from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os, sys

import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import as ccrs
from .filtering import keep_wavetrough
from .processing import populate_object
from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_vertical_dim, get_longitude_dim, get_latitude_dim
class AEWIdentification(IdentificationTechnique):

    def __init__(self, wt_traj_dir=None, cv='cv', **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        wt_traj_dir : str, optional
            Stored as ``config.out_wt_dir``. When it is set, ``postprocess``
            builds one wave-trough trajectory dataset per timestep for this
            output location. ``None`` (default) disables that step.
        cv : str
            Name of the CV variable in the input dataset; stored as
            ``config.cv_name`` and used to look the field up in ``precompute``.
        **kwargs
            Accepted for interface compatibility; not used here.
        """
        # Imported locally: only needed for the lock created below, and the
        # module's visible top-level imports do not provide it.
        import threading

        import enstools.feature.identification.african_easterly_waves.configuration as cfg

        # NOTE: the configuration *module object* is stored and mutated, so the
        # two settings below are shared with everything else importing it.
        self.config = cfg  # config
        self.config.out_wt_dir = wt_traj_dir
        self.config.cv_name = cv

        # Running maximum of contour vertices per kept wave trough; updated in
        # identify() under lock_. -1 means "nothing found yet".
        self.found_max_wt_pts = -1
        self.lock_ = threading.Lock()

    def precompute(self, dataset: xr.Dataset, **kwargs):
        # Prepare the full dataset before per-chunk identification:
        #   1. load (or compute) the CV climatology,
        #   2. subset the data to the configured lat/lon/time window and levels,
        #   3. build CV anomalies and their advection terms,
        #   4. store the resulting 'trough_mask' field in the dataset.
        # NOTE(review): several statements below are cut off mid-expression in
        # this view (marked individually) and no return statement is visible;
        # the tail of this method appears to be missing.
        # NOTE(review): the log message says "PV" although this is the AEW technique.
        print("Precompute for PV identification...")
        plt.switch_backend('agg')  # this is thread safe matplotlib but cant display.
        # --------------- CLIMATOLOGY
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        # dimension names as they are called in the given dataset
        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        # NOTE(review): an `else:` appears to be missing after the open_dataset
        # line. As written, the "not found" message and compute_climatology()
        # sit inside the `if` branch, so cv_clim is recomputed even when the
        # file exists and stays undefined when it does not.
        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)

            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
        # --------------- SUBSET DATA ACCORDING TO CFG
        # None bounds leave the corresponding side of the time slice open.
        start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
        end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            time=slice(start_date_dt, end_date_dt))
        # No vertical dimension detected: add a 'level' dimension and select the
        # configured levels.
        # NOTE(review): as indented here, the level selection only runs in this
        # branch; datasets that already have a vertical dimension are not
        # subset by level. Confirm this is intended.
        if level_str is None:
            if not 'level' in dataset.coords:
                print("No level information given in input. Assume 700hPa.")
            dataset = dataset.expand_dims('level')
            level_str = 'level'
            dataset = dataset.sel(**{level_str: self.config.levels})
        # rename cv_clim dimensions to be same as in data.
        cv_clim = cv_clim.rename({'lat': lat_str, 'lon': lon_str})
        # Climatology level coordinate is divided by 100 to match the data
        # (presumably Pa -> hPa; verify against the climatology file).
        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("plev from clim to level: div by 100.")
            cv_clim = cv_clim.rename({'plev': level_str})
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})

        # make sure that lat and lon are last two dimensions
        # NOTE(review): the condition on the next line is truncated in this view
        # (unclosed `dataset[` and no trailing colon).
        if lat_str not in dataset[self.config.cv_name].coords.dims[-2:] or lon_str not in dataset[
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)


        # wind components: accept lower- or upper-case variable names
        u = dataset.u if 'u' in dataset.data_vars else dataset.U
        v = dataset.v if 'v' in dataset.data_vars else dataset.V
        cv = dataset[self.config.cv_name]
        # smooth CV with kernel
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()
        # create hourofyear to get anomalies
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        # NOTE(review): the right-hand operand of the subtraction is missing in
        # this view; presumably a per-hourofyear climatological field from
        # cv_clim - TODO confirm.
        cv_anom = cv.groupby('hourofyear') -
        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
        # replace data time with hourofyear -> compare with climatology percentile -> back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        # NOTE(review): the .sel(dict(...)) selection below is truncated in this view.
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        # NOTE(review): computed on `cv`, not `cv_anom`, and not referenced
        # again in the visible code - the trailing comment may be stale.
        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)  # 66th percentile of cv anomalies

        print('Locating wave troughs...')
        # filter the advection field given our conditions:
        # NOTE(review): np.logical_and is a binary ufunc; a third positional
        # argument is taken as `out`, so the u-threshold array below is used as
        # the output buffer rather than combined as a third condition. Verify.
        trough_mask = adv1.where(np.logical_and(
            ~np.isnan(perc_mask),  # percentile of anomaly over threshold from climatology
            adv2.values > self.config.second_advection_min_thr,
            # second time derivative > 0: dont detect local minima over the percentile
            u.values < self.config.max_u_thresh))  # threshold for propagation speed -> keep only westward

        # identify() reads this field back per chunk
        dataset['trough_mask'] = trough_mask

        # create 0.5x0.5 dataarray for wavetroughs
        min_lat = dataset[lat_str].data.min()
        max_lat = dataset[lat_str].data.max()
        min_lon = dataset[lon_str].data.min()
        max_lon = dataset[lon_str].data.max()

        # half-degree coordinate axes spanning the data extent (endpoints included)
        # NOTE(review): lat05/lon05 are not used in the visible code; the line
        # that consumed them is commented out and truncated below.
        lat05 = np.linspace(min_lat, max_lat, int((max_lat - min_lat) * 2) + 1)
        lon05 = np.linspace(min_lon, max_lon, int((max_lon - min_lon) * 2) + 1)
        # wt = xr.DataArray(coords=[('lon05', lon05), ('lat05', lat05)],
    def identify(self, data_chunk: xr.Dataset, **kwargs):
        # Identify wave-trough objects in one data chunk by contouring the
        # zero level of the chunk's 'trough_mask' field, creating one candidate
        # object per contour path and filtering it with keep_wavetrough().
        # Returns the unchanged chunk together with the list of objects.
        # NOTE(review): parts of this method are missing in this view - see the
        # individual notes below.
        objs = []
        trough_mask_cur = data_chunk.trough_mask

        # NOTE(review): this helper is not called in the visible code.
        def clip(tup, mint, maxt):
            return np.clip(tup, mint, maxt)

        # NOTE(review): the figure is never closed in the visible code; if
        # identify() runs once per chunk, figures may accumulate - confirm.
        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
        # generate zero-contours with matplotlib core
        c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0],
                                         subplot_kws={'projection': ccrs.PlateCarree()})
        # NOTE(review): `paths` is never assigned in the visible code;
        # presumably it is extracted from the contour set `c` - TODO confirm.
        id_ = 1
        for path in paths:
            o = self.get_new_object()
            # populate it
            # NOTE(review): the first argument is missing in this view.
            populate_object(, path)
            # add to objects if keep
            # NOTE(review): the first argument is missing in this view, and no
            # statement appending `o` to `objs` is visible, so `objs` stays
            # empty as written.
            if keep_wavetrough(, self.config):
                id_ += 1
                num_verts = len(path.vertices)
                # track the largest vertex count seen over all kept troughs
                with self.lock_: # TODO remove this?
                    self.found_max_wt_pts = max(self.found_max_wt_pts, num_verts)
        return data_chunk, objs
    def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
        # Finalize the identification: make object ids unique and, when an
        # output directory is configured, assemble one trajectory-style
        # xr.Dataset per timestep with dims (ensemble, trajectory, time),
        # where each wave trough is one "trajectory" and its points are laid
        # out along a fake time axis.
        # Returns the (unchanged) dataset and the updated description.
        # NOTE(review): several statements below are truncated in this view
        # and no statement writing `dataset_wt` to `out_path` is visible.
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        data_desc = self.make_ids_unique(data_desc)

        # create met3d like trajectories
        if self.config.out_wt_dir:
            # NOTE(review): the body of this `if` is missing in this view;
            # presumably it creates the output directory - TODO confirm.
            if not os.path.exists(self.config.out_wt_dir):

            assert(len(data_desc.sets) == 1) # TODO assert one set. maybe expand at some point
            desc_set = data_desc.sets[0]
            desc_times = desc_set.timesteps

            for idx, ts in enumerate(desc_times):
                # need to make separate dataset for each init-time
                # because number of trajs (WTs) are different from time to time
                dataset_wt = xr.Dataset()

                # one array per wave trough; lengths differ until padded below
                lon_list = []
                lat_list = []
                pres_list = []
                # NOTE(review): with no objects this stays -1 and np.zeros below
                # would receive a negative dimension.
                max_pts_in_wt = -1 # TODO what if no wts
                for o in ts.objects: # get lons and lats
                    # NOTE(review): the right-hand side is missing in this view;
                    # presumably the object's list of line points - TODO confirm.
                    pt_list =
                    lon_list.append(np.array([pt.lon for pt in pt_list]))
                    # NOTE(review): the element expression is missing in this view.
                    lat_list.append(np.array([ for pt in pt_list]))
                    # every point gets the constant pressure value 850.0
                    pres_list.append(np.array([850.0 for pt in pt_list]))
                    max_pts_in_wt = max(max_pts_in_wt, len(lon_list[-1]))
                # go again and fill with NaNs at end
                # NOTE(review): the three np.pad calls are truncated in this view
                # (remaining arguments and the closing parenthesis are missing).
                for i in range(len(lon_list)): # get lons and lats
                    lon_list[i] = np.pad(lon_list[i], (0, max_pts_in_wt - len(lon_list[i])), mode='constant',
                    lat_list[i] = np.pad(lat_list[i], (0, max_pts_in_wt - len(lat_list[i])), mode='constant',
                    pres_list[i] = np.pad(pres_list[i], (0, max_pts_in_wt - len(pres_list[i])), mode='constant',

                # coordinates: point index as float "time", one ensemble member,
                # trajectories numbered from 1
                dataset_wt = dataset_wt.expand_dims(time=np.arange(0, max_pts_in_wt).astype(dtype=float)) # fake traj time
                dataset_wt = dataset_wt.expand_dims(ensemble=[0])
                dataset_wt = dataset_wt.expand_dims(trajectory=np.arange(1, len(ts.objects) + 1))

                # zero-initialized variables; filled from the padded lists below
                lons = xr.DataArray(np.zeros((1, len(ts.objects), max_pts_in_wt)), dims=("ensemble", "trajectory", "time"))
                lons.attrs['standard_name'] = "longitude"
                lons.attrs['long_name'] = "longitude"
                lons.attrs['units'] = "degrees_east"

                lats = xr.zeros_like(lons)
                lats.attrs['standard_name'] = "latitude"
                lats.attrs['long_name'] = "latitude"
                lats.attrs['units'] = "degrees_north"

                pres = xr.zeros_like(lons)
                pres.attrs['standard_name'] = "air_pressure"
                pres.attrs['long_name'] = "pressure"
                pres.attrs['units'] = "hPa"
                pres.attrs['positive'] = "down"
                pres.attrs['axis'] = "Z"

                dataset_wt['lon'] = lons
                dataset_wt['lat'] = lats
                dataset_wt['pressure'] = pres
                # TODO auxiliary smth?
                lon_list_np = np.array(lon_list)
                lat_list_np = np.array(lat_list)
                pres_list_np = np.array(pres_list)

                # write into the single ensemble member (index 0)
                dataset_wt['lon'].data[0] = lon_list_np
                dataset_wt['lat'].data[0] = lat_list_np
                dataset_wt['pressure'].data[0] = pres_list_np

                # ts.valid_time is handled as a string here ('T' replaced by a space)
                dataset_wt['time'].attrs['standard_name'] = "time"
                dataset_wt['time'].attrs['long_name'] = "time"
                dataset_wt['time'].attrs['units'] = "hours since " + ts.valid_time.replace('T', ' ')
                dataset_wt['time'].attrs['trajectory_starttime'] = ts.valid_time.replace('T', ' ')
                dataset_wt['time'].attrs['forecast_inittime'] = ts.valid_time.replace('T', ' ') # '2006-09-01 12:00:00' # TODO ts.valid_time.replace('T', ' ')
                # NOTE(review): plain string concatenation - out_wt_dir must end
                # with a path separator for the file to land inside the directory.
                out_path = self.config.out_wt_dir + ts.valid_time.replace(':','_') + '.nc'
        return dataset, data_desc