diff --git a/toolkit.py b/toolkit.py index a0d3a1c04c798f6722385772088f4684f74bfb81..9573f262c5e005f15e578e0d84ec9be621994143 100755 --- a/toolkit.py +++ b/toolkit.py @@ -93,6 +93,8 @@ class ClassificationProject(object): :param optimizer_opts: dictionary of options for the optimizer + :param use_earlystopping: set true to use the keras EarlyStopping callback + :param earlystopping_opts: options for the keras EarlyStopping callback :param random_seed: use this seed value when initialising the model and produce consistent results. Note: @@ -140,6 +142,7 @@ class ClassificationProject(object): step_bkg=2, optimizer="SGD", optimizer_opts=None, + use_earlystopping=True, earlystopping_opts=None, random_seed=1234): @@ -159,6 +162,7 @@ class ClassificationProject(object): self.step_signal = step_signal self.step_bkg = step_bkg self.optimizer = optimizer + self.use_earlystopping = use_earlystopping if optimizer_opts is None: optimizer_opts = dict() self.optimizer_opts = optimizer_opts @@ -332,7 +336,8 @@ class ClassificationProject(object): def callbacks_list(self): self._callbacks_list = [] self._callbacks_list.append(self.history) - self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) + if self.use_earlystopping: + self._callbacks_list.append(EarlyStopping(**self.earlystopping_opts)) self._callbacks_list.append(CSVLogger(os.path.join(self.project_dir, "training.log"), append=True)) return self._callbacks_list