Skip to content
Snippets Groups Projects
Commit 660b4beb authored by Thomas Weber's avatar Thomas Weber
Browse files

Added plotROC function and scaler function call

parent 34930398
No related branches found
No related tags found
No related merge requests found
......@@ -13,11 +13,14 @@ import pandas as pd
import h5py
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
from sklearn.metrics import roc_curve
from keras.models import Sequential
from keras.layers import Dense
from keras.models import model_from_json
import matplotlib.pyplot as plt
# configure number of cores
# this doesn't seem to work, but at least with these settings keras only uses 4 processes
import tensorflow as tf
......@@ -77,6 +80,9 @@ class KerasROOTClassification:
self._class_weight = None
self._model = None
self.score_train = None
self.score_test = None
# track the number of epochs this model has been trained
self.total_epochs = 0
......@@ -173,6 +179,8 @@ class KerasROOTClassification:
self._scaler = StandardScaler()
logger.info("Fitting StandardScaler to training data")
self._scaler.fit(self.x_train)
logger.info("Fitting StandardScaler to test data")
self._scaler.fit(self.x_test)
joblib.dump(self._scaler, filename)
return self._scaler
......@@ -211,7 +219,8 @@ class KerasROOTClassification:
self._model.add(Dense(self.nodes, activation=self.activation_function))
# last layer is one neuron (binary classification)
self._model.add(Dense(1, activation='sigmoid'))
logger.info("Compile model")
self._model.compile(optimizer='SGD',
loss='binary_crossentropy',
metrics=['accuracy'])
......@@ -234,6 +243,8 @@ class KerasROOTClassification:
if not self.data_loaded:
self._load_data()
self.scaler
try:
self.model.load_weights(os.path.join(self.project_dir, "weights.h5"))
......@@ -243,17 +254,25 @@ class KerasROOTClassification:
logger.info("No weights found, starting completely new training")
self.total_epochs = self._read_info("epochs", 0)
logger.info("Train model")
self.model.fit(self.x_train, self.y_train,
epochs=epochs,
class_weight=self.class_weight,
shuffle=True,
batch_size=self.batch_size)
logger.info("Save weights")
self.model.save_weights(os.path.join(self.project_dir, "weights.h5"))
self.total_epochs += epochs
self._write_info("epochs", self.total_epochs)
logger.info("Create scores for ROC curve")
self.scores_test = self.model.predict(self.x_test)
self.scores_train = self.model.predict(self.x_train)
def evaluate(self):
pass
......@@ -262,7 +281,25 @@ class KerasROOTClassification:
pass
def plotROC(self):
pass
logger.info("Plot ROC curve")
fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test)
plt.grid(color='gray', linestyle='--', linewidth=1)
plt.plot(fpr, tpr, label='NN')
plt.plot([0,1],[0,1], linestyle='--', color='black', label='Luck')
plt.xlabel("False positive rate (background rejection)")
plt.ylabel("True positive rate (signal efficiency)")
plt.title('Receiver operating characteristic')
plt.xlim(0,1)
plt.ylim(0,1)
plt.xticks(np.arange(0,1,0.1))
plt.yticks(np.arange(0,1,0.1))
plt.legend(loc='lower left', framealpha=1.0)
plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
plt.clf()
def plotScore(self):
pass
......@@ -286,3 +323,4 @@ if __name__ == "__main__":
identifiers = ["DatasetNumber", "EventNumber"])
c.train(epochs=1)
c.plotROC()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment