Skip to content
Snippets Groups Projects
Commit ce7b4c4f authored by Nikolai's avatar Nikolai
Browse files

correct project dir initialisation and csvlogger

parent 81dfd04c
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,7 @@ import os
import json
import pickle
import importlib
import csv
import logging
logger = logging.getLogger("KerasROOTClassification")
......@@ -20,7 +21,7 @@ from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
from keras.callbacks import History, EarlyStopping
from keras.callbacks import History, EarlyStopping, CSVLogger
from keras.optimizers import SGD
import keras.optimizers
......@@ -121,6 +122,7 @@ class KerasROOTClassification(object):
def _init_from_dir(self, dirname):
with open(os.path.join(dirname, "options.json")) as f:
options = json.load(f)
options["kwargs"]["project_dir"] = dirname
self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"])
......@@ -132,7 +134,7 @@ class KerasROOTClassification(object):
batch_size=128,
validation_split=0.33,
activation_function='relu',
out_dir="./outputs",
project_dir=None,
scaler_type="RobustScaler",
step_signal=2,
step_bkg=2,
......@@ -153,7 +155,6 @@ class KerasROOTClassification(object):
self.batch_size = batch_size
self.validation_split = validation_split
self.activation_function = activation_function
self.out_dir = out_dir
self.scaler_type = scaler_type
self.step_signal = step_signal
self.step_bkg = step_bkg
......@@ -165,10 +166,9 @@ class KerasROOTClassification(object):
earlystopping_opts = dict()
self.earlystopping_opts = earlystopping_opts
self.project_dir = os.path.join(self.out_dir, name)
if not os.path.exists(self.out_dir):
os.mkdir(self.out_dir)
self.project_dir = project_dir
if self.project_dir is None:
self.project_dir = name
if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir)
......@@ -330,10 +330,10 @@ class KerasROOTClassification(object):
@property
def callbacks_list(self):
if not self._callbacks_list:
self._callbacks_list.append(self.history)
self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
self._callbacks_list = []
self._callbacks_list.append(self.history)
self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
return self._callbacks_list
......@@ -369,10 +369,11 @@ class KerasROOTClassification(object):
history_file = os.path.join(self.project_dir, "history_history.json")
if self._history is None:
self._history = History()
with open(params_file) as f:
self._history.params = json.load(f)
with open(history_file) as f:
self._history.history = json.load(f)
if os.path.exists(params_file) and os.path.exists(history_file):
with open(params_file) as f:
self._history.params = json.load(f)
with open(history_file) as f:
self._history.history = json.load(f)
return self._history
......@@ -502,7 +503,6 @@ class KerasROOTClassification(object):
logger.info("Train model")
try:
self.history = History()
self.shuffle_training_data()
self.model.fit(self.x_train,
# the reshape might be unnescessary here
......@@ -684,11 +684,31 @@ class KerasROOTClassification(object):
pass
def plot_loss(self):
@property
def csv_hist(self):
with open(os.path.join(self.project_dir, "training.log")) as f:
reader = csv.reader(f)
history_list = list(reader)
hist_dict = {}
for hist_key_index, hist_key in enumerate(history_list[0]):
hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
return hist_dict
def plot_loss(self, all_trainings=False):
"""
Plot the value of the loss function for each epoch
:param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
"""
if all_trainings:
hist_dict = self.csv_hist
else:
hist_dict = self.history.history
logger.info("Plot losses")
plt.plot(self.history.history['loss'])
plt.plot(self.history.history['val_loss'])
plt.plot(hist_dict['loss'])
plt.plot(hist_dict['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left')
......@@ -696,11 +716,21 @@ class KerasROOTClassification(object):
plt.clf()
def plot_accuracy(self):
def plot_accuracy(self, all_trainings=False):
"""
Plot the value of the accuracy metric for each epoch
:param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
"""
if all_trainings:
hist_dict = self.csv_hist
else:
hist_dict = self.history.history
logger.info("Plot accuracy")
plt.plot(self.history.history['acc'])
plt.plot(self.history.history['val_acc'])
plt.plot(hist_dict['acc'])
plt.plot(hist_dict['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment