diff --git a/browse.py b/browse.py index ed91b8681e49c16653bde554903c0cb62ee7a2ea..b641459fb75196a2ed350fff40570c4cf1b422d4 100755 --- a/browse.py +++ b/browse.py @@ -1,8 +1,12 @@ import sys +import logging import numpy as np import matplotlib.pyplot as plt from KerasROOTClassification import * +logging.basicConfig() +logging.getLogger("KerasROOTClassification").setLevel(logging.INFO) + c = ClassificationProject(sys.argv[1]) diff --git a/toolkit.py b/toolkit.py index 75cdbc1ec01c3a26c30ec3bc2bdcfda07477bb0a..b8ae99305951737f50d3d0712a333bb1449a7f19 100755 --- a/toolkit.py +++ b/toolkit.py @@ -52,6 +52,19 @@ K.set_session(session) import ROOT +def byteify(input): + "From stackoverflow https://stackoverflow.com/a/13105359" + if isinstance(input, dict): + return {byteify(key): byteify(value) + for key, value in input.iteritems()} + elif isinstance(input, list): + return [byteify(element) for element in input] + elif isinstance(input, unicode): + return input.encode('utf-8') + else: + return input + + class ClassificationProject(object): """Simple framework to load data from ROOT TTrees and train Keras @@ -144,13 +157,18 @@ class ClassificationProject(object): else: # otherwise initialise new project self._init_from_args(name, *args, **kwargs) - with open(os.path.join(self.project_dir, "options.json"), "w") as of: - json.dump(dict(args=args, kwargs=kwargs), of) + with open(os.path.join(self.project_dir, "options.pickle"), "w") as of: + pickle.dump(dict(args=args, kwargs=kwargs), of) def _init_from_dir(self, dirname): - with open(os.path.join(dirname, "options.json")) as f: - options = yaml.safe_load(f) + if not os.path.exists(os.path.join(dirname, "options.pickle")): + # for backward compatibility + with open(os.path.join(dirname, "options.json")) as f: + options = byteify(json.load(f)) + else: + with open(os.path.join(dirname, "options.pickle")) as f: + options = pickle.load(f) options["kwargs"]["project_dir"] = dirname self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"])