diff --git a/compare.py b/compare.py
index 7e4c9a86870dd1b01b4e733d0b9164929b049758..9ebcb65a56c001bededa325b20a411a6712e9f40 100755
--- a/compare.py
+++ b/compare.py
@@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
 from sklearn.metrics import roc_curve, auc
 
 from .toolkit import ClassificationProject
+from .plotting import save_show
 
 """
 A few functions to compare different setups
@@ -62,8 +63,7 @@ def overlay_ROC(filename, *projects, **kwargs):
     if plot_thresholds:
         # to fit right y-axis description
         fig.tight_layout()
-    fig.savefig(filename)
-    plt.close(fig)
+    return save_show(plt, fig, filename)
 
 def overlay_loss(filename, *projects, **kwargs):
 
@@ -78,22 +78,23 @@ def overlay_loss(filename, *projects, **kwargs):
     prop_cycle = plt.rcParams['axes.prop_cycle']
     colors = prop_cycle.by_key()['color']
 
+    fig, ax = plt.subplots()
+
     for p,color in zip(projects,colors):
         hist_dict = p.csv_hist
-        plt.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color)
-        plt.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color)
+        ax.plot(hist_dict['loss'], linestyle='--', label="Training Loss "+p.name, color=color)
+        ax.plot(hist_dict['val_loss'], label="Validation Loss "+p.name, color=color)
 
-    plt.ylabel('loss')
-    plt.xlabel('epoch')
+    ax.set_ylabel('loss')
+    ax.set_xlabel('epoch')
     if log:
-        plt.yscale("log")
+        ax.set_yscale("log")
     if xlim is not None:
-        plt.xlim(*xlim)
+        ax.set_xlim(*xlim)
     if ylim is not None:
-        plt.ylim(*ylim)
-    plt.legend(loc='upper right')
-    plt.savefig(filename)
-    plt.clf()
+        ax.set_ylim(*ylim)
+    ax.legend(loc='upper right')
+    return save_show(plt, fig, filename)
 
 
 
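
Note on the compare.py hunks above: overlay_ROC and overlay_loss now delegate to save_show and return its result, so in a notebook the overlays come back as live figure objects instead of being closed unconditionally. A hedged usage sketch (the project directories and the log keyword are assumptions based on the hunk context):

```python
from KerasROOTClassification import load_from_dir
from KerasROOTClassification.compare import overlay_ROC, overlay_loss

# load two previously trained projects (paths are made up)
p1 = load_from_dir("runs/setup_a")
p2 = load_from_dir("runs/setup_b")

# in a Jupyter cell these return (and render) the figure;
# in a batch script they write the file and return None
overlay_ROC("roc_overlay.pdf", p1, p2)
overlay_loss("loss_overlay.pdf", p1, p2, log=True)
```
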
diff --git a/keras_visualize_activations/read_activations.py b/keras_visualize_activations/read_activations.py
index 0e4641ac32bfcd48aa42cefb676d021e27697310..053439b84ebb50be638acb9d158694e2f3c1d9aa 100644
--- a/keras_visualize_activations/read_activations.py
+++ b/keras_visualize_activations/read_activations.py
@@ -1,5 +1,7 @@
 import keras.backend as K
 
+from keras.engine.input_layer import InputLayer
+from keras.layers.core import Masking
 
 def get_activations(model, model_inputs, print_shape_only=False, layer_name=None):
     print('----- activations -----')
@@ -12,8 +14,12 @@ def get_activations(model, model_inputs, print_shape_only=False, layer_name=None
         inp = [inp]
         model_multi_inputs_cond = False
 
+    # all layer outputs
+    # skip input and masking layers
     outputs = [layer.output for layer in model.layers if
-               layer.name == layer_name or layer_name is None]  # all layer outputs
+               (layer.name == layer_name or layer_name is None)
+               and not isinstance(layer, InputLayer)
+               and not isinstance(layer, Masking)]
 
     funcs = [K.function(inp + [K.learning_phase()], [out]) for out in outputs]  # evaluation functions
 
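
The extra isinstance filter above keeps K.function from being built for layers whose "output" is just the (masked) input. A minimal sketch of the intended use, assuming a Keras 2 style Sequential model with a Masking layer (the architecture here is an arbitrary example):

```python
import numpy as np
from keras.models import Sequential
from keras.layers import Masking, LSTM, Dense

model = Sequential([
    Masking(mask_value=-999., input_shape=(5, 3)),
    LSTM(8),
    Dense(1, activation="sigmoid"),
])

x = np.random.rand(2, 5, 3)
x[0, 3:] = -999.  # padded timesteps in the first sequence

# with the filter, only the LSTM and Dense activations are evaluated
activations = get_activations(model, x, print_shape_only=True)
```
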
diff --git a/plotting.py b/plotting.py
index e79b0d52d4a9e5fc3ba3a9c5e5cc338bdbc3cfe8..749d2a9d6690e4859df8c55dc1cad79a9f6778fa 100644
--- a/plotting.py
+++ b/plotting.py
@@ -20,8 +20,27 @@ logger.addHandler(logging.NullHandler())
 Some further plotting functions
 """
 
-def get_mean_event(x, y, class_label):
-    return [np.mean(x[y==class_label][:,var_index]) for var_index in range(x.shape[1])]
+def save_show(plt, fig, filename):
+    "Save a figure and show it in case we are in ipython or jupyter notebook."
+    fig.savefig(filename)
+    try:
+        get_ipython
+        plt.show()
+        return fig
+    except NameError:
+        plt.close(fig)
+        return None
+
+
+def get_mean_event(x, y, class_label, mask_value=None):
+    means = []
+    for var_index in range(x.shape[1]):
+        values = x[y==class_label][:,var_index]
+        if mask_value is not None:
+            # drop entries equal to the mask value for this variable
+            values = values[values != mask_value]
+        means.append(np.mean(values))
+    return means
 
 
 def plot_NN_vs_var_1D(plotname, means, scorefun, var_index, var_range, var_label=None):
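
save_show probes for the get_ipython builtin that IPython injects: if it exists, the figure is shown and returned (so it also renders as the cell result); if the lookup raises NameError, we are in a plain script and the figure is closed after saving. A small sketch of both behaviours (the file name is arbitrary):

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])

result = save_show(plt, fig, "example.pdf")
# plain script:    get_ipython is undefined -> figure closed, result is None
# IPython/Jupyter: figure shown and returned, so `result` displays inline
```
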
diff --git a/scripts/plot_NN_2D.py b/scripts/plot_NN_2D.py
index 58dbc7d92413f95ae25a0bb2d738aa5705580d97..460934daca75730d6f3d0df7b2d881af6c1595af 100755
--- a/scripts/plot_NN_2D.py
+++ b/scripts/plot_NN_2D.py
@@ -11,7 +11,7 @@ import ROOT
 ROOT.gROOT.SetBatch()
 ROOT.PyConfig.IgnoreCommandLineOptions = True
 
-from KerasROOTClassification import ClassificationProject, load_from_dir
+from KerasROOTClassification import load_from_dir
 from KerasROOTClassification.plotting import (
     get_mean_event,
     plot_NN_vs_var_2D,
@@ -73,13 +73,21 @@ else:
 varx_label = args.varx
 vary_label = args.vary
 
-# percentilesx = np.percentile(c.x_test[:,varx_index], [1,99])
-# percentilesy = np.percentile(c.x_test[:,vary_index], [1,99])
-
 total_weights = c.w_test*np.array(c.class_weight)[c.y_test.astype(int)]
 
-percentilesx = weighted_quantile(c.x_test[:,varx_index], [0.1, 0.99], sample_weight=total_weights)
-percentilesy = weighted_quantile(c.x_test[:,vary_index], [0.1, 0.99], sample_weight=total_weights)
+try:
+    mask_value = c.mask_value
+except AttributeError:
+    mask_value = None
+
+varx_test = c.x_test[:,varx_index]
+vary_test = c.x_test[:,vary_index]
+
+x_not_masked = np.where(varx_test != mask_value)[0]
+y_not_masked = np.where(vary_test != mask_value)[0]
+
+percentilesx = weighted_quantile(varx_test[x_not_masked], [0.1, 0.99], sample_weight=total_weights[x_not_masked])
+percentilesy = weighted_quantile(vary_test[y_not_masked], [0.1, 0.99], sample_weight=total_weights[y_not_masked])
 
 if args.xrange is not None:
     if len(args.xrange) < 3:
@@ -100,9 +108,11 @@ else:
 if args.mode.startswith("mean"):
 
     if args.mode == "mean_sig":
-        means = get_mean_event(c.x_test, c.y_test, 1)
+        means = get_mean_event(c.x_test, c.y_test, 1, mask_value=mask_value)
     elif args.mode == "mean_bkg":
-        means = get_mean_event(c.x_test, c.y_test, 0)
+        means = get_mean_event(c.x_test, c.y_test, 0, mask_value=mask_value)
+
+    print("mean event used for the scan: {}".format(means))
 
     if hasattr(c, "get_input_list"):
         input_transform = c.get_input_list
@@ -126,11 +136,15 @@ if args.mode.startswith("mean"):
             logscale=args.log, only_pixels=(not args.contour)
         )
     else:
+        if hasattr(c, "get_input_list"):
+            transform_function = lambda inp : c.get_input_list(c.scaler.transform(inp))
+        else:
+            transform_function = c.scaler.transform
         plot_NN_vs_var_2D_all(
             args.output_filename,
             means=means,
             model=c.model,
-            transform_function=c.scaler.transform,
+            transform_function=transform_function,
             varx_index=varx_index,
             vary_index=vary_index,
             xmin=varx_range[0], xmax=varx_range[1], nbinsx=varx_range[2],
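
The quantile computation above now runs only on unmasked entries; otherwise a mask value such as -999 would drag the lower percentile far below the physical range. A toy illustration (numbers are made up, and it assumes weighted_quantile returns one value per requested quantile, as its use above suggests):

```python
import numpy as np
from KerasROOTClassification.utils import weighted_quantile

mask_value = -999.
values = np.array([-999., 0.2, 0.5, 0.9, -999., 0.7])
weights = np.ones_like(values)

not_masked = np.where(values != mask_value)[0]
lo, hi = weighted_quantile(values[not_masked], [0.1, 0.99],
                           sample_weight=weights[not_masked])
# without the filter, `lo` would sit near -999 instead of inside [0.2, 0.9]
```
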
diff --git a/toolkit.py b/toolkit.py
index 29adbfb84cef516a9dc6eb3259b5a60690f75655..8779db4810fd803a0811bdb0e2de1f4f1ca5ae57 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -42,6 +42,7 @@ from keras import backend as K
 import matplotlib.pyplot as plt
 
 from .utils import WeightedRobustScaler, weighted_quantile, poisson_asimov_significance
+from .plotting import save_show
 
 # configure number of cores
 # this doesn't seem to work, but at least with these settings keras only uses 4 processes
@@ -348,6 +349,7 @@ class ClassificationProject(object):
 
         self.data_loaded = False
         self.data_transformed = False
+        self.data_shuffled = False
 
         # track if we are currently training
         self.is_training = False
@@ -447,6 +449,7 @@ class ClassificationProject(object):
             self._dump_to_hdf5(*self.dataset_names_tree)
 
         self.data_loaded = True
+        self.data_shuffled = False
 
 
     def _dump_training_list(self):
@@ -771,6 +774,7 @@ class ClassificationProject(object):
             logger.info("Shuffling scores, since they are also there")
             np.random.set_state(rn_state)
             np.random.shuffle(self._scores_train)
+        self.data_shuffled = True
 
 
     @property
@@ -789,6 +793,8 @@ class ClassificationProject(object):
     @property
     def validation_data(self):
         "Validation data. Attention: Shuffle training data before using this!"
+        if not self.data_shuffled:
+            raise ValueError("Training data isn't shuffled, can't split of validation data")
         split_index = int((1-self.validation_split)*len(self.x_train))
         return self.x_train[split_index:], self.y_train[split_index:], self.w_train_tot[split_index:]
 
@@ -796,6 +802,8 @@ class ClassificationProject(object):
     @property
     def training_data(self):
         "Training data with validation data split off. Attention: Shuffle training data before using this!"
+        if not self.data_shuffled:
+            raise ValueError("Training data isn't shuffled, can't split of validation data")
         split_index = int((1-self.validation_split)*len(self.x_train))
         return self.x_train[:split_index], self.y_train[:split_index], self.w_train_tot[:split_index]
 
@@ -908,10 +916,10 @@ class ClassificationProject(object):
 
         logger.info("Create/Update scores for train/test sample")
         if do_test:
-            self.scores_test = self.predict(self.x_test, mode=mode)
+            self.scores_test = self.predict(self.x_test, mode=mode).reshape(-1)
             self._dump_to_hdf5("scores_test")
         if do_train:
-            self.scores_train = self.predict(self.x_train, mode=mode)
+            self.scores_train = self.predict(self.x_train, mode=mode).reshape(-1)
             self._dump_to_hdf5("scores_train")
 
 
@@ -1012,15 +1020,26 @@ class ClassificationProject(object):
         return centers, hist, errors
 
 
-    def plot_input(self, var_index):
+    def plot_input(self, var_index, ax=None):
         "plot a single input variable"
         branch = self.fields[var_index]
-        fig, ax = plt.subplots()
+        if ax is None:
+            fig, ax = plt.subplots()
+        else:
+            fig = None
         bkg = self.x_train[:,var_index][self.y_train == 0]
         sig = self.x_train[:,var_index][self.y_train == 1]
         bkg_weights = self.w_train_tot[self.y_train == 0]
         sig_weights = self.w_train_tot[self.y_train == 1]
 
+        if hasattr(self, "mask_value"):
+            bkg_not_masked = np.where(bkg != self.mask_value)[0]
+            sig_not_masked = np.where(sig != self.mask_value)[0]
+            bkg = bkg[bkg_not_masked]
+            sig = sig[sig_not_masked]
+            bkg_weights = bkg_weights[bkg_not_masked]
+            sig_weights = sig_weights[sig_not_masked]
+
         if self.balance_dataset:
             if len(sig) < len(bkg):
                 logger.warning("Plotting only up to {} bkg events, since we use balance_dataset".format(len(sig)))
@@ -1035,33 +1054,56 @@ class ClassificationProject(object):
         logger.debug("Plotting sig (min={}, max={}) from {}".format(np.min(sig), np.max(sig), sig))
 
         # calculate percentiles to get a heuristic for the range to be plotted
+        x_total = np.concatenate([bkg, sig])
+        w_total = np.concatenate([bkg_weights, sig_weights])
         plot_range = weighted_quantile(
-            self.x_train[:,var_index], [0.01, 0.99],
-            sample_weight=self.w_train_tot
+            x_total,
+            [0.01, 0.99],
+            sample_weight=w_total,
         )
-
         logger.debug("Calculated range based on percentiles: {}".format(plot_range))
 
+        bins = 50
+
+        # check if we have a distribution of integer numbers (e.g. njet or something categorical)
+        # in that case we want to have a bin for each number
+        if (x_total == x_total.astype(int)).all():
+            plot_range = (math.floor(plot_range[0])-0.5, math.ceil(plot_range[1])+0.5)
+            bins = int(plot_range[1]-plot_range[0])
+
         try:
-            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=50, range=plot_range, weights=sig_weights)
-            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=50, range=plot_range, weights=bkg_weights)
+            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=bins, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=bins, range=plot_range, weights=bkg_weights)
         except ValueError:
             # weird, probably not always working workaround for a numpy bug
             plot_range = (float("{:.3f}".format(plot_range[0])), float("{:.3f}".format(plot_range[1])))
             logger.warn("Got a value error during plotting, maybe this is due to a numpy bug - changing range to {}".format(plot_range))
-            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=50, range=plot_range, weights=sig_weights)
-            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=50, range=plot_range, weights=bkg_weights)
+            centers_sig, hist_sig, _ = self.get_bin_centered_hist(sig, bins=bins, range=plot_range, weights=sig_weights)
+            centers_bkg, hist_bkg, _ = self.get_bin_centered_hist(bkg, bins=bins, range=plot_range, weights=bkg_weights)
 
         width = centers_sig[1]-centers_sig[0]
         ax.bar(centers_bkg, hist_bkg, color="b", alpha=0.5, width=width)
         ax.bar(centers_sig, hist_sig, color="r", alpha=0.5, width=width)
 
-        ax.set_xlabel(branch+" (transformed)")
-        plot_dir = os.path.join(self.project_dir, "plots")
-        if not os.path.exists(plot_dir):
-            os.mkdir(plot_dir)
-        fig.savefig(os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
-        plt.close(fig)
+        label = branch
+        if self.data_transformed:
+            label += " (transformed)"
+        ax.set_xlabel(label)
+        if fig is not None:
+            plot_dir = os.path.join(self.project_dir, "plots")
+            if not os.path.exists(plot_dir):
+                os.mkdir(plot_dir)
+            return save_show(plt, fig, os.path.join(plot_dir, "var_{}.pdf".format(var_index)))
+
+
+    def plot_all_inputs(self):
+        nrows = math.ceil(math.sqrt(len(self.fields)))
+        fig, axes = plt.subplots(nrows=int(nrows), ncols=int(nrows),
+                                 figsize=(3*nrows, 3*nrows),
+                                 gridspec_kw=dict(wspace=0.4, hspace=0.4))
+        for i in range(len(self.fields)):
+            self.plot_input(i, ax=axes.reshape(-1)[i])
+        return save_show(plt, fig, os.path.join(self.project_dir, "all_inputs.pdf"))
 
 
     def plot_weights(self, bins=100, range=None):
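
The integer-binning heuristic in plot_input deserves a closer look: if every entry is integer-valued (e.g. a jet multiplicity), the range is widened by half a unit on each side so that each integer gets exactly one bin centered on it. A standalone sketch with made-up numbers:

```python
import math
import numpy as np

x_total = np.array([0., 1., 1., 2., 3., 3., 3.])  # e.g. a jet multiplicity
plot_range = (0., 3.)  # pretend these are the 1%/99% weighted quantiles

bins = 50
if (x_total == x_total.astype(int)).all():
    plot_range = (math.floor(plot_range[0])-0.5, math.ceil(plot_range[1])+0.5)
    bins = int(plot_range[1]-plot_range[0])
# result: plot_range == (-0.5, 3.5), bins == 4 -> one bin per value 0..3
```
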
@@ -1070,19 +1112,19 @@ class ClassificationProject(object):
         sig = self.w_train_tot[self.y_train == 1]
         ax.hist(bkg, bins=bins, range=range, color="b", alpha=0.5)
         ax.set_yscale("log")
-        fig.savefig(os.path.join(self.project_dir, "eventweights_bkg.pdf"))
-        plt.close(fig)
+        return save_show(plt, fig, os.path.join(self.project_dir, "eventweights_bkg.pdf"))
         fig, ax = plt.subplots()
         ax.hist(sig, bins=bins, range=range, color="r", alpha=0.5)
         ax.set_yscale("log")
-        fig.savefig(os.path.join(self.project_dir, "eventweights_sig.pdf"))
-        plt.close(fig)
+        return save_show(plt, fig, os.path.join(self.project_dir, "eventweights_sig.pdf"))
 
 
     def plot_ROC(self, xlim=(0,1), ylim=(0,1)):
 
         logger.info("Plot ROC curve")
-        plt.grid(color='gray', linestyle='--', linewidth=1)
+
+        fig, ax = plt.subplots()
+        ax.grid(color='gray', linestyle='--', linewidth=1)
 
         for y, scores, weight, label in [
                 (self.y_train, self.scores_train, self.w_train, "train"),
@@ -1095,23 +1137,28 @@ class ClassificationProject(object):
             except ValueError:
                 logger.warning("Got a value error from auc - trying to rerun with reorder=True")
                 roc_auc = auc(tpr, fpr, reorder=True)
-            plt.plot(tpr,  fpr, label=str(self.name + " {} (AUC = {:.3f})".format(label, roc_auc)))
-
-        plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
-        plt.ylabel("Background rejection")
-        plt.xlabel("Signal efficiency")
-        plt.title('Receiver operating characteristic')
-        plt.xlim(*xlim)
-        plt.ylim(*ylim)
+            ax.plot(tpr,  fpr, label=str(self.name + " {} (AUC = {:.3f})".format(label, roc_auc)))
+
+        ax.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
+        ax.set_ylabel("Background rejection")
+        ax.set_xlabel("Signal efficiency")
+        ax.set_title('Receiver operating characteristic')
+        ax.set_xlim(*xlim)
+        ax.set_ylim(*ylim)
         # plt.xticks(np.arange(0,1,0.1))
         # plt.yticks(np.arange(0,1,0.1))
-        plt.legend(loc='lower left', framealpha=1.0)
-        plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
-        plt.clf()
+        ax.legend(loc='lower left', framealpha=1.0)
+        return save_show(plt, fig, os.path.join(self.project_dir, "ROC.pdf"))
 
 
-    def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0, 1)),
-                   ylim=None, xlim=None, density=True, lumifactor=None, apply_class_weight=True):
+    def plot_score(self, log=True, plot_opts=dict(bins=50, range=(0,1)),
+                   ylim=None, xlim=None, density=True,
+                   lumifactor=None, apply_class_weight=True,
+                   invert_activation=False):
+        if invert_activation:
+            trf = self.get_inverse_act_fn()
+        else:
+            trf = lambda y : y
         fig, ax = plt.subplots()
         for scores, weights, y, class_label, fn, opts in [
                 (self.scores_train, self.w_train, self.y_train, 1, ax.bar, dict(color="r", label="signal train")),
@@ -1127,7 +1174,7 @@ class ClassificationProject(object):
             if lumifactor is not None:
                 weights = weights*lumifactor
             centers, hist, rel_errors = self.get_bin_centered_hist(
-                scores[y==class_label].reshape(-1),
+                trf(scores[y==class_label].reshape(-1)),
                 weights=weights,
                 **plot_opts
             )
@@ -1154,11 +1201,10 @@ class ClassificationProject(object):
         if apply_class_weight:
             ax.set_title("Class weights applied")
         ax.legend(loc='upper center', framealpha=0.5)
-        fig.savefig(os.path.join(self.project_dir, "scores.pdf"))
-        plt.close(fig)
+        return save_show(plt, fig, os.path.join(self.project_dir, "scores.pdf"))
 
 
-    def plot_significance_hist(self, lumifactor=1., significance_function=None, plot_opts=dict(bins=50, range=(0, 1))):
+    def plot_significance_hist(self, lumifactor=1., significance_function=None, plot_opts=dict(bins=50, range=(0, 1)), invert_activation=False):
 
         """
         Plot significances based on a histogram of scores
@@ -1166,10 +1212,15 @@ class ClassificationProject(object):
 
         logger.info("Plot significances")
 
-        centers_sig_train, hist_sig_train, rel_errors_sig_train = self.get_bin_centered_hist(self.scores_train[self.y_train==1].reshape(-1), weights=self.w_train[self.y_train==1], **plot_opts)
-        centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(self.scores_train[self.y_train==0].reshape(-1), weights=self.w_train[self.y_train==0], **plot_opts)
-        centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(self.scores_test[self.y_test==1].reshape(-1), weights=self.w_test[self.y_test==1], **plot_opts)
-        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(self.scores_test[self.y_test==0].reshape(-1), weights=self.w_test[self.y_test==0], **plot_opts)
+        if invert_activation:
+            trf = self.get_inverse_act_fn()
+        else:
+            trf = lambda y : y
+
+        centers_sig_train, hist_sig_train, rel_errors_sig_train = self.get_bin_centered_hist(trf(self.scores_train[self.y_train==1].reshape(-1)), weights=self.w_train[self.y_train==1], **plot_opts)
+        centers_bkg_train, hist_bkg_train, rel_errors_bkg_train = self.get_bin_centered_hist(trf(self.scores_train[self.y_train==0].reshape(-1)), weights=self.w_train[self.y_train==0], **plot_opts)
+        centers_sig_test, hist_sig_test, rel_errors_sig_test = self.get_bin_centered_hist(trf(self.scores_test[self.y_test==1].reshape(-1)), weights=self.w_test[self.y_test==1], **plot_opts)
+        centers_bkg_test, hist_bkg_test, rel_errors_bkg_test = self.get_bin_centered_hist(trf(self.scores_test[self.y_test==0].reshape(-1)), weights=self.w_test[self.y_test==0], **plot_opts)
 
         significances_train = []
         significances_test = []
@@ -1210,8 +1261,7 @@ class ClassificationProject(object):
         ax.set_xlabel("Cut on NN score")
         ax.set_ylabel("Significance")
         ax.legend(loc='lower center', framealpha=0.5)
-        fig.savefig(os.path.join(self.project_dir, "significances_hist.pdf"))
-        plt.close(fig)
+        return save_show(plt, fig, os.path.join(self.project_dir, "significances_hist.pdf"))
 
 
     @staticmethod
@@ -1238,7 +1288,15 @@ class ClassificationProject(object):
         return s_sumw, np.sqrt(s_sumw2), b_sumw, np.sqrt(b_sumw2), scores_sorted[threshold_idxs]
 
 
-    def plot_significance(self, significance_function=None, maxsteps=1000, lumifactor=1., vectorized=False):
+    def get_inverse_act_fn(self):
+        if self.activation_function_output != "sigmoid":
+            raise NotImplementedError("Inverse function of {} not supported yet - "
+                                      "currently only sigmoid"
+                                      .format(self.activation_function_output))
+        return lambda y : np.log(y/(1-y))
+
+
+    def plot_significance(self, significance_function=None, maxsteps=1000, lumifactor=1., vectorized=False, invert_activation=False):
         """
         Plot the significance when cutting on all posible thresholds and plot against signal efficiency.
         """
@@ -1247,6 +1305,11 @@ class ClassificationProject(object):
             vectorized = True
             significance_function = poisson_asimov_significance
 
+        if invert_activation:
+            trf = self.get_inverse_act_fn()
+        else:
+            trf = lambda y : y
+
         fig, ax = plt.subplots()
         ax2 = ax.twinx()
         prop_cycle = plt.rcParams['axes.prop_cycle']
@@ -1256,6 +1319,7 @@ class ClassificationProject(object):
                  (self.scores_test, self.y_test, self.w_test, "test")],
                 colors
         ):
+            scores = trf(scores)
             s_sumws, s_errs, b_sumws, b_errs, thresholds = self.calc_s_ds_b_db(scores, y, w)
             stepsize = int(len(s_sumws))/int(maxsteps)
             if stepsize == 0:
@@ -1283,8 +1347,7 @@ class ClassificationProject(object):
         ax.set_xlim(0, 1)
         ax2.set_ylabel("Threshold")
         ax.legend()
-        fig.savefig(os.path.join(self.project_dir, "significances.pdf"))
-        plt.close(fig)
+        return save_show(plt, fig, os.path.join(self.project_dir, "significances.pdf"))
 
 
     @property
@@ -1624,13 +1687,14 @@ class ClassificationProjectRNN(ClassificationProject):
     def train(self, epochs=10):
         self.load()
 
+        self.shuffle_training_data()
+
         for branch_index, branch in enumerate(self.fields):
             self.plot_input(branch_index)
 
         self.total_epochs = self._read_info("epochs", 0)
 
         try:
-            self.shuffle_training_data() # needed here too, in order to get correct validation data
             self.is_training = True
             logger.info("Training on batches for RNN")
             # note: the batches have class_weight already applied
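
With the new data_shuffled flag, the training_data and validation_data properties refuse to split unshuffled data, which previously could silently hand back a non-random validation set. A hedged usage sketch (the directory name is made up, and it assumes load() and shuffle_training_data() are the public entry points, as train() above suggests):

```python
from KerasROOTClassification import load_from_dir

c = load_from_dir("runs/my_project")
c.load()

# while data_shuffled is False this raises ValueError:
# x_val, y_val, w_val = c.validation_data

c.shuffle_training_data()
x_val, y_val, w_val = c.validation_data  # now a random split
```
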
diff --git a/utils.py b/utils.py
index 8bdee4db37f6c960cda30459369cf61eec1c7894..39ccc4865990a98c6afcf75fd086915849049f13 100644
--- a/utils.py
+++ b/utils.py
@@ -15,7 +15,11 @@ logger.addHandler(logging.NullHandler())
 
 def get_single_neuron_function(model, layer, neuron, input_transform=None):
 
-    f = K.function([model.input]+[K.learning_phase()], [model.layers[layer].output[:,neuron]])
+    inp = model.input
+    if not isinstance(inp, list):
+        inp = [inp]
+
+    f = K.function(inp+[K.learning_phase()], [model.layers[layer].output[:,neuron]])
 
     def eval_single_neuron(x):
         if input_transform is not None:
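
The fix above matters for multi-input models, where model.input is already a list; unconditionally writing [model.input] would nest that list and break K.function. A minimal sketch of the wrapping logic (the two-input architecture is an arbitrary example, and the final call follows the Keras backend convention of appending the learning-phase flag):

```python
import numpy as np
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Dense, concatenate

in_a = Input(shape=(3,))
in_b = Input(shape=(2,))
out = Dense(1)(concatenate([in_a, in_b]))
model = Model(inputs=[in_a, in_b], outputs=out)

inp = model.input          # already a list here
if not isinstance(inp, list):
    inp = [inp]            # single-input models get wrapped instead

f = K.function(inp + [K.learning_phase()], [model.layers[-1].output[:, 0]])
neuron_values, = f([np.random.rand(4, 3), np.random.rand(4, 2), 0])  # 0 = test phase
```
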