from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os, sys
import threading

import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavetrough
from .processing import populate_object
from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_vertical_dim, get_longitude_dim, get_latitude_dim


class AEWIdentification(IdentificationTechnique):

    def __init__(self, wt_traj_dir=None, cv='cv', **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        wt_traj_dir: output directory for the wavetroughs, written as the new and only out-field on a 0.5x0.5 grid
        cv: name of the curvature vorticity variable in the dataset
        kwargs
        """

        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg  # config
        self.config.out_wt_dir = wt_traj_dir
        self.config.cv_name = cv
        self.found_max_wt_pts = -1
        self.lock_ = threading.Lock()
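
    # Usage sketch (hypothetical paths; the object is normally driven by the
    # enstools-feature pipeline, which calls precompute() once, identify() per
    # data chunk, and postprocess() at the end):
    #   ident = AEWIdentification(wt_traj_dir='./wt_trajs/', cv='cv')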

    def precompute(self, dataset: xr.Dataset, **kwargs):
        print("Precompute for PV identification...")
        plt.switch_backend('agg')  # this is thread safe matplotlib but cant display.
        # --------------- CLIMATOLOGY
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate the climatology: needs all 40y of CV data.
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)  # cache for subsequent runs

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
        end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            time=slice(start_date_dt, end_date_dt))
        if level_str is None:
            if 'level' not in dataset.coords:
                print("No level information given in input. Assuming 700hPa.")
                dataset = dataset.assign_coords(level=700)
            dataset = dataset.expand_dims('level')
            level_str = 'level'
        else:
            dataset = dataset.sel(**{level_str: self.config.levels})

        # rename cv_clim dimensions to match the data.
        cv_clim = cv_clim.rename({'lat': lat_str, 'lon': lon_str})
        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("Renaming plev (Pa) from climatology to level (hPa): dividing by 100.")
            cv_clim = cv_clim.rename({'plev': level_str})
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})

        # make sure that lat and lon are the last two dimensions
        cv_dims = dataset[self.config.cv_name].dims
        if lat_str not in cv_dims[-2:] or lon_str not in cv_dims[-2:]:
            print("Reordering dimensions so lat and lon are at the back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS

        u = dataset.u if 'u' in dataset.data_vars else dataset.U
        v = dataset.v if 'v' in dataset.data_vars else dataset.V
        cv = dataset[self.config.cv_name]
        # smooth CV with kernel
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # create hourofyear to get anomalies
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
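        # each time step is keyed by month-day-hour (e.g. '09-01 12' for Sep 1, 12 UTC),
        # so the groupby below subtracts the climatological mean of that hour of the year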
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)
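        # the wave axis lies where the CV advection changes sign (zero contour of adv1);
        # adv2 is used below to keep CV maxima and reject local minima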

        # xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
        # replace data time with hourofyear -> compare with climatology percentile -> back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)  # configured percentile of the smoothed CV field
        print("CV percentile threshold: " + str(cv_perc_thresh))

        print('Locating wave troughs...')
        # filter the advection field given our conditions (np.logical_and is binary, so combine pairwise):
        trough_mask = adv1.where(np.logical_and(
            np.logical_and(
                ~np.isnan(perc_mask),  # percentile of anomaly over threshold from climatology
                adv2.values > self.config.second_advection_min_thr),
                # second time derivative > 0: don't detect local minima over the percentile
            u.values < self.config.max_u_thresh))  # threshold for propagation speed -> keep only westward-moving

        dataset['trough_mask'] = trough_mask

        # create 0.5x0.5 dataarray for wavetroughs
        min_lat = dataset[lat_str].data.min()
        max_lat = dataset[lat_str].data.max()
        min_lon = dataset[lon_str].data.min()
        max_lon = dataset[lon_str].data.max()

        lat05 = np.linspace(min_lat, max_lat, int((max_lat - min_lat) * 2) + 1)
        lon05 = np.linspace(min_lon, max_lon, int((max_lon - min_lon) * 2) + 1)
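        # 2*(span)+1 linspace points give an exact 0.5 degree spacing,
        # e.g. a latitude span of 40 degrees yields 81 grid points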
        # wt = xr.DataArray(coords=[('lon05', lon05), ('lat05', lat05)],

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        objs = []
        trough_mask_cur = data_chunk.trough_mask

        def clip(tup, mint, maxt):
            return np.clip(tup, mint, maxt)

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
        # generate zero-contours with matplotlib core, drawn on the axes created above
        c = trough_mask_cur.plot.contour(ax=ax, transform=ccrs.PlateCarree(), colors='blue', levels=[0.0])
        id_ = 1
        # extract the vertex paths of the zero-contour lines (matplotlib < 3.8 API;
        # newer versions offer c.get_paths() directly)
        paths = c.collections[0].get_paths()
        for path in paths:
            o = self.get_new_object()
            # populate it
            populate_object(o.properties, path)
            # add to objects if it passes the wavetrough filtering criteria
            if keep_wavetrough(o.properties, self.config):
                objs.append(o)
                id_ += 1
                num_verts = len(path.vertices)
                with self.lock_: # TODO remove this?
                    self.found_max_wt_pts = max(self.found_max_wt_pts, num_verts)
        plt.close(fig)  # free the figure; identify() is called for every chunk
        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):

        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        # create met3d like trajectories
        if self.config.out_wt_dir:
            if not os.path.exists(self.config.out_wt_dir):
                os.makedirs(self.config.out_wt_dir)

            assert len(pb2_desc.sets) == 1  # TODO only one set supported; maybe expand at some point
            desc_set = pb2_desc.sets[0]
            desc_times = desc_set.timesteps

            for idx, ts in enumerate(desc_times):
                if len(ts.objects) == 0:
                    continue  # no wavetroughs at this time step: skip, the dims below would be empty

                # need to make a separate dataset for each init time,
                # because the number of trajectories (WTs) differs from time to time
                dataset_wt = xr.Dataset()

                lon_list = []
                lat_list = []
                pres_list = []
                max_pts_in_wt = -1
                for o in ts.objects:  # get lons and lats
                    pt_list = o.properties.line_pts
                    lon_list.append(np.array([pt.lon for pt in pt_list]))
                    lat_list.append(np.array([pt.lat for pt in pt_list]))
                    pres_list.append(np.array([850.0 for pt in pt_list]))  # constant display pressure level
                    max_pts_in_wt = max(max_pts_in_wt, len(lon_list[-1]))
                # second pass: pad all trajectories with NaNs at the end to the same length
                for i in range(len(lon_list)):
                    lon_list[i] = np.pad(lon_list[i], (0, max_pts_in_wt - len(lon_list[i])), mode='constant',
                                         constant_values=np.nan)
                    lat_list[i] = np.pad(lat_list[i], (0, max_pts_in_wt - len(lat_list[i])), mode='constant',
                                         constant_values=np.nan)
                    pres_list[i] = np.pad(pres_list[i], (0, max_pts_in_wt - len(pres_list[i])), mode='constant',
                                        constant_values=np.nan)

                dataset_wt = dataset_wt.expand_dims(time=np.arange(0, max_pts_in_wt).astype(dtype=float)) # fake traj time
                dataset_wt = dataset_wt.expand_dims(ensemble=[0])
                dataset_wt = dataset_wt.expand_dims(trajectory=np.arange(1, len(ts.objects) + 1))
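                # Met3D-style trajectory files use dims (ensemble, trajectory, time);
                # lon/lat/pressure are filled below, NaN-padded per trajectory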

                lons = xr.DataArray(np.zeros((1, len(ts.objects), max_pts_in_wt)), dims=("ensemble", "trajectory", "time"))
                lons.attrs['standard_name'] = "longitude"
                lons.attrs['long_name'] = "longitude"
                lons.attrs['units'] = "degrees_east"

                lats = xr.zeros_like(lons)
                lats.attrs['standard_name'] = "latitude"
                lats.attrs['long_name'] = "latitude"
                lats.attrs['units'] = "degrees_north"

                pres = xr.zeros_like(lons)
                pres.attrs['standard_name'] = "air_pressure"
                pres.attrs['long_name'] = "pressure"
                pres.attrs['units'] = "hPa"
                pres.attrs['positive'] = "down"
                pres.attrs['axis'] = "Z"

                dataset_wt['lon'] = lons
                dataset_wt['lat'] = lats
                dataset_wt['pressure'] = pres
                # TODO add auxiliary variables?
                lon_list_np = np.array(lon_list)
                lat_list_np = np.array(lat_list)
                pres_list_np = np.array(pres_list)

                dataset_wt['lon'].data[0] = lon_list_np
                dataset_wt['lat'].data[0] = lat_list_np
                dataset_wt['pressure'].data[0] = pres_list_np

                dataset_wt['time'].attrs['standard_name'] = "time"
                dataset_wt['time'].attrs['long_name'] = "time"
                dataset_wt['time'].attrs['units'] = "hours since " + ts.valid_time.replace('T', ' ')
                dataset_wt['time'].attrs['trajectory_starttime'] = ts.valid_time.replace('T', ' ')
                dataset_wt['time'].attrs['forecast_inittime'] = ts.valid_time.replace('T', ' ')
                
                out_path = os.path.join(self.config.out_wt_dir, ts.valid_time.replace(':', '_') + '.nc')
                dataset_wt.to_netcdf(out_path)
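
    # A minimal inspection sketch for the written files (hypothetical file name,
    # matching the out_path pattern above):
    #   ds = xr.open_dataset('wt_trajs/2006-09-01T12_00_00.nc')
    #   first_wt = ds.lon.isel(ensemble=0, trajectory=0).dropna('time')  # lons of the first wavetrough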