From 64b906462009446fa141b6656f56c80b901f3096 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann <Nikolai.Hartmann@physik.uni-muenchen.de> Date: Thu, 26 Apr 2018 14:00:00 +0200 Subject: [PATCH] Loading and saving of StandardScaler --- toolkit.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/toolkit.py b/toolkit.py index 0240e19..425a95b 100755 --- a/toolkit.py +++ b/toolkit.py @@ -10,13 +10,17 @@ from root_numpy import tree2array, rec2array import numpy as np import pandas as pd import h5py +from sklearn.preprocessing import StandardScaler +from sklearn.externals import joblib import ROOT class KerasROOTClassification: + dataset_names = ["x_train", "x_test", "y_train", "y_test", "w_train", "w_test"] + def __init__(self, name, signal_trees, bkg_trees, branches, weight_expr, identifiers, layers=3, nodes=64, out_dir="./outputs"): @@ -50,6 +54,9 @@ class KerasROOTClassification: self.s_eventlist_train = None self.b_eventlist_train = None + self._scaler = None + + def load_data(self): try: @@ -125,6 +132,29 @@ class KerasROOTClassification: logger.info("Data loaded") + @property + def scaler(self): + # create the scaler (and fit to training data) if not existent + if self._scaler is None: + filename = os.path.join(self.project_dir, "scaler.pkl") + try: + self._scaler = joblib.load(filename) + logger.info("Loaded existing StandardScaler from {}".format(filename)) + except IOError: + logger.info("Creating new StandardScaler") + self._scaler = StandardScaler() + logger.info("Fitting StandardScaler to training data") + self._scaler.fit(self.x_train) + joblib.dump(self._scaler, filename) + return self._scaler + + + def _transform_data(self): + pass + + def _create_model(self): + pass + def train(self): pass @@ -161,3 +191,5 @@ if __name__ == "__main__": c.load_data() print(c.x_train) print(len(c.x_train)) + + print(c.scaler) -- GitLab