From bf50acf015a42b412dc479cc861eff629104e0d7 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Thu, 16 Aug 2018 14:15:40 +0200 Subject: [PATCH] Improve loading from dir * Call _init_from_args instead of init in subclasses * Store project_type variable in ClassificationProjectRNN * Introduce load_from_dir function to automatically choose correct class --- browse.py | 4 ++-- toolkit.py | 62 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/browse.py b/browse.py index 643624e..4021782 100755 --- a/browse.py +++ b/browse.py @@ -9,11 +9,11 @@ from KerasROOTClassification import * logging.basicConfig() logging.getLogger("KerasROOTClassification").setLevel(logging.INFO) -c = ClassificationProject(sys.argv[1]) +c = load_from_dir(sys.argv[1]) cs = [] cs.append(c) if len(sys.argv) > 2: for project_name in sys.argv[2:]: - cs.append(ClassificationProject(project_name)) + cs.append(load_from_dir(project_name)) diff --git a/toolkit.py b/toolkit.py index e912881..9f0ccef 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -__all__ = ["ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"] +__all__ = ["load_from_dir", "ClassificationProject", "ClassificationProjectDataFrame", "ClassificationProjectRNN"] from sys import version_info @@ -71,6 +71,19 @@ if version_info[0] > 2: byteify = lambda input : input +def load_from_dir(path): + "Load a project and the options from a directory" + try: + with open(os.path.join(path, "info.json")) as f: + info = json.load(f) + project_type = info["project_type"] + if project_type == "ClassificationProjectRNN": + return ClassificationProjectRNN(path) + except (KeyError, IOError): + pass + return ClassificationProject(path) + + class ClassificationProject(object): """Simple framework to load data from ROOT TTrees and train Keras @@ -607,6 +620,9 @@ class 
ClassificationProject(object): def _write_info(self, key, value): filename = os.path.join(self.project_dir, "info.json") + if not os.path.exists(filename): + with open(filename, "w") as of: + json.dump({}, of) with open(filename) as f: info = json.load(f) info[key] = value @@ -851,10 +867,10 @@ class ClassificationProject(object): except KeyboardInterrupt: logger.info("Interrupt training - continue with rest") - self.checkpoint_model(epochs) + self.checkpoint_model() - def checkpoint_model(self, epochs): + def checkpoint_model(self): logger.info("Save history") self._dump_history() @@ -870,7 +886,7 @@ class ClassificationProject(object): logger.info("Reloading weights file since we are using model checkpoint!") self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) - self.total_epochs += epochs + self.total_epochs += self.history.epoch[-1]+1 self._write_info("epochs", self.total_epochs) @@ -1255,17 +1271,17 @@ class ClassificationProjectDataFrame(ClassificationProject): A little hack to initialize a ClassificationProject from a pandas DataFrame instead of ROOT TTrees """ - def __init__(self, - name, - df, - input_columns, - weight_column="weights", - label_column="labels", - signal_label="signal", - background_label="background", - split_mode="split_column", - split_column="is_train", - **kwargs): + def _init_from_args(self, + name, + df, + input_columns, + weight_column="weights", + label_column="labels", + signal_label="signal", + background_label="background", + split_mode="split_column", + split_column="is_train", + **kwargs): self.df = df self.input_columns = input_columns @@ -1374,11 +1390,11 @@ class ClassificationProjectRNN(ClassificationProject): A little wrapper to use recurrent units for things like jet collections """ - def __init__(self, name, - recurrent_field_names=None, - rnn_layer_nodes=32, - mask_value=-999, - **kwargs): + def _init_from_args(self, name, + recurrent_field_names=None, + rnn_layer_nodes=32, + mask_value=-999, + 
**kwargs): """ recurrent_field_names example: [["jet1Pt", "jet1Eta", "jet1Phi"], @@ -1387,7 +1403,9 @@ class ClassificationProjectRNN(ClassificationProject): [["lep1Pt", "lep1Eta", "lep1Phi", "lep1flav"], ["lep2Pt", "lep2Eta", "lep2Phi", "lep2flav"]], """ - super(ClassificationProjectRNN, self).__init__(name, **kwargs) + super(ClassificationProjectRNN, self)._init_from_args(name, **kwargs) + + self._write_info("project_type", "ClassificationProjectRNN") self.recurrent_field_names = recurrent_field_names if self.recurrent_field_names is None: @@ -1485,7 +1503,7 @@ class ClassificationProjectRNN(ClassificationProject): except KeyboardInterrupt: logger.info("Interrupt training - continue with rest") - self.checkpoint_model(epochs) + self.checkpoint_model() def get_input_list(self, x): -- GitLab