import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
# from enstools.feature.identification._proto_gen import identification_pb2
from enstools.feature.util.data_utils import get_subset_by_description, squeeze_nodes
from multiprocessing.pool import ThreadPool as Pool
from functools import partial


class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.
    Implementations need to override the abstract track() method.
    """

    def __init__(self):
        self.pb_reference = None
        self.graph = None

    @abstractmethod
    def track(self, track_set, subset: xr.Dataset):  # TODO update docstrings
        """
        Abstract tracking method. This gets called for each timestep list of the feature descriptions.
        Such a timeline can for example be a reforecast or a member of an ensemble forecast, in which
        detected objects should be tracked over multiple timestamps. This method is called in parallel
        for all timelines in the dataset.

        This method should compute links between corresponding objects of two consecutive timestamps.
        Each object in a timestep has a unique ID, and a computed tuple (id1, id2) states that object
        id1 from timestamp t is the same object as id2 from timestamp t+1. One such tuple is an edge
        in the tracking graph. A minimal illustrative subclass is sketched at the end of this module.

        Parameters
        ----------
        track_set : iterable of identification_pb2.TrackingSet
            The data to be tracked.
        subset : xarray.Dataset
            Subset to be tracked. Forwarded from the identification.

        Returns
        -------
        connections : list of tuples of int
            The connections forming the paths of the objects over the timesteps.
        """
        connections = []
        return connections

    @abstractmethod
    def postprocess(self, object_desc):  # TODO update docstrings
        pass

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        obj_set = obj_desc.sets[set_idx]

        # get the according subset
        dataset_sel = get_subset_by_description(dataset, obj_set)

        print("Track " + str(obj_set)[:40] + " (" + str(set_idx) + ")")
        nodes = self.track(obj_set, dataset_sel)  # TODO check subsets in here on more complex DS

        print("Nodes before squeeze: " + str(len(nodes)))
        nodes = squeeze_nodes(nodes)
        print("Nodes after squeeze: " + str(len(nodes)))

        # exit()  # debugging stop disabled, otherwise the connections below are never written
        # TODO if squeezed, generate tracks.

        # create object connections from the index connections for the output graph
        # not really efficient...
        obj_set.ref_graph.connections.extend(nodes)

    def execute(self, object_desc, dataset_ref: xr.Dataset):
        """
        Execute the tracking procedure. The description is split into the different timelines, which
        can be processed in parallel.

        Parameters
        ----------
        object_desc : identification_pb2.DatasetDescription TODO?
            The description of the detected features from the identification technique.
        dataset_ref : xarray.Dataset
            Reference to the dataset used in the pipeline.
        """

        # parallel for all tracking sets
        pool = Pool()
        pool.map(partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref),
                 range(len(object_desc.sets)))  # iterate over sets

        self.postprocess(object_desc)  # only postprocess the object description, THEN convert to graph

        # TODO what if tracks are the identification? e.g. AEW identification in Hovmöller
    def get_graph(self, object_desc):
        if self.graph is not None:
            return self.graph  # use cached graph

        # object_desc is required here: generate_graph() builds the graph from it (was missing before)
        self.generate_graph(object_desc)
        return self.graph

    def generate_graph(self, object_desc):
        self.graph = self.pb_reference.DatasetDescription()
        self.graph.CopyFrom(object_desc)

        for set_idx, objdesc_set in enumerate(object_desc.sets):
            graph_set = self.graph.sets[set_idx]
            # empty the copied set: keep only the reference graph, drop the timesteps
            del graph_set.timesteps[:]
            graph_set.ref_graph.Clear()

            for c in objdesc_set.ref_graph.connections:
                obj_con = self.pb_reference.ObjectConnection()
                start, end = c.n1, c.n2

                # resolve the start node: index connection -> actual object
                obj_con.n1.time = objdesc_set.timesteps[start.time_index].valid_time
                obj_with_startid = [objindex for objindex, obj
                                    in enumerate(objdesc_set.timesteps[start.time_index].objects)
                                    if obj.id == start.object_id][0]
                obj_con.n1.object.CopyFrom(objdesc_set.timesteps[start.time_index].objects[obj_with_startid])

                # resolve the end node analogously
                obj_con.n2.time = objdesc_set.timesteps[end.time_index].valid_time
                obj_with_endid = [objindex for objindex, obj
                                  in enumerate(objdesc_set.timesteps[end.time_index].objects)
                                  if obj.id == end.object_id][0]
                obj_con.n2.object.CopyFrom(objdesc_set.timesteps[end.time_index].objects[obj_with_endid])

                graph_set.object_graph.connections.append(obj_con)

        return self.graph

        # TODO add (n1, None) pairs for isolated nodes
        # TODO also for end nodes (n1, None) --> so n1 in (n1, n2) covers ALL nodes.

    # after the graph has been computed, compute "tracks": a disjoint list of subgraphs of the total graph
    def get_tracks(self):
        if self.graph is None:
            raise ValueError("Compute the graph first.")

        # for each set: order the connections by the time of their first node
        print("get tracks")

        for graph_set in self.graph.sets:
            # object_set = self.obj_ref.set_idx
            print("new set get tracks")

            # sort connections by the time of their first node
            # TODO assert that comparison by string yields a time-sorted list (yyyymmdd)
            time_sorted_nodes = list(sorted(graph_set.object_graph.connections, key=lambda item: item.n1.time))
            wave_id_per_node = [None] * len(time_sorted_nodes)
            cur_id = 0

            # iterate over all time-sorted identified connections:
            # search temporally downstream tracked troughs and group them under one wave id
            # (a standalone sketch of this grouping on plain tuples is at the end of this module)
            for con_idx, con in enumerate(time_sorted_nodes):
                if wave_id_per_node[con_idx] is not None:  # already part of a wave
                    continue

                # not part of a wave -> the (temporally) downstream connections form a wave (get their indices)
                # TODO unresolved: get_downstream_node_indices only exists in the commented-out block below
                downstream_wave_node_indices = get_downstream_node_indices(time_sorted_nodes, con_idx)

                # is any of the downstream nodes already part of a wave?
                connected_wave_id = None
                for ds_node_idx in downstream_wave_node_indices:
                    if wave_id_per_node[ds_node_idx] is not None:
                        if connected_wave_id is not None:
                            print("Double ID...")
                        connected_wave_id = wave_id_per_node[ds_node_idx]

                # if so, set all downstream nodes to the found id
                if connected_wave_id is not None:
                    for ds_node_idx in downstream_wave_node_indices:
                        wave_id_per_node[ds_node_idx] = connected_wave_id
                    continue

                # else: assign a new wave id to all wave nodes
                cur_id_needs_update = False
                for ds_node_idx in downstream_wave_node_indices:
                    wave_id_per_node[ds_node_idx] = cur_id
                    cur_id_needs_update = True
                if cur_id_needs_update:
                    cur_id += 1

            print(wave_id_per_node)

            # done; now extract every wave by id and put them into subgraphs
            waves = []
            # generate the list of waves according to the above policy
            for wave_id in range(cur_id):
                wave_idxs = [i for i in range(len(wave_id_per_node)) if wave_id_per_node[i] == wave_id]
                cur_wave_nodes = [time_sorted_nodes[i] for i in wave_idxs]  # troughs of this wave
                # cur_troughs = cur_troughs.sortbytime  # already sorted?
                wave = cur_wave_nodes
                waves.append(wave)  # was missing, so the waves list below stayed empty
                print(wave)

            print(waves)  # waves for this set only

        return None

    # Leftover draft from a WaveGraph prototype, kept commented out for reference.
    # get_tracks() above still calls get_downstream_node_indices, which only exists in
    # here and expects (node, connected_nodes) tuples rather than connection messages.
    """
    @staticmethod
    def get_downstream_node_indices(graph_list, start_idx):
        node_indices = [start_idx]
        node, connected_nodes = graph_list[start_idx]

        for c_node in connected_nodes:
            # get the index of the connected node in the graph
            c_node_idx = [n_ for n_, cv_ in graph_list].index(c_node)
            # recurse on the connected node
            c_node_downstream_indices = WaveGraph.get_downstream_node_indices(graph_list, c_node_idx)
            node_indices.extend(c_node_downstream_indices)

        return list(set(node_indices))

    def get_items(self):
        return self._graph.items()

    # extract wave objects from the graph. Question here: what is a wave?
    # e.g. before/after merging, what belongs to the same wave?
    # a wave object is a list of wave troughs, i.e. wave troughs for consecutive timesteps
    def extract_waves(self):
    """
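

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the library: a hypothetical minimal
# TrackingTechnique subclass demonstrating the track() contract from its
# docstring. It assumes the protobuf layout used above (track_set.timesteps,
# timestep.objects, obj.id) and links objects of consecutive timesteps that
# share the same id; a real tracker would match by spatial overlap or
# distance. Note that it returns plain (id1, id2) tuples as the docstring
# states, while generate_graph() consumes connection messages carrying
# time_index/object_id fields, so a production implementation would build
# those messages via self.pb_reference instead.
# ---------------------------------------------------------------------------
class _PersistentIdTracking(TrackingTechnique):
    """Hypothetical example: link objects of consecutive timesteps sharing an id."""

    def track(self, track_set, subset: xr.Dataset):
        connections = []
        # pair each object id at timestep t with the same id at timestep t+1
        for t in range(len(track_set.timesteps) - 1):
            ids_now = {obj.id for obj in track_set.timesteps[t].objects}
            ids_next = {obj.id for obj in track_set.timesteps[t + 1].objects}
            for obj_id in sorted(ids_now & ids_next):
                connections.append((obj_id, obj_id))  # one edge in the tracking graph
        return connections

    def postprocess(self, object_desc):
        # nothing to clean up in this sketch
        pass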
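

# ---------------------------------------------------------------------------
# Standalone sketch of the grouping performed in get_tracks() above, on plain
# hashable nodes instead of protobuf messages: all connections that share a
# node, directly or transitively, end up in the same "wave", i.e. the waves
# are the connected components of the tracking graph. The function name is
# hypothetical and only illustrates the intended policy.
#
# >>> _group_connections_into_waves([(1, 2), (2, 3), (7, 8)])
# [[(1, 2), (2, 3)], [(7, 8)]]
# ---------------------------------------------------------------------------
def _group_connections_into_waves(connections):
    """Group (n1, n2) edges into connected components (waves) via union-find."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a, b):
        parent[find(a)] = find(b)

    for n1, n2 in connections:
        union(n1, n2)

    # collect the edges of each component, keyed by the component root
    waves = {}
    for edge in connections:
        waves.setdefault(find(edge[0]), []).append(edge)
    return list(waves.values())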