from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os, sys

import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavetrough
from .processing import populate_object

class AEWIdentification(IdentificationTechnique):

    def __init__(self, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        kwargs
        """

        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg  # config
        self.processing_mode = '2d'

    def precompute(self, dataset: xr.Dataset, **kwargs):
        print("Precompute for PV identification...")

        plt.switch_backend('agg')  # non-interactive backend: thread-safe, but figures cannot be displayed

        # --------------- CLIMATOLOGY
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # to generate the climatology, all 40 years of CV data are needed
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)
        
        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date)
        end_date_dt = np.datetime64(self.config.end_date)
        # get the data we want to investigate
        # can span multiple timesteps and hence multiple years -> load all years we need
        years = list(range(start_date_dt.astype(object).year, end_date_dt.astype(object).year + 1))
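        # (.astype(object) converts np.datetime64 to datetime.datetime, exposing .year)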
        print("Processing years: " + str(years))
        # process_data = xr.open_mfdataset([diri + str(y) + "cv.nc" for y in years])  # open years in range of requested
        dataset = dataset.sel(plev=self.config.levels,
                              lat=slice(lat_range[0], lat_range[1]),
                              lon=slice(lon_range[0], lon_range[1]),
                              time=slice(start_date_dt, end_date_dt))
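        # note: slice-based .sel follows the coordinate's own ordering, so the configured
        # lat/lon ranges are assumed to match it (e.g. ascending lat for slice(lat0, lat1))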

        # make sure that lat and lon are last two dimensions
        if 'lat' not in dataset.cv.dims[-2:] or 'lon' not in dataset.cv.dims[-2:]:
            print("Reordering dimensions so lat and lon are the last two. Required by metpy.calc.")
            dataset = dataset.transpose(..., 'lat', 'lon')

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS

        u = dataset.U
        v = dataset.V
        cv = dataset.cv
        # smooth CV with kernel
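        # (smooth_n_point(n=9) is metpy's 9-point smoother; passes=2 applies it twice
        # to damp grid-scale noise before anomalies are formed)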
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # create hourofyear to get anomalies
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
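        # subtracting the climatology per hourofyear group removes the seasonal and
        # diurnal cycle; what remains is the CV anomaly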
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
        # replace data time with hourofyear -> compare with climatology percentile -> back to real time
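        # e.g. a timestamp 2008-09-02T06 maps to the hourofyear key '09-02 06'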
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)  # configured percentile of the smoothed CV field
        print("CV percentile threshold: " + str(cv_perc_thresh))

        print('Locating wave troughs...')
        # filter the advection field given our conditions:
        trough_mask = adv1.where((~np.isnan(perc_mask))  # percentile of anomaly over threshold from climatology
                                 & (adv2.values > self.config.second_advection_min_thr)  # second time derivative > 0: don't detect local minima over the percentile
                                 & (u.values < self.config.max_u_thresh))  # propagation speed threshold -> keep only westward-moving

        dataset['trough_mask'] = trough_mask
        dataset['wavetroughs'] = xr.zeros_like(dataset.trough_mask, dtype=bool)  # filled in identify()

        return dataset

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        objs = []
        trough_mask_cur = data_chunk.trough_mask

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
        # generate zero-contours of the trough mask with matplotlib
        c = trough_mask_cur.plot.contour(ax=ax, transform=ccrs.PlateCarree(), colors='blue', levels=[0.0])
        wt = data_chunk.wavetroughs
        # TODO path to data field...
        # maybe skimage draw line(), but consider lat/lons...

        # one matplotlib collection per requested contour level; level 0.0 is the only one here
        paths = c.collections[0].get_paths()
        id_ = 1
        for path in paths:
            o = self.get_new_object()
            # populate the object's properties from the contour path
            populate_object(o.properties, path)
            # keep the object only if it qualifies as a wavetrough
            if keep_wavetrough(o.properties, self.config):
                o.id = id_  # assumption: identified objects carry an integer id
                id_ += 1
                objs.append(o)
        return data_chunk, objs


    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):

        return dataset, pb2_desc
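
# Minimal usage sketch (assumption: enstools-feature wires identification strategies
# into a FeaturePipeline; the names and call signatures below are illustrative and
# not confirmed by this file):
#
#     from enstools.feature.pipeline import FeaturePipeline
#     pipeline = FeaturePipeline()
#     pipeline.set_identification_strategy(AEWIdentification())
#     pipeline.set_data_path("<path-to-cv-data>.nc")
#     pipeline.execute()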