# tracking.py
import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
from enstools.feature.util.data_utils import print_lock, get_subset_by_description, get_split_dimensions, squeeze_nodes, \
    SplitDimension
from multiprocessing.pool import ThreadPool as Pool
from functools import partial
from enstools.feature.util.data_utils import pb_str_to_datetime

class TrackingTechnique(ABC):
    """
    Base abstract class for feature tracking algorithms.

    Implementations need to override track(), which computes the connections of one tracking set,
    and the abstract postprocess(). adjust_track() can optionally be overridden to filter or split tracks.
    """

    def __init__(self):
        # module providing the GraphConnection message used by get_new_connection*();
        # presumably assigned by the pipeline before tracking starts - TODO confirm
        self.pb_reference = None
        # NOTE(review): self.processing_mode is read in track_set() but not initialized here;
        # it must be assigned (presumably by the pipeline) before execute() is called.

    def track(self, track_set, subset: xr.Dataset):
        """
        Tracking method, called once per tracking set (time line) of the feature descriptions.

        This time line can be for example a reforecast or a member of an ensemble forecast, in which detected
        objects should be tracked over multiple timestamps. execute() calls this method in parallel for all
        time lines in the dataset.

        Implementations should compute links of corresponding objects of two time steps, e.g. built with
        get_new_connection() or get_new_connection_from_id().
        See template/ and especially template_object_compare/ for examples.

        NOTE(review): documented as abstract, but not decorated with @abstractmethod; the base implementation
        returns no connections.

        Parameters
        ----------
        track_set : identification_pb2.TrackingSet
                The data to be tracked (one element of the description's sets).
        subset : xarray.Dataset
                Subset to be tracked. Forwarded from identification

        Returns
        -------
        connections : list of pb_reference.GraphConnection, or None
                The connections forming the path of the objects in the time steps.
                track_set() treats None like an empty list.
        """
        return []

    @abstractmethod
    def postprocess(self, object_desc):
        """
        Abstract method for the postprocess of the tracking.

        Parameters
        ----------
        object_desc: The whole pb2.DatasetDescription, can be altered inplace
        """
        pass

    def track_set(self, set_idx, obj_desc=None, dataset=None):
        """
        Tracks a given TrackableSet and stores the resulting graph edges in it.

        Parameters
        ----------
        set_idx: index of the set in ``obj_desc.sets``
        obj_desc: the object description; the selected set is altered in place
        dataset: the dataset

        Returns
        -------
        None. The connections are appended to ``obj_desc.sets[set_idx].graph.edges``.
        """
        # get according subset
        obj_set = obj_desc.sets[set_idx]
        dataset_sel = get_subset_by_description(dataset, obj_set, self.processing_mode)
        split_dims = get_split_dimensions(dataset, self.processing_mode)

        split_string = '; '.join(str(dim.name) + ": " + str(getattr(obj_set, dim.dim)) for dim in split_dims)
        print_lock("Start tracking data block with dimensions:   " + split_string)

        # track this set
        cons = self.track(obj_set, dataset_sel)
        if cons is None:
            cons = []

        # one parent-only connection per detected object, so that objects without any link
        # are still part of the graph
        empty_cons = [self.get_new_connection(ts.valid_time, obj)
                      for ts in obj_set.timesteps
                      for obj in ts.objects]

        # squeeze the parent-only nodes together with the tracked connections
        # (presumably merges connections sharing the same parent node - TODO confirm)
        cons = squeeze_nodes(empty_cons + list(cons))  # TODO could be None.
        # sort them by time (string key)
        cons = sorted(cons, key=lambda c: c.parent.time)
        # remove transitive edges: a->b->c and a->c --> remove the a->c edge
        cons = TrackingTechnique.remove_transitive(cons)
        # add graph nodes to ref graph
        obj_set.graph.edges.extend(cons)

    def execute(self, object_desc, dataset_ref: xr.Dataset):
        """
        Execute the tracking procedure.

        The description is split into the different timelines (``object_desc.sets``), which are tracked in
        parallel on a thread pool via track_set(). Afterwards postprocess() is called on the whole description.

        Parameters
        ----------
        object_desc : pb2.DatasetDescription
                The description of the detected features from the identification technique. Altered in place.
        dataset_ref : xarray.Dataset
                Reference to the dataset used in the pipeline.
        """
        # parallel for all tracking sets; map() blocks until every set is done and re-raises a worker's
        # exception, and the context manager releases the pool's threads afterwards
        with Pool() as pool:
            pool.map(partial(self.track_set, obj_desc=object_desc, dataset=dataset_ref),
                     range(len(object_desc.sets)))  # iterate over sets.
        self.postprocess(object_desc)  # only postprocess object desc, THEN to graph

    def adjust_track(self, track):
        """
        Override this method to filter, change and split certain tracks.
        When generate_tracks() on a Graph is called, each track is checked via this method.
        For example, you could check here if the track holds together for longer than a certain time range.

        Parameters
        ----------
        track: the pb2.ObjectGraph track

        Returns
        -------
        Returns a list of tracks which should be used instead of the track.
        For example, an empty list [] should be returned if the track should be discarded,
        a list of the track itself [track] if it should be kept as it is or altered,
        or a list of multiple tracks if the track should be split into sub-tracks.
        The default implementation keeps every track unchanged.
        """
        return [track]

    @staticmethod
    def get_downstream_node_indices(graph_list, start_idx, until_time=None):
        """
        Helper method for the track generation. Searches all downstream node indices.

        Starting at ``graph_list[start_idx]``, the child nodes of each reached connection are followed to the
        connection having that node as parent. Every connection is visited at most once.

        Parameters
        ----------
        graph_list: list of connections (with ``parent`` and ``childs``)
        start_idx: index of the start connection in graph_list
        until_time: optional datetime; child nodes whose time (parsed with pb_str_to_datetime) is later
                    are not followed

        Returns
        -------
        sorted list of downstream indices in graph_list, including start_idx

        Raises
        ------
        ValueError
            If a followed child node is not the parent of any connection in graph_list.
        """
        # built once for the whole search; list.index matches nodes by equality (==)
        obj_node_list = [con_.parent for con_ in graph_list]
        visited = {start_idx}
        stack = [start_idx]
        while stack:
            for c_node in graph_list[stack.pop()].childs:
                # if until_time is set: do not follow nodes later than until_time
                if until_time is not None and pb_str_to_datetime(c_node.time) > until_time:
                    continue

                # get index of connected node in graph
                c_node_idx = obj_node_list.index(c_node)
                if c_node_idx not in visited:
                    visited.add(c_node_idx)
                    stack.append(c_node_idx)

        return sorted(visited)

    def get_new_connection(self, start_time, start_obj, end_time=None, end_obj=None):
        """
        Create a new connection for the tracking graph and populates it.
        Takes start time and object, and if existent end time and object.

        Parameters
        ----------
        start_time : start time (str)
        start_obj : start object
        end_time : end time (str)
        end_obj : end object

        Returns
        -------
        the pb2.GraphConnection; it has a child node only if both end_time and end_obj are given
        """
        new_connection = self.pb_reference.GraphConnection()

        new_connection.parent.time = start_time
        new_connection.parent.object.CopyFrom(start_obj)
        if end_time is not None and end_obj is not None:
            n2 = new_connection.childs.add()
            n2.time = end_time
            n2.object.CopyFrom(end_obj)

        return new_connection

    @staticmethod
    def _find_object_by_id(timestep, obj_id):
        """
        Return the first object in ``timestep.objects`` whose ``id`` equals ``obj_id``.

        Raises
        ------
        ValueError
            If no object of the time step has this ID.
        """
        for obj in timestep.objects:
            if obj.id == obj_id:
                return obj
        raise ValueError("Could not find object ID " + str(obj_id)
                         + " in given description for timestep at " + str(timestep.valid_time))

    def get_new_connection_from_id(self, timestep_start, obj_id_start, timestep_end=None, obj_id_end=None):
        """
        Create a new connection for the tracking graph and populates it.
        Takes start timestep and object ID, and if existent end timestep and object ID.

        Parameters
        ----------
        timestep_start : start time pb2.Timestep
        obj_id_start : start object id
        timestep_end : end time pb2.Timestep
        obj_id_end : end object id

        Returns
        -------
        the pb2.GraphConnection

        Raises
        ------
        ValueError
            If an object ID is not found in its time step. (Raised instead of exiting the process,
            since this is typically called from track() inside a thread pool worker.)
        """
        new_connection = self.pb_reference.GraphConnection()
        new_connection.parent.time = timestep_start.valid_time

        # search given object ID in objects at timestep_start and add this object
        start_obj = TrackingTechnique._find_object_by_id(timestep_start, obj_id_start)
        new_connection.parent.object.CopyFrom(start_obj)

        # if end node specified, do the same
        if timestep_end is not None and obj_id_end is not None:
            end_obj = TrackingTechnique._find_object_by_id(timestep_end, obj_id_end)
            n2 = new_connection.childs.add()
            n2.time = timestep_end.valid_time
            n2.object.CopyFrom(end_obj)

        return new_connection

    @staticmethod
    def remove_transitive(connections):
        """
        Remove transitive edges from the connections, modifying their ``childs`` in place.

        An edge a->c is transitive if c is still reachable from a after temporarily removing
        that edge (e.g. via a->b->c). During this search only nodes up to the time of c are followed.

        Parameters
        ----------
        connections: list of connections (with ``parent`` and ``childs``)

        Returns
        -------
        The same ``connections`` list.
        """
        removed = 0
        for con_id, con in enumerate(connections):
            # snapshot of the children, since con.childs is mutated below
            childs = con.childs[:]
            # cant be transitive edge if only one child
            if len(childs) < 2:
                continue
            for cld in childs:
                child_time = pb_str_to_datetime(cld.time)
                con.childs.remove(cld)  # remove edge temporarily
                # and check if downstream nodes still contain end node of edge. if yes -> transitive
                downstream_idxs = TrackingTechnique.get_downstream_node_indices(connections, con_id,
                                                                                until_time=child_time)
                downstream_objs = [connections[ds_idx].parent for ds_idx in downstream_idxs]
                if cld in downstream_objs:
                    removed += 1  # transitive: keep the edge removed
                else:
                    con.childs.append(cld)  # not transitive: add edge back (at the end of childs)

        # same output helper as track_set(), which calls this from thread pool workers
        print_lock("Removed " + str(removed) + " transitive edges in current set.")
        return connections

# TODO: What if the tracks are the identification itself? E.g. AEW identification in a Hovmoller diagram.