Skip to content
Snippets Groups Projects
Commit a35af907 authored by Nikolai.Hartmann
Browse files

Plot score for training and test data, signal and background

parent f205e49d
No related branches found
No related tags found
No related merge requests found
...@@ -472,8 +472,13 @@ class KerasROOTClassification(object): ...@@ -472,8 +472,13 @@ class KerasROOTClassification(object):
return self._class_weight return self._class_weight
def load(self): def load(self, reload=False):
"Load all data needed for plotting and training" "Load all data needed for plotting and training"
if reload:
self.data_loaded = False
self.data_transformed = False
if not self.data_loaded: if not self.data_loaded:
self._load_data() self._load_data()
...@@ -489,6 +494,7 @@ class KerasROOTClassification(object): ...@@ -489,6 +494,7 @@ class KerasROOTClassification(object):
np.random.set_state(rn_state) np.random.set_state(rn_state)
np.random.shuffle(self.w_train) np.random.shuffle(self.w_train)
if self._scores_train is not None: if self._scores_train is not None:
logger.info("Shuffling scores, since they are also there")
np.random.set_state(rn_state) np.random.set_state(rn_state)
np.random.shuffle(self._scores_train) np.random.shuffle(self._scores_train)
...@@ -528,12 +534,18 @@ class KerasROOTClassification(object): ...@@ -528,12 +534,18 @@ class KerasROOTClassification(object):
self.total_epochs += epochs self.total_epochs += epochs
self._write_info("epochs", self.total_epochs) self._write_info("epochs", self.total_epochs)
logger.info("Reloading (and re-transforming) unshuffled training data")
self.load(reload=True)
logger.info("Create/Update scores for ROC curve") logger.info("Create/Update scores for ROC curve")
self.scores_test = self.model.predict(self.x_test) self.scores_test = self.model.predict(self.x_test)
self.scores_train = self.model.predict(self.x_train) self.scores_train = self.model.predict(self.x_train)
self._dump_to_hdf5("scores_train", "scores_test") self._dump_to_hdf5("scores_train", "scores_test")
logger.info("Creating all validation plots")
self.plot_all()
def evaluate(self, x_eval): def evaluate(self, x_eval):
...@@ -588,7 +600,8 @@ class KerasROOTClassification(object): ...@@ -588,7 +600,8 @@ class KerasROOTClassification(object):
def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
    """Histogram *x* and return the bin centers along with the bin contents.

    Parameters:
        x: values to histogram, passed straight to ``np.histogram``.
        scale_factor: optional factor every bin content is multiplied with;
            ``None`` (the default) leaves the contents untouched.
        **np_kwargs: forwarded unchanged to ``np.histogram`` (e.g. ``bins``,
            ``range``, ``weights``, ``density``).

    Returns:
        A ``(centers, hist)`` tuple where ``centers`` holds the midpoint of
        each bin and ``hist`` the (optionally scaled) bin contents.
    """
    hist, bins = np.histogram(x, **np_kwargs)
    # Midpoint of each bin: average of its lower and upper edge.
    centers = (bins[:-1] + bins[1:]) / 2
    if scale_factor is not None:
        # Multiply out-of-place: without weights/density np.histogram returns
        # integer counts, and an in-place ``*=`` with a float factor cannot be
        # cast back into the integer array.
        hist = hist * scale_factor
    return centers, hist
...@@ -667,8 +680,24 @@ class KerasROOTClassification(object): ...@@ -667,8 +680,24 @@ class KerasROOTClassification(object):
plt.savefig(os.path.join(self.project_dir, "ROC.pdf")) plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
plt.clf() plt.clf()
def plot_score(self):
    """Plot the classifier score distributions and save them as ``scores.pdf``.

    Four weighted, density-normalised histograms of the scores (50 bins over
    the range 0 to 1) are drawn on one logarithmic axis: entries with label 0
    ("background") and label 1 ("signal"), each for the training sample (bars)
    and the test sample (markers). The figure is written to the project
    directory and closed afterwards.
    """
    plot_opts = dict(bins=50, range=(0, 1))

    def score_density(scores, labels, weights, label):
        # Select the entries of one class and histogram their flattened scores.
        selected = labels == label
        return self.get_bin_centered_hist(
            scores[selected].reshape(-1),
            density=True,
            weights=weights[selected],
            **plot_opts
        )

    centers_sig_train, hist_sig_train = score_density(self.scores_train, self.y_train, self.w_train, 1)
    centers_bkg_train, hist_bkg_train = score_density(self.scores_train, self.y_train, self.w_train, 0)
    centers_sig_test, hist_sig_test = score_density(self.scores_test, self.y_test, self.w_test, 1)
    centers_bkg_test, hist_bkg_test = score_density(self.scores_test, self.y_test, self.w_test, 0)

    fig, ax = plt.subplots()
    # All histograms share the same binning, so one bar width fits all.
    width = centers_sig_train[1] - centers_sig_train[0]
    ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
    ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
    ax.scatter(centers_bkg_test, hist_bkg_test, color="b", label="background test")
    ax.scatter(centers_sig_test, hist_sig_test, color="r", label="signal test")
    ax.set_yscale("log")
    ax.set_xlabel("NN output")
    ax.legend(loc='upper right', framealpha=1.0)
    fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
    # plt.subplots() opens a new figure on every call; close it so repeated
    # plotting does not pile up open figures.
    plt.close(fig)
def plot_loss(self): def plot_loss(self):
...@@ -695,6 +724,15 @@ class KerasROOTClassification(object): ...@@ -695,6 +724,15 @@ class KerasROOTClassification(object):
plt.savefig(os.path.join(self.project_dir, "accuracy.pdf")) plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
plt.clf() plt.clf()
def plot_all(self):
    """Run every plotting method of this object, in a fixed order."""
    plot_methods = (
        "plot_ROC",
        "plot_accuracy",
        "plot_loss",
        "plot_score",
        "plot_weights",
    )
    for method_name in plot_methods:
        # Look each method up right before calling it, as direct calls would.
        getattr(self, method_name)()
def create_getter(dataset_name): def create_getter(dataset_name):
def getx(self): def getx(self):
if getattr(self, "_"+dataset_name) is None: if getattr(self, "_"+dataset_name) is None:
...@@ -739,9 +777,7 @@ if __name__ == "__main__": ...@@ -739,9 +777,7 @@ if __name__ == "__main__":
np.random.seed(42) np.random.seed(42)
c.train(epochs=20) c.train(epochs=20)
c.plot_ROC() #c.plot_all()
c.plot_loss()
c.plot_accuracy()
# c.write_friend_tree("test4_score", # c.write_friend_tree("test4_score",
# source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys", # source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment