Skip to content
Snippets Groups Projects
Commit d9f78c70 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

loss function configurable

parent a027593b
No related branches found
No related tags found
No related merge requests found
...@@ -150,6 +150,8 @@ class ClassificationProject(object): ...@@ -150,6 +150,8 @@ class ClassificationProject(object):
random data is also used for shuffling the training data, so results may vary still. To random data is also used for shuffling the training data, so results may vary still. To
produce consistent results, set the numpy random seed before training. produce consistent results, set the numpy random seed before training.
:param loss: loss function name
""" """
...@@ -205,7 +207,8 @@ class ClassificationProject(object): ...@@ -205,7 +207,8 @@ class ClassificationProject(object):
use_modelcheckpoint=True, use_modelcheckpoint=True,
modelcheckpoint_opts=None, modelcheckpoint_opts=None,
random_seed=1234, random_seed=1234,
balance_dataset=False): balance_dataset=False,
loss='binary_crossentropy'):
self.name = name self.name = name
self.signal_trees = signal_trees self.signal_trees = signal_trees
...@@ -253,6 +256,7 @@ class ClassificationProject(object): ...@@ -253,6 +256,7 @@ class ClassificationProject(object):
self.modelcheckpoint_opts = modelcheckpoint_opts self.modelcheckpoint_opts = modelcheckpoint_opts
self.random_seed = random_seed self.random_seed = random_seed
self.balance_dataset = balance_dataset self.balance_dataset = balance_dataset
self.loss = loss
self.s_train = None self.s_train = None
self.b_train = None self.b_train = None
...@@ -562,7 +566,7 @@ class ClassificationProject(object): ...@@ -562,7 +566,7 @@ class ClassificationProject(object):
rn_state = np.random.get_state() rn_state = np.random.get_state()
np.random.seed(self.random_seed) np.random.seed(self.random_seed)
self._model.compile(optimizer=optimizer, self._model.compile(optimizer=optimizer,
loss='binary_crossentropy', loss=self.loss,
metrics=['accuracy']) metrics=['accuracy'])
np.random.set_state(rn_state) np.random.set_state(rn_state)
if os.path.exists(os.path.join(self.project_dir, "weights.h5")): if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
...@@ -1031,7 +1035,7 @@ class ClassificationProject(object): ...@@ -1031,7 +1035,7 @@ class ClassificationProject(object):
plt.plot(hist_dict['val_loss']) plt.plot(hist_dict['val_loss'])
plt.ylabel('loss') plt.ylabel('loss')
plt.xlabel('epoch') plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left') plt.legend(['training data','validation data'], loc='upper left')
if log: if log:
plt.yscale("log") plt.yscale("log")
if xlim is not None: if xlim is not None:
...@@ -1065,7 +1069,7 @@ class ClassificationProject(object): ...@@ -1065,7 +1069,7 @@ class ClassificationProject(object):
plt.title('model accuracy') plt.title('model accuracy')
plt.ylabel('accuracy') plt.ylabel('accuracy')
plt.xlabel('epoch') plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left') plt.legend(['training data', 'validation data'], loc='upper left')
if log: if log:
plt.yscale("log") plt.yscale("log")
plt.savefig(os.path.join(self.project_dir, "accuracy.pdf")) plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment