Newer
Older
import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
# from enstools.feature.identification._proto_gen import identification_pb2
from enstools.feature.util.data_utils import get_subset_by_description
from multiprocessing.pool import ThreadPool as Pool
from functools import partial
class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.

    Implementations need to override the abstract track() and postprocess() methods.
    A run consists of ``execute(object_desc, dataset)``; afterwards ``get_graph()``
    returns the resulting tracking graph.
    """

    def __init__(self):
        # Module holding the protobuf message classes used below (DatasetDescription,
        # ObjectConnection, ...). Not assigned here; presumably set by the pipeline
        # before get_graph() is used -- TODO confirm.
        self.pb_reference = None
        # Cached tracking graph, built lazily by get_graph().
        self.graph = None
        # Description passed to the most recent execute() call; input of generate_graph().
        self.object_desc = None

    @abstractmethod
    def track(self, track_set, subset: "xr.Dataset"):
        """
        Abstract tracking method, called once for each set (timeline) of feature descriptions.

        A timeline can be, for example, a reforecast or one member of an ensemble forecast,
        in which detected objects should be tracked over multiple timestamps. execute() calls
        this method concurrently from a thread pool for all sets of the description, so
        implementations should be safe to run in parallel.

        The method should compute links between corresponding objects of two consecutive
        timestamps. Each object in a timestep has a unique ID; a connection from node
        (t, id1) to node (t + 1, id2) states that object id1 at timestamp t is the same
        object as id2 at timestamp t + 1. Each connection is an edge in the tracking graph.

        Parameters
        ----------
        track_set : protobuf set message
            One element of ``object_desc.sets``, holding the timesteps and objects to track.
        subset : xarray.Dataset
            Part of the dataset matching ``track_set`` (selected via get_subset_by_description).

        Returns
        -------
        connections : iterable of connection messages
            Edges of the tracking graph. They are appended to
            ``track_set.ref_graph.connections``; each edge has nodes ``n1`` and ``n2``
            carrying ``time_index`` and ``object_id``.
        """
        connections = []
        return connections

    @abstractmethod
    def postprocess(self, object_desc):
        """
        Post-process the object description after all of its sets have been tracked.

        Called once by execute(), after every set's ``ref_graph`` has been filled.

        Parameters
        ----------
        object_desc : protobuf DatasetDescription
            The description passed to execute(), already modified in place by tracking.
        """

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        """
        Track a single set of the description and store its connections in place.

        Parameters
        ----------
        set_idx : int
            Index of the set in ``obj_desc.sets``.
        obj_desc : protobuf DatasetDescription
            The full object description (bound via functools.partial by execute()).
        dataset : xarray.Dataset
            The full dataset; the part matching the set is selected before tracking.
        """
        obj_set = obj_desc.sets[set_idx]
        # get according subset
        dataset_sel = get_subset_by_description(dataset, obj_set)
        print("Track " + str(obj_set)[:40] + " (" + str(set_idx) + ")")
        connections = self.track(obj_set, dataset_sel)  # TODO check subsets in here on more complex DS
        obj_set.ref_graph.connections.extend(connections)

    def execute(self, object_desc, dataset_ref: "xr.Dataset"):
        """
        Execute the tracking procedure.

        The description is split into its sets (timelines), which are tracked in parallel;
        the resulting connections are written into ``object_desc`` in place. Afterwards
        postprocess() is called. Any previously cached graph is discarded, so the next
        get_graph() call builds the graph for this description.

        Parameters
        ----------
        object_desc : protobuf DatasetDescription
            The description of the detected features from the identification technique.
        dataset_ref : xarray.Dataset
            Reference to the dataset used in the pipeline.
        """
        self.object_desc = object_desc
        self.graph = None  # a graph cached by an earlier run would be stale now

        track_one = partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref)
        # Threads rather than processes: track_set mutates the sets of object_desc in place,
        # which only works when the workers share memory with the caller. The context
        # manager releases the worker threads once map() has returned (or raised).
        with Pool() as pool:
            pool.map(track_one, range(len(object_desc.sets)))

        self.postprocess(object_desc)  # only postprocess object desc, THEN to graph
        # TODO link pairs together to "path": a track then has an N-list of IDs
        # TODO what if tracks are the identification? e.g. AEW identification in Hovmoller

    def get_graph(self):
        """
        Return the tracking graph, building and caching it on first use.

        Raises
        ------
        RuntimeError
            If no graph is cached and execute() has not been called yet.
        """
        if self.graph is None:
            self.generate_graph()
        return self.graph

    def generate_graph(self):
        """
        Build ``self.graph`` from the description of the last execute() call.

        The graph is a copy of that description in which, for every set, each index-based
        connection of ``ref_graph`` is resolved into an ``ObjectConnection`` (appended to
        ``object_graph``) that carries both objects and their timesteps' ``valid_time``.

        Raises
        ------
        RuntimeError
            If execute() has not been called yet.
        ValueError
            If a connection references an object id missing from its timestep.
        """
        if self.object_desc is None:
            raise RuntimeError("execute() must be called before the tracking graph can be generated.")

        graph = self.pb_reference.DatasetDescription()
        graph.CopyFrom(self.object_desc)

        for graph_set in graph.sets:
            for connection in graph_set.ref_graph.connections:
                obj_con = self.pb_reference.ObjectConnection()
                self._fill_graph_node(obj_con.n1, graph_set, connection.n1)
                self._fill_graph_node(obj_con.n2, graph_set, connection.n2)
                graph_set.object_graph.connections.append(obj_con)

        # Only publish the graph once it is complete, so a failure leaves no partial cache.
        self.graph = graph

    @staticmethod
    def _fill_graph_node(node, trackable_set, ref_node):
        """Copy the timestep's valid_time and the object referenced by ``ref_node`` into ``node``."""
        timestep = trackable_set.timesteps[ref_node.time_index]
        node.time = timestep.valid_time
        obj = next((o for o in timestep.objects if o.id == ref_node.object_id), None)
        if obj is None:
            raise ValueError(
                f"No object with id {ref_node.object_id} in timestep {ref_node.time_index}."
            )
        node.object.CopyFrom(obj)