From cb4b04fcf3cffd19a6cc5a2dacdb56b164ca11e5 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Tue, 8 May 2018 13:17:49 +0200
Subject: [PATCH] Rename project class to ClassificationProject

---
 __init__.py |  3 +++
 compare.py  | 16 ++++++++--------
 toolkit.py  | 36 ++++++++++++++++++------------------
 3 files changed, 29 insertions(+), 26 deletions(-)
 create mode 100644 __init__.py

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..2702f3d
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,3 @@
+from toolkit import ClassificationProject
+from compare import overlay_ROC, overlay_loss
+
diff --git a/compare.py b/compare.py
index 4f2332d..f0a2a3d 100755
--- a/compare.py
+++ b/compare.py
@@ -7,7 +7,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.metrics import roc_curve, auc
 
-from toolkit import KerasROOTClassification
+from toolkit import ClassificationProject
 
 """
 A few functions to compare different setups
@@ -76,14 +76,14 @@ if __name__ == "__main__":
                         identifiers = ["DatasetNumber", "EventNumber"],
                         step_bkg = 100)
 
-    example1 = KerasROOTClassification("test_sgd",
-                                       optimizer="SGD",
-                                       optimizer_opts=dict(lr=1000., decay=1e-6, momentum=0.9),
-                                       **data_options)
+    example1 = ClassificationProject("test_sgd",
+                                     optimizer="SGD",
+                                     optimizer_opts=dict(lr=1000., decay=1e-6, momentum=0.9),
+                                     **data_options)
 
-    example2 = KerasROOTClassification("test_adam",
-                                       optimizer="Adam",
-                                       **data_options)
+    example2 = ClassificationProject("test_adam",
+                                     optimizer="Adam",
+                                     **data_options)
 
 
     if not os.path.exists("outputs/test_sgd/scores_test.h5"):
diff --git a/toolkit.py b/toolkit.py
index 7dc4d20..271989d 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -41,7 +41,7 @@ K.set_session(session)
 
 import ROOT
 
-class KerasROOTClassification(object):
+class ClassificationProject(object):
 
     """Simple framework to load data from ROOT TTrees and train Keras
     neural networks for classification according to some global settings.
@@ -751,9 +751,9 @@ def create_setter(dataset_name):
     return setx
 
 # define getters and setters for all datasets
-for dataset_name in KerasROOTClassification.dataset_names:
-    setattr(KerasROOTClassification, dataset_name, property(create_getter(dataset_name),
-                                                            create_setter(dataset_name)))
+for dataset_name in ClassificationProject.dataset_names:
+    setattr(ClassificationProject, dataset_name, property(create_getter(dataset_name),
+                                                          create_setter(dataset_name)))
 
 
 if __name__ == "__main__":
@@ -764,21 +764,21 @@ if __name__ == "__main__":
 
     filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
 
-    c = KerasROOTClassification("test4",
-                                signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
-                                bkg_trees = [(filename, "ttbar_NoSys"),
-                                             (filename, "wjets_Sherpa221_NoSys")
-                                ],
-                                optimizer="Adam",
-                                #optimizer="SGD",
-                                #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
+    c = ClassificationProject("test4",
+                              signal_trees = [(filename, "GG_oneStep_1705_1105_505_NoSys")],
+                              bkg_trees = [(filename, "ttbar_NoSys"),
+                                           (filename, "wjets_Sherpa221_NoSys")
+                              ],
+                              optimizer="Adam",
+                              #optimizer="SGD",
+                              #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
                                 earlystopping_opts=dict(monitor='val_loss',
-                                    min_delta=0, patience=2, verbose=0, mode='auto'),
-                                selection="lep1Pt<5000", # cut out a few very weird outliers
-                                branches = ["met", "mt"],
-                                weight_expr = "eventWeight*genWeight",
-                                identifiers = ["DatasetNumber", "EventNumber"],
-                                step_bkg = 100)
+                                                        min_delta=0, patience=2, verbose=0, mode='auto'),
+                              selection="lep1Pt<5000", # cut out a few very weird outliers
+                              branches = ["met", "mt"],
+                              weight_expr = "eventWeight*genWeight",
+                              identifiers = ["DatasetNumber", "EventNumber"],
+                              step_bkg = 100)
 
     np.random.seed(42)
     c.train(epochs=20)
-- 
GitLab