Compare revisions

Target project: Christoph.Fischer/enstools-feature
Commits on Source (2)
......@@ -12,10 +12,13 @@ import numpy as np
data_lat = (3, 35)
data_lon = (-100, 45)
aew_clim_dir = '/home/ws/he7273/phd_all/data/aew/clim/cv_clim_era5.nc' # '/project/meteo/w2w/C3/fischer/belanger/aew_clim/cv_clim_era5.nc' # '/lsdf/MOD/Gruppe_Transregio/Gruppe_Knippertz/kitweather/data/era5/cv_clim_era5.nc' 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\cv_clim_era5.nc' # '/home/christoph/phd/data/aew/clim/cv_clim_era5.nc' # '/home/christoph/phd/data/framework_example_ds/aew/' # '/project/meteo/w2w/C3/fischer/belanger/aew_clim/' #
in_files = '/home/ws/he7273/phd_all/data/coll_oper/jjaso2021.nc' # '/project/meteo/w2w/C3/fischer/data/jja2022.nc' # '/home/ws/he7273/phd_all/data/coll_oper/jja2021/jja2021.nc' # 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\2008_sum_uv.nc' # '/home/christoph/phd/data/framework_example_ds/aew/cv_aug_08.nc'
out_dir = '/home/ws/he7273/phd_all/data/coll_oper/' # '/project/meteo/w2w/C3/fischer/belanger/out/' # join('/home/ws/he7273/phd_all/data/aew/out/') # '/project/meteo/w2w/C3/fischer/belanger/out/'
aew_clim_dir = '/project/meteo/w2w/C3/fischer/belanger/aew_clim/cv_clim_era5.nc' # '/home/ws/he7273/phd_all/data/aew/clim/cv_clim_era5.nc' # # '/lsdf/MOD/Gruppe_Transregio/Gruppe_Knippertz/kitweather/data/era5/cv_clim_era5.nc' 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\cv_clim_era5.nc' # '/home/christoph/phd/data/aew/clim/cv_clim_era5.nc' # '/home/christoph/phd/data/framework_example_ds/aew/' # '/project/meteo/w2w/C3/fischer/belanger/aew_clim/' #
in_files = '/project/meteo/w2w/C3/fischer/data/coll_oper/2021/jjaso2021.nc' # '/home/ws/he7273/phd_all/data/coll_oper/jja2021/jja2021.nc' # 'C:\\Users\\Christoph\\phd\\data\\enstools-feature\\2008_sum_uv.nc' # '/home/christoph/phd/data/framework_example_ds/aew/cv_aug_08.nc'
out_dir = '/project/meteo/w2w/C3/fischer/belanger/out/' # join('/home/ws/he7273/phd_all/data/aew/out/') # '/project/meteo/w2w/C3/fischer/belanger/out/'
out_json_path = out_dir + 'jjaso2021.json'
out_data_path = out_dir + 'jjaso2021_wts.nc'
generate_output = True
plot_dir = '/home/ws/he7273/phd_all/data/coll_oper/' # '/project/meteo/w2w/C3/fischer/belanger/plots/' # join('/home/ws/he7273/phd_all/data/aew/plots/') # '/project/meteo/w2w/C3/fischer/belanger/plots/'
......@@ -73,8 +76,10 @@ v_dim = 'v'
# time range of interest; if None, use all timesteps
# June-October is the AEW season
start_date = None # '2022-08-01T00:00' # None # '2008-08-01T00:00' # # '2008-08-01T00:00'
end_date = None # '2022-08-15T00:00' # None # '2008-08-15T00:00' # None # '2008-08-03T00:00'
start_date = None # '2021-08-01T00:00' # None # '2008-08-01T00:00' # # '2008-08-01T00:00'
end_date = None # '2021-08-10T00:00' # None # '2008-08-15T00:00' # None # '2008-08-03T00:00'
add_fake_wts = True # add fake wavetroughs to the object description to bridge gaps in tracks
# Algorithm parameters
# max u wind (m/s) (0 = only keep west-propagating). Belanger: 2.5; Berry: 0.0
......
......@@ -33,7 +33,9 @@ class AEWIdentification(IdentificationStrategy):
self.config = cfg # config
self.config.out_traj_dir = wt_traj_dir
self.config.cv_name = cv
self.orig_dataset = None
if year_summer is not None:
if month is not None:
m_str = str(month).zfill(2)
......@@ -143,7 +145,9 @@ class AEWIdentification(IdentificationStrategy):
if self.config.cv_name not in dataset.data_vars:
print("Curvature Vorticity not found, trying to compute it out of U and V...")
dataset = compute_cv(dataset, u_name, v_name, self.config.cv_name)
self.orig_dataset = dataset
# regrid dataset to 2.5 degrees (or to the same grid as cv_clim)
dataset = dataset.interp({lat_str: cv_clim.coords[lat_str], lon_str: cv_clim.coords[lon_str]})
......@@ -163,7 +167,6 @@ class AEWIdentification(IdentificationStrategy):
cv = mpcalc.smooth_n_point(cv, n=9, passes=2).metpy.dequantify()
# create hourofyear coordinate to compute anomalies against the climatology
cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))
cv_anom = cv.groupby('hourofyear') - cv_clim.cv
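As an aside, the hourofyear grouping above subtracts a climatology keyed by month-day-hour from each timestep. A minimal standalone sketch of the same xarray pattern, using synthetic data rather than the variables from this diff:

import numpy as np
import pandas as pd
import xarray as xr

# synthetic curvature-vorticity-like field on a 6-hourly time axis
times = pd.date_range("2021-06-01", periods=8, freq="6h")
cv = xr.DataArray(np.random.rand(len(times), 4, 4),
                  coords={"time": times, "lat": np.arange(4), "lon": np.arange(4)},
                  dims=("time", "lat", "lon"))

# label every timestep with its month-day-hour key, matching the climatology index
cv = cv.assign_coords(hourofyear=cv.time.dt.strftime("%m-%d %H"))

# stand-in climatology indexed by the same hourofyear labels
cv_clim = cv.groupby("hourofyear").mean("time")

# anomaly: each timestep minus its climatological value for that hourofyear
cv_anom = cv.groupby("hourofyear") - cv_clim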
......@@ -190,7 +193,7 @@ class AEWIdentification(IdentificationStrategy):
u.values < self.config.max_u_thresh)) # threshold for propagation speed -> keep only westward
dataset['trough_mask'] = trough_mask
"""
# create 0.5x0.5 dataarray for wavetroughs
min_lat = dataset[lat_str].data.min()
max_lat = dataset[lat_str].data.max()
......@@ -216,7 +219,7 @@ class AEWIdentification(IdentificationStrategy):
dataset['lon05'].attrs['standard_name'] = 'longitude'
dataset['lat05'].attrs['units'] = 'degrees_north'
dataset['lon05'].attrs['units'] = 'degrees_east'
"""
return dataset
def identify(self, data_chunk: xr.Dataset, **kwargs):
......@@ -235,17 +238,8 @@ class AEWIdentification(IdentificationStrategy):
subplot_kws={'projection': ccrs.PlateCarree()})
paths = c.collections[0].get_paths()
wt = data_chunk.wavetroughs
min_lat = wt.lat05.data.min()
max_lat = wt.lat05.data.max()
min_lon = wt.lon05.data.min()
max_lon = wt.lon05.data.max()
lons = len(wt.lon05.data)
lats = len(wt.lat05.data)
id_ = 1
for path in paths:
# get new object, set id
......@@ -261,28 +255,6 @@ class AEWIdentification(IdentificationStrategy):
objs.append(o)
id_ += 1
# if wavetrough output is requested, generate the lines
if not self.config.out_wt:
continue
for v_idx in range(len(path.vertices) - 1):
start_lonlat = path.vertices[v_idx][0], path.vertices[v_idx][1]
end_lonlat = path.vertices[v_idx + 1][0], path.vertices[v_idx + 1][1]
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# start_idx = clip(start_idx, (0, 0), (lons, lats))
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# end_idx = clip(end_idx, (0, 0), (lons, lats))
rr, cc, val = line_aa(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = clip(rr, 0, lons - 1)
cc = clip(cc, 0, lats - 1)
wt.data[cc, rr] = np.where(np.greater(val, wt.data[cc, rr]), val, wt.data[cc, rr])
return data_chunk, objs
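The block removed above (and its counterpart in add_wts_to_ds further down) rasterizes lon/lat line segments onto the grid. A minimal sketch of that index arithmetic, with grid bounds and segment endpoints chosen purely for illustration:

import numpy as np
from skimage.draw import line

# hypothetical regular grid covering the AEW domain at roughly 0.5 degrees
min_lon, max_lon, lons = -100.0, 45.0, 290
min_lat, max_lat, lats = 3.0, 35.0, 64
grid = np.zeros((lats, lons), dtype=int)

start_lonlat = (-20.0, 10.0)   # illustrative segment endpoints (lon, lat)
end_lonlat = (-18.5, 14.0)

# map lon/lat onto fractional grid indices (first index = lon, second = lat)
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
             (start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
          (end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)

# Bresenham line between the two index pairs, clipped to the grid extent
rr, cc = line(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = np.clip(rr, 0, lons - 1)
cc = np.clip(cc, 0, lats - 1)
grid[cc, rr] = 1   # mark the wavetrough pixels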
def postprocess(self, dataset: xr.Dataset, data_desc, **kwargs):
......@@ -291,7 +263,8 @@ class AEWIdentification(IdentificationStrategy):
lon_str = get_longitude_dim(dataset)
data_desc = self.make_ids_unique(data_desc)
"""
# drop everything, only keep WTs as 0.5x0.5
if self.config.out_wt:
for var in dataset.data_vars:
......@@ -388,7 +361,8 @@ class AEWIdentification(IdentificationStrategy):
out_path = self.config.out_traj_dir + ts.valid_time.replace(':', '_') + '.nc'
dataset_wt.to_netcdf(out_path)
"""
return dataset, data_desc
# filter: keep current wavetrough if:
......@@ -427,5 +401,11 @@ class AEWIdentification(IdentificationStrategy):
mid_lat = properties.bb.max.lat - properties.bb.min.lat
if mid_lat < 5.0 or mid_lat > 25.0:
return False
height = properties.bb.max.lat - properties.bb.min.lat
width = properties.bb.max.lon - properties.bb.min.lon
if height < 0.75 * width:
return False
return True
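The lines added above extend the track filter with an aspect-ratio test on the bounding box. The same heuristic as a standalone function with plain floats instead of the protobuf properties object (purely illustrative):

def keep_wavetrough(min_lat, max_lat, min_lon, max_lon):
    """Illustrative stand-in for the added bounding-box aspect-ratio check."""
    height = max_lat - min_lat   # meridional extent in degrees
    width = max_lon - min_lon    # zonal extent in degrees
    if height < 0.75 * width:    # mostly zonal box -> reject as wavetrough
        return False
    return True

print(keep_wavetrough(8.0, 16.0, -25.0, -20.0))   # True: 8 deg tall, 5 deg wide
print(keep_wavetrough(10.0, 12.0, -30.0, -20.0))  # False: 2 deg tall, 10 deg wide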
......@@ -5,7 +5,10 @@ import xarray as xr
import math
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, \
get_latitude_dim
from enstools.feature.util.data_utils import pb_str_to_datetime64, simplenamespace_to_proto, datetime64_to_pb_str, proto_to_simplenamespace, clip
from google.protobuf.message import Message
from skimage.draw import line
from collections import defaultdict
# calculates the dx's and dy's for each grid cell
# takes lists of latitudes and longitudes as input and returns a field with dimensions len(lats) x len(lons)
......@@ -256,3 +259,244 @@ def populate_object(obj_props, path, cfg):
# identify troughs in data (should contain U,V,cv), based on the cv climatology
# def identify_troughs(data, cv_clim, cfg):
def create_fake_wt_edges(edge, child_idx, pb_ref):
# make multiple edges out of this one
parent_node = edge.parent
assert isinstance(parent_node, Message)
parent_node_pb = parent_node # simplenamespace_to_proto(parent_node, pb_ref.GraphNode())
child_node = edge.children[child_idx]
child_node_pb = child_node # simplenamespace_to_proto(child_node, pb_ref.GraphNode())
parent_node_time = pb_str_to_datetime64(parent_node.time)
child_node_time = pb_str_to_datetime64(child_node.time)
delt = child_node_time - parent_node_time
if delt != np.timedelta64(12, 'h'):
print("NO 12h abort")
print(delt)
exit(1)
fake_node_time = parent_node_time + 0.5 * delt # TODO assert only skip 1
# create fake node as copy of parent
fake_node = pb_ref.GraphNode() # json_format.Parse(json.dumps(parent_json), pb_ref.GraphNode(), ignore_unknown_fields=False)
fake_node.time = datetime64_to_pb_str(fake_node_time)
# id and flag False
fake_node.object.id = -1
fake_node.object.flag = False
# set properties: bb as mean of parent and child
parent_props = parent_node.object.properties
child_props = child_node.object.properties
fake_props = fake_node.object.properties
fake_props.bb.min.lat = (parent_props.bb.min.lat + child_props.bb.min.lat) / 2.0
fake_props.bb.max.lat = (parent_props.bb.max.lat + child_props.bb.max.lat) / 2.0
fake_props.bb.min.lon = (parent_props.bb.min.lon + child_props.bb.min.lon) / 2.0
fake_props.bb.max.lon = (parent_props.bb.max.lon + child_props.bb.max.lon) / 2.0
avg_lon = (fake_props.bb.min.lon + fake_props.bb.max.lon) / 2.0
fake_props.line_pts.add()
fake_props.line_pts[0].lat = fake_props.bb.min.lat
fake_props.line_pts[0].lon = avg_lon
fake_props.line_pts.add()
fake_props.line_pts[1].lat = fake_props.bb.max.lat
fake_props.line_pts[1].lon = avg_lon
fake_props.length_deg = math.sqrt((fake_props.bb.max.lat - fake_props.bb.min.lat) ** 2 + (fake_props.bb.max.lon - fake_props.bb.min.lon) ** 2)
# new edge:
edge1 = pb_ref.GraphConnection()
edge1.parent.CopyFrom(parent_node_pb)
edge1.children.append(fake_node)
edge2 = pb_ref.GraphConnection()
edge2.parent.CopyFrom(fake_node)
edge2.children.append(child_node_pb)
return [edge1, edge2]
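create_fake_wt_edges places the fake node halfway between parent and child in time. A small sketch of just that midpoint arithmetic with numpy datetimes (the timestamps are arbitrary examples):

import numpy as np

parent_time = np.datetime64("2021-08-01T00:00")
child_time = np.datetime64("2021-08-01T12:00")

delta = child_time - parent_time        # timedelta64 of 12 hours
fake_time = parent_time + 0.5 * delta   # lands on the skipped 06:00 timestep
print(fake_time)                        # 2021-08-01T06:00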
# interpolate wavetroughs, create fake WTs in skipped timesteps.
def interpolate_wts(data_desc, pb_ref):
for set_ in data_desc.sets:
for track in set_.tracks:
# old_edges = []
new_edges = []
for edge in track.edges:
parent_node = edge.parent
parent_node_time = pb_str_to_datetime64(parent_node.time)
if not hasattr(edge, 'children'):
continue
for child_idx, child_node in enumerate(edge.children):
child_node_time = pb_str_to_datetime64(child_node.time)
if child_node_time - parent_node_time > np.timedelta64(6, 'h'):
print("Add fake WT at " + parent_node.time)
wt_edges = create_fake_wt_edges(edge, child_idx, pb_ref)
# old_edges.append((edge, child_idx))
new_edges.extend(wt_edges)
# remove old_edges from this track and from graph
new_edges_sn = [e for e in new_edges] # proto_to_simplenamespace(e)
track.edges.extend(new_edges_sn) # TODO sort
track.edges.sort(key=lambda item: item.parent.time)
set_.graph.edges.extend(new_edges_sn)
set_.graph.edges.sort(key=lambda item: item.parent.time)
return data_desc
def add_wts_to_ds(dataset, data_desc):
print("Create WT lines...")
lon_str = get_longitude_dim(dataset)
lat_str = get_latitude_dim(dataset)
u_str = get_u_var(dataset)
v_str = get_v_var(dataset)
dataset['wavetroughs'] = xr.zeros_like(dataset[u_str].isel(level=0).squeeze(), dtype=int) # all WTs
dataset['tracks'] = xr.zeros_like(dataset[u_str].isel(level=0).squeeze(), dtype=int) # filtered by track heuristics
min_lat = dataset.latitude.data.min()
max_lat = dataset.latitude.data.max()
min_lon = dataset.longitude.data.min()
max_lon = dataset.longitude.data.max()
lons = len(dataset.longitude.data)
lats = len(dataset.latitude.data)
wt = dataset.wavetroughs
wt_t = dataset.tracks
for wt_set in data_desc.sets:
cur_set = wt_set
# if use_fc:
# initTime = wt_set.initTime
# set_ds = dataset.sel(time=initTime) # init time
set_ds = dataset
# get nodes from all tracks in set
set_nodes = []
for track_id, track in enumerate(cur_set.tracks):
set_nodes.extend([e.parent for e in track.edges])
# put them into buckets
node_buckets = defaultdict(list)
for x in set_nodes:
node_buckets[x.time].append(x)
# iterate buckets
print("Tracks")
for time, cur_nodes in node_buckets.items():
print(time)
try:
wt_t_da = wt_t.sel(time=time)
except KeyError:
print("Skipping timestep (not in dataset) " + str(vt))
continue
for node in cur_nodes:
props = node.object.properties
if not hasattr(props, 'line_pts'):
print("?")
for v_idx in range(len(props.line_pts) - 1):
start_lonlat = props.line_pts[v_idx].lon, props.line_pts[v_idx].lat
end_lonlat = props.line_pts[v_idx + 1].lon, props.line_pts[v_idx + 1].lat
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# start_idx = clip(start_idx, (0, 0), (lons, lats))
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# end_idx = clip(end_idx, (0, 0), (lons, lats))
rr, cc = line(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = clip(rr, 0, lons - 1)
cc = clip(cc, 0, lats - 1)
wt_t_da.values[cc, rr] = node.object.id
"""
# make circle
for px_idx in range(len(rr)):
circle = circles.isel(longitude_center=rr[px_idx], latitude_center=cc[px_idx])
influence_area = circle.where(circle < d, -1)
# update influence area dataarray
infl_da = xr.where(influence_area >= 0, node.object.id, infl_da)
"""
wt_t.loc[dict(time=time)] = wt_t_da.values
### ALL NODES
graph = cur_set.graph
graph_nodes = [e.parent for e in graph.edges]
# put them into buckets
node_buckets = defaultdict(list)
for x in graph_nodes:
node_buckets[x.time].append(x)
# iterate buckets
print("Tracks")
for time, cur_nodes in node_buckets.items():
print(time)
try:
wt_da = wt.sel(time=time)
except KeyError:
print("Skipping timestep (not in dataset) " + str(vt))
continue
for node in cur_nodes:
props = node.object.properties
if not hasattr(props, 'line_pts'):
print("?")
for v_idx in range(len(props.line_pts) - 1):
start_lonlat = props.line_pts[v_idx].lon, props.line_pts[v_idx].lat
end_lonlat = props.line_pts[v_idx + 1].lon, props.line_pts[v_idx + 1].lat
start_idx = ((start_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(start_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# start_idx = clip(start_idx, (0, 0), (lons, lats))
end_idx = ((end_lonlat[0] - min_lon) / (max_lon - min_lon) * lons,
(end_lonlat[1] - min_lat) / (max_lat - min_lat) * lats)
# end_idx = clip(end_idx, (0, 0), (lons, lats))
rr, cc = line(int(start_idx[0]), int(start_idx[1]), int(end_idx[0]), int(end_idx[1]))
rr = clip(rr, 0, lons - 1)
cc = clip(cc, 0, lats - 1)
wt_da.values[cc, rr] = node.object.id
"""
# make circle
for px_idx in range(len(rr)):
circle = circles.isel(longitude_center=rr[px_idx], latitude_center=cc[px_idx])
influence_area = circle.where(circle < d, -1)
# update influence area dataarray
infl_da = xr.where(influence_area >= 0, node.object.id, infl_da)
"""
wt.loc[dict(time=time)] = wt_da.values
return dataset
\ No newline at end of file
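add_wts_to_ds writes each rasterized timestep back with .loc[dict(time=...)]. A minimal sketch of that select-modify-writeback pattern on synthetic data (coordinate values are made up):

import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2021-08-01", periods=4, freq="6h")
da = xr.DataArray(np.zeros((4, 3, 3), dtype=int),
                  coords={"time": times, "lat": [10.0, 10.5, 11.0], "lon": [-20.0, -19.5, -19.0]},
                  dims=("time", "lat", "lon"))

slice_da = da.sel(time=times[1])                # 2-D slice for one timestep
slice_da.values[1, 2] = 7                       # rasterize an object id into the slice
da.loc[dict(time=times[1])] = slice_da.values   # write the modified slice back
print(int(da.sel(time=times[1])[1, 2]))         # 7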
......@@ -11,6 +11,7 @@ from enstools.feature.util.graph import DataGraph
from enstools.feature.identification.african_easterly_waves.plotting import plot_kw, plot_differences, plot_track, plot_track_in_ts, plot_timesteps_from_desc, plot_tracks_from_desc
import enstools.feature.identification.african_easterly_waves.configuration as cfg
import os, sys, glob, shutil
from enstools.feature.identification.african_easterly_waves.processing import interpolate_wts, add_wts_to_ds
from enstools.feature.util.data_utils import get_subset_by_description
import xarray as xr
xr.set_options(keep_attrs=True)
......@@ -22,6 +23,7 @@ pipeline = FeaturePipeline(african_easterly_waves_pb2, processing_mode='2d')
# in_files_all_cv_data = cfg.cv_data_ex
if len(sys.argv) >= 3 and sys.argv[1] == '-kw':
kw_mode = True
from kwutil import *
print("Executing in kitweather mode...")
# kitweather: use last 7 days of analysis and the ecmwf forecast
......@@ -65,12 +67,14 @@ if len(sys.argv) >= 3 and sys.argv[1] == '-kw':
pipeline.set_data(data_ds)
else:
kw_mode = False
in_file = cfg.in_files
out_dir = cfg.out_dir
pipeline.set_data_path(in_file)
# init AEWIdentification strategy, can take different parameters
i_strat = AEWIdentification(wt_out_file=False, cv='cv') # , year_summer=proc_summer_of_year, month=proc_month_of_year)
enable_out = (not kw_mode) and cfg.generate_output
i_strat = AEWIdentification(wt_out_file=enable_out, cv='cv') # , year_summer=proc_summer_of_year, month=proc_month_of_year)
t_strat = AEWTracking()
pipeline.set_identification_strategy(i_strat)
......@@ -82,6 +86,7 @@ pipeline.execute()
od = pipeline.get_object_desc()
all_tracks = []
# TODO generate tracks in tracking? why in proto then?
for set_id, trackable_set in enumerate(od.sets):
# generate graph out of tracked data
......@@ -92,24 +97,16 @@ for set_id, trackable_set in enumerate(od.sets):
g.generate_tracks(apply_filter=True) # add tracks to OD, applies filtering TODO tracks not in desc.
tracks = g.set_desc.tracks
# parents of a node: track.get_parents(track.graph.edges[0].parent)
# children of a node: track.get_children(track.graph.edges[0].parent)
# plot tracks
for track_id, track in enumerate(tracks):
plot_track(track, "track" + "{:03d}".format(set_id) + "_" + "{:03d}".format(track_id))
# for track_id, track in enumerate(tracks):
# plot_track(track, "track" + "{:03d}".format(set_id) + "_" + "{:03d}".format(track_id))
ds = pipeline.get_data()
ds_set = get_subset_by_description(ds, trackable_set, '2d')
# plot differences (passed filtering / did not pass)
# plot_differences(g, tracks) TODO need to fix if we want this
# for track_id, track in enumerate(tracks):
# plot_track(track, "track" + str(track_id))
all_tracks.extend(tracks)
# ds_set = get_subset_by_description(ds, trackable_set, '2d')
ds = pipeline.get_data()
......@@ -166,5 +163,20 @@ else:
"""
# no out dataset here.
pipeline.save_result(description_type='json', description_path=cfg.out_json_path) #, dataset_path=out_dataset_path) # dataset_path=out_dataset_path,
if enable_out:
ds = i_strat.orig_dataset
ob = pipeline.get_object_desc()
ds = ds.load()
print("Add fake WTs")
### ADD FAKE WTs
ob = interpolate_wts(ob, african_easterly_waves_pb2)
print(ds)
### add WTs to the original dataset: fields 'tracks' and 'wavetroughs'
ds = add_wts_to_ds(ds, ob)
print(ds)
pipeline.set_data(ds)
pipeline.save_result(description_type='json', description_path=cfg.out_json_path, dataset_path=cfg.out_data_path)
from ..object_compare_tracking import ObjectComparisonTracking
import enstools.feature.identification.african_easterly_waves.configuration as cfg
from enstools.feature.identification.african_easterly_waves.processing import interpolate_wts, add_wts_to_ds
from datetime import datetime, timedelta
from collections import defaultdict
import statistics
from shapely.geometry import Polygon, LineString
from enstools.feature.util.data_utils import pb_str_to_datetime
from enstools.feature.util.data_utils import pb_str_to_datetime, clip
from enstools.feature.util.graph import DataGraph
zero_dt = timedelta(seconds=0)
......@@ -59,8 +60,9 @@ class AEWTracking(ObjectComparisonTracking):
return True
def postprocess(self, obj_desc):
print("postprocess @ aew")
return
print("Postprocess AEW tracking.")
return obj_desc
# get dict (key=endnode, values=nodes on the way) for endnodes at the end of time_delta
@staticmethod
......