from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os, sys
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavetrough
from .processing import populate_object
from skimage.draw import line_aa
from enstools.feature.util.enstools_utils import get_vertical_dim, get_longitude_dim, get_latitude_dim
import threading
class AEWIdentification(IdentificationTechnique):
def __init__(self, wt_traj_dir=None, cv='cv', **kwargs):
"""
Initialize the AEW Identification.
Parameters (experimental)
----------
kwargs
wt_out_file: output the wavetroughs as new and only out-field in 0.5x0.5
Christoph Fischer
committed
"""
        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg
        self.config.out_wt_dir = wt_traj_dir
        self.config.cv_name = cv
        self.found_max_wt_pts = -1
        self.lock_ = threading.Lock()

    def precompute(self, dataset: xr.Dataset, **kwargs):
        print("Precompute for AEW identification...")
        plt.switch_backend('agg')  # the 'agg' backend is thread-safe, but cannot display figures
# --------------- CLIMATOLOGY
lon_range = self.config.data_lon
lat_range = self.config.data_lat
clim_file = self.config.get_clim_file()
level_str = get_vertical_dim(dataset)
lat_str = get_latitude_dim(dataset)
lon_str = get_longitude_dim(dataset)
        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:
            # generate climatology: needs the full (~40y) CV dataset.
            print("Climatology file not found. Computing climatology...")
            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)
# --------------- SUBSET DATA ACCORDING TO CFG
start_date_dt = np.datetime64(self.config.start_date) if self.config.start_date is not None else None
end_date_dt = np.datetime64(self.config.end_date) if self.config.end_date is not None else None
        # get the data we want to investigate
        dataset = dataset.sel(
            **{lat_str: slice(lat_range[0], lat_range[1])},
            **{lon_str: slice(lon_range[0], lon_range[1])},
            time=slice(start_date_dt, end_date_dt))
        if level_str is None:
            if 'level' not in dataset.coords:
                print("No level information given in input. Assuming 700hPa.")
            dataset = dataset.expand_dims('level')
            level_str = 'level'
        else:
            dataset = dataset.sel(**{level_str: self.config.levels})
        # rename cv_clim dimensions to match the input data.
        cv_clim = cv_clim.rename({'lat': lat_str, 'lon': lon_str})
        if 'plev' in cv_clim.dims and 'plev' != level_str:
            print("Renaming climatology 'plev' to '" + level_str + "' and converting Pa to hPa.")
            cv_clim = cv_clim.rename({'plev': level_str})
            cv_clim = cv_clim.assign_coords({level_str: cv_clim[level_str] / 100})
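            # e.g. a climatology coordinate plev = 70000 (Pa) becomes level = 700 (hPa),
            # matching the hPa levels selected from the input data above.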
Christoph Fischer
committed
        # make sure that lat and lon are the last two dimensions
        cv_dims = dataset[self.config.cv_name].dims
        if lat_str not in cv_dims[-2:] or lon_str not in cv_dims[-2:]:
            print("Reordering dimensions so lat and lon are at the back. Required for metpy.calc.")
            dataset = dataset.transpose(..., lat_str, lon_str)
# --------------- DO NUMPY PARALLELIZED STUFF: CREATE TROUGH MASKS
u = dataset.u if 'u' in dataset.data_vars else dataset.U
v = dataset.v if 'v' in dataset.data_vars else dataset.V
cv = dataset[self.config.cv_name]
# smooth CV with kernel
cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()
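        # smooth_n_point with n=9 applies a weighted 3x3 neighborhood smoother, twice;
        # dequantify() moves the pint units that metpy attaches back into plain attributes.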
# create hourofyear to get anomalies
cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
cv_anom = cv.groupby('hourofyear') - cv_clim.cv
# compute advection of cv: first and second derivative
adv1, adv2 = calc_adv(cv_anom, u, v)
# xr.where() anomaly data exceeds the percentile from the hourofyear climatology:
# replace data time with hourofyear -> compare with climatology percentile -> back to real time
cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
perc_mask_h = cv_anom_h.where(
cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})
        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)  # configured percentile of the cv field
        print("CV percentile threshold: " + str(cv_perc_thresh))

        print('Locating wave troughs...')
        # filter the advection field given our conditions; combine the masks pairwise,
        # since np.logical_and(a, b, c) would treat the third array as the ufunc's `out` argument:
        cond = ~np.isnan(perc_mask)  # percentile of anomaly over threshold from climatology
        cond = np.logical_and(cond, adv2.values > self.config.second_advection_min_thr)  # second time derivative > 0: don't detect local minima over the percentile
        cond = np.logical_and(cond, u.values < self.config.max_u_thresh)  # threshold for propagation speed -> keep only westward
        trough_mask = adv1.where(cond)
dataset['trough_mask'] = trough_mask
# create 0.5x0.5 dataarray for wavetroughs
min_lat = dataset[lat_str].data.min()
max_lat = dataset[lat_str].data.max()
min_lon = dataset[lon_str].data.min()
max_lon = dataset[lon_str].data.max()
lat05 = np.linspace(min_lat, max_lat, int((max_lat - min_lat) * 2) + 1)
lon05 = np.linspace(min_lon, max_lon, int((max_lon - min_lon) * 2) + 1)
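        # e.g. a longitude range of -45..45 degrees yields int(90 * 2) + 1 = 181 grid points at 0.5 deg spacing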
# wt = xr.DataArray(coords=[('lon05', lon05), ('lat05', lat05)],
return dataset
    def identify(self, data_chunk: xr.Dataset, **kwargs):
        objs = []
        trough_mask_cur = data_chunk.trough_mask

        def clip(tup, mint, maxt):
            return np.clip(tup, mint, maxt)

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw={'projection': ccrs.PlateCarree()})
# generate zero-contours with matplotlib core
c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0],
subplot_kws={'projection': ccrs.PlateCarree()})
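        # note: ContourSet.collections is deprecated since matplotlib 3.8; on newer
        # versions the zero-contour paths have to be read via c.get_paths() instead.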
paths = c.collections[0].get_paths()
id_ = 1
for path in paths:
# get new object, set id
o = self.get_new_object()
# populate it
populate_object(o.properties, path)
# add to objects if keep
if keep_wavetrough(o.properties, self.config):
objs.append(o)
id_ += 1
num_verts = len(path.vertices)
            with self.lock_:  # TODO remove this?
                self.found_max_wt_pts = max(self.found_max_wt_pts, num_verts)

        return data_chunk, objs

def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
lat_str = get_latitude_dim(dataset)
lon_str = get_longitude_dim(dataset)
# create met3d like trajectories
if self.config.out_wt_dir:
if not os.path.exists(self.config.out_wt_dir):
os.makedirs(self.config.out_wt_dir)
        assert len(pb2_desc.sets) == 1  # TODO assert one set. maybe expand at some point
desc_set = pb2_desc.sets[0]
desc_times = desc_set.timesteps
for idx, ts in enumerate(desc_times):
# need to make separate dataset for each init-time
# because number of trajs (WTs) are different from time to time
dataset_wt = xr.Dataset()
lon_list = []
lat_list = []
pres_list = []
max_pts_in_wt = -1 # TODO what if no wts
for o in ts.objects: # get lons and lats
pt_list = o.properties.line_pts
lon_list.append(np.array([pt.lon for pt in pt_list]))
lat_list.append(np.array([pt.lat for pt in pt_list]))
pres_list.append(np.array([850.0 for pt in pt_list]))
max_pts_in_wt = max(max_pts_in_wt, len(lon_list[-1]))
            # second pass: pad all coordinate lists with trailing NaNs to a common length
            for i in range(len(lon_list)):
lon_list[i] = np.pad(lon_list[i], (0, max_pts_in_wt - len(lon_list[i])), mode='constant',
constant_values=np.nan)
lat_list[i] = np.pad(lat_list[i], (0, max_pts_in_wt - len(lat_list[i])), mode='constant',
constant_values=np.nan)
pres_list[i] = np.pad(pres_list[i], (0, max_pts_in_wt - len(pres_list[i])), mode='constant',
constant_values=np.nan)
dataset_wt = dataset_wt.expand_dims(time=np.arange(0, max_pts_in_wt).astype(dtype=float)) # fake traj time
dataset_wt = dataset_wt.expand_dims(ensemble=[0])
dataset_wt = dataset_wt.expand_dims(trajectory=np.arange(1, len(ts.objects) + 1))
lons = xr.DataArray(np.zeros((1, len(ts.objects), max_pts_in_wt)), dims=("ensemble", "trajectory", "time"))
lons.attrs['standard_name'] = "longitude"
lons.attrs['long_name'] = "longitude"
lons.attrs['units'] = "degrees_east"
lats = xr.zeros_like(lons)
lats.attrs['standard_name'] = "latitude"
lats.attrs['long_name'] = "latitude"
lats.attrs['units'] = "degrees_north"
pres = xr.zeros_like(lons)
pres.attrs['standard_name'] = "air_pressure"
pres.attrs['long_name'] = "pressure"
pres.attrs['units'] = "hPa"
pres.attrs['positive'] = "down"
pres.attrs['axis'] = "Z"
dataset_wt['lon'] = lons
dataset_wt['lat'] = lats
dataset_wt['pressure'] = pres
# TODO auxiliary smth?
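            # stack the per-trough point lists into (trajectory, time) arrays and fill
            # them into the single ensemble slot (index 0) of the output variables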
lon_list_np = np.array(lon_list)
lat_list_np = np.array(lat_list)
pres_list_np = np.array(pres_list)
dataset_wt['lon'].data[0] = lon_list_np
dataset_wt['lat'].data[0] = lat_list_np
dataset_wt['pressure'].data[0] = pres_list_np
dataset_wt['time'].attrs['standard_name'] = "time"
dataset_wt['time'].attrs['long_name'] = "time"
dataset_wt['time'].attrs['units'] = "hours since " + ts.valid_time.replace('T', ' ')
dataset_wt['time'].attrs['trajectory_starttime'] = ts.valid_time.replace('T', ' ')
            dataset_wt['time'].attrs['forecast_inittime'] = ts.valid_time.replace('T', ' ')
            out_path = os.path.join(self.config.out_wt_dir, ts.valid_time.replace(':', '_') + '.nc')
dataset_wt.to_netcdf(out_path)
return dataset, pb2_desc
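

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the identification).
# Assumptions: the strategy is normally driven by the enstools.feature pipeline,
# which calls precompute() once, then identify() per timestep and postprocess()
# on the collected results; '/path/to/input.nc' is a placeholder and the file
# is assumed to contain u, v and the curvature vorticity field named by `cv`.
if __name__ == '__main__':
    strategy = AEWIdentification(wt_traj_dir=None, cv='cv')
    data = xr.open_dataset('/path/to/input.nc')  # placeholder input path
    data = strategy.precompute(data)
    print(data.trough_mask)  # masked advection field marking candidate troughs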