diff --git a/toolkit.py b/toolkit.py index eb5b02c4a0c94cac86d944257baab8f75f977b0a..d18fe5f021da2b3ebaed0cbd9de006b50e7dd7ac 100755 --- a/toolkit.py +++ b/toolkit.py @@ -100,10 +100,13 @@ class KerasROOTClassification(object): self.b_train = None self.s_test = None self.b_test = None - self.x_train = None - self.x_test = None - self.y_train = None - self.y_test = None + + self._x_train = None + self._x_test = None + self._y_train = None + self._y_test = None + self._w_train = None + self._w_test = None self.s_eventlist_train = None self.b_eventlist_train = None @@ -560,6 +563,25 @@ class KerasROOTClassification(object): plt.savefig(os.path.join(self.project_dir, "accuracy.pdf")) plt.clf() +def create_getter(dataset_name): + def getx(self): + if getattr(self, "_"+dataset_name) is None: + self._load_from_hdf5([dataset_name]) + return getattr(self, "_"+dataset_name) + return getx + +def create_setter(dataset_name): + def setx(self, value): + # maybe change this at some point to also dump into hdf + setattr(self, "_"+dataset_name, value) + return setx + +# define getters and setters for all datasets +for dataset_name in KerasROOTClassification.dataset_names: + setattr(KerasROOTClassification, dataset_name, property(create_getter(dataset_name), + create_setter(dataset_name))) + + if __name__ == "__main__": logging.basicConfig()