Newer
Older
from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
from dask import delayed, compute
import numpy as np
import xarray as xr
from enstools.feature.util.data_utils import pb_str_to_datetime, SplitDimension
class ObjectComparisonTracking(TrackingTechnique):
"""
Implementation of a tracking technique which tracks objects by a simple pairwise comparison of their feature
descriptions. This acts as an abstract class for comparison tracking techniques.
It implements the track() method and provides an abstract correspond() method which gives a binary answer whether
two objects of (consecutive) timesteps are the same. Objects do not necessarily have to be associated with
consecutive timesteps. Using set_max_delta_compare() the user can set a maximum timedelta.
"""
def set_max_delta_compare(self, cmp_delta):
    """
    Store the largest time difference for which two objects are still compared during tracking.

    Parameters
    ----------
    cmp_delta
        Maximum time delta between the timesteps of two objects. Pass None to compare only objects of
        consecutive timesteps. Presumably a datetime.timedelta, since track() compares it against the
        difference of two valid times -- verify against callers.

    Returns
    -------
    None
    """
    # Kept on the instance; track() reads it back as self.max_compare_delta.
    self.max_compare_delta = cmp_delta
def correspond(self, time1, obj1, time2, obj2):
    """
    Decide whether two objects may be regarded as the same object, which creates a tracking link.

    Abstract method: concrete comparison tracking techniques are expected to override it. This base
    version never links two objects.

    Parameters
    ----------
    time1
        Time belonging to obj1. track() passes the result of pb_str_to_datetime() here, presumably a
        datetime -- verify against that helper.
    obj1 : identification_pb2.Object
        Object to compare from timestamp t.
    time2
        Time belonging to obj2, passed by track() in the same form as time1.
    obj2 : identification_pb2.Object
        Object to compare from a later timestamp (t+1 when only consecutive timesteps are compared).

    Returns
    -------
    bool
        True if the objects can be considered the same, otherwise False. This default implementation
        always returns False.
    """
    return False
def track(self, tracking_set, subset: xr.Dataset):
"""
Implementation of track() for tracking techniques which are based on pairwise comparisons of objects. Using
this tracking technique, the correspond() method has to implement a boolean function which returns True if
the given object pair of consecutive time steps should be considered as the same object.
Parameters
----------
tracking_set : pb2.TrackingSet
Holds the timesteps, a list of timestamps of this (re)forecast.
subset : xr.Dataset
The corresponding xarray subset.
Returns
-------
"""
try:
cmp_delta = self.max_compare_delta
except AttributeError:
cmp_delta = None
def get_connection_if_correspond(time1, obj1, time2, obj2):
time1_dt = pb_str_to_datetime(time1)
time2_dt = pb_str_to_datetime(time2)
# if correspond...
if self.correspond(time1_dt, obj1, time2_dt, obj2): # TODO allow skips? as param.
# new object
return self.get_new_connection(time1, obj1, time2, obj2)
else:
return None
delayed_connections = []
timesteps = tracking_set.timesteps
# get pair of timesteps, check if their delta has been requested
for t1 in range(len(timesteps)):
for t2 in range(t1 + 1, len(timesteps)):
t1_ts = timesteps[t1]
t2_ts = timesteps[t2]
t1_dt = pb_str_to_datetime(t1_ts.valid_time)
t2_dt = pb_str_to_datetime(t2_ts.valid_time)
if cmp_delta is None:
consider = (t2 - t1 == 1)
else:
consider = t2_dt - t1_dt <= cmp_delta
if consider:
for o1 in t1_ts.objects:
for o2 in t2_ts.objects:
app = delayed(get_connection_if_correspond)(t1_ts.valid_time, o1, t2_ts.valid_time, o2)
delayed_connections.append(app)
connections = compute(*delayed_connections)
# remove all None elements in list - they equal the non-corresponding elements
connections = np.asarray(connections, dtype=object)
connections = connections[connections != None]