diff --git a/toolkit.py b/toolkit.py index f679f69b1e7e56a2bf0eac46cda2649f44006f9a..1f86123c8489763d5c135c000f4c101e85367618 100755 --- a/toolkit.py +++ b/toolkit.py @@ -44,7 +44,7 @@ class KerasROOTClassification: def __init__(self, name, signal_trees, bkg_trees, branches, weight_expr, identifiers, - layers=3, nodes=64, batch_size=128, activation_function='relu', out_dir="./outputs"): + layers=3, nodes=64, batch_size=128, validation_split=0.33, activation_function='relu', out_dir="./outputs"): self.name = name self.signal_trees = signal_trees self.bkg_trees = bkg_trees @@ -54,6 +54,7 @@ class KerasROOTClassification: self.layers = layers self.nodes = nodes self.batch_size = batch_size + self.validation_split = validation_split self.activation_function = activation_function self.out_dir = out_dir @@ -82,6 +83,7 @@ class KerasROOTClassification: self._bkg_weights = None self._sig_weights = None self._model = None + self._history = None self.score_train = None self.score_test = None @@ -280,10 +282,11 @@ class KerasROOTClassification: self.total_epochs = self._read_info("epochs", 0) logger.info("Train model") - self.model.fit(self.x_train, + self._history = self.model.fit(self.x_train, # the reshape might be unnescessary here self.y_train.reshape(-1, 1), epochs=epochs, + validation_split = self.validation_split, class_weight=self.class_weight, shuffle=True, batch_size=self.batch_size) @@ -376,7 +379,28 @@ class KerasROOTClassification: def plot_score(self): pass - + + def plot_loss(self): + + logger.info("Plot losses") + plt.plot(self._history.history['loss']) + plt.plot(self._history.history['val_loss']) + plt.ylabel('loss') + plt.xlabel('epoch') + plt.legend(['train','test'], loc='upper left') + plt.savefig(os.path.join(self.project_dir, "losses.pdf")) + + + def plot_accuracy(self): + + logger.info("Plot accuracy") + plt.plot(self._history.history['acc']) + plt.plot(self._history.history['val_acc']) + plt.title('model accuracy') + plt.ylabel('accuracy') + plt.xlabel('epoch') + plt.legend(['train', 'test'], loc='upper left') + plt.savefig(os.path.join(self.project_dir, "accuracy.pdf")) if __name__ == "__main__": @@ -395,5 +419,7 @@ if __name__ == "__main__": weight_expr = "eventWeight*genWeight", identifiers = ["DatasetNumber", "EventNumber"]) - c.train(epochs=1) + c.train(epochs=20) c.plot_ROC() + c.plot_loss() + c.plot_accuracy()