Skip to content
Snippets Groups Projects
Unverified Commit 595d0af5 authored by Eric Schanet's avatar Eric Schanet
Browse files

Merge branch 'master' of gitlab.physik.uni-muenchen.de:Nikolai.Hartmann/KerasROOTClassification

* 'master' of gitlab.physik.uni-muenchen.de:Nikolai.Hartmann/KerasROOTClassification:
  normalize weights for default training as well
  try normalising weights
  put model checkpoint/reload weights into separate function
parents 17746465 4ac1c6b5
No related branches found
No related tags found
No related merge requests found
......@@ -338,6 +338,8 @@ class ClassificationProject(object):
self._fields = None
self._mean_train_weight = None
@property
def fields(self):
......@@ -741,6 +743,13 @@ class ClassificationProject(object):
np.random.shuffle(self._scores_train)
@property
def mean_train_weight(self):
if self._mean_train_weight is None:
self._mean_train_weight = np.mean(self.w_train*np.array(self.class_weight)[self.y_train.astype(int)])
return self._mean_train_weight
@property
def w_validation(self):
"class weighted validation data weights"
......@@ -749,7 +758,7 @@ class ClassificationProject(object):
self._w_validation = np.array(self.w_train[split_index:])
self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
return self._w_validation
return self._w_validation/self.mean_train_weight
@property
......@@ -819,7 +828,7 @@ class ClassificationProject(object):
validation_split = self.validation_split,
# we have to multiply by class weight since keras ignores class weight if sample weight is given
# see https://github.com/keras-team/keras/issues/497
sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)],
sample_weight=self.w_train*np.array(self.class_weight)[self.y_train.astype(int)]/self.mean_train_weight,
shuffle=True,
batch_size=self.batch_size,
callbacks=self.callbacks_list)
......@@ -842,6 +851,11 @@ class ClassificationProject(object):
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
self.checkpoint_model(epochs)
def checkpoint_model(self, epochs):
logger.info("Save history")
self._dump_history()
......@@ -1455,6 +1469,8 @@ class ClassificationProjectRNN(ClassificationProject):
for branch_index, branch in enumerate(self.fields):
self.plot_input(branch_index)
self.total_epochs = self._read_info("epochs", 0)
try:
self.shuffle_training_data() # needed here too, in order to get correct validation data
self.is_training = True
......@@ -1468,8 +1484,8 @@ class ClassificationProjectRNN(ClassificationProject):
self.is_training = False
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
logger.info("Save history")
self._dump_history()
self.checkpoint_model(epochs)
def get_input_list(self, x):
......@@ -1494,7 +1510,7 @@ class ClassificationProjectRNN(ClassificationProject):
x_input = self.get_input_list(x_batch)
yield (x_input,
y_train[shuffled_idx[start:start+int(self.batch_size)]],
w_batch*np.array(self.class_weight)[y_batch.astype(int)])
w_batch*np.array(self.class_weight)[y_batch.astype(int)]/self.mean_train_weight)
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment