Skip to content
Snippets Groups Projects
toolkit.py 79.6 KiB
Newer Older
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed


    def write_friend_tree(self, score_name,
                          source_filename, source_treename,
                          target_filename, target_treename,
                          batch_size=100000):
        """
        Evaluate the model on a ROOT tree and write the scores into a new tree.

        The source tree is read in batches of ``batch_size`` entries. For each
        batch the branches ``self.branches`` are evaluated with
        ``self.evaluate`` and the result is written to ``target_treename`` in
        ``target_filename`` with two branches:

        - ``score_name``: the evaluated score
        - ``score_name + "_is_train"``: whether the event (matched via the
          columns ``self.identifiers``) was part of the training event lists;
          all zero if no identifiers are configured

        :param score_name: name of the score branch in the output tree
        :param source_filename: ROOT file containing the input tree
        :param source_treename: name of the input tree
        :param target_filename: ROOT file to create; must not exist yet
        :param target_treename: name of the tree to write
        :param batch_size: number of entries read and evaluated per batch
        :raises IOError: if ``target_filename`` already exists
        """
        if os.path.exists(target_filename):
            raise IOError("{} already exists, if you want to recreate it, delete it first".format(target_filename))
        f = ROOT.TFile.Open(source_filename)
        try:
            tree = f.Get(source_treename)
            entries = tree.GetEntries()
            logger.info("Write friend tree for {} in {}".format(source_treename, source_filename))

            # identifiers of all events used for training (signal + background);
            # built once since it does not depend on the batch. Duplicates are
            # dropped so the left merge below yields exactly one row per event.
            train_identifiers = None
            if len(self.identifiers) > 0:
                train_identifiers = pd.DataFrame(
                    np.concatenate((self.s_eventlist_train, self.b_eventlist_train))
                )[list(self.identifiers)].drop_duplicates()

            for start in range(0, entries, batch_size):
                logger.info("Evaluating score for entry {}/{}".format(start, entries))
                logger.debug("Loading next batch")
                x_from_tree = tree2array(tree,
                                         branches=self.branches+self.identifiers,
                                         start=start, stop=start+batch_size)
                x_eval = rec2array(x_from_tree[self.branches])
                if train_identifiers is not None:
                    # create list of booleans that indicate which events were used for training
                    df_identifiers = pd.DataFrame(x_from_tree[self.identifiers])
                    merged = df_identifiers.merge(train_identifiers, on=list(self.identifiers),
                                                  indicator=True, how="left")
                    is_train = np.array(merged["_merge"] == "both")
                else:
                    is_train = np.zeros(len(x_eval))

                # join scores and is_train array
                scores = self.evaluate(x_eval).reshape(-1)
                friend_df = pd.DataFrame(np.array(scores, dtype=[(score_name, np.float64)]))
                friend_df[score_name+"_is_train"] = is_train
                friend_tree = friend_df.to_records()[[score_name, score_name+"_is_train"]]
                # first batch creates the file, later batches append to the tree
                mode = "recreate" if start == 0 else "update"
                logger.debug("Write to root file")
                array2root(friend_tree, target_filename, treename=target_treename, mode=mode)
                logger.debug("Done")
        finally:
            f.Close()

    def write_all_friend_trees(self):
        """
        Placeholder for writing friend trees for all input trees.

        Not implemented yet: this is currently a no-op and returns None.
        Use ``write_friend_tree`` for individual trees.
        """
        pass

    @staticmethod
    def get_bin_centered_hist(x, scale_factor=None, **np_kwargs):
        """
        Return bin centers, histogram and relative (!) errors.

        :param x: values to histogram
        :param scale_factor: optional factor the histogram is multiplied with
                             (the relative errors are unaffected)
        :param np_kwargs: passed on to ``np.histogram`` (e.g. ``bins``,
                          ``range``, ``weights``)
        :returns: tuple ``(centers, hist, errors)``; ``errors`` are the
                  relative statistical uncertainties per bin
                  (sqrt(sum w^2) / sum w, or sqrt(N)/N without weights),
                  ``nan`` for empty bins
        """
        hist, bins = np.histogram(x, **np_kwargs)
        centers = (bins[:-1] + bins[1:]) / 2
        if "weights" in np_kwargs:
            weights = np.asarray(np_kwargs["weights"])
            # reuse the exact bin edges of ``hist`` so entries outside the
            # histogram range are dropped here too (as np.histogram does)
            sumw, _ = np.histogram(x, bins=bins, weights=weights)
            sumw2, _ = np.histogram(x, bins=bins, weights=weights**2)
            # calculate relative error
            errors = np.sqrt(sumw2)/sumw
        else:
            errors = np.sqrt(hist)/hist
        if scale_factor is not None:
            # not in-place: an unweighted histogram has an integer dtype
            hist = hist * scale_factor
        return centers, hist, errors
    def plot_input(self, var_index, ax=None):
        """
        Plot a single input variable, signal (red) vs. background (blue).

        The training data (``self.x_train``) is histogrammed with the total
        training weights (``self.w_train_tot``). Entries equal to
        ``self.mask_value`` (if the project has that attribute) are skipped.
        If ``self.balance_dataset`` is set, the larger class is truncated to
        the size of the smaller one. The plot range covers the weighted 1% to
        99% quantiles; integer-valued variables get one bin per integer.

        :param var_index: index of the variable in ``self.fields``
        :param ax: matplotlib axes to draw into. If None, a new figure is
                   created and saved to ``<project_dir>/plots/var_<var_index>.pdf``
        :returns: the result of ``save_show`` when a new figure was created,
                  otherwise None
        """
        branch = self.fields[var_index]
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = None
        bkg = self.x_train[:,var_index][self.y_train == 0]
        sig = self.x_train[:,var_index][self.y_train == 1]
        bkg_weights = self.w_train_tot[self.y_train == 0]
        sig_weights = self.w_train_tot[self.y_train == 1]

        if hasattr(self, "mask_value"):
            bkg_not_masked = np.where(bkg != self.mask_value)[0]
            sig_not_masked = np.where(sig != self.mask_value)[0]
            bkg = bkg[bkg_not_masked]
            sig = sig[sig_not_masked]
            bkg_weights = bkg_weights[bkg_not_masked]
            sig_weights = sig_weights[sig_not_masked]

        if self.balance_dataset:
            if len(sig) < len(bkg):
                logger.warning("Plotting only up to {} bkg events, since we use balance_dataset".format(len(sig)))
                bkg = bkg[0:len(sig)]
                bkg_weights = bkg_weights[0:len(sig)]
            else:
                logger.warning("Plotting only up to {} sig events, since we use balance_dataset".format(len(bkg)))
                sig = sig[0:len(bkg)]
                sig_weights = sig_weights[0:len(bkg)]
        logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
        logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))

        # calculate percentiles to get a heuristic for the range to be plotted
        x_total = np.concatenate([bkg, sig])
        w_total = np.concatenate([bkg_weights, sig_weights])
        plot_range = weighted_quantile(
            x_total,
            [0.01, 0.99],
            sample_weight=w_total,
        )
        logger.debug("Calculated range based on percentiles: {}".format(plot_range))

        bins = 50

        # check if we have a distribution of integer numbers (e.g. njet or something categorical)
        # in that case we want to have a bin for each number
        if (x_total == x_total.astype(int)).all():
            plot_range = (math.floor(plot_range[0])-0.5, math.ceil(plot_range[1])+0.5)
            bins = int(plot_range[1]-plot_range[0])

        try:
            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=bins, range=plot_range, weights=sig_weights)
            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=bins, range=plot_range, weights=bkg_weights)
        except ValueError:
            # weird, probably not always working workaround for a numpy bug
            plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1])))
            logger.warning("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=bins, range=plot_range, weights=sig_weights)
            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=bins, range=plot_range, weights=bkg_weights)

        if len(centers_sig) > 1:
            width = centers_sig[1]-centers_sig[0]
        else:
            # a single bin spans the whole plot range
            width = plot_range[1]-plot_range[0]
        ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
        ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)

        label = branch
        if self.data_transformed:
            label += " (transformed)"
        ax.set_xlabel(label)
        if fig is not None:
            plot_dir = os.path.join(self.project_dir, "plots")
            if not os.path.exists(plot_dir):
                os.mkdir(plot_dir)
            return save_show(plt, fig, os.path.join(plot_dir, "var_{}.pdf".format(var_index)))


    def plot_all_inputs(self):
        """
        Plot all input variables in a square grid of subplots.

        Each cell is drawn with ``plot_input``; the figure is saved to
        ``<project_dir>/all_inputs.pdf``.

        :returns: the result of ``save_show``
        """
        nrows = int(math.ceil(math.sqrt(len(self.fields))))
        # squeeze=False: ``axes`` is always a 2D array, even for a single variable
        # (otherwise plt.subplots returns a bare Axes without ``reshape``)
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows,
                                 figsize=(3*nrows, 3*nrows),
                                 gridspec_kw=dict(wspace=0.4, hspace=0.4),
                                 squeeze=False)
        flat_axes = axes.reshape(-1)
        for i in range(len(self.fields)):
            self.plot_input(i, ax=flat_axes[i])
        # hide grid cells that don't hold a variable
        for unused_ax in flat_axes[len(self.fields):]:
            unused_ax.set_visible(False)
        return save_show(plt, fig, os.path.join(self.project_dir, "all_inputs.pdf"))
    def plot_weights(self, bins=100, range=None):
        """
        Histogram the total training event weights on a log scale.

        Background (blue) and signal (red) get separate figures, saved as
        ``eventweights_bkg.pdf`` and ``eventweights_sig.pdf`` in the project
        directory.

        :param bins: number of histogram bins
        :param range: optional (min, max) histogram range
        """
        per_class = (
            (0, "b", "eventweights_bkg.pdf"),
            (1, "r", "eventweights_sig.pdf"),
        )
        for class_label, color, filename in per_class:
            class_weights = self.w_train_tot[self.y_train == class_label]
            fig, ax = plt.subplots()
            ax.hist(class_weights, bins=bins, range=range, color=color, alpha=0.5)
            ax.set_yscale("log")
            save_show(plt, fig, os.path.join(self.project_dir, filename))
    def plot_ROC(self, xlim=(0,1), ylim=(0,1)):
        """
        Plot background rejection vs. signal efficiency for train and test sample.

        The AUC of each curve is shown in the legend. The figure is saved to
        ``<project_dir>/ROC.pdf``.

        :param xlim: x-axis (signal efficiency) limits
        :param ylim: y-axis (background rejection) limits
        :returns: the result of ``save_show``
        """
        logger.info("Plot ROC curve")

        fig, ax = plt.subplots()
        ax.grid(color='gray', linestyle='--', linewidth=1)

        for y, scores, weight, label in [
                (self.y_train, self.scores_train, self.w_train, "train"),
                (self.y_test, self.scores_test, self.w_test, "test")
        ]:
            fpr, tpr, _ = roc_curve(y, scores, sample_weight=weight)
            fpr = 1.0 - fpr # background rejection
            try:
                roc_auc = auc(tpr, fpr)
            except ValueError:
                # auc needs monotonic x values. sklearn removed ``reorder=True``,
                # so sort by efficiency (ties broken by rejection) explicitly.
                logger.warning("Got a value error from auc - trying to rerun with sorted points")
                order = np.lexsort((fpr, tpr))
                roc_auc = auc(tpr[order], fpr[order])
            ax.plot(tpr, fpr, label="{} {} (AUC = {:.3f})".format(self.name, label, roc_auc))

        ax.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
        ax.set_ylabel("Background rejection")
        ax.set_xlabel("Signal efficiency")
        ax.set_title('Receiver operating characteristic')
        ax.set_xlim(*xlim)
        ax.set_ylim(*ylim)
        ax.legend(loc='lower left', framealpha=1.0)
        return save_show(plt, fig, os.path.join(self.project_dir, "ROC.pdf"))
    def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0,1)),
                   ylim=None, xlim=None, density=True,
                   lumifactor=None, apply_class_weight=True,
                   invert_activation=False):
        """
        Plot the score distributions of signal and background.

        The training sample is drawn as bars and the test sample as points with
        error bars. The figure is saved to ``<project_dir>/scores.pdf``.

        :param log: logarithmic y-axis
        :param plot_opts: histogram options passed to ``get_bin_centered_hist``
                          (e.g. ``bins``, ``range``)
        :param ylim: optional (min, max) y-axis limits
        :param xlim: optional (min, max) x-axis limits
        :param density: divide bin contents by the bin width
        :param lumifactor: if given, multiply all weights by this factor
                           (class weights are then not applied)
        :param apply_class_weight: multiply weights by ``self.class_weight``
                                   (only if no ``lumifactor`` is given)
        :param invert_activation: transform the scores with the inverse of the
                                  output activation function before plotting
        :returns: the result of ``save_show``
        """
        if invert_activation:
            trf = self.get_inverse_act_fn()
        else:
            trf = lambda y : y
        if apply_class_weight is True and (lumifactor is not None):
            logger.warning("not applying class weight, since lumifactor given")
        # class weights and lumifactor are mutually exclusive
        use_class_weight = apply_class_weight and (lumifactor is None)

        fig, ax = plt.subplots()
        for scores, weights, y, class_label, fn, opts in [
                (self.scores_train, self.w_train, self.y_train, 1, ax.bar, dict(color="r", label="signal train")),
                (self.scores_train, self.w_train, self.y_train, 0, ax.bar, dict(color="b", label="background train")),
                (self.scores_test, self.w_test, self.y_test, 1, ax.errorbar, dict(fmt="ro", label="signal test")),
                (self.scores_test, self.w_test, self.y_test, 0, ax.errorbar, dict(fmt="bo", label="background test")),
        ]:
            weights = weights[y==class_label]
            if use_class_weight:
                weights = weights*self.class_weight[class_label]
            if lumifactor is not None:
                weights = weights*lumifactor
            centers, hist, rel_errors = self.get_bin_centered_hist(
                trf(scores[y==class_label].reshape(-1)),
                weights=weights,
                **plot_opts
            )
            width = centers[1]-centers[0]
            if density:
                hist = hist/width
            if fn == ax.errorbar:
                # get_bin_centered_hist returns relative errors
                errors = rel_errors*hist
                opts.update(yerr=errors)
            else:
                opts.update(width=width, alpha=0.5)
            fn(centers, hist, **opts)
        if log:
            ax.set_yscale("log")
        if ylim is not None:
            ax.set_ylim(*ylim)
        if xlim is not None:
            ax.set_xlim(*xlim)
        if density:
            ax.set_ylabel("dN / d(NN output)")
        else:
            ax.set_ylabel("Events / {:.2f}".format(width))
        if use_class_weight:
            ax.set_title("Class weights applied")
        ax.legend(loc='upper center', framealpha=0.5)
        return save_show(plt, fig, os.path.join(self.project_dir, "scores.pdf"))

    def plot_significance_hist(self, lumifactor=1., significance_function=None, plot_opts=dict(bins=50, range=(0, 1)), invert_activation=False):

        """
        Plot significances based on a histogram of scores.

        Signal and background scores of the training and test sample are
        histogrammed with ``plot_opts``. For every bin, the weighted signal (s)
        and background (b) yields of that bin and all bins above it are summed.
        This corresponds to a cut at the lower edge of the bin. The sums are
        scaled by ``lumifactor`` and ``self.step_signal`` / ``self.step_bkg``
        and converted into a significance. The figure is saved to
        ``<project_dir>/significances_hist.pdf``.

        :param lumifactor: factor applied to all yields
        :param significance_function: function ``f(s, b, db)``; by default
            ``poisson_asimov_significance(s, 0, b, db)`` is used (0 if that
            raises ZeroDivisionError or ValueError)
        :param plot_opts: histogram options; should contain a fixed ``range``
            so that all four histograms share the same binning
        :param invert_activation: transform the scores with the inverse of the
            output activation function
        :returns: the result of ``save_show``
        """

        logger.info("Plot significances")

        if invert_activation:
            trf = self.get_inverse_act_fn()
        else:
            trf = lambda y : y

        _, hist_sig_train, _ = self.get_bin_centered_hist(trf(self.scores_train[self.y_train==1].reshape(-1)), weights=self.w_train[self.y_train==1], **plot_opts)
        centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(trf(self.scores_train[self.y_train==0].reshape(-1)), weights=self.w_train[self.y_train==0], **plot_opts)
        _, hist_sig_test, _ = self.get_bin_centered_hist(trf(self.scores_test[self.y_test==1].reshape(-1)), weights=self.w_test[self.y_test==1], **plot_opts)
        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(trf(self.scores_test[self.y_test==0].reshape(-1)), weights=self.w_test[self.y_test==0], **plot_opts)

        # factor to rescale due to using only a fraction of events (training and test samples)
        normfactor_sig = self.step_signal
        normfactor_bkg = self.step_bkg

        significances_train = []
        significances_test = []
        for hist_sig, hist_bkg, rel_errors_bkg, significances in [
                (hist_sig_train, hist_bkg_train, rel_errors_bkg_train, significances_train),
                (hist_sig_test, hist_bkg_test, rel_errors_bkg_test, significances_test),
        ]:
            # first set nan values (empty bins) to 0 and multiply by lumi
            for arr in hist_sig, hist_bkg, rel_errors_bkg:
                arr[np.isnan(arr)] = 0
            hist_sig *= lumifactor*normfactor_sig
            hist_bkg *= lumifactor*normfactor_bkg
            for i in range(len(hist_sig)):
                s = sum(hist_sig[i:])
                b = sum(hist_bkg[i:])
                # absolute background uncertainty from the relative per-bin errors
                db = math.sqrt(sum((rel_errors_bkg[i:]*hist_bkg[i:])**2))
                if significance_function is None:
                    try:
                        z = poisson_asimov_significance(s, 0, b, db)
                    except (ZeroDivisionError, ValueError):
                        z = 0
                else:
                    z = significance_function(s, b, db)
                logger.debug("s, b, db, z = {}, {}, {}, {}".format(s, b, db, z))
                significances.append(z)

        # the sums start at bin i, so the corresponding cut is the lower bin edge
        cuts_train = centers_bkg_train - (centers_bkg_train[1]-centers_bkg_train[0])/2
        cuts_test = centers_bkg_test - (centers_bkg_test[1]-centers_bkg_test[0])/2

        fig, ax = plt.subplots()
        ax.plot(cuts_train, significances_train, label="train, Z_max={}".format(np.amax(significances_train)))
        ax.plot(cuts_test, significances_test, label="test, Z_max={}".format(np.amax(significances_test)))
        ax.set_xlabel("Cut on NN score")
        ax.set_ylabel("Significance")
        ax.legend(loc='lower center', framealpha=0.5)
        return save_show(plt, fig, os.path.join(self.project_dir, "significances_hist.pdf"))


    @staticmethod
    def calc_s_ds_b_db(scores, y, w):

        """
        Calculate the sum of weights of signal (s), background (b) and the
        sqrt of the squared sum of weights for all possible thresholds
        of the output score.
        Following the implementation from sklearn.metrics.ranking._binary_clf_curve

        :param scores: classifier scores (flattened, so (n, 1) arrays work too)
        :param y: labels, 1 for signal and 0 for background
        :param w: event weights
        :returns: tuple ``(s, ds, b, db, thresholds)``. ``s``/``b`` are the
            cumulative signal/background sums of weights of all events with a
            score >= threshold, and ``ds``/``db`` the corresponding
            sqrt(sum of squared weights). All are ordered by decreasing
            threshold, so the last entry holds the totals.
        """

        scores = np.asarray(scores).reshape(-1)
        y = np.asarray(y).reshape(-1)
        w = np.asarray(w).reshape(-1)
        desc_score_indices = np.argsort(scores, kind="mergesort")[::-1]
        scores_sorted = scores[desc_score_indices]
        y_sorted = y[desc_score_indices]
        w_sorted = w[desc_score_indices]
        # last index of every run of equal scores, plus the very last entry,
        # so that each distinct score is one threshold
        distinct_value_indices = np.where(np.diff(scores_sorted))[0]
        threshold_idxs = np.r_[distinct_value_indices, y_sorted.size - 1]
        s_sumw = stable_cumsum(y_sorted * w_sorted)[threshold_idxs]
        s_sumw2 = stable_cumsum(y_sorted * (w_sorted**2))[threshold_idxs]
        b_sumw = stable_cumsum(np.logical_not(y_sorted) * w_sorted)[threshold_idxs]
        b_sumw2 = stable_cumsum(np.logical_not(y_sorted) * (w_sorted**2))[threshold_idxs]

        return s_sumw, np.sqrt(s_sumw2), b_sumw, np.sqrt(b_sumw2), scores_sorted[threshold_idxs]


    def get_inverse_act_fn(self):
        """
        Return the inverse of the output activation function.

        Only ``"sigmoid"`` is supported. Its inverse is the logit
        ``log(y / (1 - y))``, applied element-wise.

        :raises NotImplementedError: for any other ``activation_function_output``
        """
        activation = self.activation_function_output
        if activation != "sigmoid":
            message = "Inverse function of {} not supported yet - currently only sigmoid"
            raise NotImplementedError(message.format(activation))

        def logit(y):
            return np.log(y/(1-y))

        return logit


    def plot_significance(self, significance_function=None, maxsteps=None, lumifactor=1., vectorized=False, invert_activation=False):
        """
        Plot the significance when cutting on all possible thresholds and plot against signal efficiency.

        For the training and test sample, the cumulative signal/background
        yields per threshold (from ``calc_s_ds_b_db``) are scaled by
        ``lumifactor`` and ``self.step_signal`` / ``self.step_bkg``.
        Thresholds without background are dropped. The threshold itself is
        drawn as a dashed line on a second y-axis. The figure is saved to
        ``<project_dir>/significances.pdf``.

        :param significance_function: ``f(s, ds, b, db)``; defaults to
            ``poisson_asimov_significance`` (then always called vectorized)
        :param maxsteps: if given, thin out the thresholds by only taking
            every ``len(thresholds) // maxsteps``-th one (at least every one)
        :param lumifactor: factor applied to all yields
        :param vectorized: ``significance_function`` accepts numpy arrays
        :param invert_activation: transform the scores with the inverse of the
            output activation function
        :returns: the result of ``save_show``
        """

        if significance_function is None:
            vectorized = True
            significance_function = poisson_asimov_significance

        if invert_activation:
            trf = self.get_inverse_act_fn()
        else:
            trf = lambda y : y

        fig, ax = plt.subplots()
        ax2 = ax.twinx()
        prop_cycle = plt.rcParams['axes.prop_cycle']
        colors = prop_cycle.by_key()['color']
        for (scores, y, w, label), col in zip(
                [(self.scores_train, self.y_train, self.w_train, "train"),
                 (self.scores_test, self.y_test, self.w_test, "test")],
                colors
        ):
            scores = trf(scores)
            s_sumws, s_errs, b_sumws, b_errs, thresholds = self.calc_s_ds_b_db(scores, y, w)
            if maxsteps is not None:
                # integer division: the step is used for slicing
                stepsize = max(1, len(s_sumws) // int(maxsteps))
            else:
                stepsize = 1
            # efficiency relative to the total signal yield (last cumulative sum),
            # taken before thinning so it doesn't depend on stepsize
            s_eff = s_sumws[::stepsize]/s_sumws[-1]
            s_sumws = s_sumws[::stepsize]*lumifactor*self.step_signal
            s_errs = s_errs[::stepsize]*lumifactor*self.step_signal
            b_sumws = b_sumws[::stepsize]*lumifactor*self.step_bkg
            b_errs = b_errs[::stepsize]*lumifactor*self.step_bkg
            # thin the thresholds too, so they stay aligned with the yields
            thresholds = thresholds[::stepsize]
            nonzero_b = np.where(b_sumws!=0)[0]
            s_eff = s_eff[nonzero_b]
            s_sumws = s_sumws[nonzero_b]
            s_errs = s_errs[nonzero_b]
            b_sumws = b_sumws[nonzero_b]
            b_errs = b_errs[nonzero_b]
            thresholds = thresholds[nonzero_b]
            if not vectorized:
                zs = []
                for s, ds, b, db in zip(s_sumws, s_errs, b_sumws, b_errs):
                    zs.append(significance_function(s, ds, b, db))
            else:
                zs = significance_function(s_sumws, s_errs, b_sumws, b_errs)
            ax.plot(s_eff, zs, label=label, color=col)
            ax2.plot(s_eff, thresholds, "--", color=col)
        ax.set_xlabel("Signal efficiency")
        ax.set_ylabel("Significance")
        ax.set_xlim(0, 1)
        ax2.set_ylabel("Threshold")
        ax.legend()
        return save_show(plt, fig, os.path.join(self.project_dir, "significances.pdf"))
    @property
    def csv_hist(self):
        """
        Training history read from ``<project_dir>/training.log``.

        The log is parsed as CSV. The first row holds the column names and
        every following row one value per column. The result maps each column
        name to the list of its values, converted to float.
        """
        log_path = os.path.join(self.project_dir, "training.log")
        with open(log_path) as log_file:
            rows = [row for row in csv.reader(log_file)]
        header, records = rows[0], rows[1:]
        return {
            key: [float(record[col]) for record in records]
            for col, key in enumerate(header)
        }

    def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None):
        """
        Plot the value of the loss function for each epoch

        The training and validation loss are drawn with the global pyplot state
        and saved to ``<project_dir>/losses.pdf``; the figure is cleared afterwards.

        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
        :param log: use a logarithmic y-axis
        :param ylim: optional (min, max) y-axis limits
        :param xlim: optional (min, max) x-axis (epoch) limits
        """

        if all_trainings:
            hist_dict = self.csv_hist
        else:
            hist_dict = self.history.history
        if ('loss' not in hist_dict) or ('val_loss' not in hist_dict):
            logger.warning("No previous history found for plotting, try global history")
            hist_dict = self.csv_hist

        logger.info("Plot losses")
        plt.plot(hist_dict['loss'])
        plt.plot(hist_dict['val_loss'])
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['training data','validation data'], loc='upper left')
        if log:
            plt.yscale("log")
        if ylim is not None:
            plt.ylim(*ylim)
        if xlim is not None:
            plt.xlim(*xlim)
        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
        plt.clf()
    def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):
        """
        Plot the value of the accuracy metric for each epoch

        The training (``acc_suffix``) and validation (``"val_" + acc_suffix``)
        metric are drawn with the global pyplot state and saved to
        ``<project_dir>/accuracy.pdf``; the figure is cleared afterwards.

        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
        :param log: use a logarithmic y-axis
        :param acc_suffix: name of the accuracy metric in the history
        """

        if all_trainings:
            hist_dict = self.csv_hist
        else:
            hist_dict = self.history.history
        if (acc_suffix not in hist_dict) or ('val_'+acc_suffix not in hist_dict):
            logger.warning("No previous history found for plotting, try global history")
            hist_dict = self.csv_hist

        logger.info("Plot accuracy")
        plt.plot(hist_dict[acc_suffix])
        plt.plot(hist_dict['val_'+acc_suffix])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['training data', 'validation data'], loc='upper left')
        if log:
            plt.yscale("log")
        plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
        plt.clf()
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed

Nikolai's avatar
Nikolai committed
        # self.plot_accuracy()
        self.plot_loss()
        self.plot_score()
        self.plot_weights()
Nikolai's avatar
Nikolai committed
        # self.plot_significance()
Nikolai's avatar
Nikolai committed
    def to_DataFrame(self):
        """
        Return the training and test data combined into one pandas DataFrame.

        Rows are the training events followed by the test events. Columns:

        - one per input variable (``self.fields``)
        - ``weight``: ``w_train`` / ``w_test``
        - ``labels``: categorical, "background" (0) or "signal" (1)
        - one per entry of ``self.identifiers``: taken from the signal and
          then the background training event lists, -1 for test events
          (assumes ``x_train`` is ordered signal first, then background; TODO confirm)
        - ``is_train``: True for training rows

        :returns: the DataFrame
        """
        df = pd.DataFrame(np.concatenate([self.x_train, self.x_test]), columns=self.fields)
        df["weight"] = np.concatenate([self.w_train, self.w_test])
        df["labels"] = pd.Categorical.from_codes(
            # from_codes requires integer codes; labels may be bool or float
            np.concatenate([self.y_train, self.y_test]).astype(np.int64),
            categories=["background", "signal"]
        )
        for identifier in self.identifiers:
            try:
                df[identifier] = np.concatenate([self.s_eventlist_train[identifier],
                                                 self.b_eventlist_train[identifier],
                                                 -1*np.ones(len(self.x_test), dtype="i8")])
            except IOError:
                logger.warning("Can't find eventlist - DataFrame won't contain identifiers")
        # builtin bool instead of np.bool (removed in NumPy 1.24)
        df["is_train"] = np.concatenate([np.ones(len(self.x_train), dtype=bool),
                                         np.zeros(len(self.x_test), dtype=bool)])
        return df
