from enstools.feature.identification import IdentificationTechnique
from enstools.feature.util.pb2_properties_api import ObjectProperties
from .util import calc_adv

import xarray as xr
import numpy as np
import metpy.calc as mpcalc
import os
import sys

from matplotlib import pyplot as plt
import cartopy.crs as ccrs


class AEWIdentification(IdentificationTechnique):

    def __init__(self, **kwargs):
        """
        Initialize the AEW identification.

        Parameters (experimental)
        ----------
        kwargs
        """
        import enstools.feature.identification.african_easterly_waves.configuration as cfg
        self.config = cfg  # configuration module
        self.processing_mode = '2d'  # TODO enum?
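
    # Configuration attributes referenced below: data_lon, data_lat, levels,
    # start_date, end_date, cv_percentile, second_advection_min_thr,
    # max_u_thresh, and get_clim_file() for the cached climatology path.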

    def precompute(self, dataset: xr.Dataset, **kwargs):
        print("Precompute for AEW identification...")

        # --------------- CLIMATOLOGY
        # TODO use config for file path
        lon_range = self.config.data_lon
        lat_range = self.config.data_lat
        clim_file = self.config.get_clim_file()

        if os.path.isfile(clim_file):
            cv_clim = xr.open_dataset(clim_file)
        else:  # TODO not yet accessed in framework
            # generate the climatology: needs all 40 years of CV data
            print("Climatology file not found. Computing climatology...")
            from .climatology import compute_climatology
            cv_clim = compute_climatology(self.config)
            cv_clim.to_netcdf(clim_file)
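        # The computed climatology is cached as NetCDF at clim_file, so the
        # expensive pass over the full record only happens on the first run.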

        # --------------- SUBSET DATA ACCORDING TO CONFIG
        start_date_dt = np.datetime64(self.config.start_date)
        end_date_dt = np.datetime64(self.config.end_date)

        # the requested data can span multiple timesteps, and hence multiple years
        years = list(range(start_date_dt.astype(object).year, end_date_dt.astype(object).year + 1))
        print(years)
        # process_data = xr.open_mfdataset([diri + str(y) + "cv.nc" for y in years])  # open years in the requested range

        dataset = dataset.sel(plev=self.config.levels,
                              lat=slice(lat_range[0], lat_range[1]),
                              lon=slice(lon_range[0], lon_range[1]),
                              time=slice(start_date_dt, end_date_dt))

        # make sure that lat and lon are the last two dimensions (required by metpy.calc)
        if 'lat' not in dataset.cv.dims[-2:] or 'lon' not in dataset.cv.dims[-2:]:
            print("Reordering dimensions so lat and lon are at the back. Required for metpy.calc.")
            dataset = dataset.transpose(..., 'lat', 'lon')

        # --------------- NUMPY-PARALLELIZED PART: CREATE TROUGH MASKS
        u = dataset.U
        v = dataset.V
        cv = dataset.cv

        # smooth CV with a 9-point kernel
        cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()

        # add an hourofyear coordinate ("MM-DD HH") so anomalies can be taken
        # against the hour-of-year climatology
        cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
        cv_anom = cv.groupby('hourofyear') - cv_clim.cv

        # compute advection of cv: first and second derivative
        adv1, adv2 = calc_adv(cv_anom, u, v)
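        # calc_adv lives in .util and is not shown here; presumably it returns
        # the horizontal advection of cv_anom by (u, v) together with its
        # derivative, e.g. built on metpy.calc.advection(cv_anom, u=u, v=v).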

        # xr.where(): keep anomaly values exceeding the percentile from the hourofyear climatology.
        # Swap the time dimension for hourofyear, compare against the climatology percentile,
        # then swap back to real time.
        cv_anom_h = cv_anom.swap_dims(dims_dict={'time': 'hourofyear'})
        perc_mask_h = cv_anom_h.where(
            cv_anom_h > cv_clim.cva_quantile_hoy.sel(dict(hourofyear=cv_anom.hourofyear.data)))
        perc_mask = perc_mask_h.swap_dims(dims_dict={'hourofyear': 'time'})

        cv_perc_thresh = np.nanpercentile(cv, self.config.cv_percentile)  # configured percentile of the smoothed cv field
        print(cv_perc_thresh)

        print('Locating wave troughs...')
        # filter the advection field given our conditions (combined elementwise):
        trough_mask = adv1.where(
            ~np.isnan(perc_mask)  # anomaly percentile exceeds the climatology threshold
            & (adv2.values > self.config.second_advection_min_thr)  # second derivative > 0: don't detect local minima over the percentile
            & (u.values < self.config.max_u_thresh))  # propagation speed threshold -> keep only westward-moving troughs

        dataset['trough_mask'] = trough_mask
        return dataset
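
    # After precompute(), dataset.trough_mask holds the first-derivative advection
    # values wherever all three conditions hold and NaN elsewhere; its zero contours
    # are the wave trough candidates that identify() is meant to extract.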

    # main identify, called on spatial 2d/3d subsets from enstools-feature.
    # TODO maybe rename to identify_per_block,
    # or separate methods identify / identify_per_block. But how to build the object if only "identify"?
    def identify(self, data_chunk: xr.Dataset, **kwargs):
        print("chunk identify")
        # TODO parallelize over pressure? PV needs 3d input, this parallelizes over 2d!
        from .._proto_gen import african_easterly_waves_pb2

        # placeholder: assume 5 detected objects
        objs = []
        for i in range(5):
            # get an instance of a new object and its id in the descriptions
            s_object = ObjectProperties.get_instance(african_easterly_waves_pb2)
            # fill the properties defined in the .proto file
            s_object.set('a', 42.0 * i)
            objs.append(s_object)
            # ObjectProperties.add_to_objects(object_block, s_object)
            # the id can also be set manually, e.g. for a labeled dataset:
            # ObjectProperties.add_to_objects(object_block, s_object, id=i)

        # fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 15), subplot_kw=dict(projection=ccrs.PlateCarree()))
        # TODO can the user decide the ID?
        # waves_da = xr.DataArray(coords=[data.coords['time'], data.coords['plev']]).astype(dtype=WaveTroughState)
        # TODO parallelize?
        # cart = list(itertools.product(trough_mask.time.data, trough_mask.plev.data))
        # cart_n = len(cart)
        # for i, (ctime, cplev) in enumerate(cart):

        return data_chunk, objs

    # --- Work-in-progress continuation, unreachable after the return above.
    # Kept commented out for reference: it references names (identification_pb2,
    # ctime, cplev, WaveTrough, WaveTroughState, waves_da, cfg, spacial_filter)
    # that are not defined in this scope.
    #
    # # get a new object structure
    # dummy_tl = identification_pb2.Timeline()
    # object_block_ref = dummy_tl.objects
    #
    # trough_mask_cur = data_chunk.trough_mask
    # print(trough_mask_cur)
    # # print(str(ctime)[:13] + " | " + str(int(cplev / 100)) + "hPa (" + str(i + 1) + "/" + str(cart_n) + ")")
    # exit()
    #
    # # generate zero-contours with matplotlib core
    # c = trough_mask_cur.plot.contour(transform=ccrs.PlateCarree(), colors='blue', levels=[0.0])
    # paths = c.collections[0].get_paths()
    #
    # from filtering import spacial_filter
    # filtered_paths = spacial_filter(paths, cfg.spatial_thr, cfg.wave_lat, cfg.wave_lon)
    #
    # waves = []
    # for path in filtered_paths:
    #     waves.append(WaveTrough(cplev, ctime, path))
    # waves_da.loc[dict(time=ctime, plev=cplev)] = WaveTroughState(cplev, ctime, waves)
    #
    # print('Finished.')
    # return waves_da
    #
    # # for id_, obj in object_properties_with_corres_indices:
    # #     ObjectProperties.add_to_objects(object_block, obj, id=id_)
    # return data_chunk

    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        return dataset, pb2_desc
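

# A hypothetical usage sketch (illustration only): the identification is normally
# driven by the enstools-feature pipeline, whose API is not shown in this file,
# so precompute/identify are called by hand below. The "cv_data.nc" path and the
# expectation of U, V and cv variables on (time, plev, lat, lon) are assumptions.
if __name__ == '__main__':
    technique = AEWIdentification()
    data = xr.open_dataset("cv_data.nc")  # assumed input file with U, V and cv
    data = technique.precompute(data)
    chunk, objects = technique.identify(data)
    print("Identified {} objects.".format(len(objects)))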