Skip to content
Snippets Groups Projects
Commit 01681528 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Switch saving options of new projects to pickle, but maintain support for json

parent 8b1f3c2e
No related branches found
No related tags found
No related merge requests found
import sys import sys
import logging
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from KerasROOTClassification import * from KerasROOTClassification import *
logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
c = ClassificationProject(sys.argv[1]) c = ClassificationProject(sys.argv[1])
...@@ -52,6 +52,19 @@ K.set_session(session) ...@@ -52,6 +52,19 @@ K.set_session(session)
import ROOT import ROOT
def byteify(input):
"From stackoverflow https://stackoverflow.com/a/13105359"
if isinstance(input, dict):
return {byteify(key): byteify(value)
for key, value in input.iteritems()}
elif isinstance(input, list):
return [byteify(element) for element in input]
elif isinstance(input, unicode):
return input.encode('utf-8')
else:
return input
class ClassificationProject(object): class ClassificationProject(object):
"""Simple framework to load data from ROOT TTrees and train Keras """Simple framework to load data from ROOT TTrees and train Keras
...@@ -144,13 +157,18 @@ class ClassificationProject(object): ...@@ -144,13 +157,18 @@ class ClassificationProject(object):
else: else:
# otherwise initialise new project # otherwise initialise new project
self._init_from_args(name, *args, **kwargs) self._init_from_args(name, *args, **kwargs)
with open(os.path.join(self.project_dir, "options.json"), "w") as of: with open(os.path.join(self.project_dir, "options.pickle"), "w") as of:
json.dump(dict(args=args, kwargs=kwargs), of) pickle.dump(dict(args=args, kwargs=kwargs), of)
def _init_from_dir(self, dirname): def _init_from_dir(self, dirname):
with open(os.path.join(dirname, "options.json")) as f: if not os.path.exists(os.path.join(dirname, "options.pickle")):
options = yaml.safe_load(f) # for backward compatibility
with open(os.path.join(dirname, "options.json")) as f:
options = byteify(json.load(f))
else:
with open(os.path.join(dirname, "options.pickle")) as f:
options = pickle.load(f)
options["kwargs"]["project_dir"] = dirname options["kwargs"]["project_dir"] = dirname
self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"]) self._init_from_args(os.path.basename(dirname), *options["args"], **options["kwargs"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment