From 54d5cf3a4a15bdfacc68651c55376e2afb8202b5 Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Thu, 17 May 2018 07:55:36 +0200
Subject: [PATCH] Add option to train on batches with equal numbers of events
 from both classes

---
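Note for reviewers: a self-contained sketch of the batching scheme this
patch implements (plain NumPy, not the toolkit API; names are illustrative):

    import numpy as np

    def yield_class_batches(x, y, w, label, batch_size):
        """Endlessly yield shuffled batches of the events with y == label."""
        xc, yc, wc = x[y == label], y[y == label], w[y == label]
        while True:
            idx = np.random.permutation(len(xc))  # fresh shuffle every pass
            for start in range(0, len(xc), batch_size):
                sel = idx[start:start + batch_size]
                yield xc[sel], yc[sel], wc[sel]

    def yield_balanced(x, y, w, batch_size):
        """Concatenate one batch per class, giving equal class counts."""
        for b0, b1 in zip(yield_class_batches(x, y, w, 0, batch_size),
                          yield_class_batches(x, y, w, 1, batch_size)):
            yield tuple(np.concatenate((a, b)) for a, b in zip(b0, b1))

Since the generators restart with a fresh shuffle whenever a class is
exhausted, the overrepresented class contributes a different subset in
every epoch.
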
 toolkit.py | 114 ++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 95 insertions(+), 19 deletions(-)

diff --git a/toolkit.py b/toolkit.py
index dbec934..5593a29 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -4,6 +4,9 @@ from sys import version_info
 
 if version_info[0] > 2:
     raw_input = input
+    izip = zip
+else:
+    from itertools import izip
 
 import os
 import json
@@ -108,6 +111,13 @@ class ClassificationProject(object):
 
     :param earlystopping_opts: options for the keras EarlyStopping callback
 
+    :param use_modelcheckpoint: save the model weights after each epoch, skipping epochs in which the validation loss did not improve
+
+    :param balance_dataset: if True, balance the dataset instead of
+                            applying class weights. Each epoch then uses only a
+                            fraction of the overrepresented class, but a
+                            different subset of it in every epoch.
+
     :param random_seed: use this seed value when initialising the model and produce consistent results. Note:
                         random data is also used for shuffling the training data, so results may vary still. To
                         produce consistent results, set the numpy random seed before training.
@@ -158,7 +168,8 @@ class ClassificationProject(object):
                         use_earlystopping=True,
                         earlystopping_opts=None,
                         use_modelcheckpoint=True,
-                        random_seed=1234):
+                        random_seed=1234,
+                        balance_dataset=False):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -186,6 +197,8 @@ class ClassificationProject(object):
         if earlystopping_opts is None:
             earlystopping_opts = dict()
         self.earlystopping_opts = earlystopping_opts
+        self.random_seed = random_seed
+        self.balance_dataset = balance_dataset
 
         self.project_dir = project_dir
         if self.project_dir is None:
@@ -194,8 +207,6 @@ class ClassificationProject(object):
         if not os.path.exists(self.project_dir):
             os.mkdir(self.project_dir)
 
-        self.random_seed = random_seed
-
         self.s_train = None
         self.b_train = None
         self.s_test = None
@@ -210,6 +221,9 @@ class ClassificationProject(object):
         self._scores_train = None
         self._scores_test = None
 
+        # class weighted validation data
+        self._w_validation = None
+
         self._s_eventlist_train = None
         self._b_eventlist_train = None
 
@@ -550,6 +564,54 @@ class ClassificationProject(object):
             np.random.shuffle(self._scores_train)
 
 
+    @property
+    def w_validation(self):
+        "class weighted validation data"
+        if self._w_validation is None:
+            split_index = int((1-self.validation_split)*len(self.x_train))
+            self._w_validation = np.array(self.w_train[split_index:])
+            self._w_validation[self.y_train[split_index:]==0] *= self.class_weight[0]
+            self._w_validation[self.y_train[split_index:]==1] *= self.class_weight[1]
+        return self._w_validation
+
+
+    @property
+    def class_weighted_validation_data(self):
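+        # keras' validation_split takes the last fraction of the samples,
+        # so the same tail split is reproduced here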
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
+
+
+    @property
+    def training_data(self):
+        "training data with validation data split off"
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
+
+
+    def yield_batch(self, class_label):
+        while True:
+            x_train, y_train, w_train = self.training_data
+            # shuffle the entries for this class label
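+            # save the RNG state and restore it below so that x_train and
+            # w_train are permuted with the identical permutation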
+            rn_state = np.random.get_state()
+            x_train[y_train==class_label] = np.random.permutation(x_train[y_train==class_label])
+            np.random.set_state(rn_state)
+            w_train[y_train==class_label] = np.random.permutation(w_train[y_train==class_label])
+            # yield them batch wise
+            for start in range(0, len(x_train[y_train==class_label]), self.batch_size):
+                yield (x_train[y_train==class_label][start:start+self.batch_size],
+                       y_train[y_train==class_label][start:start+self.batch_size],
+                       w_train[y_train==class_label][start:start+self.batch_size])
+            # all events of this class consumed: loop around and reshuffle
+
+
+    def yield_balanced_batch(self):
+        "generate batches with equal amounts of both classes"
+        for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)):
+            yield (np.concatenate((batch_0[0], batch_1[0])),
+                   np.concatenate((batch_0[1], batch_1[1])),
+                   np.concatenate((batch_0[2], batch_1[2])))
+
+
     def train(self, epochs=10):
 
         self.load()
@@ -560,22 +622,36 @@ class ClassificationProject(object):
         self.total_epochs = self._read_info("epochs", 0)
 
         logger.info("Train model")
-        try:
-            self.shuffle_training_data()
-            self.is_training = True
-            self.model.fit(self.x_train,
-                           # the reshape might be unnescessary here
-                           self.y_train.reshape(-1, 1),
-                           epochs=epochs,
-                           validation_split = self.validation_split,
-                           class_weight=self.class_weight,
-                           sample_weight=self.w_train,
-                           shuffle=True,
-                           batch_size=self.batch_size,
-                           callbacks=self.callbacks_list)
-            self.is_training = False
-        except KeyboardInterrupt:
-            logger.info("Interrupt training - continue with rest")
+        if not self.balance_dataset:
+            try:
+                self.shuffle_training_data()
+                self.is_training = True
+                self.model.fit(self.x_train,
+                               # the reshape might be unnecessary here
+                               self.y_train.reshape(-1, 1),
+                               epochs=epochs,
+                               validation_split=self.validation_split,
+                               class_weight=self.class_weight,
+                               sample_weight=self.w_train,
+                               shuffle=True,
+                               batch_size=self.batch_size,
+                               callbacks=self.callbacks_list)
+                self.is_training = False
+            except KeyboardInterrupt:
+                logger.info("Interrupt training - continue with rest")
+        else:
+            try:
+                self.is_training = True
+                labels, label_counts = np.unique(self.y_train, return_counts=True)
+                logger.info("Training on balanced batches")
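+                # steps_per_epoch is set such that one epoch covers the
+                # underrepresented class roughly once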
+                self.model.fit_generator(self.yield_balanced_batch(),
+                                         steps_per_epoch=int(min(label_counts)/self.batch_size),
+                                         epochs=epochs,
+                                         validation_data=self.class_weighted_validation_data,
+                                         callbacks=self.callbacks_list)
+                self.is_training = False
+            except KeyboardInterrupt:
+                logger.info("Interrupt training - continue with rest")
 
         logger.info("Save history")
         self._dump_history()
-- 
GitLab