from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
from datetime import datetime
from dask import delayed, compute
import numpy as np
import xarray as xr
from enstools.feature.util.data_utils import pb_str_to_datetime, SplitDimension


class ObjectComparisonTracking(TrackingTechnique):
    """
    Tracking technique which links objects by a simple pairwise comparison
    of their feature descriptions.

    This acts as an abstract base class for comparison tracking techniques.
    It implements the track() method and declares an abstract correspond()
    method which gives a binary answer whether two objects of two timesteps
    are the same.

    Objects do not necessarily have to belong to consecutive timesteps:
    using set_max_delta_compare() the user can set a maximum time delta up to
    which pairs of timesteps are compared.
    """

    def set_max_delta_compare(self, cmp_delta):
        """
        Set the maximum time delta between timesteps whose objects are compared.

        Parameters
        ----------
        cmp_delta
                Upper bound (inclusive) for ``valid_time(t2) - valid_time(t1)``.
                If None, only objects of consecutive timesteps are compared.
                Presumably a datetime.timedelta, since track() compares it
                against the difference of two parsed valid times.

        Returns
        -------
        None
        """
        self.max_compare_delta = cmp_delta

    @abstractmethod
    def correspond(self, time1, obj1, time2, obj2):
        """
        Abstract method. Implementations should check here if obj1 and obj2 can
        be regarded as the same object, which creates a tracking link.

        Parameters
        ----------
        time1
                Valid time of obj1, as returned by pb_str_to_datetime().
        obj1 : identification_pb2.Object
                Object to compare, taken from the earlier-indexed timestep.
        time2
                Valid time of obj2, as returned by pb_str_to_datetime().
        obj2 : identification_pb2.Object
                Object to compare, taken from the later-indexed timestep.

        Returns
        -------
        True if the objects can be considered the same. Otherwise False.
        """
        return False

    def track(self, tracking_set, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on
        pairwise comparisons of objects.

        Every pair of timesteps (t1, t2) with t1 listed before t2 is selected if
        either no maximum delta has been set and the two are direct neighbours
        in the list, or their valid times differ by at most the maximum delta.
        For each selected pair, every object of t1 is compared with every object
        of t2 through correspond(); each positive answer yields one connection
        created by get_new_connection(). The comparisons are evaluated with
        dask.

        Parameters
        ----------
        tracking_set : pb2.TrackingSet
                Provides ``timesteps``; each timestep provides ``valid_time``
                and ``objects``.
        subset : xr.Dataset
                The corresponding xarray subset. Not used by this implementation.

        Returns
        -------
        numpy.ndarray (1-D, dtype=object)
                The connections (pb2.GraphConnection) of all corresponding
                object pairs. Empty if there are none.
        """
        # set_max_delta_compare() may never have been called, in which case
        # the attribute does not exist: fall back to "consecutive only".
        cmp_delta = getattr(self, "max_compare_delta", None)

        def get_connection_if_correspond(time1, time1_dt, obj1, time2, time2_dt, obj2):
            # if correspond...
            if self.correspond(time1_dt, obj1, time2_dt, obj2):
                # TODO allow skips? as param.
                # new object
                return self.get_new_connection(time1, obj1, time2, obj2)
            return None

        timesteps = tracking_set.timesteps
        n_timesteps = len(timesteps)

        # Parse every valid time exactly once instead of once per timestep pair
        # and once more per compared object pair.
        valid_dts = [pb_str_to_datetime(ts.valid_time) for ts in timesteps]

        delayed_check = delayed(get_connection_if_correspond)
        delayed_connections = []

        # get pair of timesteps, check if their delta has been requested
        for t1, t1_ts in enumerate(timesteps):
            for t2 in range(t1 + 1, n_timesteps):
                if cmp_delta is None:
                    consider = (t2 - t1 == 1)
                else:
                    consider = valid_dts[t2] - valid_dts[t1] <= cmp_delta
                if not consider:
                    continue

                t2_ts = timesteps[t2]
                for o1 in t1_ts.objects:
                    for o2 in t2_ts.objects:
                        delayed_connections.append(
                            delayed_check(
                                t1_ts.valid_time, valid_dts[t1], o1,
                                t2_ts.valid_time, valid_dts[t2], o2,
                            )
                        )

        results = compute(*delayed_connections)

        # Drop all None elements - they mark the non-corresponding pairs.
        # An identity check is used so the filter never depends on how the
        # connection objects implement comparison.
        kept = [conn for conn in results if conn is not None]

        # Fill element by element so the result is always a 1-D object array.
        connections = np.empty(len(kept), dtype=object)
        for idx, conn in enumerate(kept):
            connections[idx] = conn

        # TODO remove transitive edges
        return connections