diff --git a/toolkit.py b/toolkit.py index 5a312a204a434545032310cf87ea7d9fda706914..e801ef6fc3f9d41b6ad42f9dc2eb74eb2d79a4ae 100755 --- a/toolkit.py +++ b/toolkit.py @@ -150,6 +150,8 @@ class ClassificationProject(object): random data is also used for shuffling the training data, so results may vary still. To produce consistent results, set the numpy random seed before training. + :param loss: loss function name + """ @@ -205,7 +207,8 @@ class ClassificationProject(object): use_modelcheckpoint=True, modelcheckpoint_opts=None, random_seed=1234, - balance_dataset=False): + balance_dataset=False, + loss='binary_crossentropy'): self.name = name self.signal_trees = signal_trees @@ -253,6 +256,7 @@ class ClassificationProject(object): self.modelcheckpoint_opts = modelcheckpoint_opts self.random_seed = random_seed self.balance_dataset = balance_dataset + self.loss = loss self.s_train = None self.b_train = None @@ -562,7 +566,7 @@ class ClassificationProject(object): rn_state = np.random.get_state() np.random.seed(self.random_seed) self._model.compile(optimizer=optimizer, - loss='binary_crossentropy', + loss=self.loss, metrics=['accuracy']) np.random.set_state(rn_state) if os.path.exists(os.path.join(self.project_dir, "weights.h5")): @@ -1031,7 +1035,7 @@ class ClassificationProject(object): plt.plot(hist_dict['val_loss']) plt.ylabel('loss') plt.xlabel('epoch') - plt.legend(['train','test'], loc='upper left') + plt.legend(['training data','validation data'], loc='upper left') if log: plt.yscale("log") if xlim is not None: @@ -1065,7 +1069,7 @@ class ClassificationProject(object): plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') - plt.legend(['train', 'test'], loc='upper left') + plt.legend(['training data', 'validation data'], loc='upper left') if log: plt.yscale("log") plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))