from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
import datetime
from dask import delayed, compute
import numpy as np
import xarray as xr


class ObjectComparisonTracking(TrackingTechnique):
    """
    Tracking technique that links objects by a simple pairwise comparison of
    their feature descriptions.

    This acts as an abstract class for comparison tracking techniques. It
    implements the track() method and declares an abstract correspond()
    method, which gives a binary answer as to whether two objects of
    consecutive timesteps are the same.
    """

    @abstractmethod
    def correspond(self, time1, obj1, time2, obj2):
        """
        Abstract method. Implementations should check here whether obj1 and
        obj2 of consecutive timesteps can be regarded as the same object,
        which creates a tracking link.

        Parameters
        ----------
        time1 : datetime.datetime
            Valid time of the timestep obj1 belongs to.
        obj1 : identification_pb2.Object
            Object to compare from timestep t.
        time2 : datetime.datetime
            Valid time of the timestep obj2 belongs to.
        obj2 : identification_pb2.Object
            Object to compare from timestep t+1.

        Returns
        -------
        True if the objects can be considered the same, otherwise False.
        """
        return False

    def track(self, tracking_set, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on
        pairwise comparisons of objects.

        Every object of timestep t is compared with every object of timestep
        t+1 via correspond(). The comparisons are built as dask delayed tasks
        and evaluated together in a single compute() call.

        Parameters
        ----------
        tracking_set : pb2.TrackingSet
            Set whose ``timesteps`` are tracked; each timestep provides a
            ``valid_time`` ISO 8601 string and a list of ``objects``.
        subset : xr.Dataset
            Not used by this implementation.

        Returns
        -------
        numpy.ndarray (dtype=object)
            One ``RefGraphNode`` per corresponding pair, linking
            ``(t, obj1.id)`` to ``(t + 1, obj2.id)``. Empty if no pair
            corresponds or there are fewer than two timesteps.
        """

        def get_connection_if_correspond(datetimes_in_set, time_idx1, obj1, time_idx2, obj2):
            """Return a RefGraphNode linking obj1 to obj2 if they correspond, else None."""
            if not self.correspond(datetimes_in_set[time_idx1], obj1,
                                   datetimes_in_set[time_idx2], obj2):
                return None
            # new object
            new_connection = self.pb_reference.RefGraphNode()  # TODO good API? bad now.
            new_connection.this_node.time_index = time_idx1  # TODO index
            new_connection.this_node.object_id = obj1.id
            n2 = new_connection.connected_nodes.add()
            n2.time_index = time_idx2  # TODO index
            n2.object_id = obj2.id
            return new_connection

        timesteps = tracking_set.timesteps
        datetimes = [datetime.datetime.fromisoformat(ts.valid_time) for ts in timesteps]

        # Wrap once. With dask's default (pure=False), every call of the
        # wrapper still becomes its own, uniquely keyed task.
        delayed_get_connection = delayed(get_connection_if_correspond)
        delayed_connections = []
        for t in range(len(timesteps) - 1):
            t1 = timesteps[t]
            t2 = timesteps[t + 1]
            for o1 in t1.objects:
                for o2 in t2.objects:
                    delayed_connections.append(
                        delayed_get_connection(datetimes, t, o1, t + 1, o2))

        results = compute(*delayed_connections)

        # Non-corresponding pairs yield None. Filter them by identity instead of
        # an elementwise ``!= None`` comparison, which would depend on each
        # message's __ne__ implementation.
        connections = np.asarray([c for c in results if c is not None], dtype=object)

        # TODO: decide where the connections are attached to
        # tracking_set.ref_graph (superclass or caller).
        return connections