Skip to content
Snippets Groups Projects
Commit 54d5cf3a authored by Nikolai's avatar Nikolai
Browse files

Adding option to train on batches with equal number of events of both classes

parent 4bd7de4a
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,9 @@ from sys import version_info
if version_info[0] > 2:
raw_input = input
izip = zip
else:
from itertools import izip
import os
import json
......@@ -108,6 +111,13 @@ class ClassificationProject(object):
:param earlystopping_opts: options for the keras EarlyStopping callback
:param use_modelcheckpoint: save model weights after each epoch and don't save after no validation loss improvement
:param balance_dataset: if True, balance the dataset instead of
applying class weights. Only a fraction of the overrepresented
class is used in each epoch, with a different subset of the
overrepresented class drawn each epoch.
:param random_seed: use this seed value when initialising the model and produce consistent results. Note:
random data is also used for shuffling the training data, so results may vary still. To
produce consistent results, set the numpy random seed before training.
......@@ -158,7 +168,8 @@ class ClassificationProject(object):
use_earlystopping=True,
earlystopping_opts=None,
use_modelcheckpoint=True,
random_seed=1234):
random_seed=1234,
balance_dataset=False):
self.name = name
self.signal_trees = signal_trees
......@@ -186,6 +197,8 @@ class ClassificationProject(object):
if earlystopping_opts is None:
earlystopping_opts = dict()
self.earlystopping_opts = earlystopping_opts
self.random_seed = random_seed
self.balance_dataset = balance_dataset
self.project_dir = project_dir
if self.project_dir is None:
......@@ -194,8 +207,6 @@ class ClassificationProject(object):
if not os.path.exists(self.project_dir):
os.mkdir(self.project_dir)
self.random_seed = random_seed
self.s_train = None
self.b_train = None
self.s_test = None
......@@ -210,6 +221,9 @@ class ClassificationProject(object):
self._scores_train = None
self._scores_test = None
# class weighted validation data
self._w_validation = None
self._s_eventlist_train = None
self._b_eventlist_train = None
......@@ -550,6 +564,54 @@ class ClassificationProject(object):
np.random.shuffle(self._scores_train)
@property
def w_validation(self):
    "class weighted validation data"
    if self._w_validation is None:
        # validation events are the tail of the training arrays
        start = int((1 - self.validation_split) * len(self.x_train))
        weights = np.array(self.w_train[start:])
        labels = self.y_train[start:]
        # fold the per-class weight into the per-event weights
        for label in (0, 1):
            weights[labels == label] *= self.class_weight[label]
        self._w_validation = weights
    return self._w_validation
@property
def class_weighted_validation_data(self):
    "x, y and class-weighted w for the validation split"
    # same split point as training_data; weights come from w_validation
    start = int((1 - self.validation_split) * len(self.x_train))
    return self.x_train[start:], self.y_train[start:], self.w_validation
@property
def training_data(self):
    "training data with validation data split off"
    # keep the leading (1 - validation_split) fraction for training
    n_train = int((1 - self.validation_split) * len(self.x_train))
    return (self.x_train[:n_train],
            self.y_train[:n_train],
            self.w_train[:n_train])
def yield_batch(self, class_label):
    """Infinitely yield (x, y, w) batches containing only events of ``class_label``.

    Each pass over the training data reshuffles the events of this
    class, so successive epochs see the same events in a different
    order.

    :param class_label: the class (0 or 1) whose events are yielded
    """
    while True:
        x_train, y_train, w_train = self.training_data
        # Select this class ONCE per pass instead of re-evaluating the
        # boolean mask for every batch (the original re-masked on each
        # yield). Fancy indexing returns copies, so - unlike the
        # original in-place assignment - the arrays behind
        # self.training_data are left untouched.
        mask = (y_train == class_label)
        x_cls = x_train[mask]
        y_cls = y_train[mask]
        w_cls = w_train[mask]
        # One shared index permutation keeps x and w aligned; the
        # original achieved the same pairing by saving/restoring the
        # RNG state between two separate permutations.
        perm = np.random.permutation(len(x_cls))
        x_cls = x_cls[perm]
        w_cls = w_cls[perm]
        # yield them batch wise, then restart with a fresh shuffle
        for start in range(0, len(x_cls), self.batch_size):
            yield (x_cls[start:start + self.batch_size],
                   y_cls[start:start + self.batch_size],
                   w_cls[start:start + self.batch_size])
def yield_balanced_batch(self):
    "generate batches with equal amounts of both classes"
    # run one per-class generator for each label and merge them pairwise
    batches_0 = self.yield_batch(0)
    batches_1 = self.yield_batch(1)
    for batch_0, batch_1 in izip(batches_0, batches_1):
        x0, y0, w0 = batch_0
        x1, y1, w1 = batch_1
        yield (np.concatenate((x0, x1)),
               np.concatenate((y0, y1)),
               np.concatenate((w0, w1)))
def train(self, epochs=10):
self.load()
......@@ -560,22 +622,36 @@ class ClassificationProject(object):
self.total_epochs = self._read_info("epochs", 0)
logger.info("Train model")
try:
self.shuffle_training_data()
self.is_training = True
self.model.fit(self.x_train,
# the reshape might be unnescessary here
self.y_train.reshape(-1, 1),
epochs=epochs,
validation_split = self.validation_split,
class_weight=self.class_weight,
sample_weight=self.w_train,
shuffle=True,
batch_size=self.batch_size,
callbacks=self.callbacks_list)
self.is_training = False
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
if not self.balance_dataset:
try:
self.shuffle_training_data()
self.is_training = True
self.model.fit(self.x_train,
# the reshape might be unnescessary here
self.y_train.reshape(-1, 1),
epochs=epochs,
validation_split = self.validation_split,
class_weight=self.class_weight,
sample_weight=self.w_train,
shuffle=True,
batch_size=self.batch_size,
callbacks=self.callbacks_list)
self.is_training = False
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
else:
try:
self.is_training = True
labels, label_counts = np.unique(self.y_train, return_counts=True)
logger.info("Training on balanced batches")
self.model.fit_generator(self.yield_balanced_batch(),
steps_per_epoch=int(min(label_counts)/self.batch_size),
epochs=epochs,
validation_data=self.class_weighted_validation_data,
callbacks=self.callbacks_list)
self.is_training = False
except KeyboardInterrupt:
logger.info("Interrupt training - continue with rest")
logger.info("Save history")
self._dump_history()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment