Commit 54d5cf3a in Eric.Schanet/KerasROOTClassification
authored 6 years ago by Nikolai

Adding option to train on batches with equal number of events of both classes

parent 4bd7de4a
Changes: 1 changed file

toolkit.py  +95 −19  (95 additions, 19 deletions)
@@ -4,6 +4,9 @@ from sys import version_info
 if version_info[0] > 2:
     raw_input = input
+    izip = zip
+else:
+    from itertools import izip
 
 import os
 import json
@@ -108,6 +111,13 @@ class ClassificationProject(object):
     :param earlystopping_opts: options for the keras EarlyStopping callback
 
     :param use_modelcheckpoint: save model weights after each epoch, and don't
                                 save when the validation loss did not improve
 
+    :param balance_dataset: if True, balance the dataset instead of applying
+                            class weights. Only a fraction of the
+                            overrepresented class will be used in each epoch,
+                            but a different subset of it each epoch.
+
     :param random_seed: use this seed value when initialising the model to
                         produce consistent results. Note: random data is also
                         used for shuffling the training data, so results may
                         still vary. To produce consistent results, set the
                         numpy random seed before training.
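A minimal numpy sketch of the idea behind balance_dataset (illustrative only, not part of this commit; all names are made up): rather than weighting the loss, each epoch pairs equally sized per-class slices, so the overrepresented class contributes only a subset of its events per epoch, and a different one each time.

import numpy as np

# Toy imbalanced dataset: 1000 background events (label 0), 100 signal (label 1).
x = np.random.normal(size=(1100, 2))
y = np.concatenate([np.zeros(1000, dtype=int), np.ones(100, dtype=int)])
batch_size = 32

# Shuffle each class independently, then pair equally sized slices.
idx_0 = np.random.permutation(np.where(y == 0)[0])
idx_1 = np.random.permutation(np.where(y == 1)[0])
n = min(len(idx_0), len(idx_1))  # the epoch ends with the smaller class
for start in range(0, n, batch_size):
    sel = np.concatenate([idx_0[start:start + batch_size],
                          idx_1[start:start + batch_size]])
    x_batch, y_batch = x[sel], y[sel]  # roughly a 50/50 mix of both classes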
@@ -158,7 +168,8 @@ class ClassificationProject(object):
                  use_earlystopping=True,
                  earlystopping_opts=None,
                  use_modelcheckpoint=True,
-                 random_seed=1234):
+                 random_seed=1234,
+                 balance_dataset=False):
 
         self.name = name
         self.signal_trees = signal_trees
@@ -186,6 +197,8 @@ class ClassificationProject(object):
         if earlystopping_opts is None:
             earlystopping_opts = dict()
         self.earlystopping_opts = earlystopping_opts
+        self.random_seed = random_seed
+        self.balance_dataset = balance_dataset
 
         self.project_dir = project_dir
         if self.project_dir is None:
@@ -194,8 +207,6 @@ class ClassificationProject(object):
         if not os.path.exists(self.project_dir):
             os.mkdir(self.project_dir)
 
-        self.random_seed = random_seed
-
         self.s_train = None
         self.b_train = None
         self.s_test = None
@@ -210,6 +221,9 @@ class ClassificationProject(object):
         self._scores_train = None
         self._scores_test = None
 
+        # class weighted validation data
+        self._w_validation = None
+
         self._s_eventlist_train = None
         self._b_eventlist_train = None
@@ -550,6 +564,54 @@ class ClassificationProject(object):
             np.random.shuffle(self._scores_train)
 
 
+    @property
+    def w_validation(self):
+        "class weighted validation data"
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        if self._w_validation is None:
+            self._w_validation = np.array(self.w_train[split_index:])
+            self._w_validation[self.y_train[split_index:] == 0] *= self.class_weight[0]
+            self._w_validation[self.y_train[split_index:] == 1] *= self.class_weight[1]
+        return self._w_validation
+
+
+    @property
+    def class_weighted_validation_data(self):
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        return self.x_train[split_index:], self.y_train[split_index:], self.w_validation
+
+
+    @property
+    def training_data(self):
+        "training data with validation data split off"
+        split_index = int((1-self.validation_split)*len(self.x_train))
+        return self.x_train[:split_index], self.y_train[:split_index], self.w_train[:split_index]
+
+
+    def yield_batch(self, class_label):
+        while True:
+            x_train, y_train, w_train = self.training_data
+            # shuffle the entries for this class label
+            rn_state = np.random.get_state()
+            x_train[y_train == class_label] = np.random.permutation(x_train[y_train == class_label])
+            np.random.set_state(rn_state)
+            w_train[y_train == class_label] = np.random.permutation(w_train[y_train == class_label])
+            # yield them batch wise
+            for start in range(0, len(x_train[y_train == class_label]), self.batch_size):
+                yield (x_train[y_train == class_label][start:start+self.batch_size],
+                       y_train[y_train == class_label][start:start+self.batch_size],
+                       w_train[y_train == class_label][start:start+self.batch_size])
+            # restart
+
+
+    def yield_balanced_batch(self):
+        "generate batches with equal amounts of both classes"
+        for batch_0, batch_1 in izip(self.yield_batch(0), self.yield_batch(1)):
+            yield (np.concatenate((batch_0[0], batch_1[0])),
+                   np.concatenate((batch_0[1], batch_1[1])),
+                   np.concatenate((batch_0[2], batch_1[2])))
+
+
     def train(self, epochs=10):
 
         self.load()
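The two generators above can be read as a pair: yield_batch(class_label) endlessly re-shuffles and slices the events of one class, and yield_balanced_batch zips two such streams into combined batches. A self-contained sketch of the same pattern (hypothetical helper names, plain zip instead of the izip shim):

import numpy as np

def yield_class_batches(x, y, class_label, batch_size):
    # Endless generator: reshuffle this class's events, then yield them batch-wise.
    idx = np.where(y == class_label)[0]
    while True:
        np.random.shuffle(idx)
        for start in range(0, len(idx), batch_size):
            sel = idx[start:start + batch_size]
            yield x[sel], y[sel]

def yield_balanced(x, y, batch_size):
    # Zip the two per-class streams; each combined batch is half class 0, half class 1.
    gen0 = yield_class_batches(x, y, 0, batch_size)
    gen1 = yield_class_batches(x, y, 1, batch_size)
    for (x0, y0), (x1, y1) in zip(gen0, gen1):
        yield np.concatenate((x0, x1)), np.concatenate((y0, y1))

x = np.random.rand(500, 3)
y = np.array([0] * 450 + [1] * 50)
xb, yb = next(yield_balanced(x, y, 16))
print(np.unique(yb, return_counts=True))  # (array([0, 1]), array([16, 16]))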
@@ -560,22 +622,36 @@ class ClassificationProject(object):
         self.total_epochs = self._read_info("epochs", 0)
 
         logger.info("Train model")
-        try:
-            self.shuffle_training_data()
-            self.is_training = True
-            self.model.fit(self.x_train,
-                           # the reshape might be unnecessary here
-                           self.y_train.reshape(-1, 1),
-                           epochs=epochs,
-                           validation_split=self.validation_split,
-                           class_weight=self.class_weight,
-                           sample_weight=self.w_train,
-                           shuffle=True,
-                           batch_size=self.batch_size,
-                           callbacks=self.callbacks_list)
-            self.is_training = False
-        except KeyboardInterrupt:
-            logger.info("Interrupt training - continue with rest")
+        if not self.balance_dataset:
+            try:
+                self.shuffle_training_data()
+                self.is_training = True
+                self.model.fit(self.x_train,
+                               # the reshape might be unnecessary here
+                               self.y_train.reshape(-1, 1),
+                               epochs=epochs,
+                               validation_split=self.validation_split,
+                               class_weight=self.class_weight,
+                               sample_weight=self.w_train,
+                               shuffle=True,
+                               batch_size=self.batch_size,
+                               callbacks=self.callbacks_list)
+                self.is_training = False
+            except KeyboardInterrupt:
+                logger.info("Interrupt training - continue with rest")
+        else:
+            try:
+                self.is_training = True
+                labels, label_counts = np.unique(self.y_train, return_counts=True)
+                logger.info("Training on balanced batches")
+                self.model.fit_generator(self.yield_balanced_batch(),
+                                         steps_per_epoch=int(min(label_counts)/self.batch_size),
+                                         epochs=epochs,
+                                         validation_data=self.class_weighted_validation_data,
+                                         callbacks=self.callbacks_list)
+                self.is_training = False
+            except KeyboardInterrupt:
+                logger.info("Interrupt training - continue with rest")
 
         logger.info("Save history")
         self._dump_history()
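A quick check of the steps_per_epoch arithmetic used above: with balance_dataset=True, one epoch ends once the underrepresented class has been cycled through roughly once. A sketch with made-up event counts (numbers purely illustrative):

import numpy as np

# Hypothetical counts: 100000 background vs. 5000 signal events.
y_train = np.array([0] * 100000 + [1] * 5000)
batch_size = 128

labels, label_counts = np.unique(y_train, return_counts=True)
steps_per_epoch = int(min(label_counts) / batch_size)  # int(5000 / 128) = 39
print(steps_per_epoch)
# Each step draws batch_size events of *each* class, so one epoch covers
# nearly all signal events (39 * 128 = 4992) but only ~5% of the background.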