Newer
Older
from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
import datetime
from dask import delayed, compute
import numpy as np
import xarray as xr
class ObjectComparisonTracking(TrackingTechnique):
    """
    Tracking technique which tracks objects by a simple pairwise comparison of their feature descriptions.

    This acts as an abstract class for comparison tracking techniques. It implements the track() method and
    provides an abstract correspond() method which gives a binary answer whether two objects of consecutive
    timesteps are the same.
    """

    @abstractmethod
    def correspond(self, time1, obj1, time2, obj2):
        """
        Abstract method. Implementations should check here if obj1 and obj2 of consecutive timesteps can be
        regarded as the same object, creating a tracking link.

        Parameters
        ----------
        time1 : datetime.datetime
            Valid time of the timestep obj1 belongs to (timestep t).
        obj1 : identification_pb2.Object
            Object to compare from timestep t.
        time2 : datetime.datetime
            Valid time of the timestep obj2 belongs to (timestep t+1).
        obj2 : identification_pb2.Object
            Object to compare from timestep t+1.

        Returns
        -------
        bool
            True if the objects can be considered the same. Otherwise False.
        """
        return False

    def track(self, tracking_set, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on pairwise comparisons of objects.

        Every object of a timestep is compared with every object of the directly following timestep using
        correspond(). Each comparison is wrapped in a dask delayed task; all tasks are evaluated together.

        Parameters
        ----------
        tracking_set : pb2.TrackingSet
            The set to track. Its ``timesteps`` are iterated in order; each timestep has to provide a
            ``valid_time`` (string readable by ``datetime.datetime.fromisoformat``) and its ``objects``.
        subset : xr.Dataset
            Not used by this implementation.

        Returns
        -------
        list
            One ``RefGraphNode`` (created via ``self.pb_reference``) per corresponding object pair. ``this_node``
            references the object of the earlier timestep, the single entry of ``connected_nodes`` references
            the object of the later timestep. Both are identified by time index and object ID. The list is
            empty if no pair corresponds.
        """

        def get_connection_if_correspond(datetimes_in_set, time_idx1, obj1, time_idx2, obj2):
            # Non-corresponding pairs yield None and are filtered out after the computation.
            if not self.correspond(datetimes_in_set[time_idx1], obj1, datetimes_in_set[time_idx2], obj2):
                return None

            new_connection = self.pb_reference.RefGraphNode()  # TODO good API? bad now.
            new_connection.this_node.time_index = time_idx1  # TODO index
            new_connection.this_node.object_id = obj1.id
            successor = new_connection.connected_nodes.add()
            successor.time_index = time_idx2  # TODO index
            successor.object_id = obj2.id
            return new_connection

        timesteps = tracking_set.timesteps
        # Parse the valid times once, not once per object pair.
        datetimes = [datetime.datetime.fromisoformat(ts.valid_time) for ts in timesteps]

        delayed_connections = []
        for t in range(len(timesteps) - 1):
            current_objects = timesteps[t].objects
            next_objects = timesteps[t + 1].objects
            for o1 in current_objects:
                for o2 in next_objects:
                    delayed_connections.append(
                        delayed(get_connection_if_correspond)(datetimes, t, o1, t + 1, o2)
                    )

        results = compute(*delayed_connections)

        # Drop the None placeholders of non-corresponding pairs. The identity check avoids invoking the
        # (in)equality operators of the connection objects, which an elementwise `!= None` would do.
        connections = [connection for connection in results if connection is not None]

        # TODO in superclass or where build what: tracking_set.ref_graph.connections.extend(connections)
        return connections