# tracking.py
import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
# from enstools.feature.identification._proto_gen import identification_pb2
from enstools.feature.util.data_utils import get_subset_by_description, squeeze_nodes
from multiprocessing.pool import ThreadPool as Pool
from functools import partial


class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.
    Implementations need to override track() and the abstract postprocess() method.
    """

    def __init__(self):
        self.pb_reference = None
        self.graph = None

    def track(self, track_set, subset: xr.Dataset):  # TODO update docstrings
        """
        Abstract tracking method. This gets called for each timesteps list of the feature descriptions.
        This timeline can be for example a reforecast or a member of an ensemble forecast, in which detected objects
        should be tracked over multiple timestamps. This method gets called in parallel for all timelines in the
        dataset.

        This method should compute links of corresponding objects of two consecutive timestamps. Each object in a
        timestep has its unique ID, and a computed tuple (id1, id2) remarks that object id1 from timestamp t is the
        same object as id2 from timestamp t+1. One tuple is an edge in a tracking graph.

        NOTE(review): track_set() passes the return value to squeeze_nodes() and appends the result to
        ref_graph.nodes, so implementations apparently have to return graph nodes rather than plain
        tuples - confirm and update this description.

        Parameters
        ----------
        track_set : iterable of identification_pb2.TrackingSet
                The data to be tracked
        subset : xarray.Dataset
                Subset to be tracked. Forwarded from identification

        Returns
        -------
        connections : list of tuples of int
                The connections forming the path of the objects in the timesteps.
                The base implementation returns an empty list.
        """
        connections = []
        return connections

    @abstractmethod
    def postprocess(self, object_desc):  # TODO update docstrings
        """
        Post-process the object description after tracking.

        Abstract hook: execute() calls it with the complete description once all sets
        have been tracked. The base implementation does nothing and returns None.

        Parameters
        ----------
        object_desc
                The description handed to execute().
        """

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        """
        Track one set of the description and append the resulting nodes to its reference graph.

        Parameters
        ----------
        set_idx : int
                Index of the set in obj_desc.sets.
        obj_desc
                Description holding the sets; despite the default it is dereferenced unconditionally.
        dataset
                Dataset from which the subset belonging to the set is selected.
        """
        current_set = obj_desc.sets[set_idx]
        # restrict the dataset to the part described by this set
        set_data = get_subset_by_description(dataset, current_set)

        print("Track %s (%s)" % (str(current_set)[:40], str(set_idx)))
        # TODO check subsets in here on more complex DS
        tracked_nodes = squeeze_nodes(self.track(current_set, set_data))
        current_set.ref_graph.nodes.extend(tracked_nodes)
        # create object connections from index connections for output graph
        # not really efficient...
    def execute(self, object_desc, dataset_ref: xr.Dataset):
        """
        Execute the tracking procedure.
        The description is split into the different timelines (sets), which are tracked in parallel
        by a thread pool. Afterwards postprocess() is called once with the complete description.

        Parameters
        ----------
        object_desc : identification_pb2.DatasetDescription TODO?
                The description of the detected features from the identification technique.
        dataset_ref : xarray.Dataset
                Reference to the dataset used in the pipeline.
        """
        track_one_set = partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref)

        # parallel for all tracking sets; map() blocks until every set is done, and the
        # context manager then shuts the worker threads down instead of leaking them
        with Pool() as pool:
            pool.map(track_one_set, range(len(object_desc.sets)))  # iterate over sets.

        self.postprocess(object_desc)  # only postprocess object desc, THEN to graph
        # TODO what if tracks are the identification? e.g. AEW identification in Hovmoller
    def get_graph(self, object_desc=None):
        """
        Return the object graph, generating it on first use.

        Parameters
        ----------
        object_desc : optional
                Description forwarded to generate_graph() when no graph is cached yet.
                Ignored if a graph already exists.

        Returns
        -------
        The cached graph, or the freshly generated one.

        Raises
        ------
        TypeError
                If no graph is cached and no object_desc was given. generate_graph() cannot
                run without a description; the previous argument-less call to it failed
                with a TypeError in this situation as well.
        """
        if self.graph is not None:
            return self.graph  # use cached
        if object_desc is None:
            raise TypeError("get_graph() needs object_desc as long as no graph has been generated.")
        self.generate_graph(object_desc)
        return self.graph
    # generate object graph from ref graph
    def generate_graph(self, object_desc):
        """
        Build the object graph from the index-based reference graph of every set.

        self.graph becomes a copy of object_desc in which, per set, the timesteps are removed and
        the reference graph is cleared. For every object of every timestep a node is added to the
        set's object_graph, holding the valid time and a copy of the object. The connections of
        the reference graph (addressed by time index and object id) are resolved to the connected
        objects and stored as connected nodes.

        Parameters
        ----------
        object_desc
                Description whose sets carry `timesteps` and a `ref_graph`. It is only read.

        Returns
        -------
        The generated graph (also stored in self.graph).

        Raises
        ------
        IndexError
                If a connection references a time index that does not exist, or an object id
                that is not present in the referenced timestep.
        """
        self.graph = self.pb_reference.DatasetDescription()
        self.graph.CopyFrom(object_desc)

        for set_idx, objdesc_set in enumerate(object_desc.sets):
            graph_set = self.graph.sets[set_idx]

            # empty the graph set
            del graph_set.timesteps[:]
            graph_set.ref_graph.Clear()

            # Index the reference graph once by (time index, object id) of the source node, so
            # it is not rescanned for every object. Order of the connections is preserved.
            # Assumes the ids are hashable scalars - TODO confirm.
            ref_by_source = {}
            for ref_connection in objdesc_set.ref_graph.nodes:
                source = ref_connection.this_node
                ref_by_source.setdefault((source.time_index, source.object_id), []).append(ref_connection)

            # per timestep: object id -> first object carrying that id
            objects_by_id = []
            for ts in objdesc_set.timesteps:
                by_id = {}
                for obj in ts.objects:
                    by_id.setdefault(obj.id, obj)
                objects_by_id.append(by_id)

            # for each object add (n1,[]) to graph. then add connections.
            for idx_ts, ts in enumerate(objdesc_set.timesteps):
                cur_time = ts.valid_time
                for obj in ts.objects:
                    obj_connection = graph_set.object_graph.nodes.add()
                    obj_connection.this_node.time = cur_time
                    obj_connection.this_node.object.CopyFrom(obj)

                    # connections of the ref graph starting at the current object: add their targets
                    for ref_connection in ref_by_source.get((idx_ts, obj.id), ()):
                        for ref_n2 in ref_connection.connected_nodes:
                            obj_n2 = obj_connection.connected_nodes.add()
                            obj_n2.time = objdesc_set.timesteps[ref_n2.time_index].valid_time
                            try:
                                n2_obj = objects_by_id[ref_n2.time_index][ref_n2.object_id]
                            except KeyError:
                                # keep the exception type of the former list lookup, but say what is wrong
                                raise IndexError(
                                    "Reference graph points to object id " + str(ref_n2.object_id)
                                    + " which does not exist in timestep " + str(ref_n2.time_index) + "."
                                ) from None
                            obj_n2.object.CopyFrom(n2_obj)
        return self.graph
    # Filter the generated tracks: keep_track() is called for each track, and rejected tracks are removed.
    # TODO assert sorted?
    def filter_tracks(self):
        """
        Remove every track rejected by keep_track() from all sets of the graph.

        The tracks of each set are visited from last to first and deleted in place.
        """
        for set_ in self.graph.sets:
            # walk backwards so that deleting by index never shifts a track still to be visited
            for t_id in reversed(range(len(set_.tracks))):
                if not self.keep_track(set_.tracks[t_id]):
                    del set_.tracks[t_id]

    def keep_track(self, track):
        """
        Decide whether a track survives filter_tracks().

        The base implementation keeps every track; override it to discard tracks.

        Parameters
        ----------
        track
                A track out of the `tracks` of a graph set.

        Returns
        -------
        bool
                Always True here.
        """
        return True
    # TODO overlap tracking def. input just field name
    # After the graph has been computed, compute the "tracks": the disjoint subgraphs of the total graph.
    # TODO mention heuristic in docstring.
    def generate_tracks(self):
        """
        Group the nodes of each set's object graph into tracks ("waves").

        Heuristic, applied per set: the nodes are sorted by the time of their own node. Going
        through them in that order, every node that has no wave id yet is taken together with
        all its (temporal) downstream nodes. If one of these already belongs to a wave, the
        whole group joins that wave (if several different waves are hit, a message is printed
        and the last one found wins). Otherwise the group gets a new wave id.
        Finally one ObjectGraph per wave id is filled with the wave's nodes, in time order, and
        appended to the `tracks` of the set.

        If no graph has been computed yet, a message is printed and the interpreter is exited
        with status 1 (SystemExit).

        Returns
        -------
        None
        """
        if self.graph is None:
            print("Compute graph first.")
            # same effect as exit(1), without depending on the helper injected by the site module
            raise SystemExit(1)

        # for each set:
        #  order connections by time of first node
        for graph_set in self.graph.sets:
            # sort nodes by time of first node (is in key), as list here
            # TODO assert comp by string yields time sorted list (yyyymmddd)
            time_sorted_nodes = sorted(graph_set.object_graph.nodes, key=lambda item: item.this_node.time)

            wave_id_per_node = [None] * len(time_sorted_nodes)
            cur_id = 0

            # iterate over all time sorted identified connections
            # search temporal downstream tracked troughs and group them using a set id
            for con_idx in range(len(time_sorted_nodes)):

                if wave_id_per_node[con_idx] is not None:  # already part of a wave
                    continue

                # not part of wave -> get (temporal) downstream connections = wave (return indices of them)
                downstream_wave_node_indices = self.get_downstream_node_indices(time_sorted_nodes, con_idx)
                print(str(con_idx) + " -> " + str(downstream_wave_node_indices))

                # any of downstream nodes already part of a wave?
                connected_wave_id = None
                for ds_node_idx in downstream_wave_node_indices:
                    if wave_id_per_node[ds_node_idx] is not None:
                        if connected_wave_id is not None:
                            print("Double ID, better resolve todo...")
                        connected_wave_id = wave_id_per_node[ds_node_idx]

                # if so join that wave, else open a new one for all wave nodes
                if connected_wave_id is not None:
                    assigned_id = connected_wave_id
                else:
                    assigned_id = cur_id
                    if downstream_wave_node_indices:  # only consume an id that is actually used
                        cur_id += 1

                for ds_node_idx in downstream_wave_node_indices:
                    wave_id_per_node[ds_node_idx] = assigned_id

            # done, now extract every wave by id and put them into subgraphs
            for wave_id in range(cur_id):
                track = self.pb_reference.ObjectGraph()
                # nodes of this wave; still time sorted because time_sorted_nodes is
                cur_wave_nodes = [
                    node for node, node_wave_id in zip(time_sorted_nodes, wave_id_per_node)
                    if node_wave_id == wave_id
                ]
                track.nodes.extend(cur_wave_nodes)
                graph_set.tracks.append(track)

    @staticmethod
    def get_downstream_node_indices(graph_list, start_idx):
        node_indices = [start_idx]

        co = graph_list[start_idx]
        node, connected_nodes = co.this_node, co.connected_nodes
        for c_node in connected_nodes:
            # get index of connected node in graph
            obj_node_list = [con_.this_node for con_ in graph_list]
            c_node_idx = obj_node_list.index(c_node)

            # call recursively on this connected node
            c_node_downstream_indices = TrackingTechnique.get_downstream_node_indices(graph_list, c_node_idx)
            node_indices.extend(c_node_downstream_indices)

        return list(set(node_indices))

    def get_items(self):
        """
        Return the result of items() called on the `_graph` attribute.

        NOTE(review): the visible code only ever assigns `self.graph`, not `self._graph` -
        confirm where `_graph` is set, or whether `graph` was meant.
        """
        backing_graph = self._graph
        return backing_graph.items()

    # extract wave objects from graph: question here: what is a wave?
    #   e.g. before/after merging, what belongs to same wave
    # wave object is list of wavetroughs, so wavetroughs for consecutive timesteps
    def extract_waves(self):