from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
from datetime import datetime
from dask import delayed, compute
import numpy as np
import xarray as xr
from enstools.feature.util.data_utils import pb_str_to_datetime, SplitDimension


class ObjectComparisonTracking(TrackingTechnique):
    """
    Tracking technique which links objects by a simple pairwise comparison of their feature descriptions.

    This acts as an abstract class for comparison tracking techniques. It implements the track() method and
    declares an abstract correspond() method which gives a binary answer whether two objects of different
    timesteps are the same. The compared objects do not necessarily have to belong to consecutive timesteps:
    using set_max_delta_compare() the user can set a maximum time delta up to which timesteps are compared.
    """

    def set_max_delta_compare(self, cmp_delta):
        """
        Set the maximum delta time of objects to compare. If None, only compare objects of consecutive timesteps.

        Parameters
        ----------
        cmp_delta
                Upper bound (inclusive) for the difference between the valid times of two timesteps whose objects
                are compared in track(). It is compared against the difference of two parsed valid times, so
                presumably a datetime.timedelta - TODO confirm against pb_str_to_datetime().

        Returns
        -------
        None
        """
        self.max_compare_delta = cmp_delta

    @abstractmethod
    def correspond(self, time1, obj1, time2, obj2):
        """
        Abstract method. Implementations should check here if obj1 and obj2 can be regarded as the same object,
        creating a tracking link.

        Parameters
        ----------
        time1
                Valid time of obj1, as returned by pb_str_to_datetime().
        obj1 : identification_pb2.Object
                Object to compare from the earlier considered timestep
        time2
                Valid time of obj2, as returned by pb_str_to_datetime().
        obj2 : identification_pb2.Object
                Object to compare from the later considered timestep

        Returns
        -------
        True if the objects can be considered the same. Otherwise False.
        """
        return False

    def track(self, tracking_set, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on pairwise comparisons of objects. Using
        this tracking technique, the correspond() method has to implement a boolean function which returns True if
        the given object pair should be considered as the same object.

        Without a maximum delta (see set_max_delta_compare()) only objects of directly neighbouring timesteps are
        compared. Otherwise every pair of timesteps whose valid time difference does not exceed the maximum delta
        is compared.

        Parameters
        ----------
        tracking_set : pb2.TrackingSet
                The tracking set. Its timesteps and their objects are compared.
        subset : xr.Dataset
                The corresponding xarray subset. It is not used by this implementation.

        Returns
        -------
        One-dimensional numpy object array holding the connection created by get_new_connection() for every
        corresponding object pair (pb2.GraphConnection s).
        """
        # the maximum delta is optional: the attribute only exists after set_max_delta_compare() has been called
        cmp_delta = getattr(self, "max_compare_delta", None)

        timesteps = tracking_set.timesteps
        n_timesteps = len(timesteps)
        # parse each valid time once here instead of once per timestep pair and once more per object pair
        valid_dts = [pb_str_to_datetime(ts.valid_time) for ts in timesteps]

        def get_connection_if_correspond(time1, time1_dt, obj1, time2, time2_dt, obj2):
            # correspond() gets the parsed times, the connection is built from the original time strings
            if self.correspond(time1_dt, obj1, time2_dt, obj2):  # TODO allow skips? as param.
                # new object
                return self.get_new_connection(time1, obj1, time2, obj2)
            return None

        delayed_connections = []

        # get pair of timesteps, check if their delta has been requested
        # NOTE(review): the delta check assumes the timesteps are in chronological order - TODO confirm
        for idx1, ts1 in enumerate(timesteps):
            for idx2 in range(idx1 + 1, n_timesteps):

                if cmp_delta is None:
                    consider = (idx2 == idx1 + 1)
                else:
                    consider = valid_dts[idx2] - valid_dts[idx1] <= cmp_delta
                if not consider:
                    continue

                ts2 = timesteps[idx2]
                for obj1 in ts1.objects:
                    for obj2 in ts2.objects:
                        task = delayed(get_connection_if_correspond)(
                            ts1.valid_time, valid_dts[idx1], obj1,
                            ts2.valid_time, valid_dts[idx2], obj2)
                        delayed_connections.append(task)

        results = compute(*delayed_connections)
        # remove all None elements - they equal the non-corresponding pairs. Identity check on purpose: it does
        # not invoke the comparison operators of the connection objects.
        matched = [connection for connection in results if connection is not None]

        # keep returning a one-dimensional object array; filled element by element so that numpy never tries to
        # interpret a connection object as a nested sequence
        connections = np.empty(len(matched), dtype=object)
        for idx, connection in enumerate(matched):
            connections[idx] = connection
        # TODO remove transitive edges
        return connections