Skip to content
Snippets Groups Projects
Unverified Commit 2e2344e7 authored by Eric Schanet's avatar Eric Schanet
Browse files

Fixing ROC AUC computation

parent 365ba76d
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ import pandas as pd ...@@ -13,7 +13,7 @@ import pandas as pd
import h5py import h5py
from sklearn.preprocessing import StandardScaler, RobustScaler from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.externals import joblib from sklearn.externals import joblib
from sklearn.metrics import roc_curve from sklearn.metrics import roc_curve, auc
from keras.models import Sequential from keras.models import Sequential
from keras.layers import Dense from keras.layers import Dense
...@@ -128,7 +128,7 @@ class KerasROOTClassification: ...@@ -128,7 +128,7 @@ class KerasROOTClassification:
self.b_train = tree2array(bkg_chain, self.b_train = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr]+self.identifiers, branches=self.branches+[self.weight_expr]+self.identifiers,
selection=self.selection, selection=self.selection,
start=0, step=2) start=0, step=200)
self.s_test = tree2array(signal_chain, self.s_test = tree2array(signal_chain,
branches=self.branches+[self.weight_expr], branches=self.branches+[self.weight_expr],
selection=self.selection, selection=self.selection,
...@@ -136,7 +136,7 @@ class KerasROOTClassification: ...@@ -136,7 +136,7 @@ class KerasROOTClassification:
self.b_test = tree2array(bkg_chain, self.b_test = tree2array(bkg_chain,
branches=self.branches+[self.weight_expr], branches=self.branches+[self.weight_expr],
selection=self.selection, selection=self.selection,
start=1, step=2) start=1, step=200)
self._dump_training_list() self._dump_training_list()
self.s_eventlist_train = self.s_train[self.identifiers] self.s_eventlist_train = self.s_train[self.identifiers]
...@@ -404,13 +404,15 @@ class KerasROOTClassification: ...@@ -404,13 +404,15 @@ class KerasROOTClassification:
logger.info("Plot ROC curve") logger.info("Plot ROC curve")
fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test) fpr, tpr, threshold = roc_curve(self.y_test, self.scores_test, sample_weight = self.w_test)
fpr = 1.0 - fpr fpr = 1.0 - fpr
roc_auc = auc(tpr, fpr)
plt.grid(color='gray', linestyle='--', linewidth=1) plt.grid(color='gray', linestyle='--', linewidth=1)
plt.plot(fpr, tpr, label='NN') plt.plot(tpr, fpr, label='NN')
plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck') plt.plot([0,1],[1,0], linestyle='--', color='black', label='Luck')
plt.xlabel("False positive rate (background rejection)") plt.xlabel("Background rejection")
plt.ylabel("True positive rate (signal efficiency)") plt.xlabel("Signal efficiency")
plt.title('Receiver operating characteristic') plt.title('Receiver operating characteristic')
plt.xlim(0,1) plt.xlim(0,1)
plt.ylim(0,1) plt.ylim(0,1)
...@@ -421,7 +423,6 @@ class KerasROOTClassification: ...@@ -421,7 +423,6 @@ class KerasROOTClassification:
plt.savefig(os.path.join(self.project_dir, "ROC.pdf")) plt.savefig(os.path.join(self.project_dir, "ROC.pdf"))
plt.clf() plt.clf()
def plot_score(self): def plot_score(self):
pass pass
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment