from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os
import sys
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavestate

from .._proto_gen import african_easterly_waves_pb2


class AEWIdentification(IdentificationTechnique):
    """Identify African Easterly Wave (AEW) troughs.

    ``precompute`` attaches a ``trough_mask`` variable to the dataset. It holds
    the first advection term of the CV anomalies wherever all trough conditions
    hold, and NaN elsewhere. ``identify`` then extracts the zero-contour lines of
    that mask for each 2D chunk and stores them as protobuf line objects.
    """

    def __init__(self, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        kwargs
            Currently unused.
        """
        self.pb_reference = african_easterly_waves_pb2
        # The configuration is a module; its attributes (data_lon, data_lat,
        # start_date, end_date, levels, thresholds, ...) are read in precompute.
        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg
        self.processing_mode = '2d'

    def _load_climatology(self):
        """Return the CV climatology, computing it and writing it to disk if the file is missing."""
        clim_file = self.config.get_clim_file()
        if os.path.isfile(clim_file):
            return xr.open_dataset(clim_file)

        # generate: need all 40y of CV data.
        print("Climatology file not found. Computing climatology...")
        from .climatology import compute_climatology
        cv_clim = compute_climatology(self.config)
        cv_clim.to_netcdf(clim_file)
        return cv_clim

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """Subset the data, build CV anomalies and advection, and attach ``trough_mask``.

        Parameters
        ----------
        dataset
            Input data providing ``U``, ``V`` and ``cv`` with ``plev``, ``lat``,
            ``lon`` and ``time`` coordinates.

        Returns
        -------
        xr.Dataset
            The subset (and possibly transposed) dataset with an added
            ``trough_mask`` variable. It equals the first advection term from
            ``calc_adv`` where all trough conditions hold and is NaN elsewhere.
        """
        print("Precompute for AEW identification...")
        # Non-interactive backend: figures are only used to extract contour
        # lines and are never displayed.
        plt.switch_backend('agg')

        # --------------- CLIMATOLOGY
        cv_clim = self._load_climatology()

        # --------------- SUBSET DATA ACCORDING TO CFG
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        start_date_dt = np.datetime64(self.config.start_date)
        end_date_dt = np.datetime64(self.config.end_date)

        # The requested period may span several years.
        years = list(range(start_date_dt.astype(object).year, end_date_dt.astype(object).year + 1))
        print(years)

        dataset = dataset.sel(plev=self.config.levels,
                              lat=slice(lat_range[0], lat_range[1]),
                              lon=slice(lon_range[0], lon_range[1]),
                              time=slice(start_date_dt, end_date_dt))

        # metpy.calc needs lat and lon as the last two dims, in this order.
        # Comparing the tuple also catches a trailing ('lon', 'lat').
        if dataset.cv.dims[-2:] != ('lat', 'lon'):
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., 'lat', 'lon')

        # --------------- CREATE TROUGH MASKS
        u = dataset.U
        v = dataset.V
        cv = dataset.cv

        # smooth CV with a 9-point kernel, two passes
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # anomalies relative to the hour-of-year climatology
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative (see calc_adv)
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # Keep anomalies exceeding the hour-of-year climatological quantile:
        # index by hourofyear -> compare with climatology -> back to real time.
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        # Diagnostic output only: the configured percentile of the smoothed CV
        # field (not of the anomalies). The value is not used further.
        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)
        print(cv_perc_thresh)

        print('Locating wave troughs...')
        # Combine all three conditions explicitly. A third positional argument
        # to np.logical_and is its `out` parameter, not another operand.
        above_clim_quantile = ~np.isnan(perc_mask)  # anomaly over climatological quantile
        # second derivative over threshold: don't detect local minima over the percentile
        second_adv_ok = adv2.values > self.config.second_advection_min_thr
        # threshold for propagation speed -> keep only westward
        westward = u.values < self.config.max_u_thresh
        trough_mask = adv1.where(above_clim_quantile & second_adv_ok & westward)

        dataset['trough_mask'] = trough_mask
        return dataset

    @staticmethod
    def _zero_contour_segments(contour_set):
        """Return the vertex arrays of the single contour level that was drawn.

        Each array holds one contour line; column 0 is used as lon and column 1
        as lat. ``ContourSet.collections`` was deprecated in matplotlib 3.8 and
        removed in 3.10, whereas ``allsegs`` exists across versions.
        """
        return contour_set.allsegs[0]

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """Turn the zero-contours of ``trough_mask`` in one 2D chunk into wave objects.

        Parameters
        ----------
        data_chunk
            One 2D slice of the precomputed dataset containing ``trough_mask``.

        Returns
        -------
        tuple
            ``(data_chunk, objs)``, where ``objs`` holds the objects whose
            properties pass ``keep_wavestate``. IDs are numbered per contour line
            starting at 1, including lines that are filtered out.
        """
        objs = []
        trough_mask_cur = data_chunk.trough_mask
        level = 0.0

        # A contour level is only drawn if it lies strictly inside the finite
        # data range. Older matplotlib instead falls back to contouring the
        # minimum, which would yield spurious troughs.
        values = np.asarray(trough_mask_cur.values)
        finite = values[np.isfinite(values)]
        if finite.size == 0 or not (finite.min() < level < finite.max()):
            return data_chunk, objs

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15),
                               subplot_kw=dict(projection=ccrs.PlateCarree()))
        try:
            # generate zero-contours with matplotlib core
            c = trough_mask_cur.plot.contour(ax=ax, transform=ccrs.PlateCarree(),
                                             colors='blue', levels=[level])
            segments = self._zero_contour_segments(c)
        finally:
            # Figures are only a vehicle for contouring; close to avoid accumulating them.
            plt.close(fig)

        for id_, segment in enumerate(segments, start=1):
            o = self.get_new_object()
            # fill the properties defined in the .proto file.
            o.id = id_
            o.properties.num_nodes = len(segment)
            for lon, lat in segment:
                line_pt = o.properties.line_pts.add()
                line_pt.lon = lon
                line_pt.lat = lat

            # add to objects if keep
            # TODO: optionally apply spacial_filter(paths, cfg.spatial_thr, cfg.wave_lat, cfg.wave_lon)
            if keep_wavestate(o.properties):
                objs.append(o)

        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        """Return the dataset and descriptor unchanged (no postprocessing yet)."""
        return dataset, pb2_desc