
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target: Christoph.Fischer/enstools-feature
Showing changes with 653 additions and 88 deletions
......@@ -59,8 +59,14 @@ def get_object_data(self_, stereo_ds, level_list, dist_expand, area_map, config)
volume_km3 += lv_area[k] * height_km
volume_km2K += lv_area[k] * config.res_z # layer * dist between layers.
obj_props.volume_km3 = volume_km3
obj_props.volume_km2K = volume_km2K
elif config.dims == 2:
lv_area = np.sum(obj_areas_s)
obj_props.area_km2 = lv_area
# ----------------------
# BOUNDING BOX (GEOREF)
......
......@@ -269,4 +269,55 @@ def plot_streamer_areas(ds: xr.Dataset, dist_field, streamer_areas, config, leve
plt.contourf(x, y, streamer_areas, 1, cmap=binary_grey, transform=ccrs.NorthPolarStereo())
plt.savefig("step5.png")
exit()
def generate_2d_plot(subset, fig_name):
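# note: circle, plt_extent, x, y and binary_grey are assumed to be module-level globals defined elsewhere in this file (not shown in this diff)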
fig = plt.figure(figsize=(10, 10))
ax = plt.axes(projection=ccrs.NorthPolarStereo())
ax.set_boundary(circle, transform=ax.transAxes)
ax.coastlines(linewidth=1)
ax.gridlines()
ax.set_extent(plt_extent, ccrs.NorthPolarStereo())
pv = subset.variables['pv'][:]
pvu2intvar = (pv > 2e-6).astype(dtype=int)
# CS = plt.contour(x, y, dist_field, [1000], colors=['black'], transform=ccrs.NorthPolarStereo(), linewidths=2)
plt.contour(x, y, pvu2intvar, [0.5], colors='black', transform=ccrs.NorthPolarStereo(), linewidths=2)
plt.contourf(x, y, subset.streamer, [0.5, 1000], cmap=binary_grey, transform=ccrs.NorthPolarStereo())
print("Saving " + fig_name)
plt.savefig(fig_name + '.png')
def generate_2d_plots(ds: xr.Dataset):
if 'member' in ds.coords:
members = ds.coords['member'].values
if not isinstance(members, np.ndarray):
members = [members]
else:
members = [-1]
if 'time' in ds.coords:
times = ds.coords['time'].values
if not isinstance(times, np.ndarray):
times = [times]
else:
times = [-1]
for member in members:
for ts in times:
figname = ''
if member != -1:
subset = ds.sel({'member': member, 'time': ts})
figname += 'mem_' + str(int(member)) + '_time_' + str(ts)[:13]
else:
subset = ds.sel({'time': ts})
figname += 'time_' + str(ts)[:13]
ss_sq = subset.squeeze()
generate_2d_plot(ss_sq, figname)
# package requirements for PV 3D identification
# package requirements for PV 2D and 3D identification
# do this within the conda environment
# Note that this strategy is not natively executable on Windows. It uses CDO, which does not have a Windows build.
# Instead you can use the Windows Subsystem for Linux and build enstools-feature therein.
conda install cdo=1.9.10
# also for PV identification: custom scikit-fmm:
cd enstools/feature/identification/pv_streamer/scikit-fmm-custom
python setup.py install
# Linux only
conda install cdo=1.9.10 scikit-image
# cdo wrapper (pip only)
pip install cdo
scikit-image
# https://github.com/conda-forge/cdo-feedstock
cdo
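# Optional sanity check after installation (a sketch; the import names below
# are assumptions based on the packages installed above):
python -c "from cdo import Cdo; import skfmm, skimage; print(Cdo().version())"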
......@@ -3,16 +3,18 @@
from enstools.feature.pipeline import FeaturePipeline
from enstools.feature.identification._proto_gen import pv_streamer_pb2
from enstools.feature.identification.pv_streamer import PVIdentification, PVWernliSprenger2007
from enstools.feature.identification.pv_streamer.plotting import generate_2d_plots
import os
pipeline = FeaturePipeline(pv_streamer_pb2, processing_mode='3d')
# init PVIdentification strategy, can take different parameters
i_strat = PVIdentification(unit='pv') # , out_type='ll' , mode_2d_layer=330) # theta_range=(300, 380), extract_containing_layer=330)
# specify unit as 'pv' or 'pvu' depending on data unit
i_strat = PVIdentification(unit='pv') # ,out_type='ll' , mode_2d_layer=330, theta_range=(300, 380), extract_containing_layer=330)
# or strategy by Wernli and Sprenger (2007). Note: Fortran required and manual adaptation of scripts needed.
# change the level in contour.f too; if the data is in PV/PVU, adapt the parameter accordingly.
# also pretty restrictive filewise. only global. need -180lon and 180lon.
# also pretty restrictive filewise.
# i_strat = PVWernliSprenger2007(unit='pv')
# t_strat = OverlapTracking()
......@@ -22,13 +24,16 @@ pipeline.set_identification_strategy(i_strat)
pipeline.set_tracking_strategy(None)
# TODO set data path
data_path = os.path.join(os.path.expanduser("~"), 'phd/data/framework_example_ds/pv/S2S_pv_avg200-500_1998.nc')
data_path = os.path.join(os.path.expanduser("~"), 'phd/data/enstools-feature/4m_12d_pvavg_200_500_1998.nc')
pipeline.set_data_path(data_path)
# execute pipeline
pipeline.execute()
if not pipeline.is_data_3d():
generate_2d_plots(pipeline.get_data())
out_netcdf_path = data_path + '_streamers.nc'
out_json_path = data_path + '_streamers_desc.json'
pipeline.save_result(description_type='json', description_path=out_json_path, dataset_path=out_netcdf_path)
......
from .identification import StormIdentification
from ..identification import IdentificationStrategy
import xarray as xr
from random import randrange
from scipy import ndimage as ndi
import numpy as np
from enstools.feature.identification.storm.util import *
class StormIdentification(IdentificationStrategy):
def __init__(self, start_pressure_thr=965, step_pressure_thr=2, size_thr=10000, amplitude_thr=2, **kwargs):
# Constructor. Called from example_template.py, parameters can be passed and set here.
self.start_pressure_thr = start_pressure_thr
self.step_pressure_thr = step_pressure_thr
self.size_thr = size_thr
self.amplitude_thr = amplitude_thr
pass
def precompute(self, dataset: xr.Dataset, **kwargs):
plt.switch_backend('agg')  # thread-safe, non-interactive backend: figures can be saved but not displayed
dataset['storm_areas'] = xr.zeros_like(dataset['msl'], dtype=int)
return dataset
def identify(self, dataset: xr.Dataset, **kwargs):
msl_data = dataset.msl
considered_mask = xr.zeros_like(msl_data, dtype=bool)
pressure_thrs = np.arange(self.start_pressure_thr, 1000, self.step_pressure_thr)
storm_id = 1
obj_list = []
for pressure_thr in pressure_thrs:
lp_area = get_msl_areas_by_hPa_thr(msl_data, pressure_thr)
storm_areas = split_areas(lp_area)
for storm_area in storm_areas:
if np.any(np.logical_and(considered_mask.data, storm_area.data)):
continue
storm_size = get_storm_size_km2(storm_area)
if storm_size < self.size_thr:
continue
if not has_local_minimum(storm_area, msl_data):
continue
storm_amplitude = get_storm_amplitude_hPa(storm_area, msl_data)
if storm_amplitude < self.amplitude_thr:
continue
# keep storm
dataset['storm_areas'].data[storm_area.data] = storm_id
considered_mask.data[storm_area.data] = True
# get an instance of a new object; an ID can be passed here or set manually afterwards
obj = self.get_new_object()
# set some ID to it
obj.id = storm_id
# get properties of the object and populate them (as defined in the .proto template)
properties = obj.properties
properties.min_pressure, lon, lat = get_min_pressure_hPa(storm_area, msl_data, return_pos=True)
properties.min_pressure_pos.lat = lat
properties.min_pressure_pos.lon = lon
obj_list.append(obj)
storm_id += 1
# return the dataset (can be changed here), and the list of objects
return dataset, obj_list
def postprocess(self, dataset: xr.Dataset, obj_desc, **kwargs):
# obj_desc = self.make_ids_unique(obj_desc)
return dataset, obj_desc
# Usage Example
from enstools.feature.pipeline import FeaturePipeline
from enstools.feature.identification.storm import StormIdentification
from enstools.feature.tracking.overlap_tracking import OverlapTracking
from enstools.feature.identification._proto_gen import storm_pb2
from os.path import expanduser
from enstools.feature.identification.storm.util import plot_timestep, plot_track
from enstools.feature.tracking.overlap_double_threshold_tracking import OverlapDoubleThresholdTracking
import os
from datetime import timedelta
from enstools.feature.util.graph import DataGraph
# set the pb_reference to the compiled pb2.py file (see proto_gen directory)
pipeline = FeaturePipeline(storm_pb2, processing_mode='2d')
i_strat = StormIdentification(start_pressure_thr=965, step_pressure_thr=2, size_thr=10000, amplitude_thr=2)  # set the identification strategy
t_strat = OverlapTracking(field_name='storm_areas', min_duration=timedelta(hours=12)) # set the tracking strategy
# t_strat = OverlapDoubleThresholdTracking(
# inner_thresh_name='storm_areas',
# outer_thresh_name='storm_areas',
# tracking_method='inner_first',
# min_duration=timedelta(hours=12)
# )
pipeline.set_identification_strategy(i_strat)
pipeline.set_tracking_strategy(t_strat) # or None as argument if no tracking
path = '/project/meteo/w2w/workshops/EMS_2022/feature/data/storm/zeynep.nc'
plot_dir = os.path.expanduser('~/plots/')
if not os.path.exists(plot_dir):
os.makedirs(plot_dir)
desc_dir = os.path.expanduser('~/feature_desc/')
if not os.path.exists(desc_dir):
os.makedirs(desc_dir)
pipeline.set_data_path(path)
# execute pipeline
pipeline.execute()
od = pipeline.get_object_desc()
reanalysis_set = od.sets[0]
g = DataGraph(reanalysis_set, t_strat)
g.generate_tracks(apply_filter=True)
# subset = get_subset_by_description(pipeline.get_data(), reanalysis_set, '2d')
for ts in reanalysis_set.timesteps:
ds_t = pipeline.get_data().sel(time=ts.valid_time)
plot_timestep(ds_t, plot_dir + str(ts.valid_time)[:13] + '.png')
for track_id, track in enumerate(reanalysis_set.tracks):
plot_track(track, plot_dir + 'track' + str(track_id) + '.png', pipeline.get_data().storm_areas)
out_json_path = desc_dir + 'desc.json'
out_ds_path = desc_dir + 'desc.nc'
pipeline.save_result(description_type='json', description_path=out_json_path, dataset_path=out_ds_path)
syntax = "proto2";
message Pos {
required float lat = 1;
required float lon = 2;
}
message Properties {
required float min_pressure = 1;
required Pos min_pressure_pos = 2;
}
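# A minimal sketch of filling the compiled message above (assuming it is
# compiled to storm_pb2, as imported in the usage example):
from enstools.feature.identification._proto_gen import storm_pb2

props = storm_pb2.Properties()
props.min_pressure = 962.5           # required float, in hPa
props.min_pressure_pos.lat = 54.0    # nested Pos message, created on access
props.min_pressure_pos.lon = -12.5
print(props.IsInitialized())         # True once all required proto2 fields are set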
from matplotlib import pyplot as plt
from cartopy import crs as ccrs
import cartopy.feature as cfeature
import numpy as np
from matplotlib import colors as mcolors
from scipy import ndimage as ndi
import xarray as xr
from pyproj import Geod
from shapely.geometry import Polygon
from shapely.ops import orient
from enstools.feature.util.data_utils import pb_str_to_datetime
import matplotlib
binary_grey = mcolors.ListedColormap(['white', 'grey'])
############### PLOTTING FUNCTIONS ###############
def plot_timestep(ds, file_name):
"""
Plots storms in the dataset ds to the given file name.
The dataset ds must only contain one timestep.
Parameters
----------
ds: xr.Dataset
file_name: str
"""
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(11, 4), subplot_kw=dict(projection=ccrs.PlateCarree()))
levels = np.linspace(97000, 104000, 14)
ds.storm_areas.plot.contourf(levels=[0.5, 1000], cmap=binary_grey, add_colorbar=False) # , transform=ccrs.NorthPolarStereo())
ds.msl.plot.contour(levels=levels, vmin=0, extend='max', cmap='Blues')
ax.coastlines()
ax.add_feature(cfeature.BORDERS.with_scale('50m'))
print("Save to " + file_name)
plt.savefig(file_name, format='png')
plt.close(fig)  # close the figure and release its memory
def plot_track(track, filename, storm_areas):
"""
Plots the given storm track.
Parameters
----------
track: pb2.Track
The track as generated by generate_tracks() and located as one element of set.tracks
filename: str
The file name
storm_areas: xr.DataArray
The storm areas DataArray to add the areas on top.
"""
nodes = [edge.parent for edge in track.edges]
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 9), subplot_kw=dict(projection=ccrs.PlateCarree()))
extent = [-50, 30, 40, 85]
min_time = pb_str_to_datetime(nodes[0].time).timestamp()
max_time = pb_str_to_datetime(nodes[-1].time).timestamp()
cmap = matplotlib.cm.get_cmap('rainbow')
cmap_v = matplotlib.cm.get_cmap('viridis')
color_wgts = np.linspace(0.0, 1.0, len(nodes))
colors = ['red', 'yellow', 'green', 'blue', 'purple']
lats = []
lons = []
press = []
for node_idx, node in enumerate(nodes):
obj = node.object
obj_id = obj.id  # avoid shadowing the builtin id()
time_d = pb_str_to_datetime(node.time).timestamp()
# current storm area to 1
storm_areas_t = storm_areas.sel(time=node.time)
binary_current_area = xr.where(storm_areas_t == obj_id, 1, 0)
time_weight = (1.0 - (time_d - min_time) / (max_time - min_time) if max_time > min_time else 1.0)
binary_current_area.plot.contour(levels=[0.5], colors=[cmap(time_weight)])
lat = obj.properties.min_pressure_pos.lat
lon = obj.properties.min_pressure_pos.lon
pres = obj.properties.min_pressure
lats.append(lat)
lons.append(lon)
press.append(pres)
colors = [cmap_v((pres - 960.0) / (1000.0 - 960.0)) for pres in press]
ax.coastlines()
ax.add_feature(cfeature.BORDERS.with_scale('50m'))
ax.set_extent(extent, crs=ccrs.PlateCarree())
plt.scatter(x=lons, y=lats,
s=40,
c=colors,
alpha=1,
transform=ccrs.PlateCarree())
plt.title(nodes[0].time + " - " + nodes[-1].time)
print("Plot to " + str(filename))
plt.savefig(filename, format='png')
plt.close(fig)  # close the figure and release its memory
############### IDENTIFICATION FUNCTIONS ###############
def get_min_pressure_hPa(storm_area, msl_data, return_pos=False):
"""
Get the minimum pressure in hPa of the given storm area.
Parameters
----------
storm_area: xr.DataArray
The DataArray of the binary storm field.
msl_data: xr.DataArray
The DataArray of mean sea level pressure.
return_pos: bool
if True also returns the position of lowest pressure.
Returns
-------
Lowest pressure in hPa, and if return_pos=True also lon and lat of that pressure.
"""
wh = msl_data.data[storm_area.data]
min_pressure_pa = np.amin(wh)
if not return_pos:
return min_pressure_pa / 100.0 # hPa
else:
# position in masked version
masked_msl_data = msl_data.where(storm_area)
wh_loc = masked_msl_data.where(masked_msl_data == min_pressure_pa, drop=True).squeeze()
# lat and lon
longitude = np.atleast_1d(wh_loc.longitude.data)[0].item()
latitude = np.atleast_1d(wh_loc.latitude.data)[0].item()
return min_pressure_pa / 100.0, longitude, latitude
def get_msl_areas_by_hPa_thr(msl_data, hPa):
"""
Get the areas with pressure lower than the given threshold as a binary field.
Parameters
----------
msl_data: xr.DataArray
The DataArray of mean sea level pressure.
hPa: float
The pressure threshold.
Returns
-------
A binary DataArray based on the given maximum threshold for pressure.
"""
return msl_data < hPa * 100.0
def split_areas(areas):
"""
Split disjoint areas in a binary DataArray into multiple DataArrays.
Parameters
----------
areas: xr.DataArray
The binary DataArray to be split.
Returns
-------
A list of DataArrays with the same shape as the input but each only containing True values of one single area.
"""
# split dataarray areas into list of cohesive regions (DAs)
lp_area_lab = xr.zeros_like(areas, dtype=int)
lp_area_lab.data, num_features = ndi.label(areas)
a_list = []
for i in range(1, num_features + 1):
a_list.append(lp_area_lab == i)
return a_list
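# A toy demonstration of the two helpers above, with made-up values:
import numpy as np
import xarray as xr

toy_msl = xr.DataArray(np.array([[990., 1010., 992.],
                                 [991., 1011., 993.],
                                 [1012., 1013., 994.]]) * 100.0)  # Pa
low = get_msl_areas_by_hPa_thr(toy_msl, 1000.0)  # True where p < 1000 hPa
parts = split_areas(low)
print(len(parts))  # 2: the low-pressure cells left and right are not connected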
def get_storm_size_km2(storm_da):
"""
Get the storm size in km².
Parameters
----------
storm_da: xr.DataArray
The binary DataArray, which is set to True for each point belonging to the storm.
Returns
-------
The "True" area of the DataArray
"""
# create contour
res = storm_da.plot.contour(levels=[0.5, 1000])
paths = res.collections[0].get_paths()
if len(paths) > 1:
print("Multi path?") # TODO
path = paths[0]
node_list = path.vertices
polygon = Polygon(node_list)
geod = Geod(ellps="WGS84")
poly_area, poly_perimeter = geod.geometry_area_perimeter(orient(polygon))
return poly_area / 1000000.0 # in km^2
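# A standalone check of the geodesic area computation used above, for a
# made-up 1-degree lon/lat box near 50N:
from pyproj import Geod
from shapely.geometry import Polygon
from shapely.ops import orient

box = Polygon([(10.0, 50.0), (11.0, 50.0), (11.0, 51.0), (10.0, 51.0)])
area_m2, _ = Geod(ellps="WGS84").geometry_area_perimeter(orient(box))
print(area_m2 / 1e6)  # roughly 7.9e3 km^2 (about 111 km x 71 km at this latitude)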
def has_local_minimum(storm_area, msl_data):
"""
Check whether the storm area has a local minimum. It has a local minimum,
iff the minimum itself is not along the data boundary.
Parameters
----------
storm_area: xr.DataArray
The binary DataArray of the storm area.
msl_data: xr.DataArray
The DataArray of the mean sea level pressure data.
Returns
-------
True if storm has a local minimum in the area.
"""
all_msl = msl_data.data[storm_area.data]
min_all_data = np.amin(all_msl)
inner_area = ndi.binary_erosion(storm_area.data)
inner_msl = msl_data.data[inner_area]
if inner_msl.size == 0:
return False # too small
min_inner_data = np.amin(inner_msl)
return min_all_data == min_inner_data # true if min not on border
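# A toy check of the erosion-based border test, with made-up data:
import numpy as np
from scipy import ndimage as ndi

toy_area = np.zeros((5, 5), dtype=bool)
toy_area[1:4, 1:4] = True              # 3x3 storm area
inner = ndi.binary_erosion(toy_area)   # only the centre cell (2, 2) survives

toy_msl = np.full((5, 5), 101500.0)    # Pa
toy_msl[2, 2] = 99000.0                # minimum in the interior
print(np.amin(toy_msl[toy_area]) == np.amin(toy_msl[inner]))  # True -> keep

toy_msl[2, 2] = 101500.0
toy_msl[1, 1] = 99000.0                # minimum on the area boundary
print(np.amin(toy_msl[toy_area]) == np.amin(toy_msl[inner]))  # False -> reject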
def get_storm_amplitude_hPa(storm_area, msl_data):
"""
Get the amplitude of the storm.
This is the difference between the average and the minimum pressure within the storm area.
Parameters
----------
storm_area: xr.DataArray
The binary storm area.
msl_data: xr.DataArray
The mean sea level pressure data.
Returns
-------
The amplitude, measured as mean minus minimum, in hPa.
"""
all_msl = msl_data.data[storm_area.data]
min_p = np.amin(all_msl)
avg_p = np.mean(all_msl)
ampl_p = avg_p - min_p
return ampl_p / 100.0
......@@ -11,7 +11,7 @@ from enstools.feature.util.graph import DataGraph
# set the pb_reference to the compiled pb2.py file (see proto_gen directory)
pipeline = FeaturePipeline(template_pb2, processing_mode='2d')
# change this to an identification technique that actually does something: existing one or implement your own
# change this to an identification strategy that actually does something: existing one or implement your own
i_strat = IdentificationTemplate(some_parameter='foo') # set the Identification strategy
t_strat = TrackingCompareTemplate() # set the tracking strategy
......@@ -28,16 +28,17 @@ od = pipeline.get_object_desc()
for trackable_set in od.sets:
# generate graph out of tracked data
g = DataGraph(trackable_set, t_strat)
g = DataGraph(set_desc=trackable_set, tr_tech=t_strat)
# generate single tracks from tracked data
# returns list of tracks, also gets added to object description.
# Also if apply_filter, keep_track() gets called for every track, a method of TrackingTechnique, which can be overwritten.
# Also if apply_filter, keep_track() gets called for every track, a method of TrackingStrategy, which can be overwritten.
print("GENERATE TRACKS...")
g.generate_tracks(apply_filter=True)
tracks = g.set_desc.tracks
# can perform operations on tracks and nodes...
track = tracks[0]
track = DataGraph(track=tracks[0])
# get earliest and latest nodes of that track
earliest_nodes = track.get_earliest_nodes()
......@@ -51,13 +52,14 @@ for trackable_set in od.sets:
parents = track.get_parents(latest_nodes[0])
print("Parents of a latest node:")
print([p.time + " ID " + str(p.object.id) for p in parents])
childs = track.get_childs(latest_nodes[0])
print("Childs of latest nodes:")
print([c.time + " ID " + str(c.object.id) for c in childs])
children = track.get_children(latest_nodes[0])
print("Children of latest nodes:")
print([c.time + " ID " + str(c.object.id) for c in children])
for track_id, track in enumerate(tracks):
print("Track " + str(track_id) + " has " + str(len(track.graph.edges)) + " nodes.")
print("Track " + str(track_id) + " has " + str(len(DataGraph(track=track).graph.edges)) + " nodes.")
out_json_path = path[:-3] + '_desc.json'
out_graph_path = path[:-3] + '_graph.json'
......
from ..identification import IdentificationTechnique
from ..identification import IdentificationStrategy
import xarray as xr
from random import randrange
class IdentificationTemplate(IdentificationTechnique):
class IdentificationTemplate(IdentificationStrategy):
def __init__(self, some_parameter='', **kwargs):
# Constructor. Called from example_template.py, parameters can be passed and set here.
......
from ..identification import IdentificationTechnique
from ..identification import IdentificationStrategy
from .processing import mask_to_proto, detection_double_thresh
import operator
......@@ -6,7 +6,7 @@ import numpy as np
import xarray as xr
class DoubleThresholdIdentification(IdentificationTechnique):
class DoubleThresholdIdentification(IdentificationStrategy):
def __init__(self, field, outer_threshold, inner_threshold, comparison_operator,
processing_mode="2d", compress=True, **kwargs):
......
......@@ -7,9 +7,9 @@ from os.path import expanduser
# set the pb_reference to the compiled pb2.py file (see proto_gen directory)
pipeline = FeaturePipeline(threshold_pb2, processing_mode='3d')
# change this to an identification technique that actually does something: existing one or implement your own
# change this to an identification strategy that actually does something: existing one or implement your own
path = expanduser("~") + '/PhD/enstools-feature/enstools/data/foo.nc' # ERA5/vietnam/t0_glob.nc' # set data path(s) here
path = expanduser("~") + '/phd/data/enstools-feature/pv_error_soeren.nc' # ERA5/vietnam/t0_glob.nc' # set data path(s) here
pipeline.set_data_path(path)
i_strat = DoubleThresholdIdentification("E", 5, 10, ">", processing_mode='3d') # set the Identification strategy
......
from enstools.feature.identification import IdentificationTechnique
from enstools.feature.tracking import TrackingTechnique
from enstools.feature.identification import IdentificationStrategy
from enstools.feature.tracking import TrackingStrategy
from enstools.feature.util.enstools_utils import get_vertical_dim
from datetime import datetime
import xarray as xr
class FeaturePipeline:
"""
This class encapsules the feature detection pipeline. The pipeline consists of an identification and a tracking procedure.
Feature detection pipeline (identification and tracking).
Parameters
----------
proto_ref
Protobuf template for the representation of identified features.
processing_mode : {'2d', '3d'}
Specify whether identification and tracking are performed on 2D levels or
per 3D block.
"""
def __init__(self, proto_ref, processing_mode='2d'):
"""
Specify processing mode 2d or 3d: In 2d, identification and tracking will be performed on 2d levels, in 3d per 3d block.
Parameters
----------
proto_ref
processing_mode
"""
self.id_tech = None
self.tr_tech = None
......@@ -30,42 +32,51 @@ class FeaturePipeline:
self.pb_reference = proto_ref
def set_identification_strategy(self, strategy: IdentificationTechnique):
def set_identification_strategy(self, strategy: IdentificationStrategy):
"""
Set the strategy to use for the identification.
Parameters
----------
strategy : enstools.feature.identification.IdentificationTechnique
strategy : IdentificationStrategy
The identification strategy to use in the pipeline.
"""
self.id_tech = strategy
self.id_tech.pb_reference = self.pb_reference
self.id_tech.processing_mode = self.processing_mode
pass
def set_tracking_strategy(self, strategy: TrackingTechnique):
def set_tracking_strategy(self, strategy: TrackingStrategy):
"""
Set the strategy to use for the tracking.
Parameters
----------
strategy : enstools.feature.tracking.TrackingTechnique
strategy : TrackingStrategy | None
The tracking strategy to use in the pipeline. Set to `None` or
don't invoke this method at all if no tracking should be carried
out.
"""
self.tr_tech = strategy
if strategy is not None:
self.tr_tech.pb_reference = self.pb_reference
self.tr_tech.processing_mode = self.processing_mode
pass
def set_data_path(self, path):
"""
Set the path to the dataset(s) to process.
This function calls enstools.io.read and therefore can read directories using wildcards.
This function calls :py:func:`enstools.io.read` and therefore can read
directories using wildcards.
Parameters
----------
path : list of str or tuple of str
names of individual files or filename pattern
See Also
--------
:py:meth:`.set_data`
"""
if path is None:
raise Exception("None path provided.")
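# e.g. read a whole experiment at once via a wildcard (hypothetical path):
#   pipeline.set_data_path(os.path.join(os.path.expanduser("~"), "data", "pv_*.nc"))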
......@@ -78,7 +89,6 @@ class FeaturePipeline:
def set_data(self, dataset: xr.Dataset):
"""
Set the dataset to process.
The function set_data_path() can be used instead.
Parameters
----------
......@@ -92,9 +102,7 @@ class FeaturePipeline:
self.dataset_path = ""
def execute_identification(self):
"""
Execute the identification strategy.
"""
"""Execute only the identification strategy."""
return_obj_desc_id, return_ds = self.id_tech.execute(self.dataset)
self.object_desc = return_obj_desc_id
if return_ds is not None:
......@@ -104,14 +112,16 @@ class FeaturePipeline:
self.object_desc.run_time = str(datetime.now().isoformat())
def execute_tracking(self):
"""
Execute the tracking strategy.
"""
"""Execute only the tracking strategy."""
self.tr_tech.execute(self.object_desc, self.dataset)
def execute(self):
"""
Execute the feature detection based on the set data and set techniques.
Execute the entire feature detection pipeline.
See Also
--------
:py:meth:`.execute_identification`, :py:meth:`.execute_tracking`
"""
# TODO need API to check if identification output type fits to tracking input type.
......@@ -125,6 +135,21 @@ class FeaturePipeline:
def get_data(self):
return self.dataset
def is_data_3d(self):
"""
Checks whether the provided dataset is spatially 3D (has a vertical dimension).
Returns
-------
bool
`True` if vertical dim in dataset else `False`.
"""
if self.dataset is None:
raise Exception("None dataset provided.")
vd = get_vertical_dim(self.dataset)
return vd is not None
def get_json_object(self):
"""
Get the JSON type message of the currently saved result.
......@@ -137,8 +162,6 @@ class FeaturePipeline:
json_dataset = MessageToJson(self.object_desc)
return json_dataset
def save_result(self, description_path=None, description_type='json', dataset_path=None):
"""
Save the result of the detection process.
......@@ -178,5 +201,3 @@ class FeaturePipeline:
# TODO do bug report
# from enstools.io import write
# write(self.dataset, dataset_path)
from .tracking import TrackingTechnique
from .tracking import TrackingStrategy
from .object_compare_tracking import ObjectComparisonTracking
......@@ -67,15 +67,15 @@ class AEWTracking(ObjectComparisonTracking):
def get_nodes_after(time_delta, edges, start_index):
this_node = edges[start_index].parent
if time_delta <= zero_dt or len(edges[start_index].childs) == 0:
if time_delta <= zero_dt or len(edges[start_index].children) == 0:
dd = defaultdict(list)
dd[start_index] = [start_index]
return dd
obj_node_list = [e.parent for e in edges]
childs_ = defaultdict(list)
children_ = defaultdict(list)
for child in edges[start_index].childs:
for child in edges[start_index].children:
dt = pb_str_to_datetime(child.time) - pb_str_to_datetime(this_node.time)
remaining_time_delta = time_delta - dt
......@@ -85,9 +85,9 @@ class AEWTracking(ObjectComparisonTracking):
after_nodes = AEWTracking.get_nodes_after(remaining_time_delta, edges, c_node_idx)
# add own node to it (this_node)
for endn_, pathn_ in after_nodes.items():
childs_[endn_].extend(pathn_ + [start_index])
children_[endn_].extend(pathn_ + [start_index])
return childs_
return children_
def adjust_track(self, tracking_graph):
......
from .tracking import TrackingTechnique
from .tracking import TrackingStrategy
from abc import ABC, abstractmethod
from datetime import datetime
from dask import delayed, compute
......@@ -7,10 +7,10 @@ import xarray as xr
from enstools.feature.util.data_utils import pb_str_to_datetime, SplitDimension
class ObjectComparisonTracking(TrackingTechnique):
class ObjectComparisonTracking(TrackingStrategy):
"""
Implementation of a tracking technique which can track objects by a simple pairwise comparison of their feature
descriptions. This acts as an abstract class for comparison tracking techniques.
Implementation of a tracking strategy which can track objects by a simple pairwise comparison of their feature
descriptions. This acts as an abstract class for comparison tracking strategies.
It implements the track() method and provides an abstract correspond() method which gives a binary answer whether two
objects of (consecutive) timesteps are the same. Objects do not necessarily have to be associated with consecutive
timesteps: using set_max_delta_compare() the user can set a maximum timedelta.
......@@ -50,8 +50,8 @@ class ObjectComparisonTracking(TrackingTechnique):
def track(self, tracking_set, subset: xr.Dataset):
"""
Implementation of track() for tracking techniques which are based on pairwise comparisons of objects. Using
this tracking technique, the correspond() method has to implement a boolean function which returns True if
Implementation of track() for tracking strategies which are based on pairwise comparisons of objects. Using
this tracking strategy, the correspond() method has to implement a boolean function which returns True if
the given object pair of consecutive time steps should be considered as the same object.
Parameters
......
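# A hedged sketch of a concrete comparison strategy; the exact correspond()
# signature may differ, and the 5-degree criterion is made up for illustration:
from datetime import timedelta
from enstools.feature.tracking.object_compare_tracking import ObjectComparisonTracking

class MinPressureDistanceTracking(ObjectComparisonTracking):
    def correspond(self, obj1, obj2):
        p1 = obj1.properties.min_pressure_pos
        p2 = obj2.properties.min_pressure_pos
        return abs(p1.lat - p2.lat) < 5.0 and abs(p1.lon - p2.lon) < 5.0

t_strat = MinPressureDistanceTracking()
t_strat.set_max_delta_compare(timedelta(hours=6))  # maximum compared time gap (argument type assumed)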
from ..tracking import TrackingTechnique
from ..tracking import TrackingStrategy
from enstools.feature.util.data_utils import print_lock, get_subset_by_description, get_split_dimensions, squeeze_nodes, \
SplitDimension
SplitDimension, pb_str_to_datetime
from enstools.misc import get_time_dim
import xarray as xr
import numpy as np
class OverlapDoubleThresholdTracking(TrackingTechnique):
class OverlapDoubleThresholdTracking(TrackingStrategy):
"""
Implementation of simple overlap tracking: objects at consecutive timestamps are considered the same if
they spatially overlap. This requires a field in which to check for overlaps.
......@@ -14,9 +14,12 @@ class OverlapDoubleThresholdTracking(TrackingTechnique):
This field should be of data type "int": zero where no object is located, and
i where the object with ID i (from identification) is located.
Optionally, min_duration can be set as a datetime.timedelta, indicating the minimum time an object has
to persist to survive track filtering.
"""
def __init__(self, inner_thresh_name=None, outer_thresh_name=None,tracking_method = None):
def __init__(self, inner_thresh_name=None, outer_thresh_name=None, tracking_method=None, min_duration=None):
self.inner_thresh_name = inner_thresh_name
self.outer_thresh_name = outer_thresh_name
if tracking_method == "inner_first":
......@@ -26,6 +29,7 @@ class OverlapDoubleThresholdTracking(TrackingTechnique):
else:
raise ValueError("tracking method has to be either inner_first or outer_first")
self.min_duration = min_duration
pass
def track(self, trackable_set, subset: xr.Dataset):
......@@ -76,6 +80,18 @@ class OverlapDoubleThresholdTracking(TrackingTechnique):
print("postprocess @ overlap tracking")
return
def adjust_track(self, track):
# keep all tracks after generating them
return track
def adjust_track(self, tracking_graph):
if self.min_duration is None:
return tracking_graph
# keep the track only if it persists for at least min_duration
track = tracking_graph.graph
nodes = [edge.parent for edge in track.edges]
min_time = pb_str_to_datetime(nodes[0].time)
max_time = pb_str_to_datetime(nodes[-1].time)
duration = max_time - min_time
if duration < self.min_duration:
return None
return tracking_graph
from ..tracking import TrackingTechnique
from ..tracking import TrackingStrategy
from enstools.feature.util.data_utils import print_lock, get_subset_by_description, get_split_dimensions, squeeze_nodes, \
SplitDimension
SplitDimension, pb_str_to_datetime
from enstools.misc import get_time_dim
import xarray as xr
import numpy as np
class OverlapTracking(TrackingTechnique):
class OverlapTracking(TrackingStrategy):
"""
Implementation of simple overlap tracking: objects at consecutive timestamps are considered the same if
they spatially overlap. This requires a field in which to check for overlaps.
......@@ -14,10 +14,14 @@ class OverlapTracking(TrackingTechnique):
This field should be of data type "int": zero where no object is located, and
i where the object with ID i (from identification) is located.
Optionally, min_duration can be set as a datetime.timedelta, indicating the minimum time an object has
to persist to survive track filtering.
"""
def __init__(self, field_name=None):
def __init__(self, field_name=None, min_duration=None):
self.field_name = field_name
self.min_duration = min_duration
pass
def track(self, trackable_set, subset: xr.Dataset):
......@@ -63,6 +67,18 @@ class OverlapTracking(TrackingTechnique):
print("postprocess @ ooverlap tracking")
return
def adjust_track(self, track):
# keep all tracks after generating them
return track
def adjust_track(self, tracking_graph):
if self.min_duration is None:
return tracking_graph
# keep the track only if it persists for at least min_duration
track = tracking_graph.graph
nodes = [edge.parent for edge in track.edges]
min_time = pb_str_to_datetime(nodes[0].time)
max_time = pb_str_to_datetime(nodes[-1].time)
duration = max_time - min_time
if duration < self.min_duration:
return None
return tracking_graph
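# e.g. keep only storms that persist for at least half a day, as in the storm
# usage example above:
from datetime import timedelta
t_strat = OverlapTracking(field_name='storm_areas', min_duration=timedelta(hours=12))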
from ..tracking import TrackingTechnique
from ..tracking import TrackingStrategy
import xarray as xr
from enstools.misc import get_time_dim
class TrackingTemplate(TrackingTechnique):
class TrackingTemplate(TrackingStrategy):
"""
Template for a default "fall-back" tracking technique.
There is also a more specific template called template_object_compare, which provides a template for techniques
Template for a default "fall-back" tracking strategy.
There is also a more specific template called template_object_compare, which provides a template for strategies
which are solely based on comparing the features of the object descriptions of consecutive time steps.
"""
def track(self, trackable_set, subset: xr.Dataset):
"""
Implementation of track() for general tracking techniques.
Implementation of track() for general tracking strategies.
In this template, we iterate over the valid_times of this forecast/reanalysis and compare all objects of two
consecutive timestamps. If they correspond according to some metric, their IDs should be added to the list.
See also the template_object_compare for a more specific case.
......
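# A rough sketch of the loop described in this template; objects_at() is a
# hypothetical helper for fetching the objects of one timestep, and the
# correspondence metric is left abstract:
from enstools.misc import get_time_dim

def track(self, trackable_set, subset):
    time_dim = get_time_dim(subset)
    valid_times = list(subset.coords[time_dim].values)
    for t_prev, t_next in zip(valid_times[:-1], valid_times[1:]):
        for obj1 in objects_at(trackable_set, t_prev):      # hypothetical helper
            for obj2 in objects_at(trackable_set, t_next):  # hypothetical helper
                if self.correspond(obj1, obj2):             # some metric
                    pass  # record (obj1.id, obj2.id) as the same feature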