# object_compare_tracking.py
import datetime
from abc import ABC, abstractmethod

from dask import delayed, compute
import numpy as np
import xarray as xr

from .tracking import TrackingTechnique


class ObjectComparisonTracking(TrackingTechnique):
    """
    Implementation of a tracking technique which can track objects by a simple pairwise comparison of their feature
    descriptions. This acts as an abstract class for comparison tracking techniques.
    It implements the track() method and provides an abstract correspond() method which gives a binary answer if two
    objects of consecutive timesteps are the same.
    """

    @abstractmethod
    def correspond(self, time1, obj1, time2, obj2):
        """
        Abstract method. Implementations should check here if obj1 and obj2 of consecutive timestamps can be regarded
        as same object, creating a tracking link.

        Parameters
        ----------
        time1 : datetime.datetime
                Valid time of timestamp t
        obj1 : identification_pb2.Object
                Object to compare from timestamp t
        time2 : datetime.datetime
                Valid time of timestamp t+1
        obj2 : identification_pb2.Object
                Object to compare from timestamp t+1

        Returns
        -------
        True if the objects can be considered the same. Otherwise False.
        """
        return False

    def track(self, tracking_set, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on pairwise comparisons of objects.

        Every object of each timestep is compared with every object of the directly following timestep via
        correspond(). The comparisons are built as dask delayed tasks and evaluated together with dask.compute().

        Parameters
        ----------
        tracking_set
                The set to track. Its timesteps are read; each timestep must provide ``valid_time``
                (an ISO 8601 string accepted by datetime.fromisoformat) and ``objects`` (each having an ``id``).
        subset : xr.Dataset
                Data subset belonging to tracking_set. Not used by this implementation.

        Returns
        -------
        List of RefGraphNode objects, one per corresponding object pair, each linking
        (time index t, object id) to (time index t+1, object id).
        """

        def get_connection_if_correspond(datetimes_in_set, time_idx1, obj1, time_idx2, obj2):
            """Return a RefGraphNode linking obj1 -> obj2 if correspond() accepts the pair, otherwise None."""
            if not self.correspond(datetimes_in_set[time_idx1], obj1, datetimes_in_set[time_idx2], obj2):
                return None

            # NOTE(review): pb_reference is not defined in this class; presumably provided by TrackingTechnique.
            new_connection = self.pb_reference.RefGraphNode()  # TODO good API? bad now.
            new_connection.this_node.time_index = time_idx1  # TODO index
            new_connection.this_node.object_id = obj1.id
            n2 = new_connection.connected_nodes.add()
            n2.time_index = time_idx2  # TODO index
            n2.object_id = obj2.id
            # Returning the node is required: without it every pair yields None and is filtered out below.
            return new_connection

        # NOTE(review): assumes the timesteps to track live on tracking_set (it is the set that also carries
        # ref_graph, see the TODO below) - confirm against the set schema.
        timesteps = tracking_set.timesteps
        datetimes = [datetime.datetime.fromisoformat(ts.valid_time) for ts in timesteps]

        delayed_connections = []
        for t in range(len(timesteps) - 1):
            t1 = timesteps[t]
            t2 = timesteps[t + 1]
            # Compare every object of timestep t with every object of timestep t+1.
            for o1 in t1.objects:
                for o2 in t2.objects:
                    delayed_connections.append(
                        delayed(get_connection_if_correspond)(datetimes, t, o1, t + 1, o2)
                    )

        # tracking_set.graph.connections[x].n1 / n2.time_id / object_id

        connections = compute(*delayed_connections)
        # remove all None elements - they represent the non-corresponding object pairs
        connections = [connection for connection in connections if connection is not None]

        # tracking_set.ref_graph.connections.extend(connections) # TODO in superclass or where build what

        # gs = GraphStructure(connections=set.graph.connections)
        # print(gs)

        return connections