Skip to content
Snippets Groups Projects
Commit b8017999 authored by Nikolai.Hartmann
Browse files

Merge branch 'master' into dev-organisation

parents cc9ee035 a35af907
No related branches found
No related tags found
No related merge requests found
......@@ -198,8 +198,6 @@ class ClassificationProject(object):
self._scaler = None
self._class_weight = None
self._bkg_weights = None
self._sig_weights = None
self._model = None
self._history = None
self._callbacks_list = []
......@@ -476,8 +474,13 @@ class ClassificationProject(object):
return self._class_weight
def load(self):
def load(self, reload=False):
"Load all data needed for plotting and training"
if reload:
self.data_loaded = False
self.data_transformed = False
if not self.data_loaded:
self._load_data()
......@@ -492,9 +495,10 @@ class ClassificationProject(object):
np.random.shuffle(self.y_train)
np.random.set_state(rn_state)
np.random.shuffle(self.w_train)
if self._scores_test is not None:
if self._scores_train is not None:
logger.info("Shuffling scores, since they are also there")
np.random.set_state(rn_state)
np.random.shuffle(self._scores_test)
np.random.shuffle(self._scores_train)
def train(self, epochs=10):
......@@ -531,12 +535,18 @@ class ClassificationProject(object):
self.total_epochs += epochs
self._write_info("epochs", self.total_epochs)
logger.info("Reloading (and re-transforming) unshuffled training data")
self.load(reload=True)
logger.info("Create/Update scores for ROC curve")
self.scores_test = self.model.predict(self.x_test)
self.scores_train = self.model.predict(self.x_train)
self._dump_to_hdf5("scores_train", "scores_test")
logger.info("Creating all validation plots")
self.plot_all()
def evaluate(self, x_eval):
......@@ -587,34 +597,13 @@ class ClassificationProject(object):
pass
@property
def bkg_weights(self):
    """
    Background event weights for plotting.

    Class weight of the background class (index 0) multiplied by the
    training event weights of all events with ``y_train == 0``. The result
    is cached in ``self._bkg_weights`` and reused on later accesses.
    TODO: find a better way to do this
    """
    if self._bkg_weights is None:
        logger.debug("Calculating background weights for plotting")
        is_bkg = self.y_train == 0
        # Build the array locally and cache it only once it is complete, so
        # a failure halfway through never leaves uninitialised memory cached.
        weights = np.full(sum(is_bkg), self.class_weight[0], dtype=float)
        weights *= self.w_train[is_bkg]
        self._bkg_weights = weights
        logger.debug("Background weights: {}".format(self._bkg_weights))
    return self._bkg_weights
@property
def sig_weights(self):
    """
    Signal event weights for plotting.

    Class weight of the signal class (index 1) multiplied by the training
    event weights of all events with ``y_train == 1``. The result is cached
    in ``self._sig_weights`` and reused on later accesses.
    TODO: find a better way to do this
    """
    if self._sig_weights is None:
        logger.debug("Calculating signal weights for plotting")
        is_sig = self.y_train == 1
        # Build the array locally and cache it only once it is complete, so
        # a failure halfway through never leaves uninitialised memory cached.
        weights = np.full(sum(is_sig), self.class_weight[1], dtype=float)
        weights *= self.w_train[is_sig]
        self._sig_weights = weights
        logger.debug("Signal weights: {}".format(self._sig_weights))
    return self._sig_weights
@staticmethod
def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
    """
    Histogram ``x`` with ``np.histogram`` and return bin centers, not edges.

    :param x: values to histogram (anything ``np.histogram`` accepts)
    :param scale_factor: optional factor each bin content is multiplied by
    :param np_kwargs: forwarded to ``np.histogram``
                      (e.g. ``bins``, ``range``, ``weights``, ``density``)
    :return: tuple ``(centers, hist)`` with one entry per bin
    """
    hist, bins = np.histogram(x, **np_kwargs)
    # midpoint of each pair of neighbouring bin edges
    centers = (bins[:-1] + bins[1:]) / 2
    if scale_factor is not None:
        # Deliberately not in-place: an unweighted histogram holds integer
        # counts, and ``hist *= <float>`` raises a casting error for those.
        hist = hist * scale_factor
    return centers, hist
def plot_input(self, var_index):
......@@ -623,6 +612,8 @@ class ClassificationProject(object):
fig, ax = plt.subplots()
bkg = self.x_train[:,var_index][self.y_train == 0]
sig = self.x_train[:,var_index][self.y_train == 1]
bkg_weights = self.w_train[self.y_train == 0]
sig_weights = self.w_train[self.y_train == 1]
logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
......@@ -636,14 +627,19 @@ class ClassificationProject(object):
logger.debug("Calculated range based on percentiles: {}".format(plot_range))
try:
ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
except ValueError:
# weird, probably not always working workaround for a numpy bug
plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
width = centers_sig[1]-centers_sig[0]
ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
ax.set_xlabel(branch+" (transformed)")
plot_dir = os.path.join(self.project_dir, "plots")
if not os.path.exists(plot_dir):
......@@ -685,8 +681,24 @@ class ClassificationProject(object):
plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
plt.clf()
def plot_score(self):
    """
    Plot the normalised NN output distributions for signal and background.

    The training scores are drawn as bars and the test scores as points on
    a logarithmic y axis. The figure is written to ``scores.pdf`` inside
    the project directory.
    """
    plot_opts = dict(bins=50, range=(0, 1))

    def score_hist(scores, labels, weights, label):
        "Weighted, density-normalised score histogram of one class"
        selected = labels == label
        return self.get_bin_centered_hist(
            scores[selected].reshape(-1),
            density=True,
            weights=weights[selected],
            **plot_opts
        )

    centers_sig_train, hist_sig_train = score_hist(self.scores_train, self.y_train, self.w_train, 1)
    centers_bkg_train, hist_bkg_train = score_hist(self.scores_train, self.y_train, self.w_train, 0)
    centers_sig_test, hist_sig_test = score_hist(self.scores_test, self.y_test, self.w_test, 1)
    centers_bkg_test, hist_bkg_test = score_hist(self.scores_test, self.y_test, self.w_test, 0)

    fig, ax = plt.subplots()
    try:
        # all histograms share the same binning, so one bar width fits all
        width = centers_sig_train[1] - centers_sig_train[0]
        ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
        ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
        ax.scatter(centers_bkg_test, hist_bkg_test, color="b", label="background test")
        ax.scatter(centers_sig_test, hist_sig_test, color="r", label="signal test")
        ax.set_yscale("log")
        ax.set_xlabel("NN output")
        ax.legend(loc='upper right', framealpha=1.0)
        fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this one figure would leak per call
        plt.close(fig)
@property
......@@ -743,6 +755,15 @@ class ClassificationProject(object):
plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
plt.clf()
def plot_all(self):
    "Create all validation plots, one after the other"
    plot_methods = (
        "plot_ROC",
        "plot_accuracy",
        "plot_loss",
        "plot_score",
        "plot_weights",
    )
    for method_name in plot_methods:
        # look up and call one at a time, in the order listed above
        getattr(self, method_name)()
def create_getter(dataset_name):
def getx(self):
if getattr(self, "_"+dataset_name) is None:
......@@ -789,9 +810,7 @@ if __name__ == "__main__":
np.random.seed(42)
c.train(epochs=20)
c.plot_ROC()
c.plot_loss()
c.plot_accuracy()
#c.plot_all()
# c.write_friend_tree("test4_score",
# source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment