From 54d5cf3a4a15bdfacc68651c55376e2afb8202b5 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Thu, 17 May 2018 07:55:36 +0200
Subject: [PATCH] Adding option to train on batches with equal numbers of events from both classes

---
 toolkit.py | 114 ++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 95 insertions(+), 19 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index dbec934..5593a29 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -4,6 +4,9 @@ from sys import version_info
 
 if version_info[0] > 2:
     raw_input = input
+    izip = zip
+else:
+    from itertools import izip
 
 import os
 import json
@@ -108,6 +111,13 @@ class ClassificationProject(object):
 
     :param earlystopping_opts: options for the keras EarlyStopping callback
 
+    :param use_modelcheckpoint: save model weights after each epoch, but only if the validation loss improved
+
+    :param balance_dataset: if True, balance the dataset instead of
+        applying class weights. Only a fraction of the overrepresented
+        class is used in each epoch, but a different subset of it is
+        used in every epoch.
+
     :param random_seed: use this seed value when initialising the model and produce consistent results. Note: random data is also
         used for shuffling the training data, so results may vary still. To produce consistent results, set the numpy random seed before training.
 
@@ -158,7 +168,8 @@ class ClassificationProject(object):
                  use_earlystopping=True,
                  earlystopping_opts=None,
                  use_modelcheckpoint=True,
-                 random_seed=1234):
+                 random_seed=1234,
+                 balance_dataset=False):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -186,6 +197,8 @@ class ClassificationProject(object):
         if earlystopping_opts is None:
             earlystopping_opts = dict()
         self.earlystopping_opts = earlystopping_opts
+        self.random_seed = random_seed
+        self.balance_dataset = balance_dataset
 
         self.project_dir = project_dir
         if self.project_dir is None:
@@ -194,8 +207,6 @@ class ClassificationProject(object):
         if not os.path.exists(self.project_dir):
             os.mkdir(self.project_dir)
 
-        self.random_seed = random_seed
-
         self.s_train = None
         self.b_train = None
         self.s_test = None
@@ -210,6 +221,9 @@ class ClassificationProject(object):
         self._scores_train = None
         self._scores_test = None
 
+        # class-weighted validation data
+        self._w_validation = None
+
         self._s_eventlist_train = None
         self._b_eventlist_train = None
 
@@ -550,6 +564,54 @@ class ClassificationProject(object):
         np.random.shuffle(self._scores_train)
 
 
+    @property
+    def w_validation(self):
+        "class-weighted validation data"
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        if self._w_validation is None:
+            self._w_validation = np.array(self.w_train[split_index:])
+            self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
+            self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
+        return self._w_validation
+
+
+    @property
+    def class_weighted_validation_data(self):
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
+
+
+    @property
+    def training_data(self):
+        "training data with the validation data split off"
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
+
+
+    def yield_batch(self, class_label):
+        while True:
+            x_train, y_train, w_train = self.training_data
+            # shuffle the entries for this class label - restoring the RNG state applies the same permutation to x and w
+            rn_state = np.random.get_state()
+            x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
+            np.random.set_state(rn_state)
+            w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
+            # yield the shuffled entries batch-wise
+            for start in range(0, len(x_train[y_train==class_label]), self.batch_size):
+                yield (x_train[y_train==class_label][start:start+self.batch_size],
+                       y_train[y_train==class_label][start:start+self.batch_size],
+                       w_train[y_train==class_label][start:start+self.batch_size])
+            # all entries yielded once - the while loop restarts with a fresh shuffle
+
+
+    def yield_balanced_batch(self):
+        "generate batches with equal numbers of events from both classes"
+        for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)):
+            yield (np.concatenate((batch_0[0], batch_1[0])),
+                   np.concatenate((batch_0[1], batch_1[1])),
+                   np.concatenate((batch_0[2], batch_1[2])))
+
+
     def train(self, epochs=10):
 
         self.load()
@@ -560,22 +622,36 @@
             self.total_epochs = self._read_info("epochs", 0)
 
         logger.info("Train model")
-        try:
-            self.shuffle_training_data()
-            self.is_training = True
-            self.model.fit(self.x_train,
-                           # the reshape might be unnescessary here
-                           self.y_train.reshape(-1, 1),
-                           epochs=epochs,
-                           validation_split = self.validation_split,
-                           class_weight=self.class_weight,
-                           sample_weight=self.w_train,
-                           shuffle=True,
-                           batch_size=self.batch_size,
-                           callbacks=self.callbacks_list)
-            self.is_training = False
-        except KeyboardInterrupt:
-            logger.info("Interrupt training - continue with rest")
+        if not self.balance_dataset:
+            try:
+                self.shuffle_training_data()
+                self.is_training = True
+                self.model.fit(self.x_train,
+                               # the reshape might be unnecessary here
+                               self.y_train.reshape(-1, 1),
+                               epochs=epochs,
+                               validation_split=self.validation_split,
+                               class_weight=self.class_weight,
+                               sample_weight=self.w_train,
+                               shuffle=True,
+                               batch_size=self.batch_size,
+                               callbacks=self.callbacks_list)
+                self.is_training = False
+            except KeyboardInterrupt:
+                logger.info("Interrupt training - continue with rest")
+        else:
+            try:
+                self.is_training = True
+                labels, label_counts = np.unique(self.y_train, return_counts=True)
+                logger.info("Training on balanced batches")
+                self.model.fit_generator(self.yield_balanced_batch(),
+                                         steps_per_epoch=int(min(label_counts)/self.batch_size),
+                                         epochs=epochs,
+                                         validation_data=self.class_weighted_validation_data,
+                                         callbacks=self.callbacks_list)
+                self.is_training = False
+            except KeyboardInterrupt:
+                logger.info("Interrupt training - continue with rest")
 
         logger.info("Save history")
         self._dump_history()
-- 
GitLab
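
For illustration, a self-contained Python 3 sketch of the balanced-batch scheme the patch implements (hypothetical standalone code, not part of toolkit.py; all names are made up, and a single shared permutation index stands in for the RNG-state trick used in yield_batch above):

import numpy as np

def yield_class_batches(x, y, w, class_label, batch_size):
    """Infinite generator of (x, y, w) batches for one class,
    reshuffled on every pass through that class."""
    mask = y == class_label
    x_c, y_c, w_c = x[mask], y[mask], w[mask]
    while True:
        # one shared permutation keeps x, y and w in sync
        perm = np.random.permutation(len(x_c))
        x_c, y_c, w_c = x_c[perm], y_c[perm], w_c[perm]
        for start in range(0, len(x_c), batch_size):
            yield (x_c[start:start + batch_size],
                   y_c[start:start + batch_size],
                   w_c[start:start + batch_size])

def yield_balanced(x, y, w, batch_size):
    """Zip the per-class generators so every batch holds an equal
    number of events of each class; the overrepresented class cycles
    through a different subset on each pass."""
    for b0, b1 in zip(yield_class_batches(x, y, w, 0, batch_size),
                      yield_class_batches(x, y, w, 1, batch_size)):
        yield tuple(np.concatenate(parts) for parts in zip(b0, b1))

# demo: 1000 background events (label 0) and 100 signal events (label 1)
x = np.random.rand(1100, 5)
y = np.concatenate([np.zeros(1000), np.ones(100)])
w = np.ones(1100)
xb, yb, wb = next(yield_balanced(x, y, w, batch_size=32))
print(xb.shape, yb.shape, wb.shape)  # (64, 5) (64,) (64,): 32 events per class

Setting steps_per_epoch to the minority-class event count divided by the batch size, as the fit_generator call in the patch does, means one epoch sees each minority-class event roughly once, while the majority-class generator keeps advancing to fresh events from epoch to epoch.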