Skip to content
Snippets Groups Projects
object_compare_tracking.py 4.14 KiB
Newer Older
from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
from datetime import datetime
from dask import delayed, compute
import numpy as np
import xarray as xr
from enstools.feature.util.data_utils import pb_str_to_datetime, SplitDimension


class ObjectComparisonTracking(TrackingTechnique):
    """
    Implementation of a tracking technique which can track objects by a simple pairwise comparison of their feature
    descriptions. This acts as an abstract class for comparison tracking techniques.
    It implements the track() method and provides an abstract correspond() method which gives a binary answer if two
    objects of different timesteps are the same. Objects do not necessarily have to belong to consecutive
    timesteps: using set_max_delta_compare() the user can set a maximum timedelta up to which objects are compared.
    """

    def set_max_delta_compare(self, cmp_delta):
        """
        Set the maximum delta time of objects to compare. If None, only compare objects of consecutive timesteps.

        Parameters
        ----------
        cmp_delta : datetime.timedelta or None
            Maximum difference between the valid times of two timesteps whose objects are still compared.
            It is compared against the difference of two parsed valid times in track(), so it is expected to be
            a timedelta. None restricts comparisons to directly consecutive timesteps.

        Returns
        -------
        None
        """
        self.max_compare_delta = cmp_delta

    def correspond(self, time1, obj1, time2, obj2):
        """
        Abstract method. Implementations should check here if obj1 and obj2 can be regarded as the same object,
        creating a tracking link. Without a maximum compare delta, time2 belongs to the timestep directly following
        the one of time1; otherwise the two timesteps may be further apart.

        Parameters
        ----------
        time1 : datetime
                Valid time of obj1, as returned by pb_str_to_datetime (presumably a datetime).
        obj1 : identification_pb2.Object
                Object to compare from timestep t
        time2 : datetime
                Valid time of obj2, as returned by pb_str_to_datetime (presumably a datetime).
        obj2 : identification_pb2.Object
                Object to compare from a subsequent entry of the timestep list

        Returns
        -------
        True if the objects can be considered the same. Otherwise False.
        This default implementation always returns False, so subclasses must override it to create any links.
        """
        return False

    def track(self, tracking_set, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on pairwise comparisons of objects. Using
        this tracking technique, the correspond() method has to implement a boolean function which returns True if
        the given object pair should be considered as the same object.

        A pair of timesteps is compared if the timesteps are consecutive (no maximum delta set), or if the difference
        of their valid times does not exceed the maximum delta set via set_max_delta_compare(). All object pairs of
        such timestep pairs are evaluated lazily with dask.delayed and computed in one dask.compute call.

        Parameters
        ----------
        tracking_set : pb2.TrackingSet
            The set to track. Its timesteps are read from ``tracking_set.timesteps`` if that attribute exists,
            otherwise the argument itself is used as the sequence of timesteps. Each timestep must provide
            ``valid_time`` (a string parsable by pb_str_to_datetime) and ``objects``.
        subset : xr.Dataset
            The corresponding xarray subset (not used by this implementation).

        Returns
        -------
        List of pb2.GraphConnection s, in timestep-pair / object-pair order.
        """
        # NOTE(review): the previous code used an undefined name `timesteps`. Accept both a TrackingSet-like container
        # and a plain sequence of timesteps until the caller's argument type is confirmed.
        timesteps = getattr(tracking_set, "timesteps", tracking_set)
        cmp_delta = getattr(self, "max_compare_delta", None)

        # Parse every valid time once, instead of once per timestep pair and again per object pair.
        valid_dts = [pb_str_to_datetime(ts.valid_time) for ts in timesteps]

        def get_connection_if_correspond(time1, time1_dt, obj1, time2, time2_dt, obj2):
            # correspond() works on parsed times; get_new_connection() keeps receiving the original time strings.
            if self.correspond(time1_dt, obj1, time2_dt, obj2):
                return self.get_new_connection(time1, obj1, time2, obj2)
            return None

        delayed_connections = []
        n_steps = len(timesteps)
        for i in range(n_steps):
            # Without a maximum delta only the directly following timestep is a candidate.
            stop = min(i + 2, n_steps) if cmp_delta is None else n_steps
            for j in range(i + 1, stop):
                if cmp_delta is not None and valid_dts[j] - valid_dts[i] > cmp_delta:
                    continue
                ts1, ts2 = timesteps[i], timesteps[j]
                for o1 in ts1.objects:
                    for o2 in ts2.objects:
                        delayed_connections.append(
                            delayed(get_connection_if_correspond)(
                                ts1.valid_time, valid_dts[i], o1, ts2.valid_time, valid_dts[j], o2
                            )
                        )

        # A None result marks an object pair that does not correspond; drop those.
        connections = [conn for conn in compute(*delayed_connections) if conn is not None]
        # TODO remove transitive edges
        return connections