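# NOTE: the following lines are residue from the GitLab web view this file
# was copied from; they are not part of the module:
#   Skip to content / Snippets Groups Projects /
#   identification.py 7.01 KiB / Newer Older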
from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os, sys

from enstools.feature.util.pb2_properties_api import ObjectProperties
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs


class AEWIdentification(IdentificationTechnique):
    """
    African Easterly Wave (AEW) identification technique.

    Wave troughs are located from the ``cv`` variable of the input data
    (presumably curvature vorticity -- confirm). The smoothed ``cv`` field is
    converted to anomalies against an hour-of-year climatology, and its
    advection is computed. Grid points are kept where all of the following
    hold:

    - the anomaly exceeds the climatological percentile;
    - the second advection term exceeds a configured minimum;
    - the zonal wind ``U`` is below a configured maximum.
    """

    def __init__(self, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        kwargs
            Currently ignored.
        """
        # The configuration module supplies data ranges, dates, levels,
        # thresholds and the climatology file path used in precompute().
        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg  # config
        self.processing_mode = '2d'  # TODO enum?

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Subset the input data and compute the AEW trough mask.

        Parameters
        ----------
        dataset : xr.Dataset
            Input data. Must contain the variables ``U``, ``V`` and ``cv``,
            indexable by ``plev``, ``lat``, ``lon`` and ``time``.

        Returns
        -------
        xr.Dataset
            The subset (and possibly transposed) dataset with an added
            ``trough_mask`` variable. ``trough_mask`` holds the first
            advection term where all trough criteria hold, and NaN elsewhere.
        """
        print("Precompute for AEW identification...")

        # --------------- CLIMATOLOGY
        # TODO use config for file path
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)

        else:  # TODO not yet accessed in framework
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date)
        end_date_dt = np.datetime64(self.config.end_date)
        # The requested period can span several years; the list is currently
        # only printed (multi-year loading is handled outside this method).
        years = list(range(start_date_dt.astype(object).year, end_date_dt.astype(object).year + 1))
        print(years)
        dataset = dataset.sel(plev=self.config.levels,
                              lat=slice(lat_range[0], lat_range[1]),
                              lon=slice(lon_range[0], lon_range[1]),
                              time=slice(start_date_dt, end_date_dt))

        # Make sure the last two dimensions are exactly (lat, lon). A pure
        # membership test would accept (lon, lat) and skip the reorder.
        if dataset.cv.dims[-2:] != ('lat', 'lon'):
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., 'lat', 'lon')

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS

        u = dataset.U
        v = dataset.V
        cv = dataset.cv
        # smooth CV with kernel
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # Label each time step with "MM-DD HH" so it can be matched against
        # the hour-of-year climatology when forming anomalies.
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # Keep anomaly values that exceed the hour-of-year climatological
        # percentile. Swap time -> hourofyear so the climatology can be
        # selected per step, then swap back to real time.
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        # Configured percentile of the smoothed cv field (not of the
        # anomalies). It is only printed for diagnostics and not used below.
        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)
        print(cv_perc_thresh)

        print('Locating wave troughs...')
        # Combine all criteria with element-wise `&`. np.logical_and accepts
        # only two inputs: a third positional argument is its `out` buffer,
        # so passing three conditions would silently drop the last one.
        criteria = (
            ~np.isnan(perc_mask)  # percentile of anomaly over threshold from climatology
            # second time derivative > 0: dont detect local minima over the percentile
            & (adv2.values > self.config.second_advection_min_thr)
            # threshold for propagation speed -> keep only westward
            & (u.values < self.config.max_u_thresh)
        )
        trough_mask = adv1.where(criteria)

        dataset['trough_mask'] = trough_mask
        return dataset

    # main identify, called on spatial 2d/3d subsets from enstools-feature.
    # TODO maybe rename identify_per_block
    # or separate methods identify, identify_per_block. But how to build obj if only "identify"
    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Identify AEW objects in one spatial subset.

        NOTE: this is still a placeholder. It creates five dummy objects, sets
        their property ``a`` to ``42.0 * i`` (i = 0..4), and returns the chunk
        unchanged.

        Parameters
        ----------
        data_chunk : xr.Dataset
            A 2d/3d spatial subset passed in by enstools-feature. It is
            presumably derived from ``precompute``'s output, and so carries
            ``trough_mask`` -- confirm.

        Returns
        -------
        tuple
            ``(data_chunk, objs)``, where ``objs`` is the list of objects
            obtained from ``ObjectProperties.get_instance``.
        """
        print("chunk identify")
        # TODO pressure parallel? PV needs 3d input, this parallel 2d!

        from .._proto_gen import african_easterly_waves_pb2

        objs = []
        for i in range(5):
            # get an instance of a new object and its id in the descriptions.
            s_object = ObjectProperties.get_instance(african_easterly_waves_pb2)

            # fill the properties defined in the .proto file.
            s_object.set('a', 42.0 * i)
            objs.append(s_object)
            # can also set id manually, e.g. if labeled dataset:
            # ObjectProperties.add_to_objects(object_block, s_object, id=i)

        # TODO: replace the placeholder above with the actual trough detection.
        # Planned approach (previously sketched as unreachable code here):
        #   1. draw zero contours of ``data_chunk.trough_mask``;
        #   2. filter the contour paths spatially (``filtering.spacial_filter``
        #      with the configured spatial_thr / wave_lat / wave_lon);
        #   3. build one object per remaining path.
        # TODO can user decide ID?
        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        """Return ``dataset`` and ``pb2_desc`` unchanged (no postprocessing yet)."""
        return dataset, pb2_desc