From d1d1cf3c25caf2d8e986a6383cdf2a937c89763e Mon Sep 17 00:00:00 2001
From: Nikolai <osterei33@gmx.de>
Date: Thu, 9 Aug 2018 20:14:57 +0200
Subject: [PATCH] adding to_DataFrame function

---
 __init__.py |  6 +++---
 toolkit.py  | 13 +++++++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/__init__.py b/__init__.py
index e938f9e..c2807c6 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,3 @@
-from .toolkit import ClassificationProject
-from .compare import overlay_ROC, overlay_loss
-from .add_friend import add_friend
+from .toolkit import *
+from .compare import *
+from .add_friend import *
diff --git a/toolkit.py b/toolkit.py
index 3c2d3d9..0cea431 100755
--- a/toolkit.py
+++ b/toolkit.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 
+__all__ = ["ClassificationProject"]
+
 from sys import version_info
 
 if version_info[0] > 2:
@@ -69,6 +71,7 @@ def byteify(input):
 if version_info[0] > 2:
     byteify = lambda input : input
 
+
 class ClassificationProject(object):
 
     """Simple framework to load data from ROOT TTrees and train Keras
@@ -1130,6 +1133,16 @@ class ClassificationProject(object):
         # self.plot_significance()
 
 
+    def to_DataFrame(self):
+        """Return train then test rows as a DataFrame with "weight" and "labels" columns."""
+        features = np.concatenate([self.x_train, self.x_test])
+        frame = pd.DataFrame(features, columns=self.fields)
+        frame["weight"] = np.concatenate([self.w_train, self.w_test])
+        label_codes = np.concatenate([self.y_train, self.y_test])
+        frame["labels"] = pd.Categorical.from_codes(label_codes, categories=["background", "signal"])
+        return frame
+
+
 def create_getter(dataset_name):
     def getx(self):
         if getattr(self, "_"+dataset_name) is None:
-- 
GitLab