diff --git a/toolkit.py b/toolkit.py index 3db07b0d43521e67499ccd62e0df2162ddbf7322..789467549887f87d25c988713db4fc3429e294a5 100755 --- a/toolkit.py +++ b/toolkit.py @@ -2,6 +2,7 @@ import os import json +import pickle import logging logger = logging.getLogger("KerasROOTClassification") @@ -18,6 +19,7 @@ from sklearn.metrics import roc_curve, auc from keras.models import Sequential from keras.layers import Dense from keras.models import model_from_json +from keras.callbacks import History import matplotlib.pyplot as plt import matplotlib.pyplot as plt @@ -256,6 +258,33 @@ class KerasROOTClassification(object): return self._scaler + @property + def history(self): + params_file = os.path.join(self.project_dir, "history_params.json") + history_file = os.path.join(self.project_dir, "history_history.json") + if self._history is None: + self._history = History() + with open(params_file) as f: + self._history.params = json.load(f) + with open(history_file) as f: + self._history.history = json.load(f) + return self._history + + + @history.setter + def history(self, value): + self._history = value + + + def _dump_history(self): + params_file = os.path.join(self.project_dir, "history_params.json") + history_file = os.path.join(self.project_dir, "history_history.json") + with open(params_file, "w") as of: + json.dump(self.history.params, of) + with open(history_file, "w") as of: + json.dump(self.history.history, of) + + def _transform_data(self): if not self.data_transformed: # todo: what to do about the outliers? Where do they come from? 
@@ -324,14 +353,28 @@ class KerasROOTClassification(object): return self._class_weight - def train(self, epochs=10): - + def load(self): + "Load all data needed for plotting and training" if not self.data_loaded: self._load_data() if not self.data_transformed: self._transform_data() + + def shuffle_training_data(self): + rn_state = np.random.get_state() + np.random.shuffle(self.x_train) + np.random.set_state(rn_state) + np.random.shuffle(self.y_train) + np.random.set_state(rn_state) + np.random.shuffle(self.w_train) + + + def train(self, epochs=10): + + self.load() + for branch_index, branch in enumerate(self.branches): self.plot_input(branch_index) @@ -345,14 +388,26 @@ class KerasROOTClassification(object): self.total_epochs = self._read_info("epochs", 0) logger.info("Train model") - self._history = self.model.fit(self.x_train, - # the reshape might be unnescessary here - self.y_train.reshape(-1, 1), - epochs=epochs, - validation_split = self.validation_split, - class_weight=self.class_weight, - shuffle=True, - batch_size=self.batch_size) + try: + self.history = History() + self.shuffle_training_data() + self.model.fit(self.x_train, + # the reshape might be unnescessary here + self.y_train.reshape(-1, 1), + epochs=epochs, + validation_split = self.validation_split, + class_weight=self.class_weight, + sample_weight=self.w_train, + shuffle=True, + batch_size=self.batch_size, + callbacks=[self.history]) + except KeyboardInterrupt: + logger.info("Interrupt training - continue with rest") + + print(self.history) + + logger.info("Save history") + self._dump_history() logger.info("Save weights") self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) @@ -467,8 +522,8 @@ class KerasROOTClassification(object): def plot_loss(self): logger.info("Plot losses") - plt.plot(self._history.history['loss']) - plt.plot(self._history.history['val_loss']) + plt.plot(self.history.history['loss']) + plt.plot(self.history.history['val_loss']) plt.ylabel('loss') 
plt.xlabel('epoch') plt.legend(['train','test'], loc='upper left') @@ -479,8 +534,8 @@ class KerasROOTClassification(object): def plot_accuracy(self): logger.info("Plot accuracy") - plt.plot(self._history.history['acc']) - plt.plot(self._history.history['val_acc']) + plt.plot(self.history.history['acc']) + plt.plot(self.history.history['val_acc']) plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') @@ -507,7 +562,8 @@ if __name__ == "__main__": identifiers = ["DatasetNumber", "EventNumber"], step_bkg = 100) -# c.train(epochs=10) + #c.load() + c.train(epochs=10) c.plot_ROC() c.plot_loss() c.plot_accuracy()