from ..tracking import TrackingStrategy
from enstools.feature.util.data_utils import print_lock, get_subset_by_description, get_split_dimensions, squeeze_nodes, \
    SplitDimension, pb_str_to_datetime
from enstools.misc import get_time_dim
import xarray as xr
import numpy as np


class OverlapDoubleThresholdTracking(TrackingStrategy):
    """
    Overlap tracking on two threshold fields ("double threshold" overlap tracking).

    Objects of consecutive timestamps are considered the same if they spatially overlap. Unlike simple
    overlap tracking, the field examined at the earlier timestep differs from the field examined at the
    later timestep:
        tracking_method="inner_first": inner-threshold field at t  vs. outer-threshold field at t+1
        tracking_method="outer_first": outer-threshold field at t  vs. inner-threshold field at t+1

    Both fields should be of data type "int", where the field is
        zero, if at the location is no object
        i, if object with the ID i (from identification) is at the location.

    Optionally, min_duration can be set as a datetime.timedelta, indicating the minimum time an object has
    to be alive in filtering (see ``adjust_track``).

    :param inner_thresh_name: Name of the inner-threshold ID field in the dataset passed to ``track``.
    :param outer_thresh_name: Name of the outer-threshold ID field in the dataset passed to ``track``.
    :param tracking_method: Either ``"inner_first"`` or ``"outer_first"``; anything else raises ValueError.
    :param min_duration: Optional ``datetime.timedelta``; tracks shorter than this are dropped.
    """

    def __init__(self, inner_thresh_name=None, outer_thresh_name=None, tracking_method=None, min_duration=None):
        # Validate first so no partially-initialised instance state is set on failure.
        if tracking_method not in ("inner_first", "outer_first"):
            raise ValueError("tracking method has to be either inner_first or outer_first")

        self.inner_thresh_name = inner_thresh_name
        self.outer_thresh_name = outer_thresh_name
        self.tracking_method = tracking_method
        self.min_duration = min_duration

    def track(self, trackable_set, subset: xr.Dataset):
        """
        Connect objects of consecutive timesteps whose ID masks overlap.

        For every pair of consecutive timesteps ``(t1, t2)`` in ``trackable_set.timesteps``, the "first"
        field (chosen by ``tracking_method``) is selected at ``t1.valid_time`` and the "second" field at
        ``t2.valid_time``. Every pixel where both are non-zero links the ID found at ``t1`` with the ID
        found at ``t2``.

        :param trackable_set: Object whose ``timesteps`` are iterated in order; each timestep provides a
            ``valid_time`` used to select along the time dimension of ``subset``.
        :param subset: Dataset containing the inner and outer threshold ID fields.
        :return: List of connections built by ``get_new_connection_from_id``, one per distinct
            ``(id_t1, id_t2)`` pair per timestep pair, in sorted ID order.
        """
        time_dim = get_time_dim(subset)
        if self.tracking_method == "inner_first":
            first_field = subset[self.inner_thresh_name]
            second_field = subset[self.outer_thresh_name]
        else:
            first_field = subset[self.outer_thresh_name]
            second_field = subset[self.inner_thresh_name]

        connections = []
        steps = list(trackable_set.timesteps)
        for t1, t2 in zip(steps, steps[1:]):
            # Data blocks (2D or 3D) to check for overlap; .values yields in-memory numpy arrays.
            data_t1 = first_field.sel({time_dim: t1.valid_time}).values
            data_t2 = second_field.sel({time_dim: t2.valid_time}).values

            # Pixels where an object is present in both fields.
            overlap = np.logical_and(data_t1, data_t2)

            # Vectorised: gather the ID pair at every overlapping pixel, dedupe, and sort so the
            # resulting connection order is deterministic (iterating a set was not).
            id_pairs = sorted(set(zip(data_t1[overlap].tolist(), data_t2[overlap].tolist())))

            for id_t1, id_t2 in id_pairs:
                connections.append(self.get_new_connection_from_id(t1, id_t1, t2, id_t2))

        return connections

    def postprocess(self, obj_desc):
        """
        Post-processing hook; this strategy only prints a marker message and returns None.

        :param obj_desc: Object description (unused here).
        """
        print("postprocess @ overlap tracking")
        return

    def adjust_track(self, tracking_graph):
        """
        Filter a single track by its lifetime.

        The lifetime is the span between the earliest and latest node time found in the track's edges.
        Edge order is not assumed to be chronological. A track without edges has a lifetime of zero.

        :param tracking_graph: Track object exposing ``.graph.edges``; each edge has a ``parent`` node whose
            ``time`` string is parsed with ``pb_str_to_datetime``.
        :return: ``tracking_graph`` unchanged if ``min_duration`` is None or the track lives at least
            ``min_duration``; ``None`` otherwise (the track is dropped).
        """
        if self.min_duration is None:
            return tracking_graph

        from datetime import timedelta

        track = tracking_graph.graph

        # Collect every node time in the track. Child nodes are included when edges expose them, so the
        # final node of a chain (which is never a parent) is counted too.
        # NOTE(review): assumes ``edge.children`` (if present) holds nodes with a ``.time`` string,
        # like ``edge.parent`` -- confirm against the graph schema.
        times = []
        for edge in track.edges:
            times.append(pb_str_to_datetime(edge.parent.time))
            times.extend(pb_str_to_datetime(child.time) for child in getattr(edge, "children", ()))

        # min/max instead of first/last element: robust to unordered edges; empty track -> zero duration.
        duration = max(times) - min(times) if times else timedelta(0)

        if duration < self.min_duration:
            return None
        return tracking_graph