from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os
import sys
from enstools.feature.util.pb2_properties_api import ObjectProperties
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs


class AEWIdentification(IdentificationTechnique):
    """Identification technique for African Easterly Wave (AEW) troughs.

    ``precompute`` builds a per-gridpoint ``trough_mask`` from curvature
    vorticity (``cv``) and the wind components (``U``, ``V``); ``identify`` is
    still a placeholder that emits dummy objects.
    """

    def __init__(self, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        kwargs
            Currently unused.
        """
        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        # the configuration module itself serves as the config object
        self.config = cfg
        self.processing_mode = '2d'  # TODO enum?

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Compute the wave-trough mask on the whole dataset before chunked identification.

        Steps: load (or compute and cache) the CV climatology, subset ``dataset``
        to the configured levels, region and period, smooth CV, form hour-of-year
        anomalies, compute the CV-anomaly advection and keep it only where all
        trough conditions hold.

        Parameters
        ----------
        dataset : xr.Dataset
            Must provide ``cv``, ``U`` and ``V`` with ``plev``, ``lat``, ``lon``
            and ``time`` coordinates.
        kwargs
            Unused.

        Returns
        -------
        xr.Dataset
            The subset dataset with an added ``trough_mask`` variable: the first
            CV advection where all conditions are met, NaN elsewhere.
        """
        print("Precompute for AEW identification...")

        # --------------- CLIMATOLOGY
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat

        clim_file = self.config.get_clim_file()
        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # TODO not yet accessed in framework
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")
            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            # cache the climatology so later runs can load it directly
            cv_clim.to_netcdf(clim_file)

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date)
        end_date_dt = np.datetime64(self.config.end_date)

        # the requested period may span several years (currently only reported)
        years = list(range(start_date_dt.astype(object).year, end_date_dt.astype(object).year + 1))
        print(years)

        dataset = dataset.sel(plev=self.config.levels,
                              lat=slice(lat_range[0], lat_range[1]),
                              lon=slice(lon_range[0], lon_range[1]),
                              time=slice(start_date_dt, end_date_dt))

        # make sure that lat and lon are the last two dimensions (required by metpy.calc)
        if 'lat' not in dataset.cv.dims[-2:] or 'lon' not in dataset.cv.dims[-2:]:
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., 'lat', 'lon')

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS
        u = dataset.U
        v = dataset.V
        cv = dataset.cv

        # smooth CV with a 9-point kernel, two passes
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # label each time step with its hour of year to subtract the matching climatology
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # keep anomalies exceeding the hour-of-year climatological quantile:
        # replace time with hourofyear -> compare with climatology quantile -> back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        # NOTE(review): intended as a percentile of cv anomalies, but computed on the
        # smoothed cv field; the value is only printed, not used further.
        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)
        print(cv_perc_thresh)

        print('Locating wave troughs...')

        # percentile of anomaly over threshold from climatology
        above_clim_quantile = ~np.isnan(perc_mask)
        # second time derivative > 0: dont detect local minima over the percentile
        positive_second_adv = adv2.values > self.config.second_advection_min_thr
        # threshold for propagation speed -> keep only westward
        below_max_u = u.values < self.config.max_u_thresh

        # np.logical_and is a *binary* ufunc: a third positional argument is its
        # ``out`` buffer, not a third operand. Combine the conditions pairwise so
        # the propagation-speed condition is actually applied.
        trough_condition = np.logical_and(
            np.logical_and(above_clim_quantile, positive_second_adv),
            below_max_u)
        trough_mask = adv1.where(trough_condition)

        dataset['trough_mask'] = trough_mask
        return dataset

    # main identify, called on spatial 2d/3d subsets from enstools-feature.
    # TODO maybe rename identify_per_block
    # or separate methods identify, identify_per_block. But how to build obj if only "identify"
    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Identify AEW objects in one spatial chunk (called by enstools-feature).

        Currently a placeholder: the chunk content is not inspected; five dummy
        objects are created whose property ``a`` is ``42.0 * i`` (i = 0..4).

        Parameters
        ----------
        data_chunk : xr.Dataset
            The chunk to process; returned unchanged.
        kwargs
            Unused.

        Returns
        -------
        tuple
            ``(data_chunk, objs)``: the unchanged chunk and the list of created objects.
        """
        print("chunk identify")
        # TODO pressure parallel? PV needs 3d input, this parallel 2d!
        from .._proto_gen import african_easterly_waves_pb2

        objs = []
        for i in range(5):
            # get an instance of a new object and fill the properties defined in the .proto file
            s_object = ObjectProperties.get_instance(african_easterly_waves_pb2)
            s_object.set('a', 42.0 * i)
            objs.append(s_object)

        # TODO replace the placeholder with the actual trough detection, e.g.:
        #   - take data_chunk.trough_mask for this chunk,
        #   - extract its zero contours (matplotlib contour at level 0.0),
        #   - spatially filter the contour paths (spatial_thr, wave_lat, wave_lon from config),
        #   - build one object per remaining trough path.
        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        """Return ``dataset`` and ``pb2_desc`` unchanged (no postprocessing yet)."""
        return dataset, pb2_desc