Skip to content
Snippets Groups Projects
Commit 1d453dc1 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Merge remote-tracking branch 'origin/dev-organisation' into dev-organisation

parents b8017999 a7a38241
No related branches found
No related tags found
No related merge requests found
from toolkit import ClassificationProject
from compare import overlay_ROC, overlay_loss
from .toolkit import ClassificationProject
from .compare import overlay_ROC, overlay_loss
......@@ -3,6 +3,6 @@ import sys
import numpy as np
import matplotlib.pyplot as plt
from KerasROOTClassification import ClassificationProject
from KerasROOTClassification import *
c = ClassificationProject(sys.argv[1])
#!/usr/bin/env python
from sys import version_info
if version_info[0] > 2:
raw_input = input
import os
import json
import pickle
......@@ -19,7 +24,7 @@ from sklearn.externals import joblib
from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dense, Dropout
from keras.models import model_from_json
from keras.callbacks import History, EarlyStopping, CSVLogger
from keras.optimizers import SGD
......@@ -41,6 +46,7 @@ K.set_session(session)
import ROOT
class ClassificationProject(object):
"""Simple framework to load data from ROOT TTrees and train Keras
......@@ -75,12 +81,16 @@ class ClassificationProject(object):
:param nodes: number of nodes in each layer
:param dropout: dropout fraction after each hidden layer. Set to None for no Dropout
:param batch_size: size of the training batches
:param validation_split: split off this fraction of training events for loss evaluation
:param activation_function: activation function in the hidden layers
:param activation_function_output: activation function in the output layer
:param out_dir: base directory in which the project directories should be stored
:param scaler_type: sklearn scaler class name to transform the data before training (options: "StandardScaler", "RobustScaler")
......@@ -133,9 +143,11 @@ class ClassificationProject(object):
selection=None,
layers=3,
nodes=64,
dropout=None,
batch_size=128,
validation_split=0.33,
activation_function='relu',
activation_function_output='sigmoid',
project_dir=None,
scaler_type="RobustScaler",
step_signal=2,
......@@ -155,9 +167,11 @@ class ClassificationProject(object):
self.identifiers = identifiers
self.layers = layers
self.nodes = nodes
self.dropout = dropout
self.batch_size = batch_size
self.validation_split = validation_split
self.activation_function = activation_function
self.activation_function_output = activation_function_output
self.scaler_type = scaler_type
self.step_signal = step_signal
self.step_bkg = step_bkg
......@@ -426,6 +440,19 @@ class ClassificationProject(object):
json.dump(info, of)
@staticmethod
def query_yn(text):
    """Prompt the user with *text* until a yes/no answer is given.

    :param text: prompt string shown to the user
    :returns: True if the answer starts with "y"/"Y", False if it
        starts with "n"/"N"; keeps re-prompting on any other input
    """
    result = None
    while result is None:
        # strip whitespace so answers like " y" or "yes " are accepted
        # instead of silently re-prompting forever
        answer = raw_input(text).strip()
        if answer:
            # uppercase only the first character once, rather than
            # uppercasing the whole string in each comparison
            first = answer[0].upper()
            if first == "Y":
                result = True
            elif first == "N":
                result = False
    return result
@property
def model(self):
"Simple MLP"
......@@ -439,8 +466,10 @@ class ClassificationProject(object):
# the other hidden layers
for layer_number in range(self.layers-1):
self._model.add(Dense(self.nodes, activation=self.activation_function))
if self.dropout is not None:
self._model.add(Dropout(rate=self.dropout))
# last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid'))
self._model.add(Dense(1, activation=self.activation_function_output))
logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts)
......@@ -451,10 +480,14 @@ class ClassificationProject(object):
loss='binary_crossentropy',
metrics=['accuracy'])
np.random.set_state(rn_state)
try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
except IOError:
if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ")
if continue_training:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
else:
logger.info("Starting completely new model")
else:
logger.info("No weights found, starting completely new model")
# dump to json for documentation
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment