import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy import odr


class Analysis:
    """
    Analysis class:
        This class computes the mean and error of the detector output.
        It then fits the exact equations of motion for a particle
        in a magnetic field to the input data.
        The results can then be visualized with plots.
    """

    def __init__(self):
        """
        Create an empty Analysis object.

        Nothing is stored at construction time: the bound arrays, the point
        tables and the results table are all created by fill().
        """
        pass

    def fill(self, bounds, timeError=1e-9):
        """
        Load the detector bounds and derive the points used by the fits.

        Parameters
        ----------
        bounds : array-like
            Indexed as ``bounds[layer, coord, side]``, with ``coord`` 0/1/2/3
            selecting x/y/z/t and ``side`` 0/1 selecting the upper ('High') /
            lower ('Low') edge. For the time coordinate only ``side`` 0 is read.
        timeError : float, optional
            Uncertainty attached to every time value. The default of 1e-9 is
            the value this method has always used.
            NOTE(review): the unit is whatever the caller uses for time; confirm.

        Sets
        ----
        xBounds, yBounds : pd.Series
            'High' / 'Low' arrays over *all* layers (read by rmNone).
        results : pd.DataFrame
            Zero-filled, columns 'w', 'vx', 'vy', 'vz', two rows; Fit later
            writes ``[value, error]`` into each column.
        xPoints, yPoints, zPoints, tPoints : pd.DataFrame
            'Mean' and 'Error' columns for layers 1 onward (layer 0 skipped).
        """
        bounds = np.asarray(bounds)

        self.xBounds = pd.Series({'High' : bounds[:, 0, 0],
                                  'Low'  : bounds[:, 0, 1]})
        self.yBounds = pd.Series({'High' : bounds[:, 1, 0],
                                  'Low'  : bounds[:, 1, 1]})

        self.results = pd.DataFrame({'w'  : np.zeros(2),
                                     'vx' : np.zeros(2),
                                     'vy' : np.zeros(2),
                                     'vz' : np.zeros(2)})

        def uniformPoints(axisBounds):
            # A coordinate only known to lie in [low, high] is treated as
            # uniformly distributed: mean (high+low)/2, std (high-low)/sqrt(12).
            high, low = axisBounds[:, 0], axisBounds[:, 1]
            return pd.DataFrame({'Mean' : (high + low) / 2.,
                                 'Error': (high - low) / np.sqrt(12.)})

        # Layer 0 is excluded from the fitted points.
        # NOTE(review): presumably the production point / origin; confirm.
        self.xPoints = uniformPoints(bounds[1:, 0])
        self.yPoints = uniformPoints(bounds[1:, 1])
        self.zPoints = uniformPoints(bounds[1:, 2])
        self.tPoints = pd.DataFrame({'Mean' : bounds[1:, 3, 0],
                                     'Error': timeError})

    def rmNone(self):
        """
        Remove ``None`` entries from the stored x and y bound arrays.

        Each of 'High' and 'Low' in ``self.xBounds`` and ``self.yBounds`` is
        filtered on its own and replaced by a new numpy array.
        NOTE(review): because the sides are filtered independently, 'High'
        and 'Low' can end up with different lengths if ``None`` does not
        appear at matching positions.
        """
        for boundSeries in (self.xBounds, self.yBounds):
            for side in ('High', 'Low'):
                kept = [value for value in boundSeries[side] if value is not None]
                boundSeries[side] = np.array(kept)

    """----------------------------------------------------------------------------
    FITTING

    Here we fit the data. As a fitting procedure we use ORTHOGONAL DISTANCE REGRESSION.
    We fit all three spatial dimensions with respect to time, which is equivalent to fitting
    against path length for non-relativistic particles.

    The fitting functions / equations of motion are xEOM, yEOM, zEOM.

    The actual fits are performed by the functions xFit, yFit, zFit.

    The fitting functions are called by the function Fit. It also collects
    all the results and fills them into the results dataframe.
    """
    
    def Fit(self):
        """
        Run the x, y and z fits and store the combined results.

        vx is taken from the x fit alone. w, vy and vz are estimated by both
        the y and the z fit (parameter order [w, vy, vz]). Each is reported as
        the mean of the absolute values from the two fits. Its error is the
        standard error of that mean, sqrt(sd_y**2 + sd_z**2) / 2, which
        assumes the two estimates are independent.

        NOTE(review): yEOM is unchanged under (w, vy, vz) -> (-w, vy, -vz) and
        zEOM under (w, vy, vz) -> (-w, -vy, vz). Presumably this is why the
        absolute values are averaged and w's sign is chosen heuristically:
        it is set to -1 only when both fits return at least one negative
        parameter.

        Results are written into ``self.results`` as ``[value, error]``.

        Raises
        ------
        RuntimeError
            If the signs of the y-fit or z-fit parameters sum to less than 1.
            (RuntimeError subclasses Exception, so existing handlers still work.)
        """
        xResults = self.xFit()
        yResults = self.yFit()
        zResults = self.zFit()

        ySgn = np.sign(yResults.beta)
        zSgn = np.sign(zResults.beta)

        # w is sometimes absorbed with the wrong sign into the other parameters.
        sign = -1 if (np.amin(ySgn) == -1 and np.amin(zSgn) == -1) else 1

        # NOTE(review): the original author questioned whether this check is needed.
        if np.sum(ySgn) < 1 or np.sum(zSgn) < 1:
            raise RuntimeError('A fitting error occurred. Try again.')

        self.results['vx'] = np.array([xResults.beta[0], xResults.sd_beta[0]])

        for idx, name in enumerate(('w', 'vy', 'vz')):
            mean = (np.abs(yResults.beta[idx]) + np.abs(zResults.beta[idx])) / 2.
            # Error of the mean of two independent estimates.
            err = np.sqrt(yResults.sd_beta[idx]**2 + zResults.sd_beta[idx]**2) / 2.
            if name == 'w':
                mean = sign * mean
            self.results[name] = np.array([mean, err])

    def xEOM(self, B, t):
        """
        Linear motion along x: ``x(t) = Vx * t``.

        B = [Vx]; only B[0] is read. ``t`` may be a scalar or an array.
        """
        vx = B[0]
        return vx * t

    def yEOM(self, B, t):
        """
        y coordinate of the rotating motion in the y-z plane.

        B = [w, Vy, Vz]:

            y(t) = (Vz * (1 - cos(w t)) + Vy * sin(w t)) / w

        This gives y(0) = 0 and dy/dt(0) = Vy. w must be non-zero.
        """
        omega, vy, vz = B[0], B[1], B[2]
        phase = t * omega
        invOmega = 1 / omega
        return invOmega * (vz * (1 - np.cos(phase)) + vy * np.sin(phase))

    def zEOM(self, B, t):
        """
        z coordinate of the rotating motion in the y-z plane.

        B = [w, Vy, Vz]:

            z(t) = (-Vy * (1 - cos(w t)) + Vz * sin(w t)) / w

        This gives z(0) = 0 and dz/dt(0) = Vz. w must be non-zero.
        """
        omega, vy, vz = B[0], B[1], B[2]
        phase = t * omega
        invOmega = 1 / omega
        return invOmega * (-vy * (1 - np.cos(phase)) + vz * np.sin(phase))

    def pol(self, B, x):
        """
        Cubic polynomial ``B[0] + B[1] x + B[2] x**2 + B[3] x**3``.

        Only the first four coefficients of B are read.
        NOTE(review): no caller is visible in this part of the file.
        """
        c0, c1, c2, c3 = B[0], B[1], B[2], B[3]
        linear = c1 * x
        quadratic = c2 * (x ** 2)
        cubic = c3 * (x ** 3)
        return c0 + linear + quadratic + cubic
    
    def xFit(self):
        """
        Fit xEOM to the x points as a function of time with scipy ODR.

        Reads ``self.tPoints`` and ``self.xPoints`` ('Mean' / 'Error' columns,
        as built by fill). The time errors are passed as ``sx`` and the x
        errors as ``sy``, so both coordinates are treated as uncertain.

        Returns
        -------
        scipy.odr.Output
            ``beta = [Vx]`` with ``sd_beta``. ``res_var`` is the residual
            variance, i.e. the reduced chi-square.
        """
        xModel = odr.Model(self.xEOM)
        xData = odr.RealData(self.tPoints['Mean'],
                             self.xPoints['Mean'],
                             sx=self.tPoints['Error'],
                             sy=self.xPoints['Error'])
        # Starting guess for the single parameter Vx.
        xODR = odr.ODR(xData, xModel, beta0=[10])
        xOUT = xODR.run()
        print('The x fit was performed with reduced chi2 = ', xOUT.res_var)
        return xOUT
    
    def yFit(self):
        """
        Fit yEOM to the y points as a function of time with scipy ODR.

        Reads ``self.tPoints`` and ``self.yPoints`` ('Mean' / 'Error' columns,
        as built by fill). The time errors are passed as ``sx`` and the y
        errors as ``sy``.

        Returns
        -------
        scipy.odr.Output
            ``beta = [w, Vy, Vz]`` (yEOM parameter order) with ``sd_beta``.
            ``res_var`` is the residual variance (reduced chi-square).
        """
        yModel = odr.Model(self.yEOM)
        yData = odr.RealData(self.tPoints['Mean'],
                             self.yPoints['Mean'],
                             sx=self.tPoints['Error'],
                             sy=self.yPoints['Error'])
        # Starting guess for [w, Vy, Vz].
        yODR = odr.ODR(yData, yModel, beta0=[1, 10, 100])
        yOUT = yODR.run()
        print('The y fit was performed with reduced chi2 = ', yOUT.res_var)
        return yOUT

    def zFit(self):
        """
        Fit zEOM to the z points as a function of time with scipy ODR.

        Reads ``self.tPoints`` and ``self.zPoints`` ('Mean' / 'Error' columns,
        as built by fill). The time errors are passed as ``sx`` and the z
        errors as ``sy``.

        Returns
        -------
        scipy.odr.Output
            ``beta = [w, Vy, Vz]`` (zEOM parameter order) with ``sd_beta``.
            ``res_var`` is the residual variance (reduced chi-square).
        """
        zModel = odr.Model(self.zEOM)
        zData = odr.RealData(self.tPoints['Mean'],
                             self.zPoints['Mean'],
                             sx=self.tPoints['Error'],
                             sy=self.zPoints['Error'])
        # Starting guess for [w, Vy, Vz].
        zODR = odr.ODR(zData, zModel, beta0=[1, 10, 100])
        zOUT = zODR.run()
        print('The z fit was performed with reduced chi2 = ', zOUT.res_var)
        return zOUT
    """
    def Bounds(self, high, low, call=0):
        
        Algorithm to compute the minimum possible range of angles.
        The upper and lower angles must be given as arguments, and the
        'call' variable specifies the plane you are working in
        (x-z or y-z plane).
        The algorithm loops over the angles from the last detector layer
        to the first, and replaces the upper and lower angle bounds if
        a certain layer restricts the range further.
        The mean is calculated as the arithmetic mean of the resulting
        angles, and the error is the difference from the mean to the angle
        bounds.
        
        highRev = high[::-1]
        lowRev = low[::-1]
        highBound = highRev[0]
        lowBound = lowRev[0]
        for a,b in zip(highRev[1:], lowRev[1:]):
            if a < highBound and a > lowBound:
                highBound = a
            if b < highBound and b > lowBound:
                lowBound = b
        error = (highBound - lowBound)/2
        mean = highBound - error
        self.results['Mean'][call] = mean
        self.results['Error'][call] = error
    """
    def Plot(self):
        ...