Skip to content
Snippets Groups Projects
Commit 1d453dc1 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

Merge remote-tracking branch 'origin/dev-organisation' into dev-organisation

parents b8017999 a7a38241
Branches dev-organisation
No related tags found
No related merge requests found
from toolkit import ClassificationProject from .toolkit import ClassificationProject
from compare import overlay_ROC, overlay_loss from .compare import overlay_ROC, overlay_loss
...@@ -3,6 +3,6 @@ import sys ...@@ -3,6 +3,6 @@ import sys
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from KerasROOTClassification import ClassificationProject from KerasROOTClassification import *
c = ClassificationProject(sys.argv[1]) c = ClassificationProject(sys.argv[1])
#!/usr/bin/env python #!/usr/bin/env python
# Compatibility shim: Python 3 renamed raw_input() to input(), so alias
# it here and the rest of the module can call raw_input() on either
# major version.
from sys import version_info

if version_info.major > 2:
    raw_input = input
import os import os
import json import json
import pickle import pickle
...@@ -19,7 +24,7 @@ from sklearn.externals import joblib ...@@ -19,7 +24,7 @@ from sklearn.externals import joblib
from sklearn.metrics import roc_curve, auc from sklearn.metrics import roc_curve, auc
from keras.models import Sequential from keras.models import Sequential
from keras.layers import Dense from keras.layers import Dense, Dropout
from keras.models import model_from_json from keras.models import model_from_json
from keras.callbacks import History, EarlyStopping, CSVLogger from keras.callbacks import History, EarlyStopping, CSVLogger
from keras.optimizers import SGD from keras.optimizers import SGD
...@@ -41,6 +46,7 @@ K.set_session(session) ...@@ -41,6 +46,7 @@ K.set_session(session)
import ROOT import ROOT
class ClassificationProject(object): class ClassificationProject(object):
"""Simple framework to load data from ROOT TTrees and train Keras """Simple framework to load data from ROOT TTrees and train Keras
...@@ -75,12 +81,16 @@ class ClassificationProject(object): ...@@ -75,12 +81,16 @@ class ClassificationProject(object):
:param nodes: number of nodes in each layer :param nodes: number of nodes in each layer
:param dropout: dropout fraction after each hidden layer. Set to None for no Dropout
:param batch_size: size of the training batches :param batch_size: size of the training batches
:param validation_split: split off this fraction of training events for loss evaluation :param validation_split: split off this fraction of training events for loss evaluation
:param activation_function: activation function in the hidden layers :param activation_function: activation function in the hidden layers
:param activation_function_output: activation function in the output layer
:param out_dir: base directory in which the project directories should be stored :param out_dir: base directory in which the project directories should be stored
:param scaler_type: sklearn scaler class name to transform the data before training (options: "StandardScaler", "RobustScaler") :param scaler_type: sklearn scaler class name to transform the data before training (options: "StandardScaler", "RobustScaler")
...@@ -133,9 +143,11 @@ class ClassificationProject(object): ...@@ -133,9 +143,11 @@ class ClassificationProject(object):
selection=None, selection=None,
layers=3, layers=3,
nodes=64, nodes=64,
dropout=None,
batch_size=128, batch_size=128,
validation_split=0.33, validation_split=0.33,
activation_function='relu', activation_function='relu',
activation_function_output='sigmoid',
project_dir=None, project_dir=None,
scaler_type="RobustScaler", scaler_type="RobustScaler",
step_signal=2, step_signal=2,
...@@ -155,9 +167,11 @@ class ClassificationProject(object): ...@@ -155,9 +167,11 @@ class ClassificationProject(object):
self.identifiers = identifiers self.identifiers = identifiers
self.layers = layers self.layers = layers
self.nodes = nodes self.nodes = nodes
self.dropout = dropout
self.batch_size = batch_size self.batch_size = batch_size
self.validation_split = validation_split self.validation_split = validation_split
self.activation_function = activation_function self.activation_function = activation_function
self.activation_function_output = activation_function_output
self.scaler_type = scaler_type self.scaler_type = scaler_type
self.step_signal = step_signal self.step_signal = step_signal
self.step_bkg = step_bkg self.step_bkg = step_bkg
...@@ -426,6 +440,19 @@ class ClassificationProject(object): ...@@ -426,6 +440,19 @@ class ClassificationProject(object):
json.dump(info, of) json.dump(info, of)
@staticmethod
def query_yn(text):
    """Ask the user a yes/no question on stdin and return the answer as a bool.

    Re-prompts until an answer starting with "Y"/"y" (True) or "N"/"n"
    (False) is entered; any other input (including empty) asks again.

    :param text: prompt string shown to the user
    :returns: True for yes, False for no
    """
    result = None
    while result is None:
        # normalise once: strip surrounding whitespace so e.g. " yes"
        # is still recognised, and upper-case for case-insensitive match
        answer = raw_input(text).strip().upper()
        # startswith is safe on the empty string, so no separate
        # length check is needed
        if answer.startswith("Y"):
            result = True
        elif answer.startswith("N"):
            result = False
    return result
@property @property
def model(self): def model(self):
"Simple MLP" "Simple MLP"
...@@ -439,8 +466,10 @@ class ClassificationProject(object): ...@@ -439,8 +466,10 @@ class ClassificationProject(object):
# the other hidden layers # the other hidden layers
for layer_number in range(self.layers-1): for layer_number in range(self.layers-1):
self._model.add(Dense(self.nodes, activation=self.activation_function)) self._model.add(Dense(self.nodes, activation=self.activation_function))
if self.dropout is not None:
self._model.add(Dropout(rate=self.dropout))
# last layer is one neuron (binary classification) # last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid')) self._model.add(Dense(1, activation=self.activation_function_output))
logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts)) logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
Optimizer = getattr(keras.optimizers, self.optimizer) Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts) optimizer = Optimizer(**self.optimizer_opts)
...@@ -451,10 +480,14 @@ class ClassificationProject(object): ...@@ -451,10 +480,14 @@ class ClassificationProject(object):
loss='binary_crossentropy', loss='binary_crossentropy',
metrics=['accuracy']) metrics=['accuracy'])
np.random.set_state(rn_state) np.random.set_state(rn_state)
try: if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) continue_training = self.query_yn("Found previously trained weights - continue training? (Y/N) ")
logger.info("Found and loaded previously trained weights") if continue_training:
except IOError: self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
else:
logger.info("Starting completely new model")
else:
logger.info("No weights found, starting completely new model") logger.info("No weights found, starting completely new model")
# dump to json for documentation # dump to json for documentation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment