Skip to content
Snippets Groups Projects
toolkit.py 44.7 KiB
Newer Older
def create_getter(dataset_name):
    """Return a lazy-loading getter function for the dataset *dataset_name*.

    The returned ``getx(self)`` reads the private attribute
    ``"_" + dataset_name`` from the instance.  If that attribute is ``None``,
    it first calls ``self._load_from_hdf5(dataset_name)`` and then reads the
    attribute again, returning whatever value is stored there afterwards.
    """
    def getx(self):
        private_name = "_" + dataset_name
        cached = getattr(self, private_name)
        if cached is not None:
            return cached
        # Not populated yet: ask the instance to load it, then re-read.
        self._load_from_hdf5(dataset_name)
        return getattr(self, private_name)
    return getx

def create_setter(dataset_name):
    """Return a setter function that stores values for *dataset_name*.

    The returned ``setx(self, value)`` assigns *value* to the instance
    attribute ``"_" + dataset_name``.
    """
    def setx(self, value):
        private_name = "_" + dataset_name
        setattr(self, private_name, value)
    return setx

# Attach a read/write property to ClassificationProject for every name listed
# in ClassificationProject.dataset_names.  Each property's getter comes from
# create_getter (lazy load via _load_from_hdf5 when the backing "_<name>"
# attribute is None) and its setter from create_setter.
for dataset_name in ClassificationProject.dataset_names:
    dataset_property = property(create_getter(dataset_name),
                                create_setter(dataset_name))
    setattr(ClassificationProject, dataset_name, dataset_property)
if __name__ == "__main__":
    # Example run: build a ClassificationProject named "test4" from trees in
    # a single ROOT file and train it.  Paths and tree names below are
    # specific to the author's environment.

    logging.basicConfig()
    logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
    #logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)

    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"

    c = ClassificationProject("test4",
                              signal_trees=[(filename, "GG_oneStep_1705_1105_505_NoSys")],
                              bkg_trees=[(filename, "ttbar_NoSys"),
                                         (filename, "wjets_Sherpa221_NoSys"),
                              ],
                              optimizer="Adam",
                              #optimizer="SGD",
                              #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
                              earlystopping_opts=dict(monitor='val_loss',
                                                      min_delta=0, patience=2, verbose=0, mode='auto'),
                              selection="lep1Pt<5000", # cut out a few very weird outliers
                              branches=["met", "mt"],
                              weight_expr="eventWeight*genWeight",
                              identifiers=["DatasetNumber", "EventNumber"],
                              step_bkg=100)

    # Seed NumPy's global RNG before training.
    # NOTE(review): the seed is set after the project is constructed, so any
    # randomness used inside the constructor is not covered — confirm intent.
    np.random.seed(42)
    c.train(epochs=20)

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
    #                     target_filename="friend.root", target_treename="test4_score")

    # c.write_friend_tree("test4_score",
    #                     source_filename=filename, source_treename="ttbar_NoSys",
    #                     target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")