Skip to content
Snippets Groups Projects
Commit d09ab3d6 authored by Nikolai.Hartmann's avatar Nikolai.Hartmann
Browse files

model and training for RNN working

parent fe84f351
No related branches found
No related tags found
No related merge requests found
......@@ -31,9 +31,8 @@ import h5py
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.externals import joblib
from sklearn.metrics import roc_curve, auc
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.models import model_from_json
from keras.models import Sequential, Model, model_from_json
from keras.layers import Dense, Dropout, Input, Masking, GRU, concatenate
from keras.callbacks import History, EarlyStopping, CSVLogger, ModelCheckpoint, TensorBoard
from keras.optimizers import SGD
import keras.optimizers
......@@ -649,37 +648,41 @@ class ClassificationProject(object):
# last layer is one neuron (binary classification)
self._model.add(Dense(1, activation=self.activation_function_output))
logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts)
logger.info("Compile model")
rn_state = np.random.get_state()
np.random.seed(self.random_seed)
self._model.compile(optimizer=optimizer,
loss=self.loss,
weighted_metrics=['accuracy']
)
np.random.set_state(rn_state)
self._compile_or_load_model()
if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
if self.is_training:
continue_training = self.query_yn("Found previously trained weights - "
"continue training (choosing N will restart)? (Y/N) ")
else:
continue_training = True
if continue_training:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights")
else:
logger.info("Starting completely new model")
else:
logger.info("No weights found, starting completely new model")
return self._model
# dump to json for documentation
with open(os.path.join(self.project_dir, "model.json"), "w") as of:
of.write(self._model.to_json())
return self._model
def _compile_or_load_model(self):
    """Compile ``self._model`` and restore previously trained weights if available.

    Side effects:
    - Instantiates the optimizer named by ``self.optimizer`` (looked up in
      ``keras.optimizers``) with ``self.optimizer_opts`` and compiles
      ``self._model`` with ``self.loss`` and weighted accuracy.
    - If ``<project_dir>/weights.h5`` exists, loads it (asking the user first
      when ``self.is_training`` is set); otherwise starts from fresh weights.
    - Dumps the model architecture to ``<project_dir>/model.json``.
    """
    logger.info("Using {}(**{}) as Optimizer".format(self.optimizer, self.optimizer_opts))
    # Resolve the optimizer class by name from keras.optimizers and
    # construct it from the user-supplied keyword options.
    Optimizer = getattr(keras.optimizers, self.optimizer)
    optimizer = Optimizer(**self.optimizer_opts)
    logger.info("Compile model")
    # Save and restore the global numpy RNG state around compile —
    # presumably to make any randomness during compilation reproducible
    # via self.random_seed without disturbing the caller's RNG stream
    # (NOTE(review): confirm compile actually consumes numpy randomness).
    rn_state = np.random.get_state()
    np.random.seed(self.random_seed)
    self._model.compile(optimizer=optimizer,
                        loss=self.loss,
                        weighted_metrics=['accuracy']
                        )
    np.random.set_state(rn_state)
    if os.path.exists(os.path.join(self.project_dir, "weights.h5")):
        if self.is_training:
            # Interactive session: let the user choose between resuming
            # from the stored weights or restarting from scratch.
            continue_training = self.query_yn("Found previously trained weights - "
                                              "continue training (choosing N will restart)? (Y/N) ")
        else:
            # Not training (e.g. evaluation) — always reuse stored weights.
            continue_training = True
        if continue_training:
            self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
            logger.info("Found and loaded previously trained weights")
        else:
            logger.info("Starting completely new model")
    else:
        logger.info("No weights found, starting completely new model")
    # dump to json for documentation
    with open(os.path.join(self.project_dir, "model.json"), "w") as of:
        of.write(self._model.to_json())
@property
......@@ -1413,18 +1416,63 @@ class ClassificationProjectRNN(ClassificationProject):
@property
def model(self):
    """Lazily build, compile and cache the RNN classification model.

    One Masking+GRU channel is built per entry of ``self.recurrent_field_idx``;
    their outputs are concatenated with a flat-variable input, passed through
    the dense stack configured by ``self.nodes``/``self.dropout`` and reduced
    to a single output neuron for binary classification.

    Returns the cached ``keras.models.Model`` (built on first access).
    """
    if self._model is None:
        # following the setup from the tutorial:
        # https://github.com/YaleATLAS/CERNDeepLearningTutorial
        rnn_inputs = []
        rnn_channels = []
        for field_idx in self.recurrent_field_idx:
            # one input/GRU channel per group of recurrent fields;
            # field_idx.shape[1:] gives the per-event input shape
            chan_inp = Input(field_idx.shape[1:])
            channel = Masking(mask_value=self.mask_value)(chan_inp)
            channel = GRU(32)(channel)
            # TODO: configure dropout for recurrent layers
            #channel = Dropout(0.3)(channel)
            rnn_inputs.append(chan_inp)
            rnn_channels.append(channel)
        flat_input = Input((len(self.flat_fields),))
        # bug fix: the condition was inverted ("is None"), which applied
        # Dropout(rate=None) when no input dropout was configured and
        # skipped it when one was.
        if self.dropout_input is not None:
            flat_channel = Dropout(rate=self.dropout_input)(flat_input)
        else:
            flat_channel = flat_input
        combined = concatenate(rnn_channels+[flat_channel])
        for node_count, dropout_fraction in zip(self.nodes, self.dropout):
            combined = Dense(node_count, activation=self.activation_function)(combined)
            if (dropout_fraction is not None) and (dropout_fraction > 0):
                combined = Dropout(rate=dropout_fraction)(combined)
        # last layer is one neuron (binary classification)
        combined = Dense(1, activation=self.activation_function_output)(combined)
        self._model = Model(inputs=rnn_inputs+[flat_input], outputs=combined)
        self._compile_or_load_model()
    # bug fix: the return was inside an "else" branch, so the first
    # access (which builds the model) returned None.
    return self._model
def train(self, epochs=10):
    """Train the RNN model on generator batches for *epochs* epochs.

    Shuffles the training data first (required for correct validation data),
    runs ``fit_generator`` with the project's callbacks, and always dumps the
    training history afterwards. A ``KeyboardInterrupt`` stops training early
    but still proceeds to the history dump.
    """
    try:
        self.shuffle_training_data()  # needed here too, in order to get correct validation data
        self.is_training = True
        logger.info("Training on batches for RNN")
        # note: the batches have class_weight already applied
        self.model.fit_generator(self.yield_batch(),
                                 steps_per_epoch=int(len(self.training_data[0])/self.batch_size),
                                 epochs=epochs,
                                 validation_data=self.class_weighted_validation_data,
                                 callbacks=self.callbacks_list)
    except KeyboardInterrupt:
        logger.info("Interrupt training - continue with rest")
    finally:
        # bug fix: the flag was only reset on normal completion, so an
        # interrupted run left is_training == True and later code (e.g.
        # the weight-loading prompt) misbehaved.
        self.is_training = False
    logger.info("Save history")
    self._dump_history()
def get_input_list(self, x):
    """Format the input starting from the flat ntuple array.

    Splits the 2D array ``x`` (events x fields) into the list of arrays the
    RNN model expects: one reshaped array per recurrent field group
    (``self.recurrent_field_idx``), followed by the flat variables
    (``self.flat_fields``) — matching the ``rnn_inputs + [flat_input]``
    ordering of the model.

    Note: the pasted diff contained both the old and new body (flat block
    appended twice, two x_recurrent assignments); this is the resolved
    post-commit version.
    """
    x_input = []
    for field_idx in self.recurrent_field_idx:
        # select the columns of this recurrent group and restore the
        # per-event shape given by field_idx.shape[1:]
        x_recurrent = x[:,field_idx.reshape(-1)].reshape(-1, *field_idx.shape[1:])
        x_input.append(x_recurrent)
    x_flat = x[:,[self.fields.index(field_name) for field_name in self.flat_fields]]
    x_input.append(x_flat)
    return x_input
......@@ -1442,6 +1490,14 @@ class ClassificationProjectRNN(ClassificationProject):
w_batch*np.array(self.class_weight)[y_batch.astype(int)])
@property
def class_weighted_validation_data(self):
    """Class-weighted validation data, with inputs split for the RNN model.

    Attention: shuffle the training data before using this!
    """
    parent = super(ClassificationProjectRNN, self)
    inputs, targets, weights = parent.class_weighted_validation_data
    # reshape the flat validation inputs into the per-channel list the model expects
    return self.get_input_list(inputs), targets, weights
if __name__ == "__main__":
logging.basicConfig()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment