Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • Christoph.Fischer/enstools-feature
1 result
Show changes
Commits on Source (2)
...@@ -12,10 +12,13 @@ import numpy as np ...@@ -12,10 +12,13 @@ import numpy as np
data_lat = (3, 35) data_lat = (3, 35)
data_lon = (-100, 45) data_lon = (-100, 45)
aew_clim_dir = '/home/ws/he7273/phd_all/data/aew/clim/cv_clim_era5.nc' # '/project/meteo/w2w/C3/fischer/belanger/aew_clim/cv_clim_era5.nc' # '/lsdf/MOD/Gruppe_Transregio/Gruppe_Knippertz/kitweather/data/era5/cv_clim_era5.nc' 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\cv_clim_era5.nc' # '/home/christoph/phd/data/aew/clim/cv_clim_era5.nc' # '/home/christoph/phd/data/framework_example_ds/aew/' # '/project/meteo/w2w/C3/fischer/belanger/aew_clim/' # aew_clim_dir = '/project/meteo/w2w/C3/fischer/belanger/aew_clim/cv_clim_era5.nc' # '/home/ws/he7273/phd_all/data/aew/clim/cv_clim_era5.nc' # # '/lsdf/MOD/Gruppe_Transregio/Gruppe_Knippertz/kitweather/data/era5/cv_clim_era5.nc' 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\cv_clim_era5.nc' # '/home/christoph/phd/data/aew/clim/cv_clim_era5.nc' # '/home/christoph/phd/data/framework_example_ds/aew/' # '/project/meteo/w2w/C3/fischer/belanger/aew_clim/' #
in_files = '/home/ws/he7273/phd_all/data/coll_oper/jjaso2021.nc' # '/project/meteo/w2w/C3/fischer/data/jja2022.nc' # '/home/ws/he7273/phd_all/data/coll_oper/jja2021/jja2021.nc' # 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\2008_sum_uv.nc' # '/home/christoph/phd/data/framework_example_ds/aew/cv_aug_08.nc' in_files = '/project/meteo/w2w/C3/fischer/data/coll_oper/2021/jjaso2021.nc' # '/home/ws/he7273/phd_all/data/coll_oper/jja2021/jja2021.nc' # 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\2008_sum_uv.nc' # '/home/christoph/phd/data/framework_example_ds/aew/cv_aug_08.nc'
out_dir = '/home/ws/he7273/phd_all/data/coll_oper/' # '/project/meteo/w2w/C3/fischer/belanger/out/' # join('/home/ws/he7273/phd_all/data/aew/out/') # '/project/meteo/w2w/C3/fischer/belanger/out/' out_dir = '/project/meteo/w2w/C3/fischer/belanger/out/' # join('/home/ws/he7273/phd_all/data/aew/out/') # '/project/meteo/w2w/C3/fischer/belanger/out/'
out_json_path = out_dir + 'jjaso2021.json' out_json_path = out_dir + 'jjaso2021.json'
out_data_path = out_dir + 'jjaso2021_wts.nc'
generate_output = True
plot_dir = '/home/ws/he7273/phd_all/data/coll_oper/' # '/project/meteo/w2w/C3/fischer/belanger/plots/' # join('/home/ws/he7273/phd_all/data/aew/plots/') # '/project/meteo/w2w/C3/fischer/belanger/plots/' plot_dir = '/home/ws/he7273/phd_all/data/coll_oper/' # '/project/meteo/w2w/C3/fischer/belanger/plots/' # join('/home/ws/he7273/phd_all/data/aew/plots/') # '/project/meteo/w2w/C3/fischer/belanger/plots/'
...@@ -73,8 +76,10 @@ v_dim = 'v' ...@@ -73,8 +76,10 @@ v_dim = 'v'
# time of interest, if None all # time of interest, if None all
# june-oct is AEW season # june-oct is AEW season
start_date = None # '2022-08-01T00:00' # None # '2008-08-01T00:00' # # '2008-08-01T00:00' start_date = None # '2021-08-01T00:00' # None # '2008-08-01T00:00' # # '2008-08-01T00:00'
end_date = None # '2022-08-15T00:00' # None # '2008-08-15T00:00' # None # '2008-08-03T00:00' end_date = None # '2021-08-10T00:00' # None # '2008-08-15T00:00' # None # '2008-08-03T00:00'
add_fake_wts = True # add fake wavetroughs in object description to bridge gaps in trackspb_reference
# Algorithm parameters # Algorithm parameters
# max u wind (m/s) (0 = only keep west-propagating). Belanger: 2.5; Berry: 0.0 # max u wind (m/s) (0 = only keep west-propagating). Belanger: 2.5; Berry: 0.0
......
...@@ -33,7 +33,9 @@ class AEWIdentification(IdentificationStrategy): ...@@ -33,7 +33,9 @@ class AEWIdentification(IdentificationStrategy):
self.config = cfg # config self.config = cfg # config
self.config.out_traj_dir = wt_traj_dir self.config.out_traj_dir = wt_traj_dir
self.config.cv_name = cv self.config.cv_name = cv
self.orig_dataset = None
if year_summer is not None: if year_summer is not None:
if month is not None: if month is not None:
m_str = str(month).zfill(2) m_str = str(month).zfill(2)
...@@ -143,7 +145,9 @@ class AEWIdentification(IdentificationStrategy): ...@@ -143,7 +145,9 @@ class AEWIdentification(IdentificationStrategy):
if self.config.cv_name not in dataset.data_vars: if self.config.cv_name not in dataset.data_vars:
print("Curvature Vorticity not found, trying to compute it out of U and V...") print("Curvature Vorticity not found, trying to compute it out of U and V...")
dataset = compute_cv(dataset, u_name, v_name, self.config.cv_name) dataset = compute_cv(dataset, u_name, v_name, self.config.cv_name)
self.orig_dataset = dataset
# make dataset to 2.5 (or same as cv_clim) # make dataset to 2.5 (or same as cv_clim)
dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]}) dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]})
...@@ -163,7 +167,6 @@ class AEWIdentification(IdentificationStrategy): ...@@ -163,7 +167,6 @@ class AEWIdentification(IdentificationStrategy):
cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify() cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()
# create hourofyear to get anomalies # create hourofyear to get anomalies
cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H")) cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
cv_anom = cv.groupby('hourofyear') - cv_clim.cv cv_anom = cv.groupby('hourofyear') - cv_clim.cv
...@@ -190,7 +193,7 @@ class AEWIdentification(IdentificationStrategy): ...@@ -190,7 +193,7 @@ class AEWIdentification(IdentificationStrategy):
u.values < self.config.max_u_thresh)) # threshold for propagation speed -> keep only westward u.values < self.config.max_u_thresh)) # threshold for propagation speed -> keep only westward
dataset['trough_mask'] = trough_mask dataset['trough_mask'] = trough_mask
"""
# create 0.5x0.5 dataarray for wavetroughs # create 0.5x0.5 dataarray for wavetroughs
min_lat = dataset[lat_str].data.min() min_lat = dataset[lat_str].data.min()
max_lat = dataset[lat_str].data.max() max_lat = dataset[lat_str].data.max()
...@@ -216,7 +219,7 @@ class AEWIdentification(IdentificationStrategy): ...@@ -216,7 +219,7 @@ class AEWIdentification(IdentificationStrategy):
dataset['lon05'].attrs['standard_name'] = 'longitude' dataset['lon05'].attrs['standard_name'] = 'longitude'
dataset['lat05'].attrs['units'] = 'degrees_north' dataset['lat05'].attrs['units'] = 'degrees_north'
dataset['lon05'].attrs['units'] = 'degrees_east' dataset['lon05'].attrs['units'] = 'degrees_east'
"""
return dataset return dataset
def identify(self, data_chunk: xr.Dataset, **kwargs): def identify(self, data_chunk: xr.Dataset, **kwargs):
...@@ -235,17 +238,8 @@ class AEWIdentification(IdentificationStrategy): ...@@ -235,17 +238,8 @@ class AEWIdentification(IdentificationStrategy):
subplot_kws={'projection': ccrs.PlateCarree()}) subplot_kws={'projection': ccrs.PlateCarree()})
paths = c.collections[0].get_paths() paths = c.collections[0].get_paths()
wt = data_chunk.wavetroughs
min_lat = wt.lat05.data.min()
max_lat = wt.lat05.data.max()
min_lon = wt.lon05.data.min()
max_lon = wt.lon05.data.max()
lons = len(wt.lon05.data)
lats = len(wt.lat05.data)
id_ = 1 id_ = 1
for path in paths: for path in paths:
# get new object, set id # get new object, set id
...@@ -261,28 +255,6 @@ class AEWIdentification(IdentificationStrategy): ...@@ -261,28 +255,6 @@ class AEWIdentification(IdentificationStrategy):
objs.append(o) objs.append(o)
id_ += 1 id_ += 1
# if wavetrough out dataset, gen lines
if not self.config.out_wt:
continue
for v_idx in range(len(path.vertices) - 1):
start_lonlat = path.vertices[v_idx][0], path.vertices[v_idx][1]
end_lonlat = path.vertices[v_idx + 1][0], path.vertices[v_idx + 1][1]
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# start_idx = clip(start_idx, (0, 0), (lons, lats))
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# end_idx = clip(end_idx, (0, 0), (lons, lats))
rr, cc, val = line_aa(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = clip(rr, 0, lons - 1)
cc = clip(cc, 0, lats - 1)
wt.data[cc, rr] = np.where(np.greater(val, wt.data[cc, rr]), val, wt.data[cc, rr])
return data_chunk, objs return data_chunk, objs
def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs): def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
...@@ -291,7 +263,8 @@ class AEWIdentification(IdentificationStrategy): ...@@ -291,7 +263,8 @@ class AEWIdentification(IdentificationStrategy):
lon_str = get_longitude_dim(dataset) lon_str = get_longitude_dim(dataset)
data_desc = self.make_ids_unique(data_desc) data_desc = self.make_ids_unique(data_desc)
"""
# drop everything, only keep WTs as 0.5x0.5 # drop everything, only keep WTs as 0.5x0.5
if self.config.out_wt: if self.config.out_wt:
for var in dataset.data_vars: for var in dataset.data_vars:
...@@ -388,7 +361,8 @@ class AEWIdentification(IdentificationStrategy): ...@@ -388,7 +361,8 @@ class AEWIdentification(IdentificationStrategy):
out_path = self.config.out_traj_dir + ts.valid_time.replace(':', '_') + '.nc' out_path = self.config.out_traj_dir + ts.valid_time.replace(':', '_') + '.nc'
dataset_wt.to_netcdf(out_path) dataset_wt.to_netcdf(out_path)
"""
return dataset, data_desc return dataset, data_desc
# filter: keep current wavetrough if: # filter: keep current wavetrough if:
...@@ -427,5 +401,11 @@ class AEWIdentification(IdentificationStrategy): ...@@ -427,5 +401,11 @@ class AEWIdentification(IdentificationStrategy):
mid_lat = properties.bb.max.lat - properties.bb.min.lat mid_lat = properties.bb.max.lat - properties.bb.min.lat
if mid_lat < 5.0 or mid_lat > 25.0: if mid_lat < 5.0 or mid_lat > 25.0:
return False return False
height = properties.bb.max.lat - properties.bb.min.lat
width = properties.bb.max.lon - properties.bb.min.lon
if height < 0.75 * width:
return False
return True return True
...@@ -5,7 +5,10 @@ import xarray as xr ...@@ -5,7 +5,10 @@ import xarray as xr
import math import math
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, \ from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, \
get_latitude_dim get_latitude_dim
from enstools.feature.util.data_utils import pb_str_to_datetime64, simplenamespace_to_proto, datetime64_to_pb_str, proto_to_simplenamespace, clip
from google.protobuf.message import Message
from skimage.draw import line
from collections import defaultdict
# calculates the dx's and dy's for each grid cell # calculates the dx's and dy's for each grid cell
# takes list of latitudes and longitudes as input and returns field with dimensions len(lats) x len(lons) # takes list of latitudes and longitudes as input and returns field with dimensions len(lats) x len(lons)
...@@ -256,3 +259,244 @@ def populate_object(obj_props, path, cfg): ...@@ -256,3 +259,244 @@ def populate_object(obj_props, path, cfg):
# identify troughs in data (should contain U,V,cv), based on the cv climatology # identify troughs in data (should contain U,V,cv), based on the cv climatology
# def identify_troughs(data, cv_clim, cfg): # def identify_troughs(data, cv_clim, cfg):
def create_fake_wt_edges(edge, child_idx, pb_ref):
# make multiple edges out of this one
parent_node = edge.parent
assert isinstance(parent_node, Message)
parent_node_pb = parent_node # simplenamespace_to_proto(parent_node, pb_ref.GraphNode())
child_node = edge.children[child_idx]
child_node_pb = child_node # simplenamespace_to_proto(child_node, pb_ref.GraphNode())
parent_node_time = pb_str_to_datetime64(parent_node.time)
child_node_time = pb_str_to_datetime64(child_node.time)
delt = child_node_time - parent_node_time
if delt != np.timedelta64(12, 'h'):
print("NO 12h abort")
print(delt)
exit(1)
fake_node_time = parent_node_time + 0.5 * delt # TODO assert only skip 1
# create fake node as copy of parent
fake_node = pb_ref.GraphNode() # json_format.Parse(json.dumps(parent_json), pb_ref.GraphNode(), ignore_unknown_fields=False)
fake_node.time = datetime64_to_pb_str(fake_node_time)
# id and flag False
fake_node.object.id = -1
fake_node.object.flag = False
# set properties: bb as mean of parent and child
parent_props = parent_node.object.properties
child_props = child_node.object.properties
fake_props = fake_node.object.properties
fake_props.bb.min.lat = (parent_props.bb.min.lat + child_props.bb.min.lat) / 2.0
fake_props.bb.max.lat = (parent_props.bb.max.lat + child_props.bb.max.lat) / 2.0
fake_props.bb.min.lon = (parent_props.bb.min.lon + child_props.bb.min.lon) / 2.0
fake_props.bb.max.lon = (parent_props.bb.max.lon + child_props.bb.max.lon) / 2.0
avg_lon = (fake_props.bb.min.lon + fake_props.bb.max.lon) / 2.0
fake_props.line_pts.add()
fake_props.line_pts[0].lat = fake_props.bb.min.lat
fake_props.line_pts[0].lon = avg_lon
fake_props.line_pts.add()
fake_props.line_pts[1].lat = fake_props.bb.max.lat
fake_props.line_pts[1].lon = avg_lon
fake_props.length_deg = math.sqrt((fake_props.bb.max.lat - fake_props.bb.min.lat) ** 2 + (fake_props.bb.max.lon - fake_props.bb.min.lon) ** 2)
# new edge:
edge1 = pb_ref.GraphConnection()
edge1.parent.CopyFrom(parent_node_pb)
edge1.children.append(fake_node)
edge2 = pb_ref.GraphConnection()
edge2.parent.CopyFrom(fake_node)
edge2.children.append(child_node_pb)
return [edge1, edge2]
# interpolate wavetroughs, create fake WTs in skipped timesteps.
def interpolate_wts(data_desc, pb_ref):
for set_ in data_desc.sets:
for track in set_.tracks:
# old_edges = []
new_edges = []
for edge in track.edges:
parent_node = edge.parent
parent_node_time = pb_str_to_datetime64(parent_node.time)
if not hasattr(edge, 'children'):
continue
for child_idx, child_node in enumerate(edge.children):
child_node_time = pb_str_to_datetime64(child_node.time)
if child_node_time - parent_node_time > np.timedelta64(6, 'h'):
print("Add fake WT at " + parent_node.time)
wt_edges = create_fake_wt_edges(edge, child_idx, pb_ref)
# old_edges.append((edge, child_idx))
new_edges.extend(wt_edges)
# remove old_edges from this track and from graph
new_edges_sn = [e for e in new_edges] # proto_to_simplenamespace(e)
track.edges.extend(new_edges_sn) # TODO sort
track.edges.sort(key=lambda item: item.parent.time)
set_.graph.edges.extend(new_edges_sn)
set_.graph.edges.sort(key=lambda item: item.parent.time)
return data_desc
def add_wts_to_ds(dataset, data_desc):
print("Create WT lines...")
lon_str = get_longitude_dim(dataset)
lat_str = get_latitude_dim(dataset)
u_str = get_u_var(dataset)
v_str = get_v_var(dataset)
dataset['wavetroughs'] = xr.zeros_like(dataset[u_str].isel(level=0).squeeze(), dtype=int) # all WTs
dataset['tracks'] = xr.zeros_like(dataset[u_str].isel(level=0).squeeze(), dtype=int) # filtered by track heuristics
min_lat = dataset.latitude.data.min()
max_lat = dataset.latitude.data.max()
min_lon = dataset.longitude.data.min()
max_lon = dataset.longitude.data.max()
lons = len(dataset.longitude.data)
lats = len(dataset.latitude.data)
wt = dataset.wavetroughs
wt_t = dataset.tracks
for wt_set in data_desc.sets:
cur_set = wt_set
# if use_fc:
# initTime = wt_set.initTime
# set_ds = dataset.sel(time=initTime) # init time
set_ds = dataset
# get nodes from all tracks in set
set_nodes = []
for track_id, track in enumerate(cur_set.tracks):
set_nodes.extend([e.parent for e in track.edges])
# put them into buckets
node_buckets = defaultdict(list)
for x in set_nodes:
node_buckets[x.time].append(x)
# iterate buckets
print("Tracks")
for time, cur_nodes in node_buckets.items():
print(time)
try:
wt_t_da = wt_t.sel(time=time)
except KeyError:
print("Skipping timestep (not in dataset) " + str(vt))
continue
for node in cur_nodes:
props = node.object.properties
if not hasattr(props, 'line_pts'):
print("?")
for v_idx in range(len(props.line_pts) - 1):
start_lonlat = props.line_pts[v_idx].lon, props.line_pts[v_idx].lat
end_lonlat = props.line_pts[v_idx + 1].lon, props.line_pts[v_idx + 1].lat
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# start_idx = clip(start_idx, (0, 0), (lons, lats))
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# end_idx = clip(end_idx, (0, 0), (lons, lats))
rr, cc = line(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = clip(rr, 0, lons - 1)
cc = clip(cc, 0, lats - 1)
wt_t_da.values[cc, rr] = node.object.id
"""
# make circle
for px_idx in range(len(rr)):
circle = circles.isel(longitude_center=rr[px_idx], latitude_center=cc[px_idx])
influence_area = circle.where(circle < d, -1)
# update influence area dataarray
infl_da = xr.where(influence_area >= 0, node.object.id, infl_da)
"""
wt_t.loc[dict(time=time)] = wt_t_da.values
### ALL NODES
graph = cur_set.graph
graph_nodes = [e.parent for e in graph.edges]
# put them into buckets
node_buckets = defaultdict(list)
for x in graph_nodes:
node_buckets[x.time].append(x)
# iterate buckets
print("Tracks")
for time, cur_nodes in node_buckets.items():
print(time)
try:
wt_da = wt.sel(time=time)
except KeyError:
print("Skipping timestep (not in dataset) " + str(vt))
continue
for node in cur_nodes:
props = node.object.properties
if not hasattr(props, 'line_pts'):
print("?")
for v_idx in range(len(props.line_pts) - 1):
start_lonlat = props.line_pts[v_idx].lon, props.line_pts[v_idx].lat
end_lonlat = props.line_pts[v_idx + 1].lon, props.line_pts[v_idx + 1].lat
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# start_idx = clip(start_idx, (0, 0), (lons, lats))
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# end_idx = clip(end_idx, (0, 0), (lons, lats))
rr, cc = line(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = clip(rr, 0, lons - 1)
cc = clip(cc, 0, lats - 1)
wt_da.values[cc, rr] = node.object.id
"""
# make circle
for px_idx in range(len(rr)):
circle = circles.isel(longitude_center=rr[px_idx], latitude_center=cc[px_idx])
influence_area = circle.where(circle < d, -1)
# update influence area dataarray
infl_da = xr.where(influence_area >= 0, node.object.id, infl_da)
"""
wt.loc[dict(time=time)] = wt_da.values
return dataset
\ No newline at end of file
...@@ -11,6 +11,7 @@ from enstools.feature.util.graph import DataGraph ...@@ -11,6 +11,7 @@ from enstools.feature.util.graph import DataGraph
from enstools.feature.identification.african_easterly_waves.plotting import plot_kw, plot_differences, plot_track, plot_track_in_ts, plot_timesteps_from_desc, plot_tracks_from_desc from enstools.feature.identification.african_easterly_waves.plotting import plot_kw, plot_differences, plot_track, plot_track_in_ts, plot_timesteps_from_desc, plot_tracks_from_desc
import enstools.feature.identification.african_easterly_waves.configuration as cfg import enstools.feature.identification.african_easterly_waves.configuration as cfg
import os, sys, glob, shutil import os, sys, glob, shutil
from enstools.feature.identification.african_easterly_waves.processing import interpolate_wts, add_wts_to_ds
from enstools.feature.util.data_utils import get_subset_by_description from enstools.feature.util.data_utils import get_subset_by_description
import xarray as xr import xarray as xr
xr.set_options(keep_attrs=True) xr.set_options(keep_attrs=True)
...@@ -22,6 +23,7 @@ pipeline = FeaturePipeline(african_easterly_waves_pb2, processing_mode='2d') ...@@ -22,6 +23,7 @@ pipeline = FeaturePipeline(african_easterly_waves_pb2, processing_mode='2d')
# in_files_all_cv_data = cfg.cv_data_ex # in_files_all_cv_data = cfg.cv_data_ex
if len(sys.argv) >= 3 and sys.argv[1] == '-kw': if len(sys.argv) >= 3 and sys.argv[1] == '-kw':
kw_mode = True
from kwutil import * from kwutil import *
print("Executing in kitweather mode...") print("Executing in kitweather mode...")
# kitweather: use last 7 days of analysis and the ecmwf forecast # kitweather: use last 7 days of analysis and the ecmwf forecast
...@@ -65,12 +67,14 @@ if len(sys.argv) >= 3 and sys.argv[1] == '-kw': ...@@ -65,12 +67,14 @@ if len(sys.argv) >= 3 and sys.argv[1] == '-kw':
pipeline.set_data(data_ds) pipeline.set_data(data_ds)
else: else:
kw_mode = False
in_file = cfg.in_files in_file = cfg.in_files
out_dir = cfg.out_dir out_dir = cfg.out_dir
pipeline.set_data_path(in_file) pipeline.set_data_path(in_file)
# init AEWIdentification strategy, can take different parameters # init AEWIdentification strategy, can take different parameters
i_strat = AEWIdentification(wt_out_file=False, cv='cv') # , year_summer=proc_summer_of_year, month=proc_month_of_year) enable_out = (not kw_mode) and cfg.generate_output
i_strat = AEWIdentification(wt_out_file=enable_out, cv='cv') # , year_summer=proc_summer_of_year, month=proc_month_of_year)
t_strat = AEWTracking() t_strat = AEWTracking()
pipeline.set_identification_strategy(i_strat) pipeline.set_identification_strategy(i_strat)
...@@ -82,6 +86,7 @@ pipeline.execute() ...@@ -82,6 +86,7 @@ pipeline.execute()
od = pipeline.get_object_desc() od = pipeline.get_object_desc()
all_tracks = [] all_tracks = []
# TODO generate tracks in tracking? why in proto then?
for set_id, trackable_set in enumerate(od.sets): for set_id, trackable_set in enumerate(od.sets):
# generate graph out of tracked data # generate graph out of tracked data
...@@ -92,24 +97,16 @@ for set_id, trackable_set in enumerate(od.sets): ...@@ -92,24 +97,16 @@ for set_id, trackable_set in enumerate(od.sets):
g.generate_tracks(apply_filter=True) # add tracks to OD, applies filtering TODO tracks not in desc. g.generate_tracks(apply_filter=True) # add tracks to OD, applies filtering TODO tracks not in desc.
tracks = g.set_desc.tracks tracks = g.set_desc.tracks
# parents of a node: track.get_parents(track.graph.edges[0].parent)
# children of a node: track.get_children(track.graph.edges[0].parent)
# plot tracks # plot tracks
# for track_id, track in enumerate(tracks):
for track_id, track in enumerate(tracks): # plot_track(track, "track" + "{:03d}".format(set_id) + "_" + "{:03d}".format(track_id))
plot_track(track, "track" + "{:03d}".format(set_id) + "_" + "{:03d}".format(track_id))
ds = pipeline.get_data() ds = pipeline.get_data()
ds_set = get_subset_by_description(ds, trackable_set, '2d') ds_set = get_subset_by_description(ds, trackable_set, '2d')
# plot differences (passed filtering / did not pass) # plot differences (passed filtering / did not pass)
# plot_differences(g, tracks) TODO need to fix if we want this # plot_differences(g, tracks) TODO need to fix if we want this
# for track_id, track in enumerate(tracks):
# plot_track(track, "track" + str(track_id))
all_tracks.extend(tracks) all_tracks.extend(tracks)
# ds_set = get_subset_by_description(ds, trackable_set, '2d')
ds = pipeline.get_data() ds = pipeline.get_data()
...@@ -166,5 +163,20 @@ else: ...@@ -166,5 +163,20 @@ else:
""" """
# no out dataset here. if enable_out:
pipeline.save_result(description_type='json', description_path=cfg.out_json_path) #, dataset_path=out_dataset_path) # dataset_path=out_dataset_path, ds = i_strat.orig_dataset
ob = pipeline.get_object_desc()
ds = ds.load()
print("Add fake WTs")
### ADD FAKE WTs
ob = interpolate_wts(ob, african_easterly_waves_pb2)
print(ds)
### add WTs to orig DS. field tracks and field WTs
ds = add_wts_to_ds(ds, ob)
print(ds)
pipeline.set_data(ds)
pipeline.save_result(description_type='json', description_path=cfg.out_json_path, dataset_path=cfg.out_data_path)
from ..object_compare_tracking import ObjectComparisonTracking from ..object_compare_tracking import ObjectComparisonTracking
import enstools.feature.identification.african_easterly_waves.configuration as cfg import enstools.feature.identification.african_easterly_waves.configuration as cfg
from enstools.feature.identification.african_easterly_waves.processing import interpolate_wts, add_wts_to_ds
from datetime import datetime, timedelta from datetime import datetime, timedelta
from collections import defaultdict from collections import defaultdict
import statistics import statistics
from shapely.geometry import Polygon, LineString from shapely.geometry import Polygon, LineString
from enstools.feature.util.data_utils import pb_str_to_datetime from enstools.feature.util.data_utils import pb_str_to_datetime, clip
from enstools.feature.util.graph import DataGraph from enstools.feature.util.graph import DataGraph
zero_dt = timedelta(seconds=0) zero_dt = timedelta(seconds=0)
...@@ -59,8 +60,9 @@ class AEWTracking(ObjectComparisonTracking): ...@@ -59,8 +60,9 @@ class AEWTracking(ObjectComparisonTracking):
return True return True
def postprocess(self, obj_desc): def postprocess(self, obj_desc):
print("postprocess @ aew") print("Postprocess AEW tracking.")
return
return obj_desc
# get dict of key=endnode, values=nodes on way) for endnodes at end if time_delta # get dict of key=endnode, values=nodes on way) for endnodes at end if time_delta
@staticmethod @staticmethod
......