Skip to content
Snippets Groups Projects
Commit 489b934d authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Merge remote-tracking branch 'origin/master'

parents b9e191be d23f0440
No related branches found
No related tags found
No related merge requests found
......@@ -619,7 +619,8 @@ class ClassificationProject(object):
np.random.seed(self.random_seed)
self._model.compile(optimizer=optimizer,
loss=self.loss,
metrics=['accuracy'])
weighted_metrics=['accuracy']
)
np.random.set_state(rn_state)
if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
......@@ -1106,7 +1107,7 @@ class ClassificationProject(object):
plt.clf()
def plot_accuracy(self, all_trainings=False, log=False):
def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):
"""
Plot the value of the accuracy metric for each epoch
......@@ -1118,14 +1119,14 @@ class ClassificationProject(object):
else:
hist_dict = self.history.history
if (not 'acc' in hist_dict) or (not 'val_acc' in hist_dict):
if (not acc_suffix in hist_dict) or (not 'val_'+acc_suffix in hist_dict):
logger.warning("No previous history found for plotting, try global history")
hist_dict = self.csv_hist
logger.info("Plot accuracy")
plt.plot(hist_dict['acc'])
plt.plot(hist_dict['val_acc'])
plt.plot(hist_dict[acc_suffix])
plt.plot(hist_dict['val_'+acc_suffix])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
......@@ -1138,11 +1139,11 @@ class ClassificationProject(object):
def plot_all(self):
self.plot_ROC()
self.plot_accuracy()
# self.plot_accuracy()
self.plot_loss()
self.plot_score()
self.plot_weights()
self.plot_significance()
# self.plot_significance()
def create_getter(dataset_name):
......@@ -1181,8 +1182,8 @@ if __name__ == "__main__":
optimizer="Adam",
#optimizer="SGD",
#optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
earlystopping_opts=dict(monitor='val_loss',
min_delta=0, patience=2, verbose=0, mode='auto'),
earlystopping_opts=dict(monitor='val_loss',
min_delta=0, patience=2, verbose=0, mode='auto'),
selection="1",
branches = ["met", "mt"],
weight_expr = "eventWeight*genWeight",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment