
        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
        """

        if all_trainings:
            hist_dict = self.csv_hist
        else:
            hist_dict = self.history.history
        if (not 'loss' in hist_dict) or (not 'val_loss' in hist_dict):
            logger.warning("No previous history found for plotting, try global history")
            hist_dict = self.csv_hist

        logger.info("Plot losses")
        plt.plot(hist_dict['loss'])
        plt.plot(hist_dict['val_loss'])
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train','test'], loc='upper left')
        if log:
            plt.yscale("log")
        plt.savefig(os.path.join(self.project_dir, "losses.pdf"))
        plt.clf()
    def plot_accuracy(self, all_trainings=False, log=False):
        """
        Plot the value of the accuracy metric for each epoch

        The 'acc' and 'val_acc' curves are drawn with pyplot, written to
        "accuracy.pdf" inside the project directory, and the current figure
        is cleared afterwards.

        :param all_trainings: set to true if you want to plot all trainings (otherwise the previous history is used)
        :param log: set to true to use a logarithmic scale on the y axis
        """

        if all_trainings:
            hist_dict = self.csv_hist
        else:
            hist_dict = self.history.history

        # Fall back to the csv history when the selected one lacks either curve.
        if 'acc' not in hist_dict or 'val_acc' not in hist_dict:
            logger.warning("No previous history found for plotting, try global history")
            hist_dict = self.csv_hist

        logger.info("Plot accuracy")
        plt.plot(hist_dict['acc'])
        plt.plot(hist_dict['val_acc'])
        plt.title('model accuracy')
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper left')
        if log:
            plt.yscale("log")
        plt.savefig(os.path.join(self.project_dir, "accuracy.pdf"))
        # Clear the figure so that later plots do not draw on top of this one.
        plt.clf()


    def plot_all(self):
        """
        Run every plotting method of this object, one after the other

        The methods are called without arguments in this order: plot_ROC,
        plot_accuracy, plot_loss, plot_score, plot_weights, plot_significance.
        """
        plot_method_names = (
            "plot_ROC",
            "plot_accuracy",
            "plot_loss",
            "plot_score",
            "plot_weights",
            "plot_significance",
        )
        # Each method is looked up right before it is called.
        for method_name in plot_method_names:
            getattr(self, method_name)()
def create_getter(dataset_name):
    """
    Build a getter function for the attribute "_<dataset_name>"

    The returned function returns the attribute if it is not None.
    Otherwise it first calls ``self._load_from_hdf5(dataset_name)`` and
    then returns whatever the attribute holds afterwards.

    :param dataset_name: name of the dataset (without the leading underscore)
    """
    private_name = "_" + dataset_name

    def getx(self):
        current = getattr(self, private_name)
        if current is not None:
            return current
        # Nothing stored yet: let the object load it, then re-read the attribute.
        self._load_from_hdf5(dataset_name)
        return getattr(self, private_name)

    return getx

def create_setter(dataset_name):
    """
    Build a setter function for the attribute "_<dataset_name>"

    The returned function stores the given value in that attribute.

    :param dataset_name: name of the dataset (without the leading underscore)
    """
    private_name = "_" + dataset_name

    def setx(self, value):
        setattr(self, private_name, value)

    return setx

# Install one property per dataset name on ClassificationProject: reads go
# through the function from create_getter, assignments through create_setter.
for dataset_name in ClassificationProject.dataset_names:
    setattr(
        ClassificationProject,
        dataset_name,
        property(fget=create_getter(dataset_name), fset=create_setter(dataset_name)),
    )
if __name__ == "__main__":
    # Example run when the module is executed as a script: set up a
    # ClassificationProject called "test4" from one signal tree and two
    # background trees of a hard-coded .root file and train it for 20 epochs.

    logging.basicConfig()
    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    # NOTE(review): site-specific absolute path - adapt before running elsewhere
    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = ClassificationProject("test4",
                              signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
                              bkg_trees = [(filename, "ttbar_NoSys"),
                                           (filename, "wjets_Sherpa221_NoSys")
                              ],
                              optimizer="Adam",
                              #optimizer="SGD",
                              #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
                              earlystopping_opts=dict(monitor='val_loss',
                                                      min_delta=0, patience=2, verbose=0, mode='auto'),
                              selection="1",
                              branches = ["met", "mt"],
                              weight_expr = "eventWeight*genWeight",
                              identifiers = ["DatasetNumber", "EventNumber"],
                              step_bkg = 100)

    # Fix the seed of NumPy's global random generator before training
    # (presumably to make runs repeatable).
    np.random.seed(42)
    c.train(epochs=20)

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
    #                     target_filename="friend.root", target_treename="test4_score")

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="ttbar_NoSys",
    #                     target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")