# identification.py
# Identification strategy for African easterly waves (AEW).
import calendar
import os
import sys
import threading

import cartopy.crs as ccrs
import metpy.calc as mpcalc
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt
from skimage.draw import line
from skimage.draw import line_aa

from enstools.feature.identification import IdentificationStrategy
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, get_latitude_dim, get_init_time_dim, get_valid_time_dim

from .processing import populate_object, compute_cv
from .util import calc_adv


class AEWIdentification(IdentificationStrategy):
    # Identification strategy that detects wave troughs of African easterly waves (AEW).

    def __init__(self, wt_out_file=False, wt_traj_dir=None, cv='cv', year_summer=None, month=None, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        wt_out_file : bool
            output the wavetroughs as new and only out-field in 0.5x0.5.
            Stored as ``config.out_wt``; if true, ``config.sum_over_all`` is enabled too.
        wt_traj_dir : str or None
            Stored as ``config.out_traj_dir``.
        cv : str
            Name of the curvature vorticity variable, stored as ``config.cv_name``.
        year_summer : int or str, optional
            if set, process AEW season (01.06.-31.10.) of given year
        month : int or str, optional
            Only evaluated together with ``year_summer``: restrict the period to this
            calendar month of that year (first to last day of the month).
        kwargs
            Ignored.

        Raises
        ------
        ValueError
            If ``month`` is not a valid month number.
        """

        import enstools.feature.identification.african_easterly_waves.configuration as cfg

        # NOTE: this binds the configuration *module* itself, not a copy. The assignments
        # below therefore change module-level state that is shared within the process.
        self.config = cfg
        self.config.out_traj_dir = wt_traj_dir
        self.config.cv_name = cv

        self.orig_dataset = None

        if year_summer is not None:
            year_str = str(year_summer)
            if month is not None:
                # use the real last day of the month: a fixed day 30 is invalid for
                # February and drops the 31st of long months.
                last_day = calendar.monthrange(int(year_summer), int(month))[1]
                m_str = str(int(month)).zfill(2)
                self.config.start_date = year_str + '-' + m_str + '-01T00:00'
                self.config.end_date = year_str + '-' + m_str + '-' + str(last_day).zfill(2) + 'T00:00'
            else:
                self.config.start_date = year_str + '-06-01T00:00'
                self.config.end_date = year_str + '-10-31T00:00'

        self.config.out_wt = wt_out_file
        if wt_out_file:
            self.config.sum_over_all = True

        self.lock_ = threading.Lock()

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Prepare the data for the trough identification and attach the 'trough_mask' field.

        Steps:
        - load the curvature vorticity (CV) climatology, or compute and store it if the
          climatology file does not exist yet
        - subset the data to the configured region, time range and levels
        - compute the CV from u and v if it is not part of the dataset
        - interpolate the data onto the grid of the climatology
        - compute CV anomalies w.r.t. the hour-of-year climatology and the advection terms
        - mask the first advection term where all trough conditions hold

        Parameters
        ----------
        dataset : xr.Dataset
            Input data with the horizontal wind components and optionally the CV field
            named ``config.cv_name``.
        kwargs
            Unused.

        Returns
        -------
        xr.Dataset
            The subsetted and regridded dataset including the new 'trough_mask' variable.
            The dataset before regridding is kept in ``self.orig_dataset``.
        """
        print("Precompute for AEW identification...")
        plt.switch_backend('agg')  # this is thread safe matplotlib but cant display.

        # --------------- CLIMATOLOGY
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        # variable names from the config take precedence over auto-detection
        u_name = getattr(self.config, 'u_dim', None)
        if u_name is None:
            u_name = get_u_var(dataset)
        v_name = getattr(self.config, 'v_dim', None)
        if v_name is None:
            v_name = get_v_var(dataset)

        if u_name is None or v_name is None:
            print("Could not locate u and v fields in dataset. Needed to compute advection terms.")
            sys.exit(1)  # non-zero status: this is an error exit

        # restrict dimensions only to the ones present in u/v
        for d in list(dataset.dims):
            if d not in dataset[u_name].dims:
                dataset = dataset.drop_dims(d)

        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)
        init_time_str = get_init_time_dim(dataset)
        valid_time_str = get_valid_time_dim(dataset)

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)

        lat_str_clim = get_latitude_dim(cv_clim)
        lon_str_clim = get_longitude_dim(cv_clim)
        cv_clim = cv_clim.sel(
            **{lat_str_clim: slice(lat_range[0], lat_range[1])},
            **{lon_str_clim: slice(lon_range[0], lon_range[1])})

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
        end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None

        # if data is lon=0..360, change it to -180..180
        dataset.coords[lon_str] = (dataset.coords[lon_str] + 180) % 360 - 180
        dataset = dataset.sortby(dataset[lon_str])

        filter_time_str = init_time_str if init_time_str is not None else valid_time_str

        # get the data we want to investigate
        dataset = dataset.sortby(lat_str)  # in case of descending
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            **{filter_time_str: slice(start_date_dt, end_date_dt)})

        # check the dimension that was actually filtered
        if dataset[filter_time_str].size == 0:
            print("Given start and end time leads to no data to process.")
            sys.exit(1)

        if level_str in dataset[u_name].dims:
            dataset = dataset.sel(**{level_str: self.config.levels})  # 3-D wind field, select levels
        elif level_str is not None:
            dataset = dataset.sel(**{level_str: self.config.levels})  # 2-D, reduce 3-D field too
            dataset[u_name] = dataset[u_name].expand_dims('level')
            dataset[v_name] = dataset[v_name].expand_dims('level')
        else:
            dataset[u_name] = dataset[u_name].expand_dims(level=self.config.levels)
            dataset[v_name] = dataset[v_name].expand_dims(level=self.config.levels)
            level_str = 'level'

        # rename cv_clim dimensions to be same as in data. Use the names detected in the
        # climatology instead of assuming they are called 'lat' / 'lon'.
        clim_renames = {old: new
                        for old, new in ((lat_str_clim, lat_str), (lon_str_clim, lon_str))
                        if old != new}
        if clim_renames:
            cv_clim = cv_clim.rename(clim_renames)
        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("plev from clim to level: div by 100.")
            cv_clim = cv_clim.rename({'plev': level_str})
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})

        # also only use levels also present in data
        cv_clim = cv_clim.sel({level_str: dataset[level_str].values})

        if self.config.cv_name not in dataset.data_vars:
            print("Curvature Vorticity not found, trying to compute it out of U and V...")
            dataset = compute_cv(dataset, u_name, v_name, self.config.cv_name)

        # keep the data on its original grid before regridding
        self.orig_dataset = dataset

        # make dataset to 2.5 (or same as cv_clim)
        dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]})

        # make sure that lat and lon are last two dimensions
        trailing_dims = dataset[self.config.cv_name].dims[-2:]
        if lat_str not in trailing_dims or lon_str not in trailing_dims:
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS

        u = dataset[u_name]
        v = dataset[v_name]
        cv = dataset[self.config.cv_name]
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # create hourofyear to get anomalies
        # NOTE(review): from here on the time dimension is assumed to be named 'time'.
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
        # replace data time with hourofyear -> compare with climatology percentile -> back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        print('Locating wave troughs...')
        # filter the advection field given our conditions. All three conditions are combined
        # with '&': np.logical_and() takes only two inputs, a third positional argument is
        # its 'out' buffer and would not be evaluated as a condition.
        over_percentile = ~np.isnan(perc_mask)  # percentile of anomaly over threshold from climatology
        # second time derivative > 0: dont detect local minima over the percentile
        adv2_ok = adv2.values > self.config.second_advection_min_thr
        # threshold for propagation speed -> keep only westward
        westward = u.values < self.config.max_u_thresh
        trough_mask = adv1.where(over_percentile & adv2_ok & westward)

        dataset['trough_mask'] = trough_mask

        # NOTE: the creation of a separate 0.5x0.5 'wavetroughs' output field is currently
        # not part of the precompute step.

        # hand the prepared dataset back: 'dataset' was rebound several times above, so the
        # object passed in by the caller does not carry the trough mask.
        return dataset

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Extract the wave trough objects from the 'trough_mask' field of one data chunk.

        The zero contour lines of the trough mask are traced using matplotlib. Each contour
        line is converted into an object via populate_object() and is only returned if it
        passes keep_wavetrough().

        Parameters
        ----------
        data_chunk : xr.Dataset
            Chunk holding the 'trough_mask' variable.
            NOTE(review): presumably a single 2-D (lat, lon) slice - confirm against caller.
        kwargs
            Unused.

        Returns
        -------
        (xr.Dataset, list)
            The unmodified data chunk and the list of kept objects.
        """
        from matplotlib.path import Path

        objs = []
        trough_mask_cur = data_chunk.trough_mask
        if np.isnan(trough_mask_cur).all():
            return data_chunk, objs

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
        try:
            # generate zero-contours with matplotlib core
            c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0],
                                             subplot_kws={'projection': ccrs.PlateCarree()})

            # one path per connected contour line of the single (zero) level.
            # NOTE(review): the statement defining 'paths' is missing in this file; it is
            # rebuilt here from the contour set's segment lists - confirm that this is the
            # path representation populate_object() expects.
            segs = c.allsegs[0]
            kinds = c.allkinds[0] if c.allkinds is not None else [None] * len(segs)
            paths = [Path(seg, codes=kind) for seg, kind in zip(segs, kinds) if len(seg) > 0]
        finally:
            # the figure is only a vehicle for the contouring: release it, otherwise one
            # figure per processed chunk stays alive.
            plt.close(fig)

        id_ = 1  # running id, only advanced for kept objects
        for path in paths:
            o = self.get_new_object()
            # NOTE(review): assumes objects expose an integer 'id' field - confirm.
            o.id = id_
            populate_object(o.properties, path, self.config)

            # add to objects if keep
            if not self.keep_wavetrough(o.properties):
                continue

            objs.append(o)
            id_ += 1

        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
        """
        Finalize the identification by making the object IDs of the description unique.

        Parameters
        ----------
        dataset : xr.Dataset
            The processed dataset, passed through unchanged.
        data_desc
            Description of the identified objects; replaced by the result of
            make_ids_unique().
        kwargs
            Unused.

        Returns
        -------
        (xr.Dataset, data_desc)
            The dataset and the description with unique IDs.
        """
        # NOTE: two experimental outputs are currently not produced here, so the settings
        # config.out_wt / config.sum_over_all and config.out_traj_dir have no effect in
        # this step:
        #   - reducing the dataset to the 0.5x0.5 'wavetroughs' field
        #   - writing one Met.3D-like trajectory netCDF file per timestep (was marked as
        #     not really working)
        data_desc = self.make_ids_unique(data_desc)
        return dataset, data_desc

    # filter: keep the current wavetrough only if
    #   - at least one of its points lies within the configured spatial window
    #   - the length of the trough exceeds the configured threshold
    def keep_wavetrough(self, properties):
        """
        Called for each wavetrough, check if kept based on filtering heuristics:
        - WT requires any point of wavetrough in config.wave_filter range
        - WT requires a length above config.degree_len_thr
        - WT requires the latitude center of its bounding box within 5..25
        - WT requires a bounding box height of at least 0.75 times its width

        Parameters
        ----------
        properties
            Properties of the wavetrough. Read here: line_pts (points with lon / lat),
            length_deg and the bounding box bb (min / max, each with lat / lon).

        Returns
        -------
        True if kept
        """
        lon_lo, lon_hi = self.config.wave_filter_lon[0], self.config.wave_filter_lon[1]
        lat_lo, lat_hi = self.config.wave_filter_lat[0], self.config.wave_filter_lat[1]

        # at least one point has to lie strictly inside the filtering area
        in_area = any(lon_lo < pt.lon < lon_hi and lat_lo < pt.lat < lat_hi
                      for pt in properties.line_pts)
        if not in_area:  # no point of line segment in our area
            return False

        if properties.length_deg <= self.config.degree_len_thr:  # too small
            return False

        bb = properties.bb

        # latitude of the bounding box center (mean of both edges, not their difference)
        mid_lat = 0.5 * (bb.max.lat + bb.min.lat)
        if mid_lat < 5.0 or mid_lat > 25.0:
            return False

        height = bb.max.lat - bb.min.lat
        width = bb.max.lon - bb.min.lon

        # reject troughs whose bounding box is clearly wider than tall
        if height < 0.75 * width:
            return False

        return True