from ..tracking import TrackingTechnique
from enstools.feature.util.data_utils import print_lock, get_subset_by_description, get_split_dimensions, squeeze_nodes, \
    SplitDimension
from enstools.misc import get_time_dim
import xarray as xr
import numpy as np


class OverlapTracking(TrackingTechnique):
    """
    Implementation of simple overlap tracking: Objects of consecutive timestamps are considered the same if
    they spatially overlap. This requires the name of the field in which to check for overlaps.

    This field should be of data type "int", where the field is
        zero, if there is no object at the location,
        i, if the object with the ID i (from identification) is at the location.
    """

    def __init__(self, field_name=None):
        """
        Parameters
        ----------
        field_name : str, optional
            Name of the variable in the dataset passed to :meth:`track` that holds the per-location
            object IDs. It is used as ``subset[self.field_name]``, so it must be set before tracking.
        """
        self.field_name = field_name

    @staticmethod
    def _overlapping_id_pairs(ids_t1, ids_t2):
        """Return the sorted, unique ``(id_t1, id_t2)`` rows at locations where both ID fields are non-zero."""
        ids_t1 = np.asarray(ids_t1)
        ids_t2 = np.asarray(ids_t2)

        # non-zero in both fields == an object exists at this location at both timesteps
        both = np.logical_and(ids_t1, ids_t2)
        pairs = np.stack((ids_t1[both], ids_t2[both]), axis=-1)

        # guard: np.unique(..., axis=0) on an empty array is not supported by every numpy version
        if pairs.shape[0] == 0:
            return pairs
        return np.unique(pairs, axis=0)

    def track(self, trackable_set, subset: xr.Dataset):
        """
        Connect objects of consecutive timesteps that spatially overlap.

        Parameters
        ----------
        trackable_set
            Object providing ``timesteps``: an ordered, sliceable sequence whose entries have a
            ``valid_time`` attribute used to select the matching time slice of ``subset``.
        subset : xr.Dataset
            Dataset containing the ID field named by ``self.field_name`` along its time dimension.

        Returns
        -------
        list
            One connection, as built by ``self.get_new_connection_from_id``, for every distinct pair of
            object IDs that share at least one location at timesteps t and t+1. Within each pair of
            timesteps, connections are ordered by (ID at t, ID at t+1).
        """
        time_dim = get_time_dim(subset)
        subset_field = subset[self.field_name]
        connections = []

        timesteps = trackable_set.timesteps
        for t1, t2 in zip(timesteps, timesteps[1:]):
            # ID blocks (2d or 3d) at both timesteps; .values materializes lazily loaded data once,
            # instead of triggering a separate element lookup per overlapping pixel.
            ids_t1 = subset_field.sel({time_dim: t1.valid_time}).values
            ids_t2 = subset_field.sel({time_dim: t2.valid_time}).values

            for obj1_id, obj2_id in self._overlapping_id_pairs(ids_t1, ids_t2):
                connections.append(self.get_new_connection_from_id(t1, obj1_id, t2, obj2_id))

        return connections

    def postprocess(self, obj_desc):
        """Post-processing hook; this technique only prints a debug message and does nothing else."""
        print("postprocess @ overlap tracking")
        return

    def keep_track(self, track):
        """Return True for every track: overlap tracking keeps all tracks after generating them."""
        return True