Skip to content
Snippets Groups Projects
Commit 660b4beb authored by Thomas Weber
Browse files

Added plotROC function and scaler function call

parent 34930398
No related branches found
No related tags found
No related merge requests found
...@@ -13,11 +13,14 @@ import pandas as pd ...@@ -13,11 +13,14 @@ import pandas as pd
import h5py import h5py
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib from sklearn.externals import joblib
from sklearn.metrics import roc_curve
from keras.models import Sequential from keras.models import Sequential
from keras.layers import Dense from keras.layers import Dense
from keras.models import model_from_json from keras.models import model_from_json
import matplotlib.pyplot as plt
# configure number of cores # configure number of cores
# this doesn't seem to work, but at least with these settings keras only uses 4 processes # this doesn't seem to work, but at least with these settings keras only uses 4 processes
import tensorflow as tf import tensorflow as tf
...@@ -77,6 +80,9 @@ class KerasROOTClassification: ...@@ -77,6 +80,9 @@ class KerasROOTClassification:
self._class_weight = None self._class_weight = None
self._model = None self._model = None
self.score_train = None
self.score_test = None
# track the number of epochs this model has been trained # track the number of epochs this model has been trained
self.total_epochs = 0 self.total_epochs = 0
...@@ -173,6 +179,8 @@ class KerasROOTClassification: ...@@ -173,6 +179,8 @@ class KerasROOTClassification:
self._scaler = StandardScaler() self._scaler = StandardScaler()
logger.info("Fitting StandardScaler to training data") logger.info("Fitting StandardScaler to training data")
self._scaler.fit(self.x_train) self._scaler.fit(self.x_train)
logger.info("Fitting StandardScaler to test data")
self._scaler.fit(self.x_test)
joblib.dump(self._scaler, filename) joblib.dump(self._scaler, filename)
return self._scaler return self._scaler
...@@ -211,7 +219,8 @@ class KerasROOTClassification: ...@@ -211,7 +219,8 @@ class KerasROOTClassification:
self._model.add(Dense(self.nodes, activation=self.activation_function)) self._model.add(Dense(self.nodes, activation=self.activation_function))
# last layer is one neuron (binary classification) # last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid')) self._model.add(Dense(1, activation='sigmoid'))
logger.info("Compile model")
self._model.compile(optimizer='SGD', self._model.compile(optimizer='SGD',
loss='binary_crossentropy', loss='binary_crossentropy',
metrics=['accuracy']) metrics=['accuracy'])
...@@ -234,6 +243,8 @@ class KerasROOTClassification: ...@@ -234,6 +243,8 @@ class KerasROOTClassification:
if not self.data_loaded: if not self.data_loaded:
self._load_data() self._load_data()
self.scaler
try: try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5")) self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
...@@ -243,17 +254,25 @@ class KerasROOTClassification: ...@@ -243,17 +254,25 @@ class KerasROOTClassification:
logger.info("No weights found, starting completely new training") logger.info("No weights found, starting completely new training")
self.total_epochs = self._read_info("epochs", 0) self.total_epochs = self._read_info("epochs", 0)
logger.info("Train model")
self.model.fit(self.x_train, self.y_train, self.model.fit(self.x_train, self.y_train,
epochs=epochs, epochs=epochs,
class_weight=self.class_weight, class_weight=self.class_weight,
shuffle=True, shuffle=True,
batch_size=self.batch_size) batch_size=self.batch_size)
logger.info("Save weights")
self.model.save_weights(os.path.join(self.project_dir, "weights.h5")) self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
self.total_epochs += epochs self.total_epochs += epochs
self._write_info("epochs", self.total_epochs) self._write_info("epochs", self.total_epochs)
logger.info("Create scores for ROC curve")
self.scores_test = self.model.predict(self.x_test)
self.scores_train = self.model.predict(self.x_train)
def evaluate(self): def evaluate(self):
pass pass
...@@ -262,7 +281,25 @@ class KerasROOTClassification: ...@@ -262,7 +281,25 @@ class KerasROOTClassification:
pass pass
def plotROC(self):
    """Plot the ROC curve of the test sample and save it as ``ROC.pdf``.

    The curve is computed with ``roc_curve`` from the true labels
    ``self.y_test``, the classifier scores ``self.scores_test`` and the
    per-event weights ``self.w_test``. ``self.scores_test`` is filled at
    the end of ``train()``, so the model has to be trained before this
    method is called.

    The figure is written to ``ROC.pdf`` inside ``self.project_dir`` and
    the current matplotlib figure is cleared afterwards.
    """
    logger.info("Plot ROC curve")

    # the thresholds returned by roc_curve are not needed for the plot
    fpr, tpr, _ = roc_curve(self.y_test, self.scores_test,
                            sample_weight=self.w_test)

    # Ticks every 0.1 *including* the end point 1.0; np.arange(0, 1, 0.1)
    # would stop at 0.9 and leave the upper axis edge unlabelled.
    ticks = np.linspace(0, 1, 11)

    plt.grid(color='gray', linestyle='--', linewidth=1)
    plt.plot(fpr, tpr, label='NN')
    # diagonal = performance of a random classifier
    plt.plot([0, 1], [0, 1], linestyle='--', color='black', label='Luck')
    # The false positive rate is the fraction of background events that
    # pass the cut, i.e. the background *efficiency* (not the rejection).
    plt.xlabel("False positive rate (background efficiency)")
    plt.ylabel("True positive rate (signal efficiency)")
    plt.title('Receiver operating characteristic')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.legend(loc='lower left', framealpha=1.0)
    plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
    plt.clf()
def plotScore(self): def plotScore(self):
pass pass
...@@ -286,3 +323,4 @@ if __name__ == "__main__": ...@@ -286,3 +323,4 @@ if __name__ == "__main__":
identifiers = ["DatasetNumber", "EventNumber"]) identifiers = ["DatasetNumber", "EventNumber"])
c.train(epochs=1) c.train(epochs=1)
c.plotROC()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment