
        # calculate percentiles to get a heuristic for the range to be plotted
        # (should in principle also be done with weights, but for now do it unweighted)
        # range_sig = np.percentile(sig, [1, 99])
        # range_bkg = np.percentile(sig, [1, 99])
        # plot_range = (min(range_sig[0], range_bkg[0]), max(range_sig[1], range_sig[1]))
        plot_range = weighted_quantile(
            self.x_train[:,var_index], [0.01, 0.99],
            sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]
        )

        logger.debug("Calculated range based on percentiles: {}".format(plot_range))

        try:
            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)
        except ValueError:
            # weird, probably not always working workaround for a numpy bug
            plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1])))
            logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, scale_factor=self.class_weight[1], bins=50, range=plot_range, weights=sig_weights)
            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, scale_factor=self.class_weight[0], bins=50, range=plot_range, weights=bkg_weights)

        width = centers_sig[1]-centers_sig[0]
        ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
        ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)

        ax.set_xlabel(branch+" (transformed)")
        plot_dir = os.path.join(self.project_dir, "plots")
        if not os.path.exists(plot_dir):
            os.mkdir(plot_dir)
        fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
        plt.close(fig)

    def plot_weights(self):
        """
        Plot the training event weights, separately for background and signal

        One histogram (100 bins) of the entries of ``self.w_train`` is written
        per class: "eventweights_bkg.pdf" for ``y_train == 0`` and
        "eventweights_sig.pdf" for ``y_train == 1``, both in the project directory.
        """
        weights_by_class = [
            (self.w_train[self.y_train == 0], "b", "eventweights_bkg.pdf"),
            (self.w_train[self.y_train == 1], "r", "eventweights_sig.pdf"),
        ]
        for class_weights, color, filename in weights_by_class:
            # each class gets its own figure, which is closed right after saving
            fig, ax = plt.subplots()
            ax.hist(class_weights, bins=100, color=color, alpha=0.5)
            fig.savefig(os.path.join(self.project_dir, filename))
            plt.close(fig)
    def plot_ROC(self, xlim=(0,1), ylim=(0,1)):
        """
        Plot the ROC curve (background rejection vs. signal efficiency)

        One curve each is drawn for the training and the test sample, with
        the event weights passed as sample weights. The area under each
        curve is quoted in the legend. The plot is drawn on the current
        pyplot figure, saved as "ROC.pdf" in the project directory and the
        figure is cleared afterwards.

        :param xlim: (min, max) of the signal efficiency axis
        :param ylim: (min, max) of the background rejection axis
        """

        logger.info("Plot ROC curve")
        plt.grid(color='gray', linestyle='--', linewidth=1)

        for y, scores, weight, label in [
                (self.y_train, self.scores_train, self.w_train, "train"),
                (self.y_test, self.scores_test, self.w_test, "test")
        ]:
            fpr, tpr, threshold = roc_curve(y, scores, sample_weight = weight)
            bkg_rejection = 1.0 - fpr
            try:
                roc_auc = auc(tpr, bkg_rejection)
            except ValueError:
                # auc() needs monotonic x values. The former ``reorder=True``
                # option no longer exists in recent scikit-learn versions, so
                # do the equivalent sorting (by x, ties broken by y) by hand.
                logger.warning("Got a value error from auc - trying to rerun with points sorted by signal efficiency")
                order = np.lexsort((bkg_rejection, tpr))
                roc_auc = auc(tpr[order], bkg_rejection[order])
            plt.plot(tpr, bkg_rejection, label=str(self.name + " {} (AUC = {:.3f})".format(label, roc_auc)))

        # diagonal corresponding to a random classifier
        plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
        plt.ylabel("Background rejection")
        plt.xlabel("Signal efficiency")
        plt.title('Receiver operating characteristic')
        plt.xlim(*xlim)
        plt.ylim(*ylim)
        # plt.xticks(np.arange(0,1,0.1))
        # plt.yticks(np.arange(0,1,0.1))
        plt.legend(loc='lower left', framealpha=1.0)
        plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
        plt.clf()

    def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)), ylim=None, xlim=None):
        """
        Plot the density-normalised score distributions for signal and background

        The training sample is drawn as bars, the test sample as points with
        error bars (histogram content times the relative error returned by
        ``get_bin_centered_hist``). The figure is saved as "scores.pdf" in
        the project directory.

        :param log: use a logarithmic y axis
        :param plot_opts: keyword arguments forwarded to ``get_bin_centered_hist``
        :param ylim: optional (min, max) for the y axis
        :param xlim: optional (min, max) for the x axis
        """

        def score_hist(scores, labels, weights, class_label):
            # weighted, density-normalised histogram for one class of one sample
            return self.get_bin_centered_hist(
                scores[labels==class_label].reshape(-1),
                density=True,
                weights=weights[labels==class_label],
                **plot_opts
            )

        centers_sig_train, hist_sig_train, _ = score_hist(self.scores_train, self.y_train, self.w_train, 1)
        centers_bkg_train, hist_bkg_train, _ = score_hist(self.scores_train, self.y_train, self.w_train, 0)
        centers_sig_test, hist_sig_test, rel_errors_sig_test = score_hist(self.scores_test, self.y_test, self.w_test, 1)
        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = score_hist(self.scores_test, self.y_test, self.w_test, 0)

        # absolute errors from the relative ones
        errors_sig_test = hist_sig_test*rel_errors_sig_test
        errors_bkg_test = hist_bkg_test*rel_errors_bkg_test

        fig, ax = plt.subplots()
        bar_width = centers_sig_train[1]-centers_sig_train[0]
        ax.bar(centers_bkg_train, hist_bkg_train, color="b", alpha=0.5, width=bar_width, label="background train")
        ax.bar(centers_sig_train, hist_sig_train, color="r", alpha=0.5, width=bar_width, label="signal train")
        ax.errorbar(centers_bkg_test, hist_bkg_test, fmt="bo", yerr=errors_bkg_test, label="background test")
        ax.errorbar(centers_sig_test, hist_sig_test, fmt="ro", yerr=errors_sig_test, label="signal test")

        if log:
            ax.set_yscale("log")
        if ylim is not None:
            ax.set_ylim(*ylim)
        if xlim is not None:
            ax.set_xlim(*xlim)

        fig.legend(loc='upper center', framealpha=0.5)
        fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
        plt.close(fig)

    def plot_significance_hist(self, lumifactor=1., significance_function=None, plot_opts=dict(bins=50, range=(0, 1))):

        """
        Plot significances based on a histogram of scores

        For every bin the signal and background histogram contents from that
        bin upwards are summed (i.e. a cut "score >= lower bin edge") and
        converted into a significance. Curves for the training and the test
        sample are saved as "significances_hist.pdf" in the project directory.

        :param lumifactor: factor multiplied to the signal and background yields
        :param significance_function: optional function called as ``f(s, b, db)``.
            If None, ``poisson_asimov_significance(s, 0, b, db)`` is used and
            ZeroDivisionError/ValueError from it are mapped to a significance of 0
        :param plot_opts: keyword arguments forwarded to ``get_bin_centered_hist``
        """

        logger.info("Plot significances")

        _, hist_sig_train, _ = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), weights=self.w_train[self.y_train==1], **plot_opts)
        centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), weights=self.w_train[self.y_train==0], **plot_opts)
        _, hist_sig_test, _ = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), weights=self.w_test[self.y_test==1], **plot_opts)
        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), weights=self.w_test[self.y_test==0], **plot_opts)

        significances_train = []
        significances_test = []
        for hist_sig, hist_bkg, rel_errors_bkg, significances in [
                (hist_sig_train, hist_bkg_train, rel_errors_bkg_train, significances_train),
                (hist_sig_test, hist_bkg_test, rel_errors_bkg_test, significances_test),
        ]:
            # factor to rescale due to using only a fraction of events
            # (taken from the configured steps, not from the sums of weights)
            normfactor_sig = self.step_signal
            normfactor_bkg = self.step_bkg
            # first set nan values to 0 and multiply by lumi
            for arr in hist_sig, hist_bkg, rel_errors_bkg:
                arr[np.isnan(arr)] = 0
            hist_sig *= lumifactor*normfactor_sig
            hist_bkg *= lumifactor*normfactor_bkg
            for i in range(len(hist_sig)):
                # yields and background uncertainty for a cut at bin i
                s = sum(hist_sig[i:])
                b = sum(hist_bkg[i:])
                db = math.sqrt(sum((rel_errors_bkg[i:]*hist_bkg[i:])**2))
                if significance_function is None:
                    try:
                        z = poisson_asimov_significance(s, 0, b, db)
                    except (ZeroDivisionError, ValueError):
                        z = 0
                else:
                    z = significance_function(s, b, db)
                logger.debug("s, b, db, z = {}, {}, {}, {}".format(s, b, db, z))
                significances.append(z)

        fig, ax = plt.subplots()
        ax.plot(centers_bkg_train, significances_train, label="train, Z_max={}".format(np.amax(significances_train)))
        ax.plot(centers_bkg_test, significances_test, label="test, Z_max={}".format(np.amax(significances_test)))
        ax.set_xlabel("Cut on NN score")
        ax.set_ylabel("Significance")
        ax.legend(loc='lower center', framealpha=0.5)
        fig.savefig(os.path.join(self.project_dir, "significances_hist.pdf"))
        plt.close(fig)


    @staticmethod
    def calc_s_ds_b_db(scores, y, w):

        """
        Calculate the sum of weights of signal (s), background (b) and the
        sqrt of the sum of squared weights for all possible thresholds
        of the output score.
        Following the implementation from sklearn.metrics.ranking._binary_clf_curve

        :param scores: array of output scores
        :param y: array of labels (nonzero/True for signal), same length as scores
        :param w: array of event weights, same length as scores
        :return: tuple (s, ds, b, db, thresholds) of arrays with one entry per
            distinct score value, ordered by decreasing threshold. Entry i
            contains the cumulative quantities for all events with
            score >= thresholds[i]
        """

        # stable sort, so events with equal scores keep their relative order
        desc_score_indices = np.argsort(scores, kind="mergesort")[::-1]
        scores_sorted = scores[desc_score_indices]
        y_sorted = y[desc_score_indices]
        w_sorted = w[desc_score_indices]
        # index of the last event of each run of equal scores ...
        distinct_value_indices = np.where(np.diff(scores_sorted))[0]
        # ... plus the very last event, which closes the final run
        threshold_idxs = np.r_[distinct_value_indices, y_sorted.size - 1]
        s_sumw = stable_cumsum(y_sorted * w_sorted)[threshold_idxs]
        s_sumw2 = stable_cumsum(y_sorted * (w_sorted**2))[threshold_idxs]
        b_sumw = stable_cumsum(np.logical_not(y_sorted) * w_sorted)[threshold_idxs]
        b_sumw2 = stable_cumsum(np.logical_not(y_sorted) * (w_sorted**2))[threshold_idxs]

        return s_sumw, np.sqrt(s_sumw2), b_sumw, np.sqrt(b_sumw2), scores_sorted[threshold_idxs]


    def plot_significance(self, significance_function=None, maxsteps=1000, lumifactor=1., vectorized=False):
        """
        Plot the significance when cutting on all possible thresholds and plot against signal efficiency.

        The significance (left axis, solid) and the corresponding score
        threshold (right axis, dashed) are drawn for the training and the
        test sample and saved as "significances.pdf" in the project directory.

        :param significance_function: function called as ``f(s, ds, b, db)``.
            If None, ``poisson_asimov_significance`` is used in vectorized mode
        :param maxsteps: approximate maximum number of thresholds to evaluate;
            the threshold list is thinned out by an integer step to reach it
        :param lumifactor: factor multiplied to all yields and uncertainties
        :param vectorized: if True, call the significance function once with
            arrays instead of once per threshold
        """

        if significance_function is None:
            vectorized = True
            significance_function = poisson_asimov_significance

        fig, ax = plt.subplots()
        ax2 = ax.twinx()
        prop_cycle = plt.rcParams['axes.prop_cycle']
        colors = prop_cycle.by_key()['color']
        for (scores, y, w, label), col in zip(
                [(self.scores_train, self.y_train, self.w_train, "train"),
                 (self.scores_test, self.y_test, self.w_test, "test")],
                colors
        ):
            s_sumws, s_errs, b_sumws, b_errs, thresholds = self.calc_s_ds_b_db(scores, y, w)
            # integer step for thinning out (a float is not a valid slice step)
            stepsize = max(1, int(len(s_sumws))//int(maxsteps))
            s_sumws = s_sumws[::stepsize]*lumifactor*self.step_signal
            s_errs = s_errs[::stepsize]*lumifactor*self.step_signal
            b_sumws = b_sumws[::stepsize]*lumifactor*self.step_bkg
            b_errs = b_errs[::stepsize]*lumifactor*self.step_bkg
            # thresholds have to be thinned out in the same way, otherwise
            # they don't correspond to the yields any more
            thresholds = thresholds[::stepsize]
            # drop thresholds without any background
            nonzero_b = np.where(b_sumws!=0)[0]
            s_sumws = s_sumws[nonzero_b]
            s_errs = s_errs[nonzero_b]
            b_sumws = b_sumws[nonzero_b]
            b_errs = b_errs[nonzero_b]
            thresholds = thresholds[nonzero_b]
            if not vectorized:
                zs = [significance_function(s, ds, b, db)
                      for s, ds, b, db in zip(s_sumws, s_errs, b_sumws, b_errs)]
            else:
                zs = significance_function(s_sumws, s_errs, b_sumws, b_errs)
            # the last entry is the loosest remaining cut and defines 100% efficiency
            signal_efficiency = s_sumws/s_sumws[-1]
            ax.plot(signal_efficiency, zs, label=label, color=col)
            ax2.plot(signal_efficiency, thresholds, "--", color=col)
        ax.set_xlabel("Signal efficiency")
        ax.set_ylabel("Significance")
        ax.set_xlim(0, 1)
        ax2.set_ylabel("Threshold")
        ax.legend()
        fig.savefig(os.path.join(self.project_dir, "significances.pdf"))
        plt.close(fig)
    @property
    def csv_hist(self):
        """
        Training history read from "training.log" in the project directory

        The file is parsed as CSV. The first row provides the keys, all
        following rows are converted to float.

        :return: dict mapping each column name to the list of its values
        """
        log_path = os.path.join(self.project_dir, "training.log")
        with open(log_path) as log_file:
            rows = list(csv.reader(log_file))
        header, records = rows[0], rows[1:]
        return {
            column: [float(record[position]) for record in records]
            for position, column in enumerate(header)
        }

    def plot_loss(self, all_trainings=False, log=False, ylim=None, xlim=None):
        """
        Plot the value of the loss function for each epoch

        The training and validation loss are drawn on the current pyplot
        figure, saved as "losses.pdf" in the project directory and the
        figure is cleared afterwards.

        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
        :param log: use a logarithmic y axis
        :param ylim: optional (min, max) for the y axis
        :param xlim: optional (min, max) for the x axis
        """

        if all_trainings:
            hist_dict = self.csv_hist
        else:
            hist_dict = self.history.history
        if ('loss' not in hist_dict) or ('val_loss' not in hist_dict):
            # fall back to the history stored in the training log
            logger.warning("No previous history found for plotting, try global history")
            hist_dict = self.csv_hist

        logger.info("Plot losses")
        plt.plot(hist_dict['loss'])
        plt.plot(hist_dict['val_loss'])
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['training data','validation data'], loc='upper left')
        if log:
            plt.yscale("log")
        if ylim is not None:
            plt.ylim(*ylim)
        if xlim is not None:
            plt.xlim(*xlim)
        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
        plt.clf()
    def plot_accuracy(self, all_trainings=False, log=False, acc_suffix="weighted_acc"):
        """
        Plot the value of the accuracy metric for each epoch

        The training and validation values are drawn on the current pyplot
        figure, saved as "accuracy.pdf" in the project directory and the
        figure is cleared afterwards.

        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
        :param log: use a logarithmic y axis
        :param acc_suffix: key of the metric in the history; the validation
            values are looked up with the prefix "val_"
        """

        if all_trainings:
            hist_dict = self.csv_hist
        else:
            hist_dict = self.history.history
        if (acc_suffix not in hist_dict) or ('val_'+acc_suffix not in hist_dict):
            # fall back to the history stored in the training log
            logger.warning("No previous history found for plotting, try global history")
            hist_dict = self.csv_hist

        logger.info("Plot accuracy")
        plt.plot(hist_dict[acc_suffix])
        plt.plot(hist_dict['val_'+acc_suffix])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['training data', 'validation data'], loc='upper left')
        if log:
            plt.yscale("log")
        plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
        plt.clf()
        # self.plot_accuracy()
        self.plot_loss()
        self.plot_score()
        self.plot_weights()
        # self.plot_significance()
    def to_DataFrame(self):
        """
        Collect the training and test data in a pandas DataFrame

        The rows of the training sample come first, followed by the test
        sample. Besides one column per entry of ``self.fields`` the frame
        contains

        - "weight": the event weights
        - "labels": categorical with the categories "background" (code 0)
          and "signal" (code 1)
        - one column per identifier (only filled for the training sample,
          test rows are set to -1)
        - "is_train": True for rows of the training sample

        :return: the assembled DataFrame
        """
        df = pd.DataFrame(np.concatenate([self.x_train, self.x_test]), columns=self.fields)
        df["weight"] = np.concatenate([self.w_train, self.w_test])
        df["labels"] = pd.Categorical.from_codes(
            np.concatenate([self.y_train, self.y_test]),
            categories=["background", "signal"]
        )
        for identifier in self.identifiers:
            try:
                # training rows are taken as signal followed by background events
                df[identifier] = np.concatenate([self.s_eventlist_train[identifier],
                                                 self.b_eventlist_train[identifier],
                                                 -1*np.ones(len(self.x_test), dtype="i8")])
            except IOError:
                logger.warning("Can't find eventlist - DataFrame won't contain identifiers")
        # builtin bool instead of the np.bool alias, which newer NumPy versions dropped
        df["is_train"] = np.concatenate([np.ones(len(self.x_train), dtype=bool),
                                         np.zeros(len(self.x_test), dtype=bool)])
        return df
def create_getter(dataset_name):
    """
    Create a getter function for the dataset with the given name

    The returned function reads the attribute "_<dataset_name>" of the
    instance. If it is None, ``_load_from_hdf5(dataset_name)`` is called
    first (which is expected to fill the attribute) and the attribute is
    read again.
    """
    private_name = "_" + dataset_name

    def getx(self):
        cached = getattr(self, private_name)
        if cached is not None:
            return cached
        self._load_from_hdf5(dataset_name)
        return getattr(self, private_name)

    return getx

def create_setter(dataset_name):
    """
    Create a setter function for the dataset with the given name

    The returned function stores the given value in the attribute
    "_<dataset_name>" of the instance.
    """
    private_name = "_" + dataset_name

    def setx(self, value):
        setattr(self, private_name, value)

    return setx

# Attach one property per dataset name to ClassificationProject. Reading the
# property goes through the getter from create_getter (which triggers loading
# from hdf5 while the private attribute is still None), assigning goes through
# the setter from create_setter.
for dataset_name in ClassificationProject.dataset_names:
    setattr(
        ClassificationProject,
        dataset_name,
        property(fget=create_getter(dataset_name), fset=create_setter(dataset_name)),
    )
class ClassificationProjectDataFrame(ClassificationProject):
    """
    A little hack to initialize a ClassificationProject from a pandas DataFrame instead of ROOT TTrees

    The inputs (x), labels (y) and weights (w) of the training and test
    sample are taken lazily from the DataFrame on first access and cached
    afterwards. Rows where the split column is True form the training
    sample, all other rows the test sample.
    """
    def _init_from_args(self,
                        name,
                        df,
                        input_columns,
                        weight_column="weights",
                        label_column="labels",
                        signal_label="signal",
                        background_label="background",
                        split_mode="split_column",
                        split_column="is_train",
                        **kwargs):
        """
        :param name: name of the project, passed on to the base class
        :param df: DataFrame holding inputs, weights, labels and the split column
        :param input_columns: list of the columns used as input variables
        :param weight_column: column containing the event weights
        :param label_column: column containing the class labels
        :param signal_label: value of the label column that marks signal
        :param background_label: value of the label column that marks background
        :param split_mode: only "split_column" is supported
        :param split_column: boolean column, True for training events
        :param kwargs: further options, passed on to the base class
        """

        self.df = df
        self.input_columns = input_columns
        self.weight_column = weight_column
        self.label_column = label_column
        self.signal_label = signal_label
        self.background_label = background_label
        if split_mode != "split_column":
            raise NotImplementedError("'split_column' is the only currently supported split mode")
        self.split_mode = split_mode
        self.split_column = split_column
        # the tree related options are not used here, so pass dummy values
        super(ClassificationProjectDataFrame, self).__init__(name,
                                                             signal_trees=[], bkg_trees=[], branches=[], weight_expr="1",
                                                             **kwargs)
        # start with empty caches - they are filled on first access
        for cache_name in ("_x_train", "_x_test", "_y_train", "_y_test", "_w_train", "_w_test"):
            setattr(self, cache_name, None)

    def _select_split(self, train):
        "Rows of the DataFrame belonging to the training (train=True) or test sample"
        if train:
            return self.df[self.df[self.split_column]]
        return self.df[~self.df[self.split_column]]

    @property
    def x_train(self):
        "Input variables of the training sample"
        if self._x_train is not None:
            return self._x_train
        self._x_train = self._select_split(True)[self.input_columns].values
        return self._x_train

    @x_train.setter
    def x_train(self, value):
        self._x_train = value

    @property
    def x_test(self):
        "Input variables of the test sample"
        if self._x_test is not None:
            return self._x_test
        self._x_test = self._select_split(False)[self.input_columns].values
        return self._x_test

    @x_test.setter
    def x_test(self, value):
        self._x_test = value

    @property
    def y_train(self):
        "Labels of the training sample (True where the label equals the signal label)"
        if self._y_train is not None:
            return self._y_train
        is_signal = self._select_split(True)[self.label_column] == self.signal_label
        self._y_train = is_signal.values
        return self._y_train

    @y_train.setter
    def y_train(self, value):
        self._y_train = value

    @property
    def y_test(self):
        "Labels of the test sample (True where the label equals the signal label)"
        if self._y_test is not None:
            return self._y_test
        is_signal = self._select_split(False)[self.label_column] == self.signal_label
        self._y_test = is_signal.values
        return self._y_test

    @y_test.setter
    def y_test(self, value):
        self._y_test = value

    @property
    def w_train(self):
        "Event weights of the training sample"
        if self._w_train is not None:
            return self._w_train
        self._w_train = self._select_split(True)[self.weight_column].values
        return self._w_train

    @w_train.setter
    def w_train(self, value):
        self._w_train = value

    @property
    def w_test(self):
        "Event weights of the test sample"
        if self._w_test is not None:
            return self._w_test
        self._w_test = self._select_split(False)[self.weight_column].values
        return self._w_test

    @w_test.setter
    def w_test(self, value):
        self._w_test = value

    @property
    def fields(self):
        "Names of the input variables (the input columns of the DataFrame)"
        return self.input_columns


    def load(self, reload=False):
        """
        Make sure the data is transformed

        :param reload: if True, drop all cached arrays and flags first, so
            the data is taken from the DataFrame and transformed again
        """

        if reload:
            self.data_loaded = False
            self.data_transformed = False
            for cache_name in ("_x_train", "_x_test", "_y_train", "_y_test", "_w_train", "_w_test"):
                setattr(self, cache_name, None)

        if not self.data_transformed:
            self._transform_data()
class ClassificationProjectRNN(ClassificationProject):

    """
    A little wrapper to use recurrent units for things like jet collections

    The input fields are split into groups of recurrent fields (each group
    is fed through its own masked GRU layer) and the remaining "flat"
    fields. The outputs of the recurrent layers and the flat fields are
    concatenated and passed through the dense layers.
    """

    def _init_from_args(self, name,
                        recurrent_field_names=None,
                        rnn_layer_nodes=32,
                        mask_value=-999,
                        **kwargs):
        """
        recurrent_field_names example:
        [["jet1Pt", "jet1Eta", "jet1Phi"],
         ["jet2Pt", "jet2Eta", "jet2Phi"],
         ["jet3Pt", "jet3Eta", "jet3Phi"]],
        [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"],
         ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]],

        :param name: name of the project, passed on to the base class
        :param recurrent_field_names: list of groups of recurrent fields (see
            example above). Within a group all lists need to have the same length
        :param rnn_layer_nodes: number of nodes of each GRU layer
        :param mask_value: value that marks missing entries in the input
        :param kwargs: further options, passed on to the base class
        :raises ValueError: if the lists within a group differ in length
        :raises NotImplementedError: if the scaler is not WeightedRobustScaler
        """
        super(ClassificationProjectRNN, self)._init_from_args(name, **kwargs)

        self._write_info("project_type", "ClassificationProjectRNN")

        self.recurrent_field_names = recurrent_field_names
        if self.recurrent_field_names is None:
            self.recurrent_field_names = []
        self.rnn_layer_nodes = rnn_layer_nodes
        self.mask_value = mask_value

        # convert to array of indices
        self.recurrent_field_idx = []
        for field_name_list in self.recurrent_field_names:
            # Lists of unequal length can't form a regular array. Depending on
            # the NumPy version this either raises a ValueError or results in
            # an array of dtype object - treat both as invalid input.
            try:
                field_names = np.array([field_name_list])
            except ValueError:
                field_names = None
            if field_names is None or field_names.dtype == object:
                raise ValueError(
                    "Invalid entry for recurrent fields: {} - "
                    "please ensure that the length for all elements in the list is equal"
                    .format(field_name_list)
                )
            field_idx = (
                np.array([self.fields.index(field_name)
                          for field_name in field_names.reshape(-1)])
                .reshape(field_names.shape)
            )
            self.recurrent_field_idx.append(field_idx)

        # all fields that are not part of any recurrent group are "flat"
        recurrent_indices = set()
        for field_idx in self.recurrent_field_idx:
            recurrent_indices.update(field_idx.reshape(-1).tolist())
        self.flat_fields = [field for field in self.fields
                            if self.fields.index(field) not in recurrent_indices]

        if self.scaler_type != "WeightedRobustScaler":
            raise NotImplementedError(
                "Invalid scaler '{}' - only WeightedRobustScaler is currently supported for RNN"
                .format(self.scaler_type)
            )


    def _transform_data(self):
        "Transform the data, treating masked entries as missing (NaN) during the transformation"
        self.x_train[self.x_train == self.mask_value] = np.nan
        self.x_test[self.x_test == self.mask_value] = np.nan
        super(ClassificationProjectRNN, self)._transform_data()
        # put the mask value back for the Masking layers
        self.x_train[np.isnan(self.x_train)] = self.mask_value
        self.x_test[np.isnan(self.x_test)] = self.mask_value


    @property
    def model(self):
        "The keras model - built (and compiled or loaded) on first access"
        if self._model is None:
            # following the setup from the tutorial:
            # https://github.com/YaleATLAS/CERNDeepLearningTutorial
            rnn_inputs = []
            rnn_channels = []
            for field_idx in self.recurrent_field_idx:
                chan_inp = Input(field_idx.shape[1:])
                channel = Masking(mask_value=self.mask_value)(chan_inp)
                channel = GRU(self.rnn_layer_nodes)(channel)
                # TODO: configure dropout for recurrent layers
                #channel = Dropout(0.3)(channel)
                rnn_inputs.append(chan_inp)
                rnn_channels.append(channel)
            flat_input = Input((len(self.flat_fields),))
            if self.dropout_input is not None:
                flat_channel = Dropout(rate=self.dropout_input)(flat_input)
            else:
                flat_channel = flat_input
            combined = concatenate(rnn_channels+[flat_channel])
            for node_count, dropout_fraction in zip(self.nodes, self.dropout):
                combined = Dense(node_count, activation=self.activation_function)(combined)
                if (dropout_fraction is not None) and (dropout_fraction > 0):
                    combined = Dropout(rate=dropout_fraction)(combined)
            combined = Dense(1, activation=self.activation_function_output)(combined)
            self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
            self._compile_or_load_model()
        return self._model


    def train(self, epochs=10):
        """
        Train the model on batches created by yield_batch

        Before the training all input variables are plotted. The training
        can be interrupted with Ctrl-C, in that case the model is still
        checkpointed.

        :param epochs: number of epochs to train
        """
        self.load()

        for branch_index, branch in enumerate(self.fields):
            self.plot_input(branch_index)

        self.total_epochs = self._read_info("epochs", 0)

        try:
            self.shuffle_training_data() # needed here too, in order to get correct validation data
            self.is_training = True
            logger.info("Training on batches for RNN")
            # note: the batches have class_weight already applied
            self.model.fit_generator(self.yield_batch(),
                                     steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
                                     epochs=epochs,
                                     validation_data=self.class_weighted_validation_data,
                                     callbacks=self.callbacks_list)
            self.is_training = False
        except KeyboardInterrupt:
            logger.info("Interrupt training - continue with rest")
        self.checkpoint_model()


    def get_input_list(self, x):
        """
        Format the input starting from flat ntuple

        :param x: 2D array with one column per entry of ``self.fields``
        :return: list with one array per group of recurrent fields (reshaped
            to the shape of that group) followed by the array of flat fields
        """
        x_input = []
        for field_idx in self.recurrent_field_idx:
            x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
            x_input.append(x_recurrent)
        x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
        x_input.append(x_flat)
        return x_input


    def yield_batch(self):
        """
        Endlessly generate batches of training data

        The training data is traversed in a new random order on each pass.
        The weights of each batch are multiplied by the class weights and
        divided by the mean training weight.

        :return: generator of tuples (list of inputs, labels, weights)
        """
        x_train, y_train, w_train = self.training_data
        while True:
            shuffled_idx = np.random.permutation(len(x_train))
            batch_size = int(self.batch_size)
            for start in range(0, len(shuffled_idx), batch_size):
                batch_idx = shuffled_idx[start:start+batch_size]
                x_batch = x_train[batch_idx]
                y_batch = y_train[batch_idx]
                w_batch = w_train[batch_idx]
                x_input = self.get_input_list(x_batch)
                yield (x_input,
                       y_batch,
                       w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight)


    @property
    def class_weighted_validation_data(self):
        "class weighted validation data. Attention: Shuffle training data before using this!"
        x_val, y_val, w_val = super(ClassificationProjectRNN, self).class_weighted_validation_data
        x_val_input = self.get_input_list(x_val)
        return x_val_input, y_val, w_val


    def evaluate_train_test(self, do_train=True, do_test=True, batch_size=10000):
        """
        Calculate the scores for the training and/or test sample

        The data is reloaded first. The scores are evaluated in batches,
        stored in ``scores_train``/``scores_test`` and dumped to hdf5.

        :param do_train: evaluate the training sample
        :param do_test: evaluate the test sample
        :param batch_size: number of events per call to the model
        """
        logger.info("Reloading (and re-transforming) unshuffled training data")
        self.load(reload=True)

        def eval_score(data_name):
            logger.info("Create/Update scores for {} sample".format(data_name))
            n_events = len(getattr(self, "x_"+data_name))
            setattr(self, "scores_"+data_name, np.empty(n_events))
            for start in range(0, n_events, batch_size):
                stop = start+batch_size
                getattr(self, "scores_"+data_name)[start:stop] = self.model.predict(self.get_input_list(getattr(self, "x_"+data_name)[start:stop])).reshape(-1)
            self._dump_to_hdf5("scores_"+data_name)

        if do_test:
            eval_score("test")
        if do_train:
            eval_score("train")


    def evaluate(self, x_eval):
        """
        Evaluate the model for the given (untransformed) inputs

        :param x_eval: array-like with one column per entry of ``self.fields``
        :return: the prediction of the model
        """
        logger.debug("Evaluate score for {}".format(x_eval))
        x_eval = np.array(x_eval) # copy
        # masked entries are treated as missing during the transformation
        x_eval[x_eval==self.mask_value] = np.nan
        x_eval = self.scaler.transform(x_eval)
        x_eval[np.isnan(x_eval)] = self.mask_value
        logger.debug("Evaluate for transformed array: {}".format(x_eval))
        return self.model.predict(self.get_input_list(x_eval))


if __name__ == "__main__":

    # Example usage: train a small project on two input variables

    logging.basicConfig()
    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = ClassificationProject("test4",
                              signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
                              bkg_trees = [(filename, "ttbar_NoSys"),
                                           (filename, "wjets_Sherpa221_NoSys")
                              ],
                              optimizer="Adam",
                              #optimizer="SGD",
                              #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
                              earlystopping_opts=dict(monitor='val_loss',
                                                      min_delta=0, patience=2, verbose=0, mode='auto'),
                              selection="1",
                              branches = ["met", "mt"],
                              weight_expr = "eventWeight*genWeight",
                              identifiers = ["DatasetNumber", "EventNumber"],
                              step_bkg = 100)

    # fixed seed to get reproducible random numbers from numpy
    np.random.seed(42)
    c.train(epochs=20)

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
    #                     target_filename="friend.root", target_treename="test4_score")

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="ttbar_NoSys",
    #                     target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")