identification.py 17.22 KiB
from enstools.feature.identification import IdentificationStrategy
import xarray as xr
import numpy as np
import os, sys
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .processing import populate_object, compute_cv
from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, get_latitude_dim
import threading
from skimage.draw import line
class AEWIdentification(IdentificationStrategy):
    """
    Identification strategy for African Easterly Wave (AEW) troughs.

    Troughs are the zero-contours of the advection of the curvature-vorticity
    (CV) anomaly, restricted to points where the anomaly exceeds its
    hour-of-year climatological quantile, the second advection term exceeds
    ``config.second_advection_min_thr`` and u is below ``config.max_u_thresh``.

    ``precompute`` builds the ``trough_mask`` (plus an all-zero 0.5x0.5 degree
    ``wavetroughs`` raster), ``identify`` turns the zero-contours of one data
    chunk into objects, and ``postprocess`` optionally reduces the output to the
    raster and/or writes Met.3D-like trajectory files.
    """

    def __init__(self, wt_out_file=False, wt_traj_dir=None, cv='cv', year_summer=None, month=None, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        wt_out_file : bool
            If True, ``postprocess`` reduces the output to the wavetroughs as the
            new and only out-field on a 0.5x0.5 grid, and sets
            ``config.sum_over_all`` so they are summed over time.
        wt_traj_dir : str or None
            If set, ``postprocess`` writes Met.3D-like trajectory files of the
            wavetroughs into this directory.
        cv : str
            Name of the curvature-vorticity variable; computed from u/v in
            ``precompute`` if the dataset does not contain it.
        year_summer : int or None
            If set, process the AEW season (01.06.-31.10.) of the given year.
        month : int or None
            Only used together with ``year_summer``: process just this month
            (first day 00:00 to last day 00:00).
        kwargs
            Accepted but unused here.

        Notes
        -----
        ``self.config`` is the ``configuration`` module object itself, so every
        setting made here is shared with all other users of that module.
        """
        import calendar
        import enstools.feature.identification.african_easterly_waves.configuration as cfg

        self.config = cfg  # NOTE: module object -> settings are shared, not per instance
        self.config.out_traj_dir = wt_traj_dir
        self.config.cv_name = cv
        if year_summer is not None:
            if month is not None:
                m_str = str(month).zfill(2)
                # real length of the month: Feb has 28/29 days, Jul/Aug/Oct have 31
                last_day = calendar.monthrange(int(year_summer), int(month))[1]
                self.config.start_date = str(year_summer) + '-' + m_str + '-01T00:00'
                self.config.end_date = str(year_summer) + '-' + m_str + '-' + str(last_day).zfill(2) + 'T00:00'
            else:
                self.config.start_date = str(year_summer) + '-06-01T00:00'
                self.config.end_date = str(year_summer) + '-10-31T00:00'
        self.config.out_wt = wt_out_file
        if wt_out_file:
            self.config.sum_over_all = True
        # serializes use of matplotlib's global pyplot state in identify()
        self.lock_ = threading.Lock()

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Prepare ``dataset`` for trough identification.

        Loads (or computes and caches) the CV climatology, subsets the data to
        the configured region, period and levels, computes CV from u/v if it is
        missing, interpolates onto the climatology grid, derives the smoothed CV
        anomaly and its advection, and stores the resulting ``trough_mask``.
        Finally an all-zero ``wavetroughs`` variable on a 0.5x0.5 degree grid
        (dims ``lat05``/``lon05``) is attached for ``identify`` to draw into.

        Parameters
        ----------
        dataset : xr.Dataset
            Input data containing the u and v wind (and optionally CV).
        kwargs
            Unused.

        Returns
        -------
        xr.Dataset
            The prepared dataset including ``trough_mask`` and ``wavetroughs``.

        Exits the interpreter with status 1 (``sys.exit``) if the configured
        period selects no time steps or the u/v fields cannot be located.
        """
        print("Precompute for AEW identification...")
        plt.switch_backend('agg')  # non-interactive backend: renders off-screen only, cannot display

        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)

        # --------------- CLIMATOLOGY
        cv_clim = self._load_climatology(lat_range, lon_range)

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
        end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            time=slice(start_date_dt, end_date_dt))
        if len(dataset.time.values) == 0:
            print("Given start and end time leads to no data to process.")
            sys.exit(1)

        u_name = self.config.u_dim if self.config.u_dim is not None else get_u_var(dataset)
        v_name = self.config.v_dim if self.config.v_dim is not None else get_v_var(dataset)
        if u_name is None or v_name is None:
            print("Could not locate u and v fields in dataset. Needed to compute advection terms.")
            sys.exit(1)  # non-zero status: this is an error, not a normal termination

        # decide before selecting: selecting a scalar level would drop the dim
        wind_is_3d = level_str in dataset[u_name].dims
        # select the configured levels (for 2-D winds this reduces the 3-D fields)
        dataset = dataset.sel(**{level_str: self.config.levels})
        if not wind_is_3d:
            # give the 2-D winds a length-1 'level' dimension
            dataset[u_name] = dataset[u_name].expand_dims('level')
            dataset[v_name] = dataset[v_name].expand_dims('level')

        # rename cv_clim lat/lon dimensions (as detected in the clim file) to the data's names
        cv_clim = cv_clim.rename({get_latitude_dim(cv_clim): lat_str,
                                  get_longitude_dim(cv_clim): lon_str})
        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("plev from clim to level: div by 100.")
            cv_clim = cv_clim.rename({'plev': level_str})
            # presumably Pa -> hPa to match the data's level coordinate; TODO confirm
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})
        # also only use levels also present in data
        cv_clim = cv_clim.sel({level_str: dataset[level_str].values})

        if self.config.cv_name not in dataset.data_vars:
            print("Curvature Vorticity not found, trying to compute it out of U and V...")
            dataset = compute_cv(dataset, u_name, v_name, self.config.cv_name)

        # bring the data onto the climatology's lat/lon grid (e.g. 2.5 degree)
        dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]})

        # make sure that lat and lon are the last two dimensions
        trailing_dims = dataset[self.config.cv_name].dims[-2:]
        if lat_str not in trailing_dims or lon_str not in trailing_dims:
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS
        u = dataset[u_name]
        v = dataset[v_name]
        cv = dataset[self.config.cv_name]

        # smooth CV: 9-point kernel, 2 passes
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # tag each time step with "MM-DD HH" so the matching climatology is subtracted
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # keep anomalies exceeding the hour-of-year climatological quantile:
        # compare on the hourofyear axis, then swap back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        # diagnostic print only: the config.cv_percentile-th percentile of the
        # smoothed CV itself (not of the anomalies); not used for the mask
        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)
        print(cv_perc_thresh)

        print('Locating wave troughs...')
        # All three conditions must hold. np.logical_and is a binary ufunc: a third
        # positional argument is taken as its `out` buffer and silently discards
        # that condition, so the conditions are combined with `&`.
        keep = (
            ~np.isnan(perc_mask)  # percentile of anomaly over threshold from climatology
            & (adv2.values > self.config.second_advection_min_thr)  # second time derivative > 0: dont detect local minima over the percentile
            & (u.values < self.config.max_u_thresh)  # threshold for propagation speed -> keep only westward
        )
        dataset['trough_mask'] = adv1.where(keep)

        return self._add_wavetrough_grid(dataset, lat_str, lon_str)

    def _load_climatology(self, lat_range, lon_range):
        """
        Return the CV climatology subset to ``lat_range``/``lon_range``.

        Opens ``config.get_clim_file()``; if that file does not exist, the
        climatology is computed with ``compute_climatology`` and written there.
        """
        clim_file = self.config.get_clim_file()
        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")
            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)

        lat_str_clim = get_latitude_dim(cv_clim)
        lon_str_clim = get_longitude_dim(cv_clim)
        return cv_clim.sel(
            **{lat_str_clim: slice(lat_range[0], lat_range[1])},
            **{lon_str_clim: slice(lon_range[0], lon_range[1])})

    @staticmethod
    def _add_wavetrough_grid(dataset, lat_str, lon_str):
        """
        Attach an all-zero ``wavetroughs`` variable on a 0.5x0.5 degree grid.

        The new dims ``lat05``/``lon05`` span the lat/lon extent of ``dataset``
        and are placed last; the remaining dims are taken from ``trough_mask``.
        Returns the dataset.
        """
        min_lat = dataset[lat_str].data.min()
        max_lat = dataset[lat_str].data.max()
        min_lon = dataset[lon_str].data.min()
        max_lon = dataset[lon_str].data.max()
        # round (not truncate) the point count so float noise in the extent cannot drop a point
        lat05 = np.linspace(min_lat, max_lat, int(round((max_lat - min_lat) * 2)) + 1)
        lon05 = np.linspace(min_lon, max_lon, int(round((max_lon - min_lon) * 2)) + 1)

        # 0.5x0.5 for wavetroughs: drop the coarse lat/lon, add the fine ones at the back
        wt = xr.zeros_like(dataset['trough_mask'], dtype=float)
        wt = wt.isel(**{lat_str: 0}).drop_vars(lat_str).isel(**{lon_str: 0}).drop_vars(lon_str)
        wt = wt.expand_dims(lon05=lon05).expand_dims(lat05=lat05)
        wt = wt.transpose(..., 'lat05', 'lon05')

        dataset['wavetroughs'] = wt
        dataset['wavetroughs'].attrs['units'] = 'prob'
        dataset['wavetroughs'].attrs['standard_name'] = 'wavetroughs'
        dataset['wavetroughs'].attrs['long_name'] = 'position_of_wavetrough'
        dataset['lat05'].attrs['long_name'] = 'latitude'
        dataset['lat05'].attrs['standard_name'] = 'latitude'
        dataset['lon05'].attrs['long_name'] = 'longitude'
        dataset['lon05'].attrs['standard_name'] = 'longitude'
        dataset['lat05'].attrs['units'] = 'degrees_north'
        dataset['lon05'].attrs['units'] = 'degrees_east'
        return dataset

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Extract wavetrough objects from one chunk of precomputed data.

        The zero-contours of ``trough_mask`` (computed with matplotlib) are the
        candidate troughs. Each becomes an object filled by ``populate_object``
        and is kept if ``keep_wavetrough`` accepts it. Kept objects get
        consecutive ids starting at 1. When ``config.out_wt`` is set, kept
        troughs are also drawn (anti-aliased) into ``data_chunk.wavetroughs``
        in place.

        Returns
        -------
        tuple
            ``(data_chunk, objs)`` with the list of kept objects.
        """
        objs = []
        trough_mask_cur = data_chunk.trough_mask

        # pyplot keeps global "current figure/axes" state, so contouring is
        # serialized, and the figure is always closed so figures do not pile up.
        with self.lock_:
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15),
                                   subplot_kw={'projection': ccrs.PlateCarree()})
            try:
                # generate zero-contours with matplotlib core
                c = trough_mask_cur.plot.contour(ax=ax, transform=ccrs.PlateCarree(), colors='blue',
                                                 levels=[0.0])
                # NOTE: ContourSet.collections is deprecated since matplotlib 3.8
                paths = c.collections[0].get_paths()
            finally:
                plt.close(fig)

        wt = data_chunk.wavetroughs
        id_ = 1
        for path in paths:
            # get new object, set id, populate it
            o = self.get_new_object()
            o.id = id_
            populate_object(o.properties, path)
            # add to objects if keep
            if not self.keep_wavetrough(o.properties):
                continue
            objs.append(o)
            id_ += 1
            # if wavetrough out dataset, draw the line
            if self.config.out_wt:
                self._rasterize_path(wt, path)
        return data_chunk, objs

    @staticmethod
    def _grid_index(value, vmin, vmax, n):
        """
        Return the index of the point nearest to ``value`` on ``np.linspace(vmin, vmax, n)``.

        Values outside [vmin, vmax] yield out-of-range indices; callers clip.
        A degenerate axis (one point or zero extent) always maps to 0.
        """
        if n < 2 or vmax == vmin:
            return 0
        return int(np.rint((value - vmin) / (vmax - vmin) * (n - 1)))

    def _rasterize_path(self, wt, path):
        """
        Draw the polyline ``path`` (vertices as (lon, lat)) into ``wt.data`` in place.

        Uses anti-aliased line segments; each touched cell keeps the maximum
        of its current value and the new anti-aliasing weight.
        """
        n_lon = len(wt.lon05.data)
        n_lat = len(wt.lat05.data)
        min_lon, max_lon = wt.lon05.data.min(), wt.lon05.data.max()
        min_lat, max_lat = wt.lat05.data.min(), wt.lat05.data.max()
        for (lon0, lat0), (lon1, lat1) in zip(path.vertices[:-1], path.vertices[1:]):
            rr, cc, val = line_aa(self._grid_index(lon0, min_lon, max_lon, n_lon),
                                  self._grid_index(lat0, min_lat, max_lat, n_lat),
                                  self._grid_index(lon1, min_lon, max_lon, n_lon),
                                  self._grid_index(lat1, min_lat, max_lat, n_lat))
            rr = np.clip(rr, 0, n_lon - 1)  # lon indices
            cc = np.clip(cc, 0, n_lat - 1)  # lat indices
            # assumes the chunk's wavetroughs is 2-D (lat05, lon05) -- TODO confirm
            wt.data[cc, rr] = np.maximum(val, wt.data[cc, rr])

    def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
        """
        Finalize the identification output.

        Makes object ids unique. If ``config.out_wt`` is set, the dataset is
        reduced to the ``wavetroughs`` raster, with ``lat05``/``lon05`` renamed
        to the original lat/lon names, length-1 dims squeezed when a vertical
        dim exists, and a sum over time when ``config.sum_over_all`` is set.
        If ``config.out_traj_dir`` is set, Met.3D-like trajectory files are
        written as well.

        Returns
        -------
        tuple
            ``(dataset, data_desc)``
        """
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)
        data_desc = self.make_ids_unique(data_desc)

        # drop everything, only keep WTs as 0.5x0.5
        if self.config.out_wt:
            dataset = dataset.drop_vars([var for var in dataset.data_vars if var != 'wavetroughs'])
            # wavetroughs are 0.5x0.5 in lat05,lon05 field: remove the coarse grid and helper dims
            # (only those that actually have a variable, so drop_vars cannot fail)
            obsolete = [dim for dim in dataset.dims
                        if dim in (lat_str, lon_str, 'hourofyear', 'quantile') and dim in dataset.variables]
            dataset = dataset.drop_vars(obsolete)
            dataset = dataset.rename({'lat05': lat_str, 'lon05': lon_str})

            level_str = get_vertical_dim(dataset)
            if level_str is not None:
                dataset = dataset.squeeze(drop=True)

            if self.config.sum_over_all:
                dataset['wavetroughs'] = dataset.wavetroughs.sum(dim='time')

        # create met3d like trajectories TODO not really working right now...
        if self.config.out_traj_dir:
            self._write_met3d_trajectories(data_desc)

        return dataset, data_desc

    def _write_met3d_trajectories(self, data_desc):
        """
        Write one Met.3D-like trajectory netCDF file per timestep.

        Files go into ``config.out_traj_dir``. Every wavetrough becomes one
        "trajectory" whose fake time axis is the point index along the trough;
        shorter troughs are NaN-padded. Timesteps without wavetroughs are
        skipped. Experimental.
        """
        os.makedirs(self.config.out_traj_dir, exist_ok=True)
        assert (len(data_desc.sets) == 1)  # TODO assert one set. maybe expand at some point
        desc_set = data_desc.sets[0]

        for ts in desc_set.timesteps:
            # separate dataset per init-time: number of trajs (WTs) differs between times
            n_wt = len(ts.objects)
            if n_wt == 0:
                # an empty trajectory dimension cannot be written sensibly
                print("No wavetroughs at " + ts.valid_time + ", skipping trajectory output.")
                continue

            line_pts = [o.properties.line_pts for o in ts.objects]
            max_pts_in_wt = max(len(pts) for pts in line_pts)

            # (trajectory, point) arrays, NaN-filled at the end of shorter troughs
            lon_arr = np.full((n_wt, max_pts_in_wt), np.nan)
            lat_arr = np.full((n_wt, max_pts_in_wt), np.nan)
            pres_arr = np.full((n_wt, max_pts_in_wt), np.nan)
            for i, pts in enumerate(line_pts):
                n_pts = len(pts)
                lon_arr[i, :n_pts] = [pt.lon for pt in pts]
                lat_arr[i, :n_pts] = [pt.lat for pt in pts]
                pres_arr[i, :n_pts] = 850.0  # TODO hard-coded pressure (hPa), independent of processed level

            dataset_wt = xr.Dataset()
            dataset_wt = dataset_wt.expand_dims(
                time=np.arange(0, max_pts_in_wt).astype(dtype=float))  # fake traj time
            dataset_wt = dataset_wt.expand_dims(ensemble=[0])
            dataset_wt = dataset_wt.expand_dims(trajectory=np.arange(1, n_wt + 1))

            traj_dims = ("ensemble", "trajectory", "time")
            dataset_wt['lon'] = xr.DataArray(
                lon_arr[np.newaxis], dims=traj_dims,
                attrs={'standard_name': "longitude", 'long_name': "longitude", 'units': "degrees_east"})
            dataset_wt['lat'] = xr.DataArray(
                lat_arr[np.newaxis], dims=traj_dims,
                attrs={'standard_name': "latitude", 'long_name': "latitude", 'units': "degrees_north"})
            dataset_wt['pressure'] = xr.DataArray(
                pres_arr[np.newaxis], dims=traj_dims,
                attrs={'standard_name': "air_pressure", 'long_name': "pressure", 'units': "hPa",
                       'positive': "down", 'axis': "Z"})
            # TODO auxiliary smth?

            valid_time = ts.valid_time.replace('T', ' ')
            dataset_wt['time'].attrs.update({
                'standard_name': "time",
                'long_name': "time",
                'units': "hours since " + valid_time,
                'trajectory_starttime': valid_time,
                # TODO should be the real forecast init time, e.g. '2006-09-01 12:00:00'
                'forecast_inittime': valid_time,
            })

            # join (not concatenate) so a directory without trailing separator works
            out_path = os.path.join(self.config.out_traj_dir, ts.valid_time.replace(':', '_') + '.nc')
            dataset_wt.to_netcdf(out_path)

    # filter: keep current wavetrough if:
    # troughs are within a certain spatial window
    # length of the trough < threshold
    def keep_wavetrough(self, properties):
        """
        Called for each wavetrough, check if kept based on filtering heuristics:
        - WT requires any point of wavetrough strictly inside the
          config.wave_filter_lon / config.wave_filter_lat window
        - WT requires length_deg above config.degree_len_thr

        Parameters
        ----------
        properties
            Object properties providing ``line_pts`` (with ``lon``/``lat``)
            and ``length_deg``.

        Returns
        -------
        True if kept
        """
        lon_lo, lon_hi = self.config.wave_filter_lon[0], self.config.wave_filter_lon[1]
        lat_lo, lat_hi = self.config.wave_filter_lat[0], self.config.wave_filter_lat[1]
        in_area = any(lon_lo < pt.lon < lon_hi and lat_lo < pt.lat < lat_hi
                      for pt in properties.line_pts)
        if not in_area:  # no point of line segment in our area
            return False
        if properties.length_deg <= self.config.degree_len_thr:  # too small
            return False
        return True