from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import calendar
import os
import sys
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavetrough
from .processing import populate_object, compute_cv
from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, \
    get_latitude_dim
import threading
from skimage.draw import line


class AEWIdentification(IdentificationTechnique):
    """
    Identification of African Easterly Wave (AEW) troughs.

    ``precompute`` builds a trough mask from curvature vorticity (CV) anomalies
    and their advection, ``identify`` extracts the zero contours of that mask as
    wave trough objects, and ``postprocess`` optionally reduces the dataset to
    the rasterised wave trough field and writes trajectory-like netCDF files.
    """

    def __init__(self, wt_out_file=False, wt_traj_dir=None, cv='cv', year_summer=None, month=None, **kwargs):
        """
        Initialize the AEW Identification.

        The configuration *module* itself is stored in ``self.config`` and is
        modified in place, so the settings written here are shared by everything
        in the process that uses that module.

        Parameters (experimental)
        ----------
        wt_out_file: output the wavetroughs as new and only out-field in 0.5x0.5
        wt_traj_dir: if set, directory into which ``postprocess`` writes one
                trajectory netCDF file per timestep
        cv: name of the curvature vorticity variable in the dataset
        year_summer: if set, process AEW season (01.06.-31.10.) of given year
        month: if set together with year_summer, process only this month of the
                given year (first day 00:00 up to last day 00:00)
        kwargs: accepted but not used here
        """

        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg  # config
        self.config.out_traj_dir = wt_traj_dir
        self.config.cv_name = cv

        if year_summer is not None:
            if month is not None:
                m_str = str(month).zfill(2)
                # Use the real length of the month: a fixed day 30 is not a
                # valid date in February and misses the 31st of longer months.
                last_day = calendar.monthrange(int(year_summer), int(month))[1]
                self.config.start_date = str(year_summer) + '-' + m_str + '-01T00:00'
                self.config.end_date = str(year_summer) + '-' + m_str + '-' + str(last_day).zfill(2) + 'T00:00'
            else:
                self.config.start_date = str(year_summer) + '-06-01T00:00'
                self.config.end_date = str(year_summer) + '-10-31T00:00'

        self.config.out_wt = wt_out_file
        if wt_out_file:
            self.config.sum_over_all = True

        self.lock_ = threading.Lock()

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Prepare the dataset for the per-chunk identification.

        Loads (or computes and stores) the CV climatology, subsets the data to
        the configured region, period and levels, computes CV from the wind
        components if it is missing, interpolates the data onto the climatology
        grid and derives the trough mask. Additionally an empty 0.5x0.5 degree
        field ``wavetroughs`` (dimensions ``lat05``/``lon05``) is added, which
        ``identify`` fills.

        Parameters
        ----------
        dataset: input data containing the CV field or the wind components
        kwargs: accepted but not used here

        Returns
        -------
        the processed dataset including ``trough_mask`` and ``wavetroughs``
        """
        print("Precompute for PV identification...")

        plt.switch_backend('agg')  # this is thread safe matplotlib but cant display.

        # --------------- CLIMATOLOGY
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")

            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)

        # rename cv_clim dimensions to be same as in data. This has to happen
        # before the climatology is subset below, because the selection uses
        # the dimension names of the data.
        cv_clim = cv_clim.rename({'lat': lat_str, 'lon': lon_str})
        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("plev from clim to level: div by 100.")
            cv_clim = cv_clim.rename({'plev': level_str})
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})

        cv_clim = cv_clim.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])})

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
        end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None

        # get the data we want to investigate
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            time=slice(start_date_dt, end_date_dt))

        if level_str is not None:
            dataset = dataset.sel(**{level_str: self.config.levels})

        if self.config.cv_name not in dataset.data_vars:
            print("Curvature Vorticity not found, trying to compute it out of U and V...")
            u_name = get_u_var(dataset)
            v_name = get_v_var(dataset)
            dataset = compute_cv(dataset, u_name, v_name, self.config.cv_name)

        # make dataset to 2.5 (or same as cv_clim)
        dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]})

        # make sure that lat and lon are last two dimensions
        trailing_dims = dataset[self.config.cv_name].coords.dims[-2:]
        if lat_str not in trailing_dims or lon_str not in trailing_dims:
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS
        u = dataset.u if 'u' in dataset.data_vars else dataset.U
        v = dataset.v if 'v' in dataset.data_vars else dataset.V
        cv = dataset[self.config.cv_name]

        # smooth CV with kernel
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # create hourofyear to get anomalies
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
        # replace data time with hourofyear -> compare with climatology percentile -> back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        print('Locating wave troughs...')
        # filter the advection field given our conditions. The conditions are
        # chained with `&`: np.logical_and() only combines its first two
        # arguments, a third positional argument is taken as the `out` array.
        is_trough = (
            ~np.isnan(perc_mask)  # percentile of anomaly over threshold from climatology
            # second time derivative > 0: dont detect local minima over the percentile
            & (adv2.values > self.config.second_advection_min_thr)
            # threshold for propagation speed -> keep only westward
            & (u.values < self.config.max_u_thresh)
        )
        trough_mask = adv1.where(is_trough)

        dataset['trough_mask'] = trough_mask

        # create 0.5x0.5 dataarray for wavetroughs
        min_lat = dataset[lat_str].data.min()
        max_lat = dataset[lat_str].data.max()
        min_lon = dataset[lon_str].data.min()
        max_lon = dataset[lon_str].data.max()

        lat05 = np.linspace(min_lat, max_lat, int((max_lat - min_lat) * 2) + 1)
        lon05 = np.linspace(min_lon, max_lon, int((max_lon - min_lon) * 2) + 1)

        # 0.5x0.5 for wavetroughs: take the trough mask without its horizontal
        # dimensions and span the new lat05/lon05 grid instead
        wt = xr.zeros_like(dataset['trough_mask'], dtype=float)
        wt = wt.isel(**{lat_str: 0}).drop_vars(lat_str).isel(**{lon_str: 0}).drop_vars(lon_str)
        wt = wt.expand_dims(lon05=lon05).expand_dims(lat05=lat05)
        wt = wt.transpose(..., 'lat05', 'lon05')

        dataset['wavetroughs'] = wt

        dataset['wavetroughs'].attrs['units'] = 'prob'
        dataset['wavetroughs'].attrs['standard_name'] = 'wavetroughs'
        dataset['wavetroughs'].attrs['long_name'] = 'position_of_wavetrough'
        dataset['lat05'].attrs['long_name'] = 'latitude'
        dataset['lat05'].attrs['standard_name'] = 'latitude'
        dataset['lon05'].attrs['long_name'] = 'longitude'
        dataset['lon05'].attrs['standard_name'] = 'longitude'
        dataset['lat05'].attrs['units'] = 'degrees_north'
        dataset['lon05'].attrs['units'] = 'degrees_east'

        return dataset

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Extract the wave troughs of one data chunk.

        The zero contours of ``trough_mask`` are computed with matplotlib. Each
        contour path becomes a candidate object; candidates accepted by
        ``keep_wavetrough`` are returned and, if wave trough output is enabled
        in the configuration, drawn as anti-aliased lines into the
        ``wavetroughs`` field of the chunk.

        Parameters
        ----------
        data_chunk: chunk of the precomputed dataset, has to contain
                ``trough_mask`` and ``wavetroughs``
        kwargs: accepted but not used here

        Returns
        -------
        tuple of the (possibly modified) data chunk and the list of kept objects
        """
        objs = []
        trough_mask_cur = data_chunk.trough_mask

        fig, _ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
        try:
            # generate zero-contours with matplotlib core
            c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0],
                                             subplot_kws={'projection': ccrs.PlateCarree()})
            # NOTE(review): ContourSet.collections is deprecated in recent
            # matplotlib releases - confirm against the version in use.
            paths = list(c.collections[0].get_paths())
        finally:
            # only the paths are needed; without closing, one figure per call
            # would stay alive
            plt.close(fig)

        wt = data_chunk.wavetroughs
        min_lat = wt.lat05.data.min()
        max_lat = wt.lat05.data.max()
        min_lon = wt.lon05.data.min()
        max_lon = wt.lon05.data.max()
        lon_span = max_lon - min_lon
        lat_span = max_lat - min_lat
        lons = len(wt.lon05.data)
        lats = len(wt.lat05.data)

        id_ = 1
        for path in paths:

            # get new object, set id
            o = self.get_new_object()
            o.id = id_

            # populate it
            populate_object(o.properties, path)

            # add to objects if keep
            if not keep_wavetrough(o.properties, self.config):
                continue
            objs.append(o)
            id_ += 1

            if not self.config.out_wt:
                continue

            # draw every segment of the kept trough into the 0.5x0.5 field
            # assumes wt is 2D (lat05, lon05) within a chunk - TODO confirm
            vertices = path.vertices
            for v_idx in range(len(vertices) - 1):
                start_lon, start_lat = vertices[v_idx][0], vertices[v_idx][1]
                end_lon, end_lat = vertices[v_idx + 1][0], vertices[v_idx + 1][1]

                # lon/lat -> (fractional) grid index
                start_idx = ((start_lon - min_lon) / lon_span * lons,
                             (start_lat - min_lat) / lat_span * lats)
                end_idx = ((end_lon - min_lon) / lon_span * lons,
                           (end_lat - min_lat) / lat_span * lats)

                rr, cc, val = line_aa(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
                # keep the indices inside of the field
                rr = np.clip(rr, 0, lons - 1)
                cc = np.clip(cc, 0, lats - 1)

                # keep the larger value where lines overlap
                wt.data[cc, rr] = np.where(np.greater(val, wt.data[cc, rr]), val, wt.data[cc, rr])

        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
        """
        Finalize dataset and object description after the identification.

        Makes the object ids unique. If wave trough output is enabled, only the
        ``wavetroughs`` field is kept and its ``lat05``/``lon05`` dimensions are
        renamed to the latitude/longitude names of the dataset. If a trajectory
        directory is configured, one netCDF file per timestep is written, in
        which each wave trough is stored as a trajectory (points along the
        line, NaN-padded to the longest trough of the timestep, pressure fixed
        to 850). Timesteps without wave troughs do not produce a file.

        Parameters
        ----------
        dataset: the dataset after identification
        data_desc: the object description; exactly one set is supported for the
                trajectory output
        kwargs: accepted but not used here

        Returns
        -------
        tuple of the processed dataset and the object description
        """
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        data_desc = self.make_ids_unique(data_desc)

        # drop everything, only keep WTs as 0.5x0.5
        if self.config.out_wt:
            for var in dataset.data_vars:
                if var not in ['wavetroughs']:
                    dataset = dataset.drop_vars([var])

            # wavetroughs are 0.5x0.5 in lat05,lon05 field. remove other stuff
            for dim in dataset.dims:
                if dim in [lat_str, lon_str, 'hourofyear', 'quantile']:
                    dataset = dataset.drop_vars([dim])
            dataset = dataset.rename({'lat05': lat_str, 'lon05': lon_str})

        level_str = get_vertical_dim(dataset)
        if level_str is not None:
            dataset = dataset.squeeze(drop=True)

        if self.config.sum_over_all:
            dataset['wavetroughs'] = dataset.wavetroughs.sum(dim='time')

        # create met3d like trajectories TODO not really working right now...
        if self.config.out_traj_dir:
            os.makedirs(self.config.out_traj_dir, exist_ok=True)

            assert (len(data_desc.sets) == 1)  # TODO assert one set. maybe expand at some point
            desc_set = data_desc.sets[0]
            desc_times = desc_set.timesteps

            traj_dims = ("ensemble", "trajectory", "time")

            for ts in desc_times:
                # need to make separate dataset for each init-time
                # because number of trajs (WTs) are different from time to time
                n_traj = len(ts.objects)
                if n_traj == 0:
                    # nothing to write; array shapes below are undefined
                    # without at least one wave trough
                    continue

                pts_per_wt = [list(o.properties.line_pts) for o in ts.objects]
                max_pts_in_wt = max(len(pts) for pts in pts_per_wt)

                # (ensemble, trajectory, time), NaN behind the last point of each WT
                lon_data = np.full((1, n_traj, max_pts_in_wt), np.nan)
                lat_data = np.full((1, n_traj, max_pts_in_wt), np.nan)
                pres_data = np.full((1, n_traj, max_pts_in_wt), np.nan)
                for i, pts in enumerate(pts_per_wt):
                    lon_data[0, i, :len(pts)] = [pt.lon for pt in pts]
                    lat_data[0, i, :len(pts)] = [pt.lat for pt in pts]
                    pres_data[0, i, :len(pts)] = 850.0

                dataset_wt = xr.Dataset()
                dataset_wt = dataset_wt.expand_dims(
                    time=np.arange(0, max_pts_in_wt).astype(dtype=float))  # fake traj time
                dataset_wt = dataset_wt.expand_dims(ensemble=[0])
                dataset_wt = dataset_wt.expand_dims(trajectory=np.arange(1, n_traj + 1))

                lons = xr.DataArray(lon_data, dims=traj_dims)
                lons.attrs['standard_name'] = "longitude"
                lons.attrs['long_name'] = "longitude"
                lons.attrs['units'] = "degrees_east"

                lats = xr.DataArray(lat_data, dims=traj_dims)
                lats.attrs['standard_name'] = "latitude"
                lats.attrs['long_name'] = "latitude"
                lats.attrs['units'] = "degrees_north"

                pres = xr.DataArray(pres_data, dims=traj_dims)
                pres.attrs['standard_name'] = "air_pressure"
                pres.attrs['long_name'] = "pressure"
                pres.attrs['units'] = "hPa"
                pres.attrs['positive'] = "down"
                pres.attrs['axis'] = "Z"

                dataset_wt['lon'] = lons
                dataset_wt['lat'] = lats
                dataset_wt['pressure'] = pres

                # TODO auxiliary smth?
                valid_time_str = ts.valid_time.replace('T', ' ')  # e.g. '2006-09-01 12:00:00'
                dataset_wt['time'].attrs['standard_name'] = "time"
                dataset_wt['time'].attrs['long_name'] = "time"
                dataset_wt['time'].attrs['units'] = "hours since " + valid_time_str
                dataset_wt['time'].attrs['trajectory_starttime'] = valid_time_str
                dataset_wt['time'].attrs['forecast_inittime'] = valid_time_str

                # join instead of concatenating, so the file ends up inside of
                # the directory even if it was given without trailing separator
                out_path = os.path.join(self.config.out_traj_dir, ts.valid_time.replace(':', '_') + '.nc')
                dataset_wt.to_netcdf(out_path)

        return dataset, data_desc