def create_getter(dataset_name):
    """
    Build a lazily-loading property getter for ``dataset_name``.

    The returned function reads the private attribute ``"_" + dataset_name``.
    While that is None, it first calls ``self._load_from_hdf5(dataset_name)``
    and then returns the (re-read) attribute.
    """
    private_name = "_" + dataset_name

    def getx(self):
        needs_loading = getattr(self, private_name) is None
        if needs_loading:
            self._load_from_hdf5(dataset_name)
        return getattr(self, private_name)

    return getx

def create_setter(dataset_name):
    """
    Build a property setter that stores values in ``"_" + dataset_name``.
    """
    private_name = "_" + dataset_name

    def setx(self, value):
        setattr(self, private_name, value)

    return setx

# define getters and setters for all datasets:
# every name in ClassificationProject.dataset_names becomes a property whose
# getter loads the data via _load_from_hdf5 on first access (create_getter)
# and whose setter stores into the "_<name>" attribute (create_setter)
for dataset_name in ClassificationProject.dataset_names:
    setattr(ClassificationProject, dataset_name, property(create_getter(dataset_name),
                                                          create_setter(dataset_name)))
class ClassificationProjectDataFrame(ClassificationProject):
    """
    A little hack to initialize a ClassificationProject from a pandas DataFrame instead of ROOT TTrees

    Inputs, weights and labels are taken from columns of ``df``. The column
    ``split_column`` (bool, or 0/1) decides whether a row belongs to the
    training (true) or test (false) sample.
    """
    def __init__(self, name, *args, **kwargs):
        if len(args) < 1 and len(kwargs) < 1:
            # if no further arguments given, interpret as directory name
            self._init_from_dir(name)
        else:
            # otherwise initialise new project
            self._init_from_args(name, *args, **kwargs)
            # don't put the dataframe into options.pickle! It is either the
            # first positional argument or passed as keyword ``df``
            pickled_args = list(args[1:])
            pickled_kwargs = {key: value for key, value in kwargs.items() if key != "df"}
            with open(os.path.join(self.project_dir, "options.pickle"), "wb") as of:
                pickle.dump(dict(args=pickled_args, kwargs=pickled_kwargs), of)


    def _init_from_args(self,
                        name,
                        df,
                        input_columns,
                        weight_column="weights",
                        label_column="labels",
                        signal_label="signal",
                        background_label="background",
                        split_mode="split_column",
                        split_column="is_train",
                        **kwargs):
        """
        :param name: project name (passed on to ``ClassificationProject``)
        :param df: DataFrame holding inputs, weights, labels and the split column
        :param input_columns: list of columns used as input variables
        :param weight_column: column with the event weights
        :param label_column: column with the class labels
        :param signal_label: value in ``label_column`` that marks signal
        :param background_label: stored; every row whose label differs from
                                 ``signal_label`` is treated as background
        :param split_mode: only ``"split_column"`` is supported
        :param split_column: column that is true for training rows
        :param kwargs: passed on to ``ClassificationProject._init_from_args``
        :raises NotImplementedError: for any other ``split_mode``
        """

        self.df = df
        self.input_columns = input_columns
        self.weight_column = weight_column
        self.label_column = label_column
        self.signal_label = signal_label
        self.background_label = background_label
        if split_mode != "split_column":
            raise NotImplementedError("'split_column' is the only currently supported split mode")
        self.split_mode = split_mode
        self.split_column = split_column
        super(ClassificationProjectDataFrame, self)._init_from_args(
            name,
            signal_trees=[], bkg_trees=[], branches=[], weight_expr="1",
            **kwargs
        )
        self._x_train = None
        self._x_test = None
        self._y_train = None
        self._y_test = None
        self._w_train = None
        self._w_test = None

    def _split_rows(self, train):
        """Return the rows of ``self.df`` in the training (``train=True``) or test sample."""
        # cast to bool: an integer (0/1) Series used as indexer would be
        # interpreted as column labels, and ``~`` would not negate it
        is_train = self.df[self.split_column].astype(bool)
        return self.df[is_train] if train else self.df[~is_train]

    @property
    def x_train(self):
        if self._x_train is None:
            self._x_train = self._split_rows(True)[self.input_columns].values
        return self._x_train

    @x_train.setter
    def x_train(self, value):
        self._x_train = value

    @property
    def x_test(self):
        if self._x_test is None:
            self._x_test = self._split_rows(False)[self.input_columns].values
        return self._x_test

    @x_test.setter
    def x_test(self, value):
        self._x_test = value

    @property
    def y_train(self):
        if self._y_train is None:
            self._y_train = (self._split_rows(True)[self.label_column] == self.signal_label).values
        return self._y_train

    @y_train.setter
    def y_train(self, value):
        self._y_train = value

    @property
    def y_test(self):
        if self._y_test is None:
            self._y_test = (self._split_rows(False)[self.label_column] == self.signal_label).values
        return self._y_test

    @y_test.setter
    def y_test(self, value):
        self._y_test = value

    @property
    def w_train(self):
        if self._w_train is None:
            self._w_train = self._split_rows(True)[self.weight_column].values
        return self._w_train

    @w_train.setter
    def w_train(self, value):
        self._w_train = value

    @property
    def w_test(self):
        if self._w_test is None:
            self._w_test = self._split_rows(False)[self.weight_column].values
        return self._w_test

    @w_test.setter
    def w_test(self, value):
        self._w_test = value

    @property
    def fields(self):
        return self.input_columns


    def load(self, reload=False):
        """
        With ``reload=True``, mark the data as neither loaded nor transformed
        and clear the cached arrays. The x/y/w properties above then recompute
        them from ``self.df`` on next access. Without ``reload`` nothing is done.
        """

        if reload:
            self.data_loaded = False
            self.data_transformed = False
            self._x_train = None
            self._x_test = None
            self._y_train = None
            self._y_test = None
            self._w_train = None
            self._w_test = None
            self._w_train_tot = None
class ClassificationProjectRNN(ClassificationProject):

    """
    A little wrapper to use recurrent units for things like jet collections
    """

    def _init_from_args(self, name,
                        recurrent_field_names=None,
                        rnn_layer_nodes=32,
                        mask_value=-999,
                        recurrent_unit_type="GRU",
                        **kwargs):
        recurrent_field_names example:
        [["jet1Pt", "jet1Eta", "jet1Phi"],
         ["jet2Pt", "jet2Eta", "jet2Phi"],
         ["jet3Pt", "jet3Eta", "jet3Phi"]],
        [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"],
         ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]],
        """
        super(ClassificationProjectRNN, self)._init_from_args(name, **kwargs)

        self._write_info("project_type", "ClassificationProjectRNN")

        self.recurrent_field_names = recurrent_field_names
        if self.recurrent_field_names is None:
            self.recurrent_field_names = []
        self.rnn_layer_nodes = rnn_layer_nodes
        self.mask_value = mask_value
        self.recurrent_unit_type = recurrent_unit_type

        # convert to  of indices
        self.recurrent_field_idx = []
        for field_name_list in self.recurrent_field_names:
            field_names = np.array([field_name_list])
            if field_names.dtype == np.object:
                raise ValueError(
                    "Invalid entry for recurrent fields: {} - "
                    "please ensure that the length for all elements in the list is equal"
                    .format(field_names)
            field_idx = (
                np.array([self.fields.index(field_name)
                          for field_name in field_names.reshape(-1)])
                .reshape(field_names.shape)
            )
            self.recurrent_field_idx.append(field_idx)
        self.flat_fields = []
        for field in self.fields:
            if any(self.fields.index(field) in field_idx.reshape(-1) for field_idx in self.recurrent_field_idx):
                continue
            self.flat_fields.append(field)

        if self.scaler_type != "WeightedRobustScaler":
            raise NotImplementedError(
                "Invalid scaler '{}' - only WeightedRobustScaler is currently supported for RNN"
                .format(self.scaler_type)
            )


Nikolai's avatar
Nikolai committed
    @property
    def model(self):
        if self._model is None:
            # following the setup from the tutorial:
            # https://github.com/YaleATLAS/CERNDeepLearningTutorial
            rnn_inputs = []
            rnn_channels = []
            for field_idx in self.recurrent_field_idx:
                chan_inp = Input(field_idx.shape[1:])
                channel = Masking(mask_value=self.mask_value)(chan_inp)
                if self.recurrent_unit_type == "GRU":
                    channel = GRU(self.rnn_layer_nodes)(channel)
                elif self.recurrent_unit_type == "SimpleRNN":
                    channel = SimpleRNN(self.rnn_layer_nodes)(channel)
                else:
                    raise NotImplementedError("{} not implemented".format(self.recurrent_unit_type))
                logger.info("Added {} unit".format(self.recurrent_unit_type))
                # TODO: configure dropout for recurrent layers
                #channel = Dropout(0.3)(channel)
                rnn_inputs.append(chan_inp)
                rnn_channels.append(channel)
            flat_input = Input((len(self.flat_fields),))
Nikolai's avatar
Nikolai committed
            if self.dropout_input is not None:
                flat_channel = Dropout(rate=self.dropout_input)(flat_input)
            else:
                flat_channel = flat_input
            combined = concatenate(rnn_channels+[flat_channel])
            for node_count, dropout_fraction in zip(self.nodes, self.dropout):
                combined = Dense(node_count, activation=self.activation_function)(combined)
                if (dropout_fraction is not None) and (dropout_fraction > 0):
                    combined = Dropout(rate=dropout_fraction)(combined)
            combined = Dense(1, activation=self.activation_function_output)(combined)
            self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
            self._compile_or_load_model()
        return self._model


    def train(self, epochs=10):
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
        self.shuffle_training_data()

        for branch_index, branch in enumerate(self.fields):
            self.plot_input(branch_index)

        self.total_epochs = self._read_info("epochs", 0)

        try:
            self.is_training = True
            logger.info("Training on batches for RNN")
            # note: the batches have class_weight already applied
            self.model.fit_generator(self.yield_batch(),
                                     steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
                                     epochs=epochs,
                                     validation_data=self.validation_data,
                                     callbacks=self.callbacks_list)
            self.is_training = False
        except KeyboardInterrupt:
            logger.info("Interrupt training - continue with rest")
        self.checkpoint_model()
    def clean_mask(self, x):
        """
        Mask recurrent fields such that once a masked value occurs,
        all values corresponding to the same and following objects are
        masked as well. Works in place.
        """
        for recurrent_field_idx in self.recurrent_field_idx:
            for evt in x:
                masked = False
                for line_idx in recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:]):
                    if (evt[line_idx] == self.mask_value).any():
                        masked=True
                    if masked:
                        evt[line_idx] = self.mask_value


    def mask_uniform(self, x):
        """
        Mask recurrent fields with a random (uniform) number of objects. Works in place.
        """
        for recurrent_field_idx in self.recurrent_field_idx:
            for evt in x:
                masked = False
                nobj = int(random.random()*(recurrent_field_idx.shape[1]+1))
                for obj_number, line_idx in enumerate(recurrent_field_idx.reshape(*recurrent_field_idx.shape[1:])):
                    if obj_number == nobj:
                        masked=True
                    if masked:
                        evt[line_idx] = self.mask_value


    def get_input_list(self, x):
        "Format the input starting from flat ntuple"
        x_input = []
        for field_idx in self.recurrent_field_idx:
            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
            x_input.append(x_recurrent)
        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
        x_input.append(x_flat)
    def get_input_flat(self, x):
        "Transform input back to flat ntuple"
        nevent = x[0].shape[0]
        x_flat = np.empty((nevent, len(self.fields)), dtype=np.float)
        # recurrent fields
        for rec_ar, idx in zip(x, self.recurrent_field_idx):
            idx = idx.reshape(-1)
            for source_idx, target_idx in enumerate(idx):
                x_flat[:,target_idx] = rec_ar.reshape(nevent, -1)[:,source_idx]
        # flat fields
        for source_idx, field_name in enumerate(self.flat_fields):
            target_idx = self.fields.index(field_name)
            x_flat[:,target_idx] = x[-1][:,source_idx]
        return x_flat


Nikolai's avatar
Nikolai committed
    def yield_batch(self):
        x_train, y_train, w_train = self.training_data
Nikolai's avatar
Nikolai committed
        while True:
            shuffled_idx = np.random.permutation(len(x_train))
            for start in range(0, len(shuffled_idx), int(self.batch_size)):
                x_batch = x_train[shuffled_idx[start:start+int(self.batch_size)]]
                y_batch = y_train[shuffled_idx[start:start+int(self.batch_size)]]
                w_batch = w_train[shuffled_idx[start:start+int(self.batch_size)]]
                x_input = self.get_input_list(x_batch)
                yield (x_input, y_batch, w_batch)
    def validation_data(self):
        "class weighted validation data. Attention: Shuffle training data before using this!"
        x_val, y_val, w_val = super(ClassificationProjectRNN, self).validation_data
        x_val_input = self.get_input_list(x_val)
        return x_val_input, y_val, w_val


    def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000, mode=None):
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
        logger.info("Reloading (and re-transforming) unshuffled training data")
        self.load(reload=True)

        if mode is not None:
            self._write_info("scores_mode", mode)

Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
        def eval_score(data_name):
            logger.info("Create/Update scores for {} sample".format(data_name))
            n_events = len(getattr(self, "x_"+data_name))
            setattr(self, "scores_"+data_name, np.empty(n_events))
            for start in range(0, n_events, batch_size):
                stop = start+batch_size
                getattr(self, "scores_"+data_name)[start:stop] = (
                    self.predict(
                        self.get_input_list(getattr(self, "x_"+data_name)[start:stop]),
                        mode=mode
                    ).reshape(-1)
                )
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
            self._dump_to_hdf5("scores_"+data_name)

        if do_test:
            eval_score("test")
        if do_train:
            eval_score("train")


    def evaluate(self, x_eval, mode=None):
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
        logger.debug("Evaluate score for {}".format(x_eval))
        x_eval = self.transform(x_eval)
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
        logger.debug("Evaluate for transformed array: {}".format(x_eval))
        return self.predict(self.get_input_list(x_eval), mode=mode)
Nikolai.Hartmann's avatar
Nikolai.Hartmann committed
if __name__ == "__main__":

    # Example usage: train a small classifier on two input branches.
    logging.basicConfig()
    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = ClassificationProject("test4",
                              signal_trees=[(filename, "GG_oneStep_1705_1105_505_NoSys")],
                              bkg_trees=[(filename, "ttbar_NoSys"),
                                         (filename, "wjets_Sherpa221_NoSys")
                              ],
                              optimizer="Adam",
                              # alternative optimizer settings:
                              #optimizer="SGD",
                              #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
                              earlystopping_opts=dict(monitor='val_loss',
                                                      min_delta=0, patience=2, verbose=0, mode='auto'),
                              selection="1",
                              branches=["met", "mt"],
                              weight_expr="eventWeight*genWeight",
                              identifiers=["DatasetNumber", "EventNumber"],
                              step_bkg=100)

    # fix NumPy's global random seed so NumPy-based randomness is reproducible
    np.random.seed(42)
    c.train(epochs=20)

    # Examples: write the scores as friend trees of the input trees
    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
    #                     target_filename="friend.root", target_treename="test4_score")

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="ttbar_NoSys",
    #                     target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")