from enstools.feature.identification import IdentificationTechnique
import xarray as xr
import numpy as np
import os
import metpy.calc as mpcalc
from .util import calc_adv
from matplotlib import pyplot as plt
import cartopy.crs as ccrs
from .filtering import keep_wavestate
from .._proto_gen import african_easterly_waves_pb2
class AEWIdentification(IdentificationTechnique):
def __init__(self, **kwargs):
"""
Initialize the AEW Identification.
Parameters (experimental)
----------
kwargs
"""
self.pb_reference = african_easterly_waves_pb2
import enstools.feature.identification.african_easterly_waves.configuration as cfg
self.config = cfg  # configuration module: domain, date range, thresholds, file paths
self.processing_mode = '2d'
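# note: '2d' presumably makes the framework hand identify() individual 2D slices
# (assumption; the mode semantics are defined by IdentificationTechnique)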
def precompute(self, dataset: xr.Dataset, **kwargs):
print("Precompute for PV identification...")
plt.switch_backend('agg')  # non-interactive backend: thread-safe matplotlib, but cannot display figures
# --------------- CLIMATOLOGY
lon_range = self.config.data_lon
lat_range = self.config.data_lat
clim_file = self.config.get_clim_file()
if os.path.isfile(clim_file):
cv_clim = xr.open_dataset(clim_file)
else:
# generate the climatology: needs the full 40 years of CV data.
print("Climatology file not found. Computing climatology...")
from .climatology import compute_climatology
cv_clim = compute_climatology(self.config)
cv_clim.to_netcdf(clim_file)
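# cache the climatology on disk; the isfile() check above reuses it on subsequent runs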
# --------------- SUBSET DATA ACCORDING TO CFG
start_date_dt = np.datetime64(self.config.start_date)
end_date_dt = np.datetime64(self.config.end_date)
# get the data we want to investigate
# the request can span multiple timesteps and thus multiple years -> load all years needed
years = list(range(start_date_dt.astype(object).year, end_date_dt.astype(object).year + 1))
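# e.g. start 2006-08-01, end 2008-09-30 -> years = [2006, 2007, 2008]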
print("Processing years:", years)
# process_data = xr.open_mfdataset([diri + str(y) + "cv.nc" for y in years]) # open years in range of requested
dataset = dataset.sel(plev=self.config.levels,
lat=slice(lat_range[0], lat_range[1]),
lon=slice(lon_range[0], lon_range[1]),
time=slice(start_date_dt, end_date_dt))
# make sure that lat and lon are last two dimensions
if 'lat' not in dataset.cv.dims[-2:] or 'lon' not in dataset.cv.dims[-2:]:
print("Reordering dimensions so that lat and lon are the last two; required for metpy.calc.")
dataset = dataset.transpose(..., 'lat', 'lon')
# --------------- NUMPY-PARALLELIZED COMPUTATION: CREATE TROUGH MASKS
u = dataset.U
v = dataset.V
cv = dataset.cv
# smooth CV with a 9-point kernel, 2 passes
cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()
# create hourofyear to get anomalies
cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
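# e.g. the timestamp 2008-08-15T06:00 gets hourofyear '08-15 06', so the anomaly
# below is taken w.r.t. the climatology of that calendar hour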
cv_anom = cv.groupby('hourofyear') - cv_clim.cv
# compute advection of the CV anomaly: first and second derivative (see .util.calc_adv)
adv1, adv2 = calc_adv(cv_anom, u, v)
# mask (xr.where) where the anomaly exceeds the percentile from the hourofyear climatology:
# replace data time with hourofyear -> compare with climatology percentile -> back to real time
cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
perc_mask_h = cv_anom_h.where(
cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})
cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)  # configured (66th) percentile of the smoothed CV field
print("CV percentile threshold:", cv_perc_thresh)
print('Locating wave troughs...')
# filter the advection field given our conditions (np.logical_and only takes two
# arrays, so the three conditions are chained with &):
trough_mask = adv1.where(~np.isnan(perc_mask)  # anomaly percentile over threshold from climatology
    & (adv2.values > self.config.second_advection_min_thr)
    # second time derivative > 0: don't detect local minima over the percentile
    & (u.values < self.config.max_u_thresh))  # propagation speed threshold -> keep only westward-moving
dataset['trough_mask'] = trough_mask
return dataset
def identify(self, data_chunk: xr.Dataset, **kwargs):
objs = []
trough_mask_cur = data_chunk.trough_mask
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw=dict(projection=ccrs.PlateCarree()))
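# nothing is rendered or saved here (agg backend); the Axes only lets matplotlib's
# contouring machinery trace the zero lines of the trough mask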
# generate zero-contours with matplotlib core
c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0])
paths = c.collections[0].get_paths()
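# note: ContourSet.collections is deprecated since matplotlib 3.8; on newer
# versions the paths are available via c.get_paths() (one compound Path per level)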
id_ = 1
for path in paths:
o = self.get_new_object()
o.properties.num_nodes = len(path.vertices)  # number of vertices along the trough line
# fill the properties defined in the .proto file.
o.id = id_
for v in path.vertices:
line_pt = o.properties.line_pts.add()
line_pt.lon = v[0]
line_pt.lat = v[1]
# keep the object only if it passes the wave-state filter
if keep_wavestate(o.properties):
# filtered_paths = spacial_filter(paths, cfg.spatial_thr, cfg.wave_lat, cfg.wave_lon)
objs.append(o)
id_ += 1
plt.close(fig)  # free the figure; it was only needed for contour extraction
return data_chunk, objs
def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
return dataset, pb2_desc
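# ---------------------------------------------------------------------------
# Minimal usage sketch, bypassing the surrounding pipeline driver and calling
# the methods defined above directly. The input file name is hypothetical; the
# dataset must provide U, V and cv on the plev/lat/lon/time grid expected above.
#
#     ident = AEWIdentification()
#     ds = ident.precompute(xr.open_dataset("cv_aug2008.nc"))  # hypothetical file
#     chunk, objects = ident.identify(ds.isel(time=0))  # one timestep (and, if needed, one level)
#     print(len(objects), "wave troughs found")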