Skip to content
Snippets Groups Projects
Commit 35d84c65 authored by Nikolai's avatar Nikolai
Browse files

Transform and plot data before training

parent 34930398
No related branches found
No related tags found
No related merge requests found
......@@ -13,10 +13,10 @@ import pandas as pd
import h5py
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
import matplotlib.pyplot as plt
# configure number of cores
# this doesn't seem to work, but at least with these settings keras only uses 4 processes
......@@ -75,12 +75,15 @@ class KerasROOTClassification:
self._scaler = None
self._class_weight = None
self._bkg_weights = None
self._sig_weights = None
self._model = None
# track the number of epochs this model has been trained
self.total_epochs = 0
self.data_loaded = False
self.data_transformed = False
def _load_data(self):
......@@ -177,6 +180,18 @@ class KerasROOTClassification:
return self._scaler
def _transform_data(self):
    """Scale the training and test inputs in place using ``self.scaler``.

    ``self.x_train`` and ``self.x_test`` are replaced by the output of
    ``self.scaler.transform``. The transformation is applied at most once:
    ``self.data_transformed`` is set afterwards and any later call returns
    without touching the data.

    The debug output (full training array before/after, per-column minima
    and maxima) is only built when DEBUG logging is enabled, so a normal
    run does not pay for stringifying the whole training matrix.
    """
    if self.data_transformed:
        return

    # todo: what to do about the outliers? Where do they come from?
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        logger.debug("training data before transformation: {}".format(self.x_train))
        # reduce over the event axis in one call instead of looping over columns
        logger.debug("minimum values: {}".format(list(np.min(self.x_train, axis=0))))
        logger.debug("maximum values: {}".format(list(np.max(self.x_train, axis=0))))

    self.x_train = self.scaler.transform(self.x_train)
    if debug_enabled:
        logger.debug("training data after transformation: {}".format(self.x_train))

    self.x_test = self.scaler.transform(self.x_test)
    self.data_transformed = True
def _read_info(self, key, default):
filename = os.path.join(self.project_dir, "info.json")
if not os.path.exists(filename):
......@@ -222,6 +237,7 @@ class KerasROOTClassification:
return self._model
@property
def class_weight(self):
if self._class_weight is None:
......@@ -230,11 +246,18 @@ class KerasROOTClassification:
self._class_weight = [(sumw_sig+sumw_bkg)/(2*sumw_bkg), (sumw_sig+sumw_bkg)/(2*sumw_sig)]
return self._class_weight
def train(self, epochs=10):
if not self.data_loaded:
self._load_data()
if not self.data_transformed:
self._transform_data()
for branch_index, branch in enumerate(self.branches):
self.plot_input(branch_index)
try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Weights found and loaded")
......@@ -244,7 +267,9 @@ class KerasROOTClassification:
self.total_epochs = self._read_info("epochs", 0)
self.model.fit(self.x_train, self.y_train,
self.model.fit(self.x_train,
# the reshape might be unnecessary here
self.y_train.reshape(-1, 1),
epochs=epochs,
class_weight=self.class_weight,
shuffle=True,
......@@ -261,6 +286,52 @@ class KerasROOTClassification:
def writeFriendTree(self):
    """Placeholder, not implemented yet: does nothing and returns ``None``."""
    return None
@property
def bkg_weights(self):
    """
    class weights multiplied by event weights (for plotting)

    Computed on first access and cached in ``self._bkg_weights``. The
    result has one entry per training event with ``y_train == 0``: the
    event weight from ``w_train`` multiplied by ``class_weight[0]``.

    TODO: find a better way to do this
    """
    if self._bkg_weights is None:
        logger.debug("Calculating background weights for plotting")
        # build the selection mask once; it is needed for counting and indexing
        is_bkg = self.y_train == 0
        # count in native code instead of iterating the mask with builtin sum()
        weights = np.full(np.count_nonzero(is_bkg), self.class_weight[0], dtype=float)
        weights *= self.w_train[is_bkg]
        # only cache a fully computed array, so a failure above can never
        # leave an uninitialised buffer behind for the next access
        self._bkg_weights = weights
    return self._bkg_weights
@property
def sig_weights(self):
    """
    class weights multiplied by event weights (for plotting)

    Computed on first access and cached in ``self._sig_weights``. The
    result has one entry per training event with ``y_train == 1``: the
    event weight from ``w_train`` multiplied by ``class_weight[1]``.

    TODO: find a better way to do this
    """
    if self._sig_weights is None:
        logger.debug("Calculating signal weights for plotting")
        # build the selection mask once; it is needed for counting and indexing
        is_sig = self.y_train == 1
        # count in native code instead of iterating the mask with builtin sum()
        weights = np.full(np.count_nonzero(is_sig), self.class_weight[1], dtype=float)
        weights *= self.w_train[is_sig]
        # only cache a fully computed array, so a failure above can never
        # leave an uninitialised buffer behind for the next access
        self._sig_weights = weights
    return self._sig_weights
def plot_input(self, var_index):
    """plot a single input variable

    Draws two overlaid 50-bin histograms of column ``var_index`` of
    ``self.x_train``: events with ``y_train == 0`` in blue weighted by
    ``self.bkg_weights`` and events with ``y_train == 1`` in red weighted
    by ``self.sig_weights``. The figure is written to
    ``<project_dir>/plots/var_<var_index>.pdf``; the ``plots`` directory
    is created if it does not exist yet.

    The figure is always closed afterwards. This method is called once per
    branch, and pyplot keeps every figure alive until it is closed.
    """
    branch = self.branches[var_index]
    column = self.x_train[:, var_index]
    bkg = column[self.y_train == 0]
    sig = column[self.y_train == 1]

    # min/max are only needed for the debug message, so skip them otherwise
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Plotting bkg (min={}, max={}) from {}".format(np.min(bkg), np.max(bkg), bkg))
        logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))

    fig, ax = plt.subplots()
    try:
        ax.hist(bkg, color="b", alpha=0.5, bins=50, weights=self.bkg_weights)
        ax.hist(sig, color="r", alpha=0.5, bins=50, weights=self.sig_weights)
        ax.set_xlabel(branch+" (transformed)")

        plot_dir = os.path.join(self.project_dir, "plots")
        # try to create the directory and tolerate "already exists" rather
        # than checking first, which is racy between the check and the mkdir
        try:
            os.mkdir(plot_dir)
        except OSError:
            if not os.path.isdir(plot_dir):
                raise
        fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
    finally:
        # release the figure even if plotting or saving failed
        plt.close(fig)
def plotROC(self):
    """Placeholder, not implemented yet: does nothing and returns ``None``."""
    return None
......@@ -272,7 +343,8 @@ class KerasROOTClassification:
if __name__ == "__main__":
logging.basicConfig()
logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
#logging.getLogger("KerasROOTClassification").setLevel(logging.INFO)
logging.getLogger("KerasROOTClassification").setLevel(logging.DEBUG)
filename = "/project/etp4/nhartmann/trees/allTrees_m1.8_NoSys.root"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment