import calendar
import os, sys
import threading

import cartopy.crs as ccrs
import metpy.calc as mpcalc
import numpy as np
import xarray as xr
from matplotlib import pyplot as plt
from skimage.draw import line_aa

from enstools.feature.identification import IdentificationStrategy
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, get_latitude_dim, get_init_time_dim, get_valid_time_dim

from .processing import populate_object, compute_cv
from .util import calc_adv
class AEWIdentification(IdentificationStrategy):
    """
    Identification strategy for African Easterly Wave (AEW) troughs.

    ``precompute`` derives a trough mask from the advection of curvature
    vorticity (CV) anomalies, ``identify`` extracts the zero contours of that
    mask as wave trough objects, and ``postprocess`` optionally reduces the
    dataset to the rasterized troughs and writes trajectory-style NetCDF files.
    """

    def __init__(self, wt_out_file=False, wt_traj_dir=None, cv='cv', year_summer=None, month=None, **kwargs):
        """
        Initialize the AEW Identification.

        Parameters (experimental)
        ----------
        wt_out_file: output the wavetroughs as new and only out-field in 0.5x0.5
        wt_traj_dir: if set, directory the per-timestep trajectory files are written to
        cv: name of the curvature vorticity variable in the dataset
        year_summer: if set, process AEW season (01.06.-31.10.) of given year
        month: if set together with year_summer, process only this month of that year
        kwargs: accepted for call compatibility, not used here
        """
        import enstools.feature.identification.african_easterly_waves.configuration as cfg

        # NOTE: the configuration *module* itself is used as config object, so
        # the assignments below are visible to every user of that module.
        self.config = cfg
        self.config.cv_name = cv

        # NOTE(review): identify()/postprocess() read config.out_wt and
        # config.out_traj_dir; forwarding the constructor arguments here is the
        # presumed intent (they were otherwise unused) - confirm.
        self.config.out_wt = wt_out_file
        self.config.out_traj_dir = wt_traj_dir

        if year_summer is not None:
            year_str = str(year_summer)
            if month is not None:
                m_str = str(month).zfill(2)
                # Use the real last day of the month: a fixed "30" is an
                # invalid date in February and cuts off the 31st elsewhere.
                last_day = calendar.monthrange(int(year_summer), int(month))[1]
                self.config.start_date = year_str + '-' + m_str + '-01T00:00'
                self.config.end_date = year_str + '-' + m_str + '-' + str(last_day).zfill(2) + 'T00:00'
            else:
                self.config.start_date = year_str + '-06-01T00:00'
                self.config.end_date = year_str + '-10-31T00:00'

        # NOTE(review): kept unconditional; the original nesting of this
        # statement could not be recovered - confirm.
        self.config.sum_over_all = True
        self.lock_ = threading.Lock()

    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        Prepare the dataset for identification.

        Loads (or computes and stores) the CV climatology, subsets the data to
        the configured region, period and levels, computes CV if it is missing,
        interpolates onto the climatology grid and derives the ``trough_mask``
        field plus an empty 0.5x0.5 degree ``wavetroughs`` field.

        Parameters
        ----------
        dataset: input data, needs the u and v wind components
        kwargs: not used here

        Returns
        -------
        the prepared dataset including ``trough_mask`` and ``wavetroughs``
        """
        print("Precompute for AEW identification...")
        plt.switch_backend('agg')  # this is thread safe matplotlib but cant display.

        config = self.config

        # --------------- CLIMATOLOGY
        lon_range = config.data_lon
        lat_range = config.data_lat
        clim_file = config.get_clim_file()

        u_name = config.u_dim if config.u_dim is not None else get_u_var(dataset)
        v_name = config.v_dim if config.v_dim is not None else get_v_var(dataset)
        if u_name is None or v_name is None:
            print("Could not locate u and v fields in dataset. Needed to compute advection terms.")
            sys.exit(1)

        # restrict dimensions only to the ones present in u/v
        wind_dims = dataset[u_name].dims
        dataset = dataset.drop_dims([d for d in dataset.dims if d not in wind_dims])

        level_str = get_vertical_dim(dataset)
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)
        init_time_str = get_init_time_dim(dataset)
        valid_time_str = get_valid_time_dim(dataset)

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate: need all 40y of CV data.
            print("Climatology file not found. Computing climatology...")
            from .climatology import compute_climatology
            cv_clim = compute_climatology(config)
            cv_clim.to_netcdf(clim_file)

        lat_str_clim = get_latitude_dim(cv_clim)
        lon_str_clim = get_longitude_dim(cv_clim)
        cv_clim = cv_clim.sel(
            **{lat_str_clim: slice(lat_range[0], lat_range[1])},
            **{lon_str_clim: slice(lon_range[0], lon_range[1])})

        # --------------- SUBSET DATA ACCORDING TO CFG
        start_date_dt = np.datetime64(config.start_date) if config.start_date is not None else None
        end_date_dt = np.datetime64(config.end_date) if config.end_date is not None else None

        # if data is lon=0..360, change it to -180..180
        # (assign_coords instead of in-place assignment: do not modify the caller's dataset)
        dataset = dataset.assign_coords({lon_str: (dataset[lon_str] + 180) % 360 - 180})
        dataset = dataset.sortby(dataset[lon_str])

        filter_time_str = init_time_str if init_time_str is not None else valid_time_str

        # get the data we want to investigate
        dataset = dataset.sortby(lat_str)  # in case of descending
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            **{filter_time_str: slice(start_date_dt, end_date_dt)})

        # check the dimension that was actually filtered above
        if dataset[filter_time_str].size == 0:
            print("Given start and end time leads to no data to process.")
            sys.exit(1)

        if level_str in dataset[u_name].dims:
            dataset = dataset.sel(**{level_str: config.levels})  # 3-D wind field, select levels
        elif level_str is not None:
            dataset = dataset.sel(**{level_str: config.levels})  # 2-D, reduce 3-D field too
            dataset[u_name] = dataset[u_name].expand_dims('level')
            dataset[v_name] = dataset[v_name].expand_dims('level')
        else:
            dataset[u_name] = dataset[u_name].expand_dims(level=config.levels)
            dataset[v_name] = dataset[v_name].expand_dims(level=config.levels)
            level_str = 'level'

        # rename cv_clim dimensions to be same as in data; use the detected
        # climatology names instead of assuming 'lat'/'lon'.
        clim_renames = {old: new
                        for old, new in ((lat_str_clim, lat_str), (lon_str_clim, lon_str))
                        if old != new}
        if clim_renames:
            cv_clim = cv_clim.rename(clim_renames)

        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("plev from clim to level: div by 100.")
            cv_clim = cv_clim.rename({'plev': level_str})
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})

        # also only use levels also present in data
        cv_clim = cv_clim.sel({level_str: dataset[level_str].values})

        if config.cv_name not in dataset.data_vars:
            print("Curvature Vorticity not found, trying to compute it out of U and V...")
            dataset = compute_cv(dataset, u_name, v_name, config.cv_name)

        # make dataset to 2.5 (or same as cv_clim)
        dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]})

        # make sure that lat and lon are last two dimensions
        trailing_dims = dataset[config.cv_name].dims[-2:]
        if lat_str not in trailing_dims or lon_str not in trailing_dims:
            print("Reordering dimensions so lat and lon at back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)

        # --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS
        u = dataset[u_name]
        v = dataset[v_name]
        cv = dataset[config.cv_name]

        # smooth CV with kernel
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # create hourofyear to get anomalies
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)

        # xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
        # replace data time with hourofyear -> compare with climatology percentile -> back to real time
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        print('Locating wave troughs...')
        # filter the advection field given our conditions.
        # The three conditions are combined with `&`: np.logical_and only takes
        # two operands, a third positional argument is its `out` buffer and
        # would silently drop the last condition.
        keep = (
            ~np.isnan(perc_mask)  # percentile of anomaly over threshold from climatology
            # second time derivative > 0: dont detect local minima over the percentile
            & (adv2.values > config.second_advection_min_thr)
            # threshold for propagation speed -> keep only westward
            & (u.values < config.max_u_thresh)
        )
        trough_mask = adv1.where(keep)
        dataset['trough_mask'] = trough_mask

        # create 0.5x0.5 dataarray for wavetroughs
        min_lat = dataset[lat_str].data.min()
        max_lat = dataset[lat_str].data.max()
        min_lon = dataset[lon_str].data.min()
        max_lon = dataset[lon_str].data.max()
        lat05 = np.linspace(min_lat, max_lat, int((max_lat - min_lat) * 2) + 1)
        lon05 = np.linspace(min_lon, max_lon, int((max_lon - min_lon) * 2) + 1)

        # 0.5x0.5 for wavetroughs: strip the coarse lat/lon, then span the fine grid
        wt = xr.zeros_like(dataset['trough_mask'], dtype=float)
        wt = wt.isel(**{lat_str: 0}).drop_vars(lat_str).isel(**{lon_str: 0}).drop_vars(lon_str)
        wt = wt.expand_dims(lon05=lon05).expand_dims(lat05=lat05)
        wt = wt.transpose(..., 'lat05', 'lon05')

        dataset['wavetroughs'] = wt
        dataset['wavetroughs'].attrs['units'] = 'prob'
        dataset['wavetroughs'].attrs['standard_name'] = 'wavetroughs'
        dataset['wavetroughs'].attrs['long_name'] = 'position_of_wavetrough'
        dataset['lat05'].attrs['long_name'] = 'latitude'
        dataset['lat05'].attrs['standard_name'] = 'latitude'
        dataset['lon05'].attrs['long_name'] = 'longitude'
        dataset['lon05'].attrs['standard_name'] = 'longitude'
        dataset['lat05'].attrs['units'] = 'degrees_north'
        dataset['lon05'].attrs['units'] = 'degrees_east'

        return dataset

    def identify(self, data_chunk: xr.Dataset, **kwargs):
        """
        Identify wave troughs in one chunk of the precomputed dataset.

        The zero contours of ``trough_mask`` are extracted via matplotlib; each
        contour path becomes one candidate object which is kept if
        ``keep_wavetrough`` accepts it. If ``config.out_wt`` is set, kept
        troughs are additionally rasterized into the ``wavetroughs`` field.

        Parameters
        ----------
        data_chunk: chunk of the dataset returned by ``precompute``
        kwargs: not used here

        Returns
        -------
        tuple of the data chunk and the list of identified objects
        """
        objs = []
        trough_mask_cur = data_chunk.trough_mask
        if np.isnan(trough_mask_cur).all():
            return data_chunk, objs

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
        try:
            # generate zero-contours with matplotlib core
            c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0],
                                             subplot_kws={'projection': ccrs.PlateCarree()})
            # NOTE(review): ContourSet.collections is deprecated in recent
            # matplotlib releases - revisit when upgrading.
            paths = c.collections[0].get_paths()
        finally:
            # the figure is only a vehicle for the contouring; release it so
            # figures do not pile up over the chunks
            plt.close(fig)

        wt = data_chunk.wavetroughs
        min_lat = wt.lat05.data.min()
        max_lat = wt.lat05.data.max()
        min_lon = wt.lon05.data.min()
        max_lon = wt.lon05.data.max()
        lons = len(wt.lon05.data)
        lats = len(wt.lat05.data)

        def to_grid_index(lon, lat):
            # position on the 0.5x0.5 grid; may land one past the last cell,
            # the rasterized line is clipped to the grid below
            return (int((lon - min_lon) / (max_lon - min_lon) * lons),
                    int((lat - min_lat) / (max_lat - min_lat) * lats))

        for path in paths:
            # get new object and populate it
            o = self.get_new_object()
            populate_object(o.properties, path)

            # add to objects if keep
            if not self.keep_wavetrough(o.properties):
                continue
            objs.append(o)

            # if wavetrough out dataset, gen lines
            if not self.config.out_wt:
                continue

            vertices = path.vertices
            for seg_start, seg_end in zip(vertices[:-1], vertices[1:]):
                start_idx = to_grid_index(seg_start[0], seg_start[1])
                end_idx = to_grid_index(seg_end[0], seg_end[1])

                # anti-aliased line from start to end: lon index, lat index, weight
                lon_idx, lat_idx, weight = line_aa(start_idx[0], start_idx[1], end_idx[0], end_idx[1])
                lon_idx = np.clip(lon_idx, 0, lons - 1)
                lat_idx = np.clip(lat_idx, 0, lats - 1)
                # keep the larger weight where lines overlap
                wt.data[lat_idx, lon_idx] = np.maximum(weight, wt.data[lat_idx, lon_idx])

        return data_chunk, objs

    def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
        """
        Finalize dataset and object description after identification.

        Makes the object IDs unique. If ``config.out_wt`` is set, the dataset
        is reduced to the ``wavetroughs`` field on the 0.5x0.5 grid. If
        ``config.out_traj_dir`` is set, one NetCDF file per timestep is written
        there, holding the trough lines as trajectory-like arrays.

        Parameters
        ----------
        dataset: the dataset after identification
        data_desc: the object description
        kwargs: not used here

        Returns
        -------
        tuple of the dataset and the object description
        """
        lat_str = get_latitude_dim(dataset)
        lon_str = get_longitude_dim(dataset)
        data_desc = self.make_ids_unique(data_desc)

        # drop everything, only keep WTs as 0.5x0.5
        if self.config.out_wt:
            dataset = dataset.drop_vars([var for var in dataset.data_vars if var != 'wavetroughs'])

            # wavetroughs are 0.5x0.5 in lat05,lon05 field. remove other stuff
            obsolete = (lat_str, lon_str, 'hourofyear', 'quantile')
            dataset = dataset.drop_vars([dim for dim in dataset.dims if dim in obsolete])
            dataset = dataset.rename({'lat05': lat_str, 'lon05': lon_str})

            level_str = get_vertical_dim(dataset)
            if level_str is not None:
                dataset = dataset.squeeze(drop=True)

            if self.config.sum_over_all:
                dataset['wavetroughs'] = dataset.wavetroughs.sum(dim='time')

        # create met3d like trajectories TODO not really working right now...
        if self.config.out_traj_dir:
            self._write_trajectories(data_desc)

        return dataset, data_desc

    def _write_trajectories(self, data_desc):
        """Write one trajectory-style NetCDF file per timestep into config.out_traj_dir."""
        out_dir = self.config.out_traj_dir
        os.makedirs(out_dir, exist_ok=True)

        assert (len(data_desc.sets) == 1)  # TODO assert one set. maybe expand at some point
        desc_set = data_desc.sets[0]

        def pad_nan(values, length):
            # fill with NaNs at end so all lines of a timestep have equal length
            return np.pad(values, (0, length - len(values)), mode='constant', constant_values=np.nan)

        # need to make separate dataset for each init-time
        # because number of trajs (WTs) are different from time to time
        for ts in desc_set.timesteps:
            n_objects = len(ts.objects)
            if n_objects == 0:
                # nothing to write; the array shapes below would be invalid
                continue

            lon_list = []
            lat_list = []
            pres_list = []
            for o in ts.objects:  # get lons and lats
                pt_list = o.properties.line_pts
                lon_list.append(np.asarray([pt.lon for pt in pt_list], dtype=float))
                lat_list.append(np.asarray([pt.lat for pt in pt_list], dtype=float))
                pres_list.append(np.full(len(pt_list), 850.0))
            max_pts_in_wt = max(len(line) for line in lon_list)

            dataset_wt = xr.Dataset()
            dataset_wt = dataset_wt.expand_dims(
                time=np.arange(0, max_pts_in_wt).astype(dtype=float))  # fake traj time
            dataset_wt = dataset_wt.expand_dims(ensemble=[0])
            dataset_wt = dataset_wt.expand_dims(trajectory=np.arange(1, n_objects + 1))

            traj_dims = ("ensemble", "trajectory", "time")

            def as_traj_array(lines):
                # (trajectory, time) -> (ensemble=1, trajectory, time)
                stacked = np.stack([pad_nan(line, max_pts_in_wt) for line in lines])
                return xr.DataArray(stacked[np.newaxis, ...], dims=traj_dims)

            lons = as_traj_array(lon_list)
            lons.attrs['standard_name'] = "longitude"
            lons.attrs['long_name'] = "longitude"
            lons.attrs['units'] = "degrees_east"

            lats = as_traj_array(lat_list)
            lats.attrs['standard_name'] = "latitude"
            lats.attrs['long_name'] = "latitude"
            lats.attrs['units'] = "degrees_north"

            pres = as_traj_array(pres_list)
            pres.attrs['standard_name'] = "air_pressure"
            pres.attrs['long_name'] = "pressure"
            pres.attrs['units'] = "hPa"
            pres.attrs['positive'] = "down"
            pres.attrs['axis'] = "Z"

            dataset_wt['lon'] = lons
            dataset_wt['lat'] = lats
            dataset_wt['pressure'] = pres
            # TODO auxiliary smth?

            start_time = ts.valid_time.replace('T', ' ')
            dataset_wt['time'].attrs['standard_name'] = "time"
            dataset_wt['time'].attrs['long_name'] = "time"
            dataset_wt['time'].attrs['units'] = "hours since " + start_time
            dataset_wt['time'].attrs['trajectory_starttime'] = start_time
            dataset_wt['time'].attrs['forecast_inittime'] = start_time  # TODO use the real init time

            # join instead of concatenating, so the directory needs no trailing separator
            out_path = os.path.join(out_dir, ts.valid_time.replace(':', '_') + '.nc')
            dataset_wt.to_netcdf(out_path)

    # filter: keep current wavetrough if:
    # troughs are within a certain spatial window
    # length of the trough < threshold
    def keep_wavetrough(self, properties):
        """
        Called for each wavetrough, check if kept based on filtering heuristics:
        - WT requires any point of wavetrough in config.wave_filter range
        - WT requires minimum length threshold

        Parameters
        ----------
        properties: properties of the wavetrough, read are line_pts and length_deg

        Returns
        -------
        True if kept, False otherwise
        """
        lon_bounds = self.config.wave_filter_lon
        lat_bounds = self.config.wave_filter_lat

        # check if any point is inside the filtering area
        in_area = any(
            lon_bounds[0] < line_pt.lon < lon_bounds[1] and lat_bounds[0] < line_pt.lat < lat_bounds[1]
            for line_pt in properties.line_pts)
        if not in_area:  # no point of line segment in our area
            return False

        if properties.length_deg <= self.config.degree_len_thr:  # too small
            return False

        return True