diff --git a/toolkit.py b/toolkit.py
index 9573f262c5e005f15e578e0d84ec9be621994143..191991f6c4a6b2fcdff6299187d5d40d65971917 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -198,8 +198,6 @@ class ClassificationProject(object):
 
         self._scaler = None
         self._class_weight = None
-        self._bkg_weights = None
-        self._sig_weights = None
         self._model = None
         self._history = None
         self._callbacks_list = []
@@ -476,8 +474,13 @@ class ClassificationProject(object):
         return self._class_weight
 
 
-    def load(self):
+    def load(self, reload=False):
         "Load all data needed for plotting and training"
+
+        if reload:
+            self.data_loaded = False
+            self.data_transformed = False
+
         if not self.data_loaded:
             self._load_data()
 
@@ -492,9 +495,10 @@ class ClassificationProject(object):
         np.random.shuffle(self.y_train)
         np.random.set_state(rn_state)
         np.random.shuffle(self.w_train)
-        if self._scores_test is not None:
+        if self._scores_train is not None:
+            logger.info("Shuffling scores, since they are also there")
             np.random.set_state(rn_state)
-            np.random.shuffle(self._scores_test)
+            np.random.shuffle(self._scores_train)
 
 
     def train(self, epochs=10):
@@ -531,12 +535,18 @@ class ClassificationProject(object):
         self.total_epochs += epochs
         self._write_info("epochs", self.total_epochs)
 
+        logger.info("Reloading (and re-transforming) unshuffled training data")
+        self.load(reload=True)
+
         logger.info("Create/Update scores for ROC curve")
         self.scores_test = self.model.predict(self.x_test)
         self.scores_train = self.model.predict(self.x_train)
 
         self._dump_to_hdf5("scores_train", "scores_test")
 
+        logger.info("Creating all validation plots")
+        self.plot_all()
+
 
 
     def evaluate(self, x_eval):
@@ -587,34 +597,18 @@ class ClassificationProject(object):
         pass
 
 
-    @property
-    def bkg_weights(self):
-        """
-        class weights multiplied by event weights (for plotting)
-        TODO: find a better way to do this
-        """
-        if self._bkg_weights is None:
-            logger.debug("Calculating background weights for plotting")
-            self._bkg_weights = np.empty(sum(self.y_train == 0))
-            self._bkg_weights.fill(self.class_weight[0])
-            self._bkg_weights *= self.w_train[self.y_train == 0]
-            logger.debug("Background weights: {}".format(self._bkg_weights))
-        return self._bkg_weights
-
-
-    @property
-    def sig_weights(self):
-        """
-        class weights multiplied by event weights (for plotting)
-        TODO: find a better way to do this
-        """
-        if self._sig_weights is None:
-            logger.debug("Calculating signal weights for plotting")
-            self._sig_weights = np.empty(sum(self.y_train == 1))
-            self._sig_weights.fill(self.class_weight[1])
-            self._sig_weights *= self.w_train[self.y_train == 1]
-            logger.debug("Signal weights: {}".format(self._sig_weights))
-        return self._sig_weights
+    @staticmethod
+    def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
+        """
+        Return bin centers and bin contents of np.histogram(x, **np_kwargs)
+        If scale_factor is given, the bin contents are multiplied by it
+        """
+        hist, bins = np.histogram(x, **np_kwargs)
+        centers = (bins[:-1] + bins[1:]) / 2
+        if scale_factor is not None:
+            # no in-place *=, unweighted histograms have integer bin contents
+            hist = hist * scale_factor
+        return centers, hist
 
 
     def plot_input(self, var_index):
@@ -623,6 +617,8 @@ class ClassificationProject(object):
         fig, ax = plt.subplots()
         bkg = self.x_train[:,var_index][self.y_train == 0]
         sig = self.x_train[:,var_index][self.y_train == 1]
+        bkg_weights = self.w_train[self.y_train == 0]
+        sig_weights = self.w_train[self.y_train == 1]
 
         logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
         logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
@@ -636,14 +632,19 @@ class ClassificationProject(object):
         logger.debug("Calculated range based on percentiles: {}".format(plot_range))
 
         try:
-            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
-            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
+            centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
         except ValueError:
             # weird, probably not always working workaround for a numpy bug
             plot_range = (float("{:.2f}".format(plot_range[0])), float("{:.2f}".format(plot_range[1])))
             logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
-            ax.hist(bkg, color="b", alpha=0.5, bins=50, range=plot_range, weights=self.bkg_weights)
-            ax.hist(sig, color="r", alpha=0.5, bins=50, range=plot_range, weights=self.sig_weights)
+            centers_sig, hist_sig = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
+
+        width = centers_sig[1]-centers_sig[0]
+        ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
+        ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
+
         ax.set_xlabel(branch+" (transformed)")
         plot_dir = os.path.join(self.project_dir, "plots")
         if not os.path.exists(plot_dir):
@@ -685,8 +686,26 @@ class ClassificationProject(object):
         plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
         plt.clf()
 
+
     def plot_score(self):
-        pass
+        "Plot the NN output for signal and background (train as bars, test as points) to scores.pdf"
+        plot_opts = dict(bins=50, range=(0, 1))
+        centers_sig_train, hist_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), density=True, weights=self.w_train[self.y_train==1], **plot_opts)
+        centers_bkg_train, hist_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), density=True, weights=self.w_train[self.y_train==0], **plot_opts)
+        centers_sig_test, hist_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), density=True, weights=self.w_test[self.y_test==1], **plot_opts)
+        centers_bkg_test, hist_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), density=True, weights=self.w_test[self.y_test==0], **plot_opts)
+        fig, ax = plt.subplots()
+        width = centers_sig_train[1]-centers_sig_train[0]
+        ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=width, label="background train")
+        ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=width, label="signal train")
+        ax.scatter(centers_bkg_test, hist_bkg_test, color="b", label="background test")
+        ax.scatter(centers_sig_test, hist_sig_test, color="r", label="signal test")
+        ax.set_yscale("log")
+        ax.set_xlabel("NN output")
+        plt.legend(loc='upper right', framealpha=1.0)
+        fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
+        plt.close(fig)
+
 
 
     @property
@@ -743,6 +762,16 @@ class ClassificationProject(object):
         plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
         plt.clf()
 
+
+    def plot_all(self):
+        "Create all validation plots (ROC, accuracy, loss, score, weights)"
+        self.plot_ROC()
+        self.plot_accuracy()
+        self.plot_loss()
+        self.plot_score()
+        self.plot_weights()
+
+
 def create_getter(dataset_name):
     def getx(self):
         if getattr(self, "_"+dataset_name) is None:
@@ -789,9 +818,7 @@ if __name__ == "__main__":
 
     np.random.seed(42)
     c.train(epochs=20)
-    c.plot_ROC()
-    c.plot_loss()
-    c.plot_accuracy()
+    #c.plot_all()
 
     # c.write_friend_tree("test4_score",
     #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",