Skip to content
Snippets Groups Projects
tracking.py 9.60 KiB
import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
from enstools.feature.util.data_utils import print_lock, get_subset_by_description, get_split_dimensions, squeeze_nodes, SplitDimension
from multiprocessing.pool import ThreadPool as Pool
from functools import partial


class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.
    Implementations need to override the abstract track() and postprocess() methods.
    """

    def __init__(self):
        # Generated protobuf module; get_new_connection() instantiates its GraphConnection and
        # generate_tracks() its ObjectGraph. It is None here, so it must be assigned before tracking.
        self.pb_reference = None

    @abstractmethod
    def track(self, track_set, subset: "xr.Dataset"):
        """
        Abstract tracking method. It is called once per timeline (list of time steps) of the feature
        descriptions. A timeline can be, for example, a reforecast or a member of an ensemble forecast,
        in which detected objects should be tracked over multiple timestamps. This method is called in
        parallel (thread pool) for all timelines in the dataset.

        This method should compute links of corresponding objects of two time steps.
        See template/ and especially template_object_compare/ for examples.

        Parameters
        ----------
        track_set : iterable of identification_pb2.TrackingSet
                The data to be tracked
        subset : xarray.Dataset
                Subset to be tracked. Forwarded from identification

        Returns
        -------
        connections : list of graph connections (e.g. created via get_new_connection())
                The connections forming the path of the objects in the time steps.
        """
        connections = []
        return connections

    @abstractmethod
    def postprocess(self, object_desc):
        """
        Abstract method for the postprocess of the tracking.

        Parameters
        ----------
        object_desc: The whole pb2.DatasetDescription, can be altered inplace
        """
        pass

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        """
        Track one set of the object description and store the resulting graph in it.

        Runs track() on the data subset belonging to the set, adds one graph node for every detected
        object (so objects without any connection still appear), merges these nodes with the computed
        connections via squeeze_nodes(), sorts the result by node time and appends it to
        ``obj_desc.sets[set_idx].graph.nodes``.

        Parameters
        ----------
        set_idx: index of the set in the object description
        obj_desc: the object description
        dataset: the dataset

        Returns
        -------
        None, the set's graph is extended in place.
        """
        obj_set = obj_desc.sets[set_idx]

        # NOTE(review): self.processing_mode is not assigned in this class; presumably set by the
        # pipeline or a subclass before execute() is called — confirm.
        dataset_sel = get_subset_by_description(dataset, obj_set, self.processing_mode)
        split_dims = get_split_dimensions(dataset, self.processing_mode)

        split_string = '; '.join(str(dim.name) + ": " + str(getattr(obj_set, dim.dim)) for dim in split_dims)
        print_lock("Start tracking data block with dimensions:   " + split_string)

        # track this set; tolerate implementations that forget to return a list
        cons = self.track(obj_set, dataset_sel)
        if cons is None:
            cons = []

        # one connection-less node per detected object, so every object is part of the graph
        empty_cons = [self.get_new_connection(ts.valid_time, obj)
                      for ts in obj_set.timesteps
                      for obj in ts.objects]

        # merge duplicate nodes, then sort them by time (string key)
        cons = squeeze_nodes(empty_cons + list(cons))
        cons = sorted(cons, key=lambda c: c.this_node.time)

        # add graph nodes to ref graph
        obj_set.graph.nodes.extend(cons)

    def execute(self, object_desc, dataset_ref: "xr.Dataset"):
        """
        Execute the tracking procedure.
        Each set of the description (e.g. a timeline) is tracked in parallel in a thread pool,
        afterwards postprocess() is called once on the whole description.

        Parameters
        ----------
        object_desc : pb2.DatasetDescription
                The description of the detected features from the identification technique.
        dataset_ref : xarray.Dataset
                Reference to the dataset used in the pipeline.
        """
        worker = partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref)

        # The context manager releases the worker threads; map() blocks until every set is done
        # (and re-raises the first exception from a worker).
        with Pool() as pool:
            pool.map(worker, range(len(object_desc.sets)))  # iterate over sets.

        self.postprocess(object_desc)  # only postprocess object desc, THEN to graph

    def filter_tracks(self, object_desc):
        """
        Filter the generated tracks in place: every track for which keep_track() returns False
        is removed from its set.

        Parameters
        ----------
        object_desc: the pb2.DatasetDescription whose sets' tracks are filtered
        """
        for set_ in object_desc.sets:
            # iterate backwards so deleting an entry does not shift the indices still to visit
            for t_id in reversed(range(len(set_.tracks))):
                if not self.keep_track(set_.tracks[t_id]):
                    del set_.tracks[t_id]

    def keep_track(self, track):
        """
        Decide whether a generated track is kept by filter_tracks(). The default keeps every track;
        subclasses can override this to discard tracks.

        Parameters
        ----------
        track: the pb2.ObjectGraph track

        Returns
        -------
        True if keep, else discard
        """
        return True

    def generate_tracks(self, object_desc):
        """
        After the tracking graph has been computed, split it into tracks, i.e. disjoint sub-graphs of
        the total graph, using a simple heuristic: the nodes are ordered by time; for each node without
        a track id, all downstream nodes are collected. If any of these nodes already has an id, the
        whole downstream group takes that id (the last one found if there are several); otherwise the
        group becomes a new track.

        Parameters
        ----------
        object_desc: the pb2.DatasetDescription whose sets' graphs are split

        Returns
        -------
        Nothing, tracks are appended to each set's ``tracks`` in place.
        """
        for graph_set in object_desc.sets:

            # sort nodes by time of their own node
            time_sorted_nodes = sorted(graph_set.graph.nodes, key=lambda item: item.this_node.time)

            wave_id_per_node = [None] * len(time_sorted_nodes)
            cur_id = 0

            # search temporal downstream nodes and group them using a wave id
            for con_idx in range(len(time_sorted_nodes)):
                if wave_id_per_node[con_idx] is not None:  # already part of a wave
                    continue

                # always contains con_idx itself, so every node ends up with an id
                downstream = TrackingTechnique.get_downstream_node_indices(time_sorted_nodes, con_idx)

                connected_wave_id = None
                for ds_node_idx in downstream:
                    if wave_id_per_node[ds_node_idx] is not None:
                        # TODO: downstream nodes may belong to several waves; only the last one is kept
                        connected_wave_id = wave_id_per_node[ds_node_idx]

                if connected_wave_id is None:  # new path for all wave nodes
                    connected_wave_id = cur_id
                    cur_id += 1

                for ds_node_idx in downstream:
                    wave_id_per_node[ds_node_idx] = connected_wave_id

            # group nodes by wave id in a single pass; nodes stay in time order within a wave
            nodes_per_wave = [[] for _ in range(cur_id)]
            for node, wave_id in zip(time_sorted_nodes, wave_id_per_node):
                nodes_per_wave[wave_id].append(node)

            for wave_nodes in nodes_per_wave:
                track = self.pb_reference.ObjectGraph()
                track.nodes.extend(wave_nodes)
                graph_set.tracks.append(track)

    @staticmethod
    def get_downstream_node_indices(graph_list, start_idx):
        """
        Helper method for the track generation. Collects the indices of all nodes reachable from
        ``graph_list[start_idx]`` by following ``connected_nodes`` (including ``start_idx`` itself).

        Connected nodes are located by equality against each entry's ``this_node``; a connected node
        that is not present in ``graph_list`` raises ValueError (from list.index).

        Returns
        -------
        sorted list of downstream indices in graph_list
        """
        # built once instead of once per visited connection
        obj_node_list = [con_.this_node for con_ in graph_list]

        # iterative DFS with a visited set: each node is expanded once, so shared sub-paths are
        # not re-walked and cyclic connections cannot recurse forever
        visited = {start_idx}
        stack = [start_idx]
        while stack:
            idx = stack.pop()
            for c_node in graph_list[idx].connected_nodes:
                c_node_idx = obj_node_list.index(c_node)
                if c_node_idx not in visited:
                    visited.add(c_node_idx)
                    stack.append(c_node_idx)

        return sorted(visited)

    def get_new_connection(self, start_time, start_obj, end_time=None, end_obj=None):
        """
        Create a new connection for the tracking graph and populate it.
        Takes start time and object, and, if both given, end time and object.

        Parameters
        ----------
        start_time : start time (str)
        start_obj : start object
        end_time : end time (str)
        end_obj : end object

        Returns
        -------
        the pb2.GraphConnection
        """
        new_connection = self.pb_reference.GraphConnection()

        new_connection.this_node.time = start_time
        new_connection.this_node.object.CopyFrom(start_obj)
        # only add a connected node when both end time and end object are given
        if end_time is not None and end_obj is not None:
            n2 = new_connection.connected_nodes.add()
            n2.time = end_time
            n2.object.CopyFrom(end_obj)

        return new_connection

# TODO: handle the case where the tracks themselves are the identification result, e.g. AEW identification in a Hovmöller diagram.