diff --git a/toolkit.py b/toolkit.py
index ec074bdf0c9e06e32de108a40fbb0e9641177488..a97a84671840ffb675f8962dab6784cc77f004c4 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1164,17 +1164,51 @@ class ClassificationProject(object):
         return centers, hist, errors
 
 
-    def plot_input(self, var_index, ax=None):
-        "plot a single input variable"
+    def plot_input(self, var_index, ax=None, from_training_batches=False, max_n_batches=None):
+        """
+        plot a single input variable as a histogram (signal vs background)
+
+        :param from_training_batches: use data from training batch generator
+        :param max_n_batches: if training batch generator is used, just use
+                              this number of batches (otherwise steps_per_epoch is used)
+        """
         branch = self.fields[var_index]
         if ax is None:
             fig, ax = plt.subplots()
         else:
             fig = None
-        bkg = self.x_train[:,var_index][self.y_train == 0]
-        sig = self.x_train[:,var_index][self.y_train == 1]
-        bkg_weights = self.w_train_tot[self.y_train == 0]
-        sig_weights = self.w_train_tot[self.y_train == 1]
+
+        if not from_training_batches:
+            bkg = self.x_train[:,var_index][self.y_train == 0]
+            sig = self.x_train[:,var_index][self.y_train == 1]
+            bkg_weights = self.w_train_tot[self.y_train == 0]
+            sig_weights = self.w_train_tot[self.y_train == 1]
+        else:
+            bkg = None
+            sig = None
+            bkg_weights = None
+            sig_weights = None
+            if max_n_batches is not None:
+                n_batches = max_n_batches
+            else:
+                n_batches = self.steps_per_epoch
+            for i_batch, (x, y, w) in enumerate(self.yield_batch()):
+                if i_batch > n_batches:
+                    break
+                bkg_batch = x[:,var_index][y==0]
+                sig_batch = x[:,var_index][y==1]
+                bkg_weights_batch = w[y==0]
+                sig_weights_batch = w[y==1]
+                if bkg is None:
+                    bkg = bkg_batch
+                    sig = sig_batch
+                    bkg_weights = bkg_weights_batch
+                    sig_weights = sig_weights_batch
+                else:
+                    bkg = np.concatenate([bkg, bkg_batch])
+                    sig = np.concatenate([sig, sig_batch])
+                    bkg_weights = np.concatenate([bkg_weights, bkg_weights_batch])
+                    sig_weights = np.concatenate([sig_weights, sig_weights_batch])
 
         if hasattr(self, "mask_value"):
             bkg_not_masked = np.where(bkg != self.mask_value)[0]
@@ -1238,13 +1272,13 @@ class ClassificationProject(object):
             return save_show(plt, fig, os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
 
 
-    def plot_all_inputs(self):
+    def plot_all_inputs(self, **kwargs):
         nrows = math.ceil(math.sqrt(len(self.fields)))
         fig, axes = plt.subplots(nrows=int(nrows), ncols=int(nrows),
                                  figsize=(3*nrows, 3*nrows),
                                  gridspec_kw=dict(wspace=0.4, hspace=0.4))
         for i in range(len(self.fields)):
-            self.plot_input(i, ax=axes.reshape(-1)[i])
+            self.plot_input(i, ax=axes.reshape(-1)[i], **kwargs)
         return save_show(plt, fig, os.path.join(self.project_dir, "all_inputs.pdf"))