Skip to content
Snippets Groups Projects
Commit 77dd63a5 authored by Thomas Weber's avatar Thomas Weber
Browse files

Added plot functions for losses and accuracy

parent 1bb1f5f8
No related branches found
No related tags found
No related merge requests found
......@@ -44,7 +44,7 @@ class KerasROOTClassification:
def __init__(self, name,
signal_trees, bkg_trees, branches, weight_expr, identifiers,
layers=3, nodes=64, batch_size=128, activation_function='relu', out_dir="./outputs"):
layers=3, nodes=64, batch_size=128, validation_split=0.33, activation_function='relu', out_dir="./outputs"):
self.name = name
self.signal_trees = signal_trees
self.bkg_trees = bkg_trees
......@@ -54,6 +54,7 @@ class KerasROOTClassification:
self.layers = layers
self.nodes = nodes
self.batch_size = batch_size
self.validation_split = validation_split
self.activation_function = activation_function
self.out_dir = out_dir
......@@ -82,6 +83,7 @@ class KerasROOTClassification:
self._bkg_weights = None
self._sig_weights = None
self._model = None
self._history = None
self.score_train = None
self.score_test = None
......@@ -280,10 +282,11 @@ class KerasROOTClassification:
self.total_epochs = self._read_info("epochs", 0)
logger.info("Train model")
self.model.fit(self.x_train,
self._history = self.model.fit(self.x_train,
# the reshape might be unnescessary here
self.y_train.reshape(-1, 1),
epochs=epochs,
validation_split = self.validation_split,
class_weight=self.class_weight,
shuffle=True,
batch_size=self.batch_size)
......@@ -376,7 +379,28 @@ class KerasROOTClassification:
def plot_score(self):
pass
def plot_loss(self):
logger.info("Plot losses")
plt.plot(self._history.history['loss'])
plt.plot(self._history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left')
plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
def plot_accuracy(self):
logger.info("Plot accuracy")
plt.plot(self._history.history['acc'])
plt.plot(self._history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
if __name__ == "__main__":
......@@ -395,5 +419,7 @@ if __name__ == "__main__":
weight_expr = "eventWeight*genWeight",
identifiers = ["DatasetNumber", "EventNumber"])
c.train(epochs=1)
c.train(epochs=20)
c.plot_ROC()
c.plot_loss()
c.plot_accuracy()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment