From d1d1cf3c25caf2d8e986a6383cdf2a937c89763e Mon Sep 17 00:00:00 2001 From: Nikolai <osterei33@gmx.de> Date: Thu, 9 Aug 2018 20:14:57 +0200 Subject: [PATCH] adding to_DataFrame function --- __init__.py | 6 +++--- toolkit.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/__init__.py b/__init__.py index e938f9e..c2807c6 100644 --- a/__init__.py +++ b/__init__.py @@ -1,3 +1,3 @@ -from .toolkit import ClassificationProject -from .compare import overlay_ROC, overlay_loss -from .add_friend import add_friend +from .toolkit import * +from .compare import * +from .add_friend import * diff --git a/toolkit.py b/toolkit.py index 3c2d3d9..0cea431 100755 --- a/toolkit.py +++ b/toolkit.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +__all__ = ["ClassificationProject"] + from sys import version_info if version_info[0] > 2: @@ -69,6 +71,7 @@ def byteify(input): if version_info[0] > 2: byteify = lambda input : input + class ClassificationProject(object): """Simple framework to load data from ROOT TTrees and train Keras @@ -1130,6 +1133,16 @@ class ClassificationProject(object): # self.plot_significance() + def to_DataFrame(self): + df = pd.DataFrame(np.concatenate([self.x_train, self.x_test]), columns=self.fields) + df["weight"] = np.concatenate([self.w_train, self.w_test]) + df["labels"] = pd.Categorical.from_codes( + np.concatenate([self.y_train, self.y_test]), + categories=["background", "signal"] + ) + return df + + def create_getter(dataset_name): def getx(self): if getattr(self, "_"+dataset_name) is None: -- GitLab