diff --git a/toolkit.py b/toolkit.py
index 3d862f5d27f7f36062635a5d906bff5e15bb0a07..7dc4d201a94a9d322428e75237bb5a68a7a7712b 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -4,6 +4,7 @@
 import os
 import json
 import pickle
 import importlib
+import csv
 import logging
 logger = logging.getLogger("KerasROOTClassification")
@@ -20,7 +21,7 @@
 from sklearn.metrics import roc_curve, auc
 from keras.models import Sequential
 from keras.layers import Dense
 from keras.models import model_from_json
-from keras.callbacks import History, EarlyStopping
+from keras.callbacks import History, EarlyStopping, CSVLogger
 from keras.optimizers import SGD
 import keras.optimizers
@@ -121,6 +122,7 @@ class KerasROOTClassification(object):
     def _init_from_dir(self, dirname):
         with open(os.path.join(dirname, "options.json")) as f:
             options = json.load(f)
+        options["kwargs"]["project_dir"] = dirname
         self._init_from_args(os.path.basename(dirname),
                              *options["args"], **options["kwargs"])
 
@@ -132,7 +134,7 @@ class KerasROOTClassification(object):
                         batch_size=128,
                         validation_split=0.33,
                         activation_function='relu',
-                        out_dir="./outputs",
+                        project_dir=None,
                         scaler_type="RobustScaler",
                         step_signal=2,
                         step_bkg=2,
@@ -153,7 +155,6 @@ class KerasROOTClassification(object):
         self.batch_size = batch_size
         self.validation_split = validation_split
         self.activation_function = activation_function
-        self.out_dir = out_dir
         self.scaler_type = scaler_type
         self.step_signal = step_signal
         self.step_bkg = step_bkg
@@ -165,10 +166,9 @@ class KerasROOTClassification(object):
             earlystopping_opts = dict()
         self.earlystopping_opts = earlystopping_opts
 
-        self.project_dir = os.path.join(self.out_dir, name)
-
-        if not os.path.exists(self.out_dir):
-            os.mkdir(self.out_dir)
+        self.project_dir = project_dir
+        if self.project_dir is None:
+            self.project_dir = name
 
         if not os.path.exists(self.project_dir):
             os.mkdir(self.project_dir)
@@ -330,10 +330,10 @@ class KerasROOTClassification(object):
 
     @property
     def callbacks_list(self):
-        if not self._callbacks_list:
-            self._callbacks_list.append(self.history)
-            self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
-
+        self._callbacks_list = []
+        self._callbacks_list.append(self.history)
+        self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts))
+        self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True))
         return self._callbacks_list
 
 
@@ -369,10 +369,11 @@ class KerasROOTClassification(object):
         history_file = os.path.join(self.project_dir, "history_history.json")
         if self._history is None:
             self._history = History()
-            with open(params_file) as f:
-                self._history.params = json.load(f)
-            with open(history_file) as f:
-                self._history.history = json.load(f)
+            if os.path.exists(params_file) and os.path.exists(history_file):
+                with open(params_file) as f:
+                    self._history.params = json.load(f)
+                with open(history_file) as f:
+                    self._history.history = json.load(f)
         return self._history
 
 
@@ -502,7 +503,6 @@ class KerasROOTClassification(object):
 
         logger.info("Train model")
         try:
-            self.history = History()
             self.shuffle_training_data()
             self.model.fit(self.x_train,
                            # the reshape might be unnescessary here
@@ -684,11 +684,31 @@ class KerasROOTClassification(object):
             pass
 
 
-    def plot_loss(self):
+    @property
+    def csv_hist(self):
+        with open(os.path.join(self.project_dir, "training.log")) as f:
+            reader = csv.reader(f)
+            history_list = list(reader)
+        hist_dict = {}
+        for hist_key_index, hist_key in enumerate(history_list[0]):
+            hist_dict[hist_key] = [float(line[hist_key_index]) for line in history_list[1:]]
+        return hist_dict
+
+    def plot_loss(self, all_trainings=False):
+        """
+        Plot the value of the loss function for each epoch
+
+        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
+        """
+
+        if all_trainings:
+            hist_dict = self.csv_hist
+        else:
+            hist_dict = self.history.history
 
         logger.info("Plot losses")
-        plt.plot(self.history.history['loss'])
-        plt.plot(self.history.history['val_loss'])
+        plt.plot(hist_dict['loss'])
+        plt.plot(hist_dict['val_loss'])
         plt.ylabel('loss')
         plt.xlabel('epoch')
         plt.legend(['train','test'], loc='upper left')
@@ -696,11 +716,21 @@ class KerasROOTClassification(object):
         plt.clf()
 
 
-    def plot_accuracy(self):
+    def plot_accuracy(self, all_trainings=False):
+        """
+        Plot the value of the accuracy metric for each epoch
+
+        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
+        """
+
+        if all_trainings:
+            hist_dict = self.csv_hist
+        else:
+            hist_dict = self.history.history
 
         logger.info("Plot accuracy")
-        plt.plot(self.history.history['acc'])
-        plt.plot(self.history.history['val_acc'])
+        plt.plot(hist_dict['acc'])
+        plt.plot(hist_dict['val_acc'])
         plt.title('model accuracy')
         plt.ylabel('accuracy')
         plt.xlabel('epoch')
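
For illustration (not part of the patch): a minimal standalone sketch of how the training.log written by the CSVLogger callback added above can be read back into per-metric lists, mirroring what the new csv_hist property does with csv.reader. The path "my_project/training.log" and the column names ("epoch", "loss", "val_loss") are assumptions based on the Keras CSVLogger defaults and whatever metrics a given training run was configured with.

import csv

def read_training_log(path):
    """Parse a Keras CSVLogger file into {column_name: [value per epoch]}."""
    with open(path) as f:
        rows = list(csv.DictReader(f))
    # CSVLogger writes one header row plus one row per epoch; with append=True,
    # rows from successive trainings accumulate in the same file, which is what
    # makes plot_loss(all_trainings=True) able to show the combined history.
    return {key: [float(row[key]) for row in rows] for key in rows[0]}

# Hypothetical usage:
# hist = read_training_log("my_project/training.log")
# plt.plot(hist["loss"])
# plt.plot(hist["val_loss"])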