import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
# from enstools.feature.identification._proto_gen import identification_pb2
from enstools.feature.util.data_utils import get_subset_by_description, squeeze_nodes
from multiprocessing.pool import ThreadPool as Pool
from functools import partial


class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.

    Implementations need to override the abstract track() and postprocess() methods.
    """

    def __init__(self):
        # Module providing the protobuf message types (DatasetDescription, ObjectGraph) used by
        # generate_graph() and generate_tracks(). It is not set in this class.
        # NOTE(review): presumably assigned by the caller/pipeline before use; confirm.
        self.pb_reference = None
        # Cached tracking graph; set by generate_graph().
        self.graph = None

    @abstractmethod
    def track(self, track_set, subset: xr.Dataset):
        """
        Abstract tracking method.

        This gets called once per timeline (one set of the feature description, i.e. a list of time steps).
        A timeline can be, for example, a reforecast or a member of an ensemble forecast, in which detected
        objects should be tracked over multiple timestamps. execute() calls this method in parallel
        (thread pool) for all timelines in the dataset.

        This method should compute links of corresponding objects of two time steps.

        See template/ and especially template_object_compare/ for examples.

        Parameters
        ----------
        track_set : identification_pb2.TrackingSet
            One set of the object description (``object_desc.sets[i]``) to be tracked.
        subset : xarray.Dataset
            Subset of the dataset belonging to this set. Forwarded from identification.

        Returns
        -------
        connections : list of pb_reference.RefGraphConnections
            The connections forming the path of the objects in the time steps.
        """
        connections = []
        return connections

    @abstractmethod
    def postprocess(self, object_desc):
        """
        Abstract method for the postprocessing of the tracking.

        Called by execute() once after all sets have been tracked.

        Parameters
        ----------
        object_desc :
            The whole pb2.DatasetDescription; can be altered in place.
        """
        pass

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        """
        Track a single set of the object description.

        The connections returned by track() are squeezed and appended to the set's ref graph in place.

        Parameters
        ----------
        set_idx : int
            Index of the set in the object description.
        obj_desc :
            The object description; ``obj_desc.sets[set_idx]`` is modified in place.
        dataset : xarray.Dataset
            The full dataset; the subset matching the set is selected from it.

        Returns
        -------
        None
        """
        # get according subset
        obj_set = obj_desc.sets[set_idx]
        dataset_sel = get_subset_by_description(dataset, obj_set)

        # track this set
        print("Track " + str(obj_set)[:40] + " (" + str(set_idx) + ")")
        nodes = self.track(obj_set, dataset_sel)  # TODO check subsets in here on more complex DS

        # squeeze result
        # TODO order?
        nodes = squeeze_nodes(nodes)

        # add graph nodes to ref graph
        obj_set.ref_graph.nodes.extend(nodes)

    def execute(self, object_desc, dataset_ref: xr.Dataset):
        """
        Execute the tracking procedure.

        The description is split into its sets (timelines), which are tracked in parallel in a thread pool.
        Afterwards postprocess() is called on the whole description.

        Parameters
        ----------
        object_desc : identification_pb2.DatasetDescription
            The description of the detected features from the identification technique. Modified in place.
        dataset_ref : xarray.Dataset
            Reference to the dataset used in the pipeline.
        """
        track_one_set = partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref)

        # The context manager releases the worker threads. map() blocks until every set is tracked and
        # re-raises an exception raised in a worker.
        with Pool() as pool:
            pool.map(track_one_set, range(len(object_desc.sets)))

        # only postprocess object desc, THEN to graph
        self.postprocess(object_desc)
        # TODO what if tracks are the identification? e.g. AEW identification in Hovmoller

    def get_graph(self):
        """
        Get the cached tracking graph.

        Returns
        -------
        The graph (pb_reference.DatasetDescription) built by generate_graph().

        Raises
        ------
        RuntimeError
            If generate_graph() has not been called yet.
        """
        if self.graph is None:
            raise RuntimeError("Generate graph first before getting (call generate_graph()).")
        return self.graph  # use cached

    def generate_graph(self, object_desc):
        """
        Generate the tracking graph from the object description with tracking info.

        After executing track(), each set of the object description holds its connections as references
        (time index, object id) in its ref graph. In the generated graph, each set instead has an
        object graph. Every node in it carries the valid time and a full copy of the object, plus copies
        of all connected objects.

        Parameters
        ----------
        object_desc :
            Object description after tracking. Not modified.

        Returns
        -------
        The graph (pb_reference.DatasetDescription), which is also cached in self.graph.
        """
        self.graph = self.pb_reference.DatasetDescription()
        self.graph.CopyFrom(object_desc)

        for set_idx, objdesc_set in enumerate(object_desc.sets):
            graph_set = self.graph.sets[set_idx]

            # the graph set only keeps the object graph: empty the copied time steps and ref graph
            del graph_set.timesteps[:]
            graph_set.ref_graph.Clear()

            # Index the ref connections by their source node (time index, object id) once, instead of
            # rescanning all of them for every object. The lists keep the original order.
            ref_connections_by_node = {}
            for ref_connection in objdesc_set.ref_graph.nodes:
                key = (ref_connection.this_node.time_index, ref_connection.this_node.object_id)
                ref_connections_by_node.setdefault(key, []).append(ref_connection)

            # Per time step: object id -> first object with that id (same pick as a first-match search).
            first_object_by_id = []
            for ts in objdesc_set.timesteps:
                lookup = {}
                for obj in ts.objects:
                    lookup.setdefault(obj.id, obj)
                first_object_by_id.append(lookup)

            # for each object add (n1, []) to the graph, then add its connections
            for idx_ts, ts in enumerate(objdesc_set.timesteps):
                cur_time = ts.valid_time
                for obj in ts.objects:
                    obj_connection = graph_set.object_graph.nodes.add()
                    obj_connection.this_node.time = cur_time
                    obj_connection.this_node.object.CopyFrom(obj)

                    # for each ref connection starting at the current object: add its n2 objects
                    for ref_connection in ref_connections_by_node.get((idx_ts, obj.id), ()):
                        for ref_n2 in ref_connection.connected_nodes:
                            obj_n2 = obj_connection.connected_nodes.add()
                            obj_n2.time = objdesc_set.timesteps[ref_n2.time_index].valid_time
                            obj_n2.object.CopyFrom(first_object_by_id[ref_n2.time_index][ref_n2.object_id])

        return self.graph

    def filter_tracks(self):
        """
        Filter the generated tracks in place.

        For each track of each set in self.graph, keep_track() is called; tracks for which it returns a
        falsy value are deleted.
        """
        for set_ in self.graph.sets:
            # iterate backwards so deleting a track does not shift the indices still to be visited
            for t_id in reversed(range(len(set_.tracks))):
                if not self.keep_track(set_.tracks[t_id]):
                    del set_.tracks[t_id]

    def keep_track(self, track):
        """
        Decide whether a track is kept by filter_tracks(). The default keeps every track.

        Parameters
        ----------
        track:
            the pb2.ObjectGraph track

        Returns
        -------
        True if keep, else discard
        """
        return True

    # TODO overlap tracking def. input just field name
    def generate_tracks(self):
        """
        Compute the tracks from the tracking graph.

        Tracks are a disjoint subset of graphs of the total graph, computed with a simple heuristic. The
        nodes are ordered by time. For each node that has no track yet, all downstream nodes are searched.
        If any of these nodes already belongs to a track, the whole stream gets that track's id. Otherwise
        the stream gets a new id (new track).

        Returns
        -------
        Nothing. For each set, one pb_reference.ObjectGraph per track is appended to ``tracks`` of
        self.graph in place.

        Raises
        ------
        RuntimeError
            If generate_graph() has not been called yet.
        """
        if self.graph is None:
            raise RuntimeError("Compute graph first.")

        for graph_set in self.graph.sets:
            # sort nodes by the time of their own node
            time_sorted_nodes = sorted(graph_set.object_graph.nodes, key=lambda item: item.this_node.time)

            wave_id_per_node = [None] * len(time_sorted_nodes)
            cur_id = 0

            # walk the time-sorted nodes; group each node with its temporal downstream under one id
            for con_idx in range(len(time_sorted_nodes)):
                if wave_id_per_node[con_idx] is not None:  # already part of a wave
                    continue

                # not part of a wave -> get (temporal) downstream nodes (includes con_idx itself)
                downstream_wave_node_indices = TrackingTechnique.get_downstream_node_indices(time_sorted_nodes,
                                                                                             con_idx)
                print(str(con_idx) + " -> " + str(downstream_wave_node_indices))

                # any of the downstream nodes already part of a wave? -> reuse that id
                connected_wave_id = None
                for ds_node_idx in downstream_wave_node_indices:
                    if wave_id_per_node[ds_node_idx] is not None:
                        if connected_wave_id is not None:
                            print("Double ID, better resolve todo...")  # TODO merge conflicting waves
                        connected_wave_id = wave_id_per_node[ds_node_idx]

                if connected_wave_id is None:
                    # New track. The downstream list always contains con_idx, so this id is used.
                    connected_wave_id = cur_id
                    cur_id += 1

                for ds_node_idx in downstream_wave_node_indices:
                    wave_id_per_node[ds_node_idx] = connected_wave_id

            # Every node has an id now. Group the nodes by id in one pass; time order within a wave is kept.
            nodes_per_wave = [[] for _ in range(cur_id)]
            for node, wave_id in zip(time_sorted_nodes, wave_id_per_node):
                nodes_per_wave[wave_id].append(node)

            # put every wave into its own subgraph (track)
            for wave_nodes in nodes_per_wave:
                track = self.pb_reference.ObjectGraph()
                track.nodes.extend(wave_nodes)
                graph_set.tracks.append(track)

    @staticmethod
    def get_downstream_node_indices(graph_list, start_idx):
        """
        Helper method for the track generation: find all nodes reachable from graph_list[start_idx].

        Starting at ``start_idx``, ``connected_nodes`` are followed transitively. Each connected node is
        mapped back to the index of the entry whose ``this_node`` equals it. Each node is visited once,
        so shared sub-paths are not re-walked and cycles terminate.

        Parameters
        ----------
        graph_list : sequence
            Graph nodes, each with a ``this_node`` and ``connected_nodes``.
        start_idx : int
            Index of the start node in graph_list.

        Returns
        -------
        list of int
            Sorted, duplicate-free indices into graph_list, including start_idx.

        Raises
        ------
        ValueError
            If a connected node equals no ``this_node`` in graph_list.
        """
        # equality-based node -> index lookup (list.index), built once per call instead of once per edge
        obj_node_list = [con_.this_node for con_ in graph_list]

        visited = {start_idx}
        stack = [start_idx]
        while stack:
            cur_idx = stack.pop()
            for c_node in graph_list[cur_idx].connected_nodes:
                c_node_idx = obj_node_list.index(c_node)
                if c_node_idx not in visited:
                    visited.add(c_node_idx)
                    stack.append(c_node_idx)

        return sorted(visited)