from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os, sys

import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavetrough
from .processing import populate_object
from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_vertical_dim, get_longitude_dim, get_latitude_dim

class AEWIdentification(IdentificationTechnique):
    """
    Identification technique for African Easterly Wave (AEW) troughs.

    Pipeline (as implemented below):

    * ``precompute``: subsets the input, builds hour-of-year anomalies of the CV field
      (the variable named by ``cv``), computes its advection terms, and masks them to
      ``trough_mask``. It also adds an empty 0.5x0.5 degree ``wavetroughs`` field.
    * ``identify``: per 2D chunk, traces the zero contour of ``trough_mask``, turns
      each contour line into an object, and optionally rasterizes kept troughs into
      ``wavetroughs``.
    * ``postprocess``: optionally reduces the output to ``wavetroughs`` only and
      renames its grid to lat/lon.
    """

    def __init__(self, wt_out_file=True, cv='cv', **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        wt_out_file: output the wavetroughs as new and only out-field in 0.5x0.5
        cv: name of the CV variable in the input dataset
        kwargs: accepted for interface compatibility; currently unused
        """
        import enstools.feature.identification.african_easterly_waves.configuration as cfg

        # NOTE(review): cfg is the configuration *module* itself, so the two settings
        # below are shared process-wide by every instance of this class.
        self.config = cfg
        self.config.out_wt = wt_out_file
        self.config.cv_name = cv
        self.processing_mode = '2d'

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Prepare the dataset for per-chunk trough identification.

        Loads (or computes and caches) the climatology from ``config.get_clim_file()``.
        Subsets ``dataset`` to the configured levels, lat/lon box and date range. Then
        adds two variables:

        * ``trough_mask``: the first advection term of the CV anomaly. It is kept only
          where three conditions hold: the anomaly exceeds the hour-of-year climatology
          quantile, the second advection term exceeds
          ``config.second_advection_min_thr``, and ``u < config.max_u_thresh``.
          Elsewhere it is NaN.
        * ``wavetroughs``: a zero-filled, writable field on a 0.5x0.5 degree grid
          (dims ``lat05``/``lon05``) that ``identify`` draws into.

        Returns the subsetted, augmented dataset.
        """
        print("Precompute for AEW identification...")

        # non-interactive backend: figures are only used to trace contours, never shown
        plt.switch_backend('agg')

        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        # --------------- CLIMATOLOGY
        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
        end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None

        dataset = dataset.sel({level_str: self.config.levels,
                               lat_str: slice(lat_range[0], lat_range[1]),
                               lon_str: slice(lon_range[0], lon_range[1]),
                               'time': slice(start_date_dt, end_date_dt)})

        # lat and lon must be the last two dims *in this order*: metpy smooths over the
        # trailing two axes, and identify() relies on xarray contouring with x=lon, y=lat.
        cv_name = self.config.cv_name
        if tuple(dataset[cv_name].dims[-2:]) != (lat_str, lon_str):
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)

        # --------------- CREATE TROUGH MASKS
        u = dataset.u if 'u' in dataset.data_vars else dataset.U
        v = dataset.v if 'v' in dataset.data_vars else dataset.V

        # smooth CV with a 9-point kernel, two passes
        cv = mpcalc.smooth_n_point(dataset[cv_name], n=9, passes=2).metpy.dequantify()
        # label every time step with its hour of year so it can be matched to the climatology
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # Anomaly above the hour-of-year climatology quantile: temporarily index by
        # hourofyear so it lines up with the climatology, then switch back to time.
        cv_anom_h = cv_anom.swap_dims({'time': 'hourofyear'})
        clim_quantile = cv_clim.cva_quantile_hoy.sel(hourofyear=cv_anom.hourofyear.data)
        perc_mask = cv_anom_h.where(cv_anom_h > clim_quantile).swap_dims({'hourofyear': 'time'})

        print('Locating wave troughs...')
        # The three conditions are combined with `&`. np.logical_and only takes two
        # inputs; a third positional argument is its `out` buffer and would be ignored
        # as a condition.
        anomaly_ok = perc_mask.notnull()  # anomaly over the climatological percentile
        # second derivative above threshold: don't detect local minima over the percentile
        curvature_ok = adv2.values > self.config.second_advection_min_thr
        # threshold for propagation speed -> keep only westward
        westward_ok = u.values < self.config.max_u_thresh
        dataset['trough_mask'] = adv1.where(anomaly_ok & curvature_ok & westward_ok)

        # --------------- 0.5x0.5 DEGREE WAVETROUGH FIELD
        def half_degree_axis(values):
            lo, hi = values.min(), values.max()
            # round, not truncate: (hi - lo) * 2 may come out as e.g. 59.999...
            return np.linspace(lo, hi, int(round((hi - lo) * 2)) + 1)

        lat05 = half_degree_axis(dataset[lat_str].data)
        lon05 = half_degree_axis(dataset[lon_str].data)

        # Same non-spatial dims/coords as trough_mask, but with the 0.5 degree grid.
        # expand_dims returns a read-only broadcast view, and identify() writes into
        # this field, so a real copy is taken.
        template = dataset['trough_mask'].isel({lat_str: 0, lon_str: 0}, drop=True)
        wt = (xr.zeros_like(template, dtype=float)
              .expand_dims(lat05=lat05, lon05=lon05)
              .transpose(..., 'lat05', 'lon05')
              .copy())

        dataset['wavetroughs'] = wt
        dataset['lat05'].attrs.update(long_name='latitude', units='degrees_north')
        dataset['lon05'].attrs.update(long_name='longitude', units='degrees_east')

        return dataset

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Extract wave-trough objects from one 2D chunk.

        Each connected line of the zero contour of ``trough_mask`` becomes a candidate
        object. ``populate_object`` fills its properties, and the object is kept only if
        ``keep_wavetrough`` accepts it. Kept objects get consecutive ids starting at 1.
        If ``config.out_wt`` is set, each kept line is drawn (anti-aliased, max-combined)
        into ``data_chunk.wavetroughs`` in place.

        Returns
        -------
        (data_chunk, objs): the chunk (with ``wavetroughs`` possibly updated) and the
        list of kept objects.
        """
        objs = []
        trough_mask_cur = data_chunk.trough_mask
        wt = data_chunk.wavetroughs

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15),
                               subplot_kw={'projection': ccrs.PlateCarree()})
        try:
            # generate zero-contours with matplotlib core
            contour_set = trough_mask_cur.plot.contour(ax=ax, transform=ccrs.PlateCarree(),
                                                       colors='blue', levels=[0.0])
            paths = self._contour_line_paths(contour_set)
        finally:
            # the figure only serves to compute contour geometry; close it so repeated
            # calls (one per chunk) do not accumulate open figures
            plt.close(fig)

        lon_axis = wt.lon05.data
        lat_axis = wt.lat05.data

        id_ = 1
        for path in paths:
            o = self.get_new_object()
            o.id = id_
            populate_object(o.properties, path)

            if not keep_wavetrough(o.properties, self.config):
                continue

            objs.append(o)
            id_ += 1

            if self.config.out_wt:
                self._rasterize_path(wt.data, path.vertices, lon_axis, lat_axis)

        return data_chunk, objs

    @staticmethod
    def _contour_line_paths(contour_set):
        """
        Return one matplotlib ``Path`` per connected line of the first contour level.

        For matplotlib < 3.8 this is simply ``contour_set.collections[0].get_paths()``.
        From 3.8 on, ``ContourSet`` is itself a Collection: ``get_paths()`` yields one
        compound Path per level, with lines separated by MOVETO codes. The deprecated
        ``collections`` shim does not split these compound paths. They are therefore
        split here, so each trough line stays a separate object.
        """
        if not hasattr(contour_set, 'get_paths'):
            return list(contour_set.collections[0].get_paths())

        from matplotlib.path import Path

        level_paths = contour_set.get_paths()
        if not level_paths or len(level_paths[0].vertices) == 0:
            return []

        compound = level_paths[0]
        codes = compound.codes
        if codes is None:
            return [compound]
        starts = np.flatnonzero(codes == Path.MOVETO)
        if starts.size == 0:
            return [compound]
        ends = np.append(starts[1:], len(codes))
        return [Path(compound.vertices[s:e], codes[s:e]) for s, e in zip(starts, ends)]

    @staticmethod
    def _rasterize_path(grid, vertices, lon_axis, lat_axis):
        """
        Draw a polyline given in (lon, lat) vertices into ``grid`` in place.

        ``grid`` is indexed as ``grid[lat_index, lon_index]``. ``lon_axis``/``lat_axis``
        are assumed evenly spaced between their min and max, as built in precompute.
        Segments are drawn with ``skimage.draw.line_aa``. Each pixel keeps the maximum
        intensity drawn onto it, and pixels falling outside the grid are discarded.
        """
        def to_index(values, axis):
            # nearest grid index: n points span [min, max], so scale by (n - 1)
            lo, hi, n = axis.min(), axis.max(), axis.size
            if n < 2 or hi == lo:
                return np.zeros(len(values), dtype=int)
            return np.rint((values - lo) / (hi - lo) * (n - 1)).astype(int)

        n_lat, n_lon = grid.shape[-2:]
        lon_idx = to_index(vertices[:, 0], lon_axis)
        lat_idx = to_index(vertices[:, 1], lat_axis)

        for lon0, lat0, lon1, lat1 in zip(lon_idx[:-1], lat_idx[:-1], lon_idx[1:], lat_idx[1:]):
            lon_px, lat_px, weight = line_aa(int(lon0), int(lat0), int(lon1), int(lat1))
            # anti-aliasing may touch pixels just outside the grid; drop them
            # (clipping would smear them onto the border, negatives would wrap around)
            inside = (lon_px >= 0) & (lon_px < n_lon) & (lat_px >= 0) & (lat_px < n_lat)
            # unbuffered max: correct even if a pixel appears more than once
            np.maximum.at(grid, (lat_px[inside], lon_px[inside]), weight[inside])

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        """
        Finalize the output dataset.

        If ``config.out_wt`` is set, every data variable except ``wavetroughs`` is
        dropped. The original lat/lon coordinates and the helper coordinates
        ``hourofyear``/``quantile`` are dropped if present. The 0.5 degree grid
        dims ``lat05``/``lon05`` are renamed to ``lat``/``lon``.
        """
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        # drop everything, only keep WTs TODO  as config.
        if self.config.out_wt:
            dataset = dataset.drop_vars([name for name in dataset.data_vars if name != 'wavetroughs'])

        # not every input carries all of these coordinates
        dataset = dataset.drop_vars([lat_str, lon_str, 'hourofyear', 'quantile'], errors='ignore')
        # NOTE(review): when out_wt is False the remaining variables still use the original
        # lat/lon dims; if those are named 'lat'/'lon' this rename collides with them.
        dataset = dataset.rename({'lat05': 'lat', 'lon05': 'lon'})

        return dataset, pb2_desc