Skip to content
Snippets Groups Projects
Commit 81dfd04c authored by Nikolai's avatar Nikolai
Browse files

Random seed before initialising model

parent 8831b35e
No related branches found
No related tags found
No related merge requests found
...@@ -94,6 +94,10 @@ class KerasROOTClassification(object): ...@@ -94,6 +94,10 @@ class KerasROOTClassification(object):
:param earlystopping_opts: options for the keras EarlyStopping callback :param earlystopping_opts: options for the keras EarlyStopping callback
:param random_seed: use this seed value when initialising the model and produce consistent results. Note:
random data is also used for shuffling the training data, so results may vary still. To
produce consistent results, set the numpy random seed before training.
""" """
...@@ -134,7 +138,8 @@ class KerasROOTClassification(object): ...@@ -134,7 +138,8 @@ class KerasROOTClassification(object):
step_bkg=2, step_bkg=2,
optimizer="SGD", optimizer="SGD",
optimizer_opts=None, optimizer_opts=None,
earlystopping_opts=None): earlystopping_opts=None,
random_seed=1234):
self.name = name self.name = name
self.signal_trees = signal_trees self.signal_trees = signal_trees
...@@ -168,6 +173,8 @@ class KerasROOTClassification(object): ...@@ -168,6 +173,8 @@ class KerasROOTClassification(object):
if not os.path.exists(self.project_dir): if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir) os.mkdir(self.project_dir)
self.random_seed = random_seed
self.s_train = None self.s_train = None
self.b_train = None self.b_train = None
self.s_test = None self.s_test = None
...@@ -434,10 +441,12 @@ class KerasROOTClassification(object): ...@@ -434,10 +441,12 @@ class KerasROOTClassification(object):
Optimizer = getattr(keras.optimizers, self.optimizer) Optimizer = getattr(keras.optimizers, self.optimizer)
optimizer = Optimizer(**self.optimizer_opts) optimizer = Optimizer(**self.optimizer_opts)
logger.info("Compile model") logger.info("Compile model")
rn_state = np.random.get_state()
np.random.seed(self.random_seed)
self._model.compile(optimizer=optimizer, self._model.compile(optimizer=optimizer,
loss='binary_crossentropy', loss='binary_crossentropy',
metrics=['accuracy']) metrics=['accuracy'])
np.random.set_state(rn_state)
try: try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
logger.info("Found and loaded previously trained weights") logger.info("Found and loaded previously trained weights")
...@@ -495,7 +504,6 @@ class KerasROOTClassification(object): ...@@ -495,7 +504,6 @@ class KerasROOTClassification(object):
try: try:
self.history = History() self.history = History()
self.shuffle_training_data() self.shuffle_training_data()
self.model.fit(self.x_train, self.model.fit(self.x_train,
# the reshape might be unnescessary here # the reshape might be unnescessary here
self.y_train.reshape(-1, 1), self.y_train.reshape(-1, 1),
...@@ -732,27 +740,26 @@ if __name__ == "__main__": ...@@ -732,27 +740,26 @@ if __name__ == "__main__":
(filename, "wjets_Sherpa221_NoSys") (filename, "wjets_Sherpa221_NoSys")
], ],
optimizer="Adam", optimizer="Adam",
#optimizer="SGD",
#optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9), #optimizer_opts=dict(lr=100., decay=1e-6, momentum=0.9),
earlystopping_opts=dict(monitor='val_loss', earlystopping_opts=dict(monitor='val_loss',
min_delta=0, patience=2, verbose=0, mode='auto'), min_delta=0, patience=2, verbose=0, mode='auto'),
# optimizer="Adam",
selection="lep1Pt<5000", # cut out a few very weird outliers selection="lep1Pt<5000", # cut out a few very weird outliers
branches = ["met", "mt"], branches = ["met", "mt"],
weight_expr = "eventWeight*genWeight", weight_expr = "eventWeight*genWeight",
identifiers = ["DatasetNumber", "EventNumber"], identifiers = ["DatasetNumber", "EventNumber"],
step_bkg = 100) step_bkg = 100)
np.random.seed(42)
c.train(epochs=20) c.train(epochs=20)
c.plot_ROC() c.plot_ROC()
c.plot_loss() c.plot_loss()
c.plot_accuracy() c.plot_accuracy()
c.write_friend_tree("test4_score", # c.write_friend_tree("test4_score",
source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys", # source_filename=filename, source_treename="GG_oneStep_1705_1105_505_NoSys",
target_filename="friend.root", target_treename="test4_score") # target_filename="friend.root", target_treename="test4_score")
np.random.seed(1234)
c.write_friend_tree("test4_score", # c.write_friend_tree("test4_score",
source_filename=filename, source_treename="ttbar_NoSys", # source_filename=filename, source_treename="ttbar_NoSys",
target_filename="friend_ttbar_NoSys.root", target_treename="test4_score") # target_filename="friend_ttbar_NoSys.root", target_treename="test4_score")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment