Commit 35d84c65 authored by Nikolai

Transform and plot data before training

parent 34930398
@@ -13,10 +13,10 @@ import pandas as pd
import h5py
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
import matplotlib.pyplot as plt
# configure number of cores
# this doesn't seem to work, but at least with these settings keras only uses 4 processes
@@ -75,12 +75,15 @@ class KerasROOTClassification:
        self._scaler = None
        self._class_weight = None
        self._bkg_weights = None
        self._sig_weights = None
        self._model = None
        # track the number of epochs this model has been trained
        self.total_epochs = 0
        self.data_loaded = False
        self.data_transformed = False
    def _load_data(self):
@@ -177,6 +180,18 @@ class KerasROOTClassification:
        return self._scaler
    def _transform_data(self):
        if not self.data_transformed:
            # todo: what to do about the outliers? Where do they come from?
            logger.debug("training data before transformation: {}".format(self.x_train))
            logger.debug("minimum values: {}".format([np.min(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
            logger.debug("maximum values: {}".format([np.max(self.x_train[:,i]) for i in range(self.x_train.shape[1])]))
            self.x_train = self.scaler.transform(self.x_train)
            logger.debug("training data after transformation: {}".format(self.x_train))
            self.x_test = self.scaler.transform(self.x_test)
            self.data_transformed = True
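As an illustration of what this step does (a minimal sketch, not the project's code, assuming self.scaler is a StandardScaler fitted on the training set): the scaler is fitted once on the training data and the same fitted transformation is then applied to both training and test data.

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    x_train = np.array([[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]])
    x_test = np.array([[1.5, 300.0]])

    scaler = StandardScaler().fit(x_train)  # fit on training data only
    x_train_t = scaler.transform(x_train)   # each column: zero mean, unit variance
    x_test_t = scaler.transform(x_test)     # reuse the training-set statistics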
    def _read_info(self, key, default):
        filename = os.path.join(self.project_dir, "info.json")
        if not os.path.exists(filename):
@@ -222,6 +237,7 @@ class KerasROOTClassification:
        return self._model
    @property
    def class_weight(self):
        if self._class_weight is None:
@@ -230,11 +246,18 @@ class KerasROOTClassification:
            self._class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
        return self._class_weight
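For illustration, with hypothetical weight sums sumw_bkg = 900.0 and sumw_sig = 100.0, this formula balances the two classes so that each contributes the same total weight:

    sumw_bkg, sumw_sig = 900.0, 100.0                 # hypothetical values
    w_bkg = (sumw_sig + sumw_bkg) / (2 * sumw_bkg)    # 1000/1800 ~ 0.556
    w_sig = (sumw_sig + sumw_bkg) / (2 * sumw_sig)    # 1000/200  = 5.0
    # balanced: 900 * 0.556 ~ 500 and 100 * 5.0 = 500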
    def train(self, epochs=10):
        if not self.data_loaded:
            self._load_data()
        if not self.data_transformed:
            self._transform_data()
        for branch_index, branch in enumerate(self.branches):
            self.plot_input(branch_index)
        try:
            self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
            logger.info("Weights found and loaded")
@@ -244,7 +267,9 @@ class KerasROOTClassification:
        self.total_epochs = self._read_info("epochs", 0)
        self.model.fit(self.x_train,
                       # the reshape might be unnecessary here
                       self.y_train.reshape(-1, 1),
                       epochs=epochs,
                       class_weight=self.class_weight,
                       shuffle=True,
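Regarding the reshape comment, a small sketch of the shapes involved (illustration only): reshape(-1, 1) turns the flat label array into a column vector matching a single-unit output layer; Keras typically accepts the flat shape as well, hence the uncertainty in the comment.

    import numpy as np
    y = np.array([0, 1, 1, 0])
    print(y.shape)                 # (4,)
    print(y.reshape(-1, 1).shape)  # (4, 1)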
@@ -261,6 +286,52 @@ class KerasROOTClassification:
    def writeFriendTree(self):
        pass
    @property
    def bkg_weights(self):
        """
        class weights multiplied by event weights (for plotting)
        TODO: find a better way to do this
        """
        if self._bkg_weights is None:
            logger.debug("Calculating background weights for plotting")
            self._bkg_weights = np.empty(sum(self.y_train == 0))
            self._bkg_weights.fill(self.class_weight[0])
            self._bkg_weights *= self.w_train[self.y_train == 0]
        return self._bkg_weights
    @property
    def sig_weights(self):
        """
        class weights multiplied by event weights (for plotting)
        TODO: find a better way to do this
        """
        if self._sig_weights is None:
            logger.debug("Calculating signal weights for plotting")
            self._sig_weights = np.empty(sum(self.y_train == 1))
            self._sig_weights.fill(self.class_weight[1])
            self._sig_weights *= self.w_train[self.y_train == 1]
        return self._sig_weights
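A compact sketch of the same idea with dummy numbers (np.full is equivalent to the empty()+fill() pattern above): each plotted event carries its event weight scaled by the class weight of its class.

    import numpy as np
    y_train = np.array([0, 0, 1, 1])
    w_train = np.array([0.5, 2.0, 1.0, 3.0])    # per-event weights
    class_weight = [0.556, 5.0]                 # hypothetical class weights
    bkg_weights = np.full((y_train == 0).sum(), class_weight[0]) * w_train[y_train == 0]
    sig_weights = np.full((y_train == 1).sum(), class_weight[1]) * w_train[y_train == 1]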
    def plot_input(self, var_index):
        "plot a single input variable"
        branch = self.branches[var_index]
        fig, ax = plt.subplots()
        bkg = self.x_train[:,var_index][self.y_train == 0]
        sig = self.x_train[:,var_index][self.y_train == 1]
        logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
        logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
        ax.hist(bkg, color="b", alpha=0.5, bins=50, weights=self.bkg_weights)
        ax.hist(sig, color="r", alpha=0.5, bins=50, weights=self.sig_weights)
        ax.set_xlabel(branch+" (transformed)")
        plot_dir = os.path.join(self.project_dir, "plots")
        if not os.path.exists(plot_dir):
            os.mkdir(plot_dir)
        fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
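For reference, a self-contained sketch of the same plotting pattern with dummy data (hypothetical weights and output path): overlaid weighted histograms for background (blue) and signal (red).

    import numpy as np
    import matplotlib.pyplot as plt

    rng = np.random.RandomState(0)
    bkg = rng.normal(0.0, 1.0, 1000)
    sig = rng.normal(1.0, 1.0, 1000)
    bkg_w = np.full(1000, 0.556)   # hypothetical background weights
    sig_w = np.full(1000, 5.0)     # hypothetical signal weights

    fig, ax = plt.subplots()
    ax.hist(bkg, color="b", alpha=0.5, bins=50, weights=bkg_w)
    ax.hist(sig, color="r", alpha=0.5, bins=50, weights=sig_w)
    ax.set_xlabel("example_var (transformed)")
    fig.savefig("var_example.pdf")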
    def plotROC(self):
        pass
@@ -272,7 +343,8 @@ class KerasROOTClassification:
if __name__ == "__main__":
    logging.basicConfig()
    #logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
    logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
    filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
...