diff --git a/toolkit.py b/toolkit.py index 3d862f5d27f7f36062635a5d906bff5e15bb0a07..d708ee11ac90375624ffc3b7fb0218562d73d677 100755 --- a/toolkit.py +++ b/toolkit.py @@ -92,6 +92,8 @@ class KerasROOTClassification(object): :param optimizer_opts: dictionary of options for the optimizer + :param use_earlystopping: set true to use the keras EarlyStopping callback + :param earlystopping_opts: options for the keras EarlyStopping callback :param random_seed: use this seed value when initialising the model and produce consistent results. Note: @@ -138,6 +140,7 @@ class KerasROOTClassification(object): step_bkg=2, optimizer="SGD", optimizer_opts=None, + use_earlystopping=True, earlystopping_opts=None, random_seed=1234): @@ -158,6 +161,7 @@ class KerasROOTClassification(object): self.step_signal = step_signal self.step_bkg = step_bkg self.optimizer = optimizer + self.use_earlystopping = use_earlystopping if optimizer_opts is None: optimizer_opts = dict() self.optimizer_opts = optimizer_opts @@ -332,8 +336,8 @@ class KerasROOTClassification(object): def callbacks_list(self): if not self._callbacks_list: self._callbacks_list.append(self.history) - self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) - + if self.use_earlystopping: + self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) return self._callbacks_list