diff --git a/toolkit.py b/toolkit.py index be90cc62c13bd99b7da8c37c78e773c41187b6d9..b6ec1f929a8eab5235c7893a174384ba2fff84ff 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1844,6 +1844,9 @@ class ClassificationProjectRNN(ClassificationProject): x_train, y_train, w_train = self.x_train, self.y_train, self.w_train_tot train_idx = list(self.train_val_idx[0]) np.random.seed(self.shuffle_seed+1) + logger.info("Generating training batches from {} signal and {} background events" + .format(len(np.where(self.y_train[train_idx]==1)[0]), + len(np.where(self.y_train[train_idx]==0)[0]))) while True: shuffled_idx = np.random.permutation(train_idx) for start in range(0, len(shuffled_idx), int(self.batch_size)):