Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
K
KerasROOTClassification
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Eric.Schanet
KerasROOTClassification
Commits
0a377f2a
Commit
0a377f2a
authored
6 years ago
by
Nikolai.Hartmann
Browse files
Options
Downloads
Patches
Plain Diff
adding BaseLogger and History callbacks
parent
37280c34
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
toolkit.py
+20
-21
20 additions, 21 deletions
toolkit.py
with
20 additions
and
21 deletions
toolkit.py
+
20
−
21
View file @
0a377f2a
...
...
@@ -37,7 +37,7 @@ from sklearn.utils.extmath import stable_cumsum
from
sklearn.model_selection
import
KFold
from
keras.models
import
Sequential
,
Model
,
model_from_json
from
keras.layers
import
Dense
,
Dropout
,
Input
,
Masking
,
GRU
,
LSTM
,
concatenate
,
SimpleRNN
from
keras.callbacks
import
History
,
EarlyStopping
,
CSVLogger
,
ModelCheckpoint
,
TensorBoard
,
CallbackList
from
keras.callbacks
import
History
,
EarlyStopping
,
CSVLogger
,
ModelCheckpoint
,
TensorBoard
,
CallbackList
,
BaseLogger
from
keras.optimizers
import
SGD
import
keras.optimizers
from
keras.utils.vis_utils
import
model_to_dot
...
...
@@ -1604,7 +1604,7 @@ class ClassificationProject(object):
hist_dict
[
hist_key
]
=
[
float
(
line
[
hist_key_index
])
for
line
in
history_list
[
1
:]]
return
hist_dict
def
plot_loss
(
self
,
all_trainings
=
False
,
log
=
False
,
ylim
=
None
,
xlim
=
None
):
def
plot_loss
(
self
,
all_trainings
=
False
,
log
=
False
,
ylim
=
None
,
xlim
=
None
,
loss_key
=
"
loss
"
):
"""
Plot the value of the loss function for each epoch
...
...
@@ -1616,14 +1616,14 @@ class ClassificationProject(object):
else
:
hist_dict
=
self
.
history
.
history
if
(
not
'
loss
'
in
hist_dict
)
or
(
not
'
val_loss
'
in
hist_dict
):
if
(
not
loss_key
in
hist_dict
)
or
(
not
'
val_
'
+
loss_key
in
hist_dict
):
logger
.
warning
(
"
No previous history found for plotting, try global history
"
)
hist_dict
=
self
.
csv_hist
logger
.
info
(
"
Plot losses
"
)
plt
.
plot
(
hist_dict
[
'
loss
'
])
plt
.
plot
(
hist_dict
[
'
val_loss
'
])
plt
.
ylabel
(
'
loss
'
)
plt
.
plot
(
hist_dict
[
loss_key
])
plt
.
plot
(
hist_dict
[
'
val_
'
+
loss_key
])
plt
.
ylabel
(
loss_key
)
plt
.
xlabel
(
'
epoch
'
)
plt
.
legend
([
'
training data
'
,
'
validation data
'
],
loc
=
'
upper left
'
)
if
log
:
...
...
@@ -2219,7 +2219,7 @@ class ClassificationProjectDecorr(ClassificationProject):
return
self
.
_model_adv
def
train
(
self
,
epochs
=
10
):
def
train
(
self
,
epochs
=
10
,
skip_checkpoint
=
False
):
"""
Train classifier and adversary concurrently. Most of the garbage in this
code block is just organising stuff to get all the keras callbacks
...
...
@@ -2227,17 +2227,20 @@ class ClassificationProjectDecorr(ClassificationProject):
"""
batch_generator
=
self
.
yield_batch
()
metric_list
=
[]
out_labels
=
self
.
model
.
metrics_names
self
.
model
.
history
=
History
()
callback_metrics
=
out_labels
+
[
'
val_
'
+
n
for
n
in
out_labels
]
callbacks
=
CallbackList
(
self
.
callbacks_list
)
callbacks
=
CallbackList
(
[
BaseLogger
()]
+
self
.
callbacks_list
+
[
self
.
model
.
history
])
callbacks
.
set_model
(
self
.
model
)
callbacks
.
set_params
({
'
epochs
'
:
epochs
,
'
steps
'
:
self
.
steps_per_epoch
,
'
verbose
'
:
self
.
verbose
,
#'do_validation': do_validation,
'
do_validation
'
:
False
,
'
do_validation
'
:
True
,
'
metrics
'
:
callback_metrics
,
})
self
.
model
.
stop_training
=
False
...
...
@@ -2264,27 +2267,23 @@ class ClassificationProjectDecorr(ClassificationProject):
self
.
model_adv
.
train_on_batch
(
x
,
y
[
1
:],
sample_weight
=
w
[
1
:]
)
batch_metrics
=
np
.
array
(
batch_metrics
).
reshape
(
1
,
len
(
batch_metrics
))
if
metrics
is
None
:
metrics
=
batch_metrics
else
:
metrics
=
np
.
concatenate
([
metrics
,
batch_metrics
])
avg_metrics
=
np
.
mean
(
metrics
,
axis
=
0
)
outs
=
list
(
batch_metrics
)
for
l
,
o
in
zip
(
out_labels
,
outs
):
batch_logs
[
l
]
=
o
batch_logs
[
l
]
=
float
(
o
)
callbacks
.
on_batch_end
(
batch_id
,
batch_logs
)
metric_list
.
append
(
avg_metrics
)
val_metrics
=
self
.
model
.
test_on_batch
(
*
self
.
validation_data
)
val_outs
=
list
(
val_metrics
)
for
l
,
o
in
zip
(
out_labels
,
val_outs
):
epoch_logs
[
'
val_
'
+
l
]
=
o
epoch_logs
[
'
val_
'
+
l
]
=
float
(
o
)
callbacks
.
on_epoch_end
(
epoch
,
epoch_logs
)
if
self
.
model
.
stop_training
:
break
callbacks
.
on_train_end
()
return
metric_list
if
not
skip_checkpoint
:
self
.
checkpoint_model
()
return
self
.
model
.
history
if
__name__
==
"
__main__
"
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment