from ..tracking import TrackingStrategy
from enstools.feature.util.data_utils import (
    print_lock,
    get_subset_by_description,
    get_split_dimensions,
    squeeze_nodes,
    SplitDimension,
    pb_str_to_datetime,
)
from enstools.misc import get_time_dim
import xarray as xr
import numpy as np


class OverlapDoubleThresholdTracking(TrackingStrategy):
    """
    Overlap tracking across two label fields (an "inner" and an "outer" threshold field).

    Objects of consecutive timestamps are considered the same if they spatially
    overlap. Unlike single-field overlap tracking, the two timesteps of each
    pair are read from *different* fields of the dataset:

    - ``tracking_method="inner_first"``: the earlier timestep is read from
      ``inner_thresh_name``, the later one from ``outer_thresh_name``.
    - ``tracking_method="outer_first"``: the earlier timestep is read from
      ``outer_thresh_name``, the later one from ``inner_thresh_name``.

    Both fields should be integer label fields: 0 where there is no object, and
    ``i`` where the object with ID ``i`` (from identification) is located.

    Optionally, ``min_duration`` can be set as a ``datetime.timedelta``; tracks
    whose time span is shorter than it are discarded in :meth:`adjust_track`.
    """

    def __init__(self, inner_thresh_name=None, outer_thresh_name=None,
                 tracking_method=None, min_duration=None):
        """
        :param inner_thresh_name: name of the inner-threshold label variable in the dataset.
        :param outer_thresh_name: name of the outer-threshold label variable in the dataset.
        :param tracking_method: ``"inner_first"`` or ``"outer_first"`` (see class docstring).
        :param min_duration: optional ``datetime.timedelta``; minimum track time span.
        :raises ValueError: if ``tracking_method`` is neither accepted value.
        """
        self.inner_thresh_name = inner_thresh_name
        self.outer_thresh_name = outer_thresh_name

        if tracking_method not in ("inner_first", "outer_first"):
            raise ValueError("tracking method has to be either inner_first or outer_first")
        self.tracking_method = tracking_method

        self.min_duration = min_duration

    def track(self, trackable_set, subset: xr.Dataset):
        """
        Overlap tracking between the inner and outer threshold fields.

        For every pair of consecutive timesteps (t1, t2), the t1 slice of the
        first field is compared with the t2 slice of the second field (which
        field is "first" depends on ``tracking_method``). Every distinct pair of
        object IDs that share at least one grid point where both slices are
        non-zero becomes one connection.

        :param trackable_set: set whose ``timesteps`` (each with ``valid_time``) are tracked.
        :param subset: dataset containing both label fields.
        :return: list of connections created via ``self.get_new_connection_from_id``,
            in timestep order and, within a timestep pair, sorted by ID pair.
        """
        time_dim = get_time_dim(subset)

        if self.tracking_method == "inner_first":
            field_t1 = subset[self.inner_thresh_name]
            field_t2 = subset[self.outer_thresh_name]
        else:
            field_t1 = subset[self.outer_thresh_name]
            field_t2 = subset[self.inner_thresh_name]

        connections = []
        timesteps = list(trackable_set.timesteps)
        for t1, t2 in zip(timesteps, timesteps[1:]):
            # label slices (2d or 3d) to check for overlap
            ids_t1 = field_t1.sel({time_dim: t1.valid_time}).values
            ids_t2 = field_t2.sel({time_dim: t2.valid_time}).values

            overlap = np.logical_and(ids_t1, ids_t2)
            if not overlap.any():
                continue

            # Distinct (id_t1, id_t2) pairs over all overlapping grid points,
            # computed in one vectorised pass instead of a per-pixel Python loop.
            stacked = np.stack((ids_t1[overlap], ids_t2[overlap]), axis=-1)
            unique_pairs = np.unique(stacked, axis=0)
            # np.stack may upcast to a common dtype; restore each field's own
            # dtype so callers receive the same scalar types as the raw fields.
            obj1_ids = unique_pairs[:, 0].astype(ids_t1.dtype, copy=False)
            obj2_ids = unique_pairs[:, 1].astype(ids_t2.dtype, copy=False)

            for obj1_id, obj2_id in zip(obj1_ids, obj2_ids):
                connections.append(self.get_new_connection_from_id(t1, obj1_id, t2, obj2_id))

        return connections

    def postprocess(self, obj_desc):
        """No post-processing is done for this strategy; only logs that it was called."""
        print("postprocess @ overlap tracking")
        return

    def adjust_track(self, tracking_graph):
        """
        Filter a track by its time span.

        :param tracking_graph: object whose ``graph.edges`` describe the track.
        :return: ``tracking_graph`` unchanged if ``min_duration`` is None or the
            track spans at least ``min_duration``; otherwise None. A track without
            any nodes is also discarded (returns None) instead of raising.
        """
        if self.min_duration is None:
            return tracking_graph

        # Collect the time strings of every node in the track. The earliest/latest
        # are taken with min/max instead of assuming the edges are time-ordered.
        time_strs = set()
        for edge in tracking_graph.graph.edges:
            time_strs.add(edge.parent.time)
            # NOTE(review): assumes edges may carry a repeated ``children`` field of
            # nodes with ``.time``; including them ensures the last node of a track
            # is counted even if it never appears as a parent. Falls back to
            # parents only if the field does not exist.
            time_strs.update(child.time for child in getattr(edge, "children", ()))

        if not time_strs:
            return None

        times = [pb_str_to_datetime(t) for t in time_strs]
        duration = max(times) - min(times)
        if duration < self.min_duration:
            return None
        return tracking_graph