From bf50acf015a42b412dc479cc861eff629104e0d7 Mon Sep 17 00:00:00 2001
From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de>
Date: Thu, 16 Aug 2018 14:15:40 +0200
Subject: [PATCH] Improve loading from dir

* Call _init_from_args instead of init in subclasses
* Store project_type variable in ClassificationProjectRNN
* Introduce load_from_dir function to automatically choose correct class
---
 browse.py  |  4 ++--
 toolkit.py | 62 +++++++++++++++++++++++++++++++++++-------------------
 2 files changed, 42 insertions(+), 24 deletions(-)

diff --git a/browse.py b/browse.py
index 643624e..4021782 100755
--- a/browse.py
+++ b/browse.py
@@ -9,11 +9,11 @@ from KerasROOTClassification import *
 logging.basicConfig()
 logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
 
-c = ClassificationProject(sys.argv[1])
+c = load_from_dir(sys.argv[1])
 
 cs = []
 cs.append(c)
 
 if len(sys.argv) > 2:
     for project_name in sys.argv[2:]:
-        cs.append(ClassificationProject(project_name))
+        cs.append(load_from_dir(project_name))
diff --git a/toolkit.py b/toolkit.py
index e912881..9f0ccef 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 
-__all__ = ["ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"]
+__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"]
 
 from sys import version_info
 
@@ -71,6 +71,19 @@ if version_info[0] > 2:
     byteify = lambda input : input
 
 
+def load_from_dir(path):
+    "Load a project and the options from a directory"
+    try:
+        with open(os.path.join(path, "info.json")) as f:
+            info = json.load(f)
+        project_type = info["project_type"]
+        if project_type == "ClassificationProjectRNN":
+            return ClassificationProjectRNN(path)
+    except KeyError, IOError:
+        pass
+    return ClassificationProject(path)
+
+
 class ClassificationProject(object):
 
     """Simple framework to load data from ROOT TTrees and train Keras
@@ -607,6 +620,9 @@ class ClassificationProject(object):
 
     def _write_info(self, key, value):
         filename = os.path.join(self.project_dir, "info.json")
+        if not os.path.exists(filename):
+            with open(filename, "w") as of:
+                json.dump({}, of)
         with open(filename) as f:
             info = json.load(f)
         info[key] = value
@@ -851,10 +867,10 @@ class ClassificationProject(object):
             except KeyboardInterrupt:
                 logger.info("Interrupt training - continue with rest")
 
-        self.checkpoint_model(epochs)
+        self.checkpoint_model()
 
 
-    def checkpoint_model(self, epochs):
+    def checkpoint_model(self):
 
         logger.info("Save history")
         self._dump_history()
@@ -870,7 +886,7 @@ class ClassificationProject(object):
             logger.info("Reloading weights file since we are using model checkpoint!")
             self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
 
-        self.total_epochs += epochs
+        self.total_epochs += self.history.epoch[-1]+1
         self._write_info("epochs", self.total_epochs)
 
 
@@ -1255,17 +1271,17 @@ class ClassificationProjectDataFrame(ClassificationProject):
     A little hack to initialize a ClassificationProject from a pandas DataFrame instead of ROOT TTrees
     """
 
-    def __init__(self,
-                 name,
-                 df,
-                 input_columns,
-                 weight_column="weights",
-                 label_column="labels",
-                 signal_label="signal",
-                 background_label="background",
-                 split_mode="split_column",
-                 split_column="is_train",
-                 **kwargs):
+    def _init_from_args(self,
+                        name,
+                        df,
+                        input_columns,
+                        weight_column="weights",
+                        label_column="labels",
+                        signal_label="signal",
+                        background_label="background",
+                        split_mode="split_column",
+                        split_column="is_train",
+                        **kwargs):
 
         self.df = df
         self.input_columns = input_columns
@@ -1374,11 +1390,11 @@ class ClassificationProjectRNN(ClassificationProject):
     A little wrapper to use recurrent units for things like jet collections
     """
 
-    def __init__(self, name,
-                 recurrent_field_names=None,
-                 rnn_layer_nodes=32,
-                 mask_value=-999,
-                 **kwargs):
+    def _init_from_args(self, name,
+                        recurrent_field_names=None,
+                        rnn_layer_nodes=32,
+                        mask_value=-999,
+                        **kwargs):
         """
         recurrent_field_names example:
         [["jet1Pt", "jet1Eta", "jet1Phi"],
@@ -1387,7 +1403,9 @@ class ClassificationProjectRNN(ClassificationProject):
         [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"],
          ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]],
         """
-        super(ClassificationProjectRNN, self).__init__(name, **kwargs)
+        super(ClassificationProjectRNN, self)._init_from_args(name, **kwargs)
+
+        self._write_info("project_type", "ClassificationProjectRNN")
 
         self.recurrent_field_names = recurrent_field_names
         if self.recurrent_field_names is None:
@@ -1485,7 +1503,7 @@ class ClassificationProjectRNN(ClassificationProject):
         except KeyboardInterrupt:
             logger.info("Interrupt training - continue with rest")
 
-        self.checkpoint_model(epochs)
+        self.checkpoint_model()
 
 
     def get_input_list(self, x):
-- 
GitLab