import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy import odr


class Analysis:
    """
    Analysis class:
        This class computes mean and error on the detector output angles
        and fits the equations of motion (``xEOM`` in the x-z plane,
        ``yEOM`` in the z-y plane) to them with orthogonal distance
        regression. The results can then be visualized with two plots.
    """

    def __init__(self):
        """
        The class is initialized with the angles from the detector results
        and then automatically deletes 'None' entries and computes the minimum
        range of angles and the error.
        """
        ...

    def fill(self, bounds):
        """
        Load detector bounds and derive the per-measurement points.

        Parameters
        ----------
        bounds : array_like
            Indexed as ``bounds[i, axis, side]``: ``axis`` 0/1/2 selects the
            x/y/z coordinate and ``side`` 0/1 selects the upper ('High') /
            lower ('Low') bound. Presumably one row ``i`` per detector
            layer -- TODO confirm against the caller.

        Side effects
        ------------
        Sets ``xBounds`` / ``yBounds`` (``pd.Series`` holding the 'High' and
        'Low' arrays), an all-zero ``results`` table with one column per fit
        parameter, and ``xPoints`` / ``yPoints`` / ``zPoints``
        (``pd.DataFrame`` with the interval midpoint as 'Mean' and the
        half-width as 'Error'). The three point tables are printed.
        """
        # Accept nested lists as well as ndarrays. No dtype is forced, so
        # None entries are not silently turned into NaN.
        bounds = np.asarray(bounds)

        self.xBounds = pd.Series({'High': bounds[:, 0, 0],
                                  'Low':  bounds[:, 0, 1]})
        self.yBounds = pd.Series({'High': bounds[:, 1, 0],
                                  'Low':  bounds[:, 1, 1]})

        # One column per fitted parameter. For these columns, row 0 holds the
        # estimate and row 1 its standard error (written by xFit / yFit).
        self.results = pd.DataFrame(
            {name: np.zeros(2)
             for name in ('w', 'Vx', 'Vy', 'Vz', 'R', 'y0', 'z0')})

        # Pipeline steps currently disabled during development:
        # self.rmNone()
        # self.Bounds(self.xBounds['High'], self.xBounds['Low'], call=0)
        # self.Bounds(self.yBounds['High'], self.yBounds['Low'], call=1)

        self.xPoints = self._midpoints(bounds, 0)
        self.yPoints = self._midpoints(bounds, 1)
        self.zPoints = self._midpoints(bounds, 2)

        print(self.xPoints)
        print(self.yPoints)
        print(self.zPoints)

        # TODO: fit (self.xFit(), self.yFit()) and plot.

    @staticmethod
    def _midpoints(bounds, axis):
        """Return the 'Mean' (midpoint) / 'Error' (half-width) table for ``axis``."""
        high = bounds[:, axis, 0]
        low = bounds[:, axis, 1]
        return pd.DataFrame({'Mean': (high + low) / 2.,
                             'Error': (high - low) / 2.})

    def rmNone(self):
        """
        Remove measurements with a missing (None) bound from xBounds/yBounds.

        A measurement is dropped from a plane when either its 'High' or its
        'Low' entry is None. This keeps the two arrays aligned index by
        index, which matters because ``Bounds`` pairs them positionally.
        """
        self.xBounds = self._dropNonePairs(self.xBounds)
        self.yBounds = self._dropNonePairs(self.yBounds)

    @staticmethod
    def _dropNonePairs(bounds):
        """Return a new 'High'/'Low' Series without any pair containing None."""
        pairs = [(hi, lo) for hi, lo in zip(bounds['High'], bounds['Low'])
                 if hi is not None and lo is not None]
        return pd.Series({'High': np.array([hi for hi, _ in pairs]),
                          'Low':  np.array([lo for _, lo in pairs])})

    def xEOM(self, B, x):
        """
        ODR model function: z as a function of x.

            z = (1/w) * (-Vy * (1 - cos(x*w/Vx)) + Vz * sin(x*w/Vx))

        Parameters
        ----------
        B : sequence of float
            ``[w, Vx, Vy, Vz]``.
        x : float or ndarray
            Abscissa value(s).
        """
        z = (1/B[0])*(-B[2]*(1-np.cos(x*B[0]/B[1])) + B[3]*np.sin(x*B[0]/B[1]))
        return z

    def yEOM(self, B, z):
        """
        ODR model function: upper half of a circle in the z-y plane.

            y = y0 + sqrt(R**2 - (z - z0)**2)

        Parameters
        ----------
        B : sequence of float
            ``[R, y0, z0]``.
        z : float or ndarray
            Abscissa value(s). Where ``|z - z0| > R`` the square root is of
            a negative number and numpy returns NaN.
        """
        y = B[1] + np.sqrt(B[0]**2 - (z-B[2])**2)
        return y

    def _storeFit(self, names, output):
        """Write ODR estimates to row 0 and standard errors to row 1 of ``results``."""
        for name, value, sd in zip(names, output.beta, output.sd_beta):
            # .loc avoids pandas chained assignment (df[col][row] = ...).
            self.results.loc[0, name] = value
            self.results.loc[1, name] = sd

    def xFit(self):
        """
        Fit ``xEOM`` to the (x, z) points with orthogonal distance regression.

        The 'Mean' columns of ``xPoints`` / ``zPoints`` are the data and
        their 'Error' columns the x / z uncertainties. The fit starts from
        ``beta0 = [1, 100, 100, 100]``.

        The ODR report is printed. The estimates and standard errors are
        stored in rows 0 and 1 of the 'w', 'Vx', 'Vy' and 'Vz' columns of
        ``self.results``. Requires ``fill`` to have been called.

        Returns
        -------
        scipy.odr.Output
        """
        xModel = odr.Model(self.xEOM)
        xData = odr.RealData(self.xPoints['Mean'],
                             self.zPoints['Mean'],
                             sx=self.xPoints['Error'],
                             sy=self.zPoints['Error'])
        xOUT = odr.ODR(xData, xModel, beta0=[1, 100, 100, 100]).run()
        xOUT.pprint()
        self._storeFit(('w', 'Vx', 'Vy', 'Vz'), xOUT)
        return xOUT

    def yFit(self):
        """
        Fit ``yEOM`` to the (z, y) points with orthogonal distance regression.

        The 'Mean' columns of ``zPoints`` / ``yPoints`` are the data and
        their 'Error' columns the z / y uncertainties. The fit starts from
        ``beta0 = [1e3, 100, 100]``.

        The ODR report is printed. The estimates and standard errors are
        stored in rows 0 and 1 of the 'R', 'y0' and 'z0' columns of
        ``self.results``. Requires ``fill`` to have been called.

        Returns
        -------
        scipy.odr.Output
        """
        yModel = odr.Model(self.yEOM)
        yData = odr.RealData(self.zPoints['Mean'],
                             self.yPoints['Mean'],
                             sx=self.zPoints['Error'],
                             sy=self.yPoints['Error'])
        yOUT = odr.ODR(yData, yModel, beta0=[1e3, 100, 100]).run()
        yOUT.pprint()
        self._storeFit(('R', 'y0', 'z0'), yOUT)
        return yOUT

    def Bounds(self, high, low, call=0):
        """
        Compute the narrowest consistent range of angles for one plane.

        Parameters
        ----------
        high, low : array_like
            Upper and lower angle bounds, one entry per detector layer,
            ordered from the first to the last layer.
        call : int, default 0
            Plane being processed (0 for the x-z plane, 1 for the y-z
            plane). It is also the ``results`` row the outcome is written to.

        The algorithm starts from the bounds of the last detector layer and
        loops back to the first. Whenever a layer's bound lies strictly
        inside the current range, it replaces the corresponding upper or
        lower bound. The mean is the midpoint of the final range and the
        error is the distance from the mean to either bound.

        Returns
        -------
        tuple
            ``(mean, error)``. The values are also written to
            ``self.results.loc[call, 'Mean']`` and
            ``self.results.loc[call, 'Error']``.
        """
        # np.asarray makes [0] positional even for pandas Series input.
        # Series integer indexing is label-based, which picks the wrong
        # element after reversing.
        highRev = np.asarray(high)[::-1]
        lowRev = np.asarray(low)[::-1]
        highBound = highRev[0]
        lowBound = lowRev[0]
        for a, b in zip(highRev[1:], lowRev[1:]):
            # Tighten a bound only if the new value is strictly inside the
            # current interval. The low check sees the updated highBound.
            if lowBound < a < highBound:
                highBound = a
            if lowBound < b < highBound:
                lowBound = b
        error = (highBound - lowBound) / 2
        mean = highBound - error
        # fill() does not create 'Mean'/'Error' columns, so plain column
        # indexing would raise KeyError. .loc adds them on first use.
        self.results.loc[call, 'Mean'] = mean
        self.results.loc[call, 'Error'] = error
        return mean, error

    def Plot(self):
        ...