from .tracking import TrackingTechnique
from abc import ABC, abstractmethod
# from enstools.feature.identification._proto_gen import identification_pb2
import google.protobuf
from dask import delayed, compute
import numpy as np
import xarray as xr


class ObjectComparisonTracking(TrackingTechnique):
    """
    Implementation of a tracking technique which can track objects by a simple pairwise comparison of their feature
    descriptions. This acts as an abstract class for comparison tracking techniques.
    It implements the track() method and provides an abstract correspond() method which gives a binary answer whether
    two objects of consecutive timesteps are the same.
    """

    @abstractmethod
    def correspond(self, obj1, obj2):
        """
        Abstract method. Implementations should check here if obj1 and obj2 of consecutive timestamps can be regarded
        as same object, creating a tracking link.

        Parameters
        ----------
        obj1 : identification_pb2.Object
                Object to compare from timestamp t
        obj2 : identification_pb2.Object
                Object to compare from timestamp t+1

        Returns
        -------
        True if the objects can be considered the same. Otherwise False.
        """
        return False

    def track(self, timesteps, subset: xr.Dataset):
        """
        Implementation of track() for tracking techniques which are based on pairwise comparisons of objects.

        Every object of timestep t is compared against every object of timestep t+1 using correspond(). Each pair
        for which correspond() returns True becomes a connection.

        Parameters
        ----------
        timesteps : iterable of pb2.Timestep
                The timesteps, a list of timestamps of this (re)forecast. Each element must expose an ``objects``
                field whose entries have an ``id``.
        subset : xr.Dataset
                Dataset subset for these timesteps. Not used by this implementation (presumably kept for the
                TrackingTechnique.track() interface).

        Returns
        -------
        1-D numpy array of dtype object. Each element is an ``(obj1.id, obj2.id)`` tuple for one pair of
        corresponding objects from consecutive timesteps. Elements are ordered by timestep, then by object order.
        The array is empty if fewer than two timesteps are given or no objects correspond.
        """

        def matching_id_pairs(step1, step2):
            # Compare every object of step1 with every object of step2 and keep the id pairs that correspond.
            return [(o1.id, o2.id)
                    for o1 in step1.objects
                    for o2 in step2.objects
                    if self.correspond(o1, o2)]

        steps = list(timesteps)

        # Use one delayed task per pair of consecutive timesteps rather than one per object pair. This keeps the dask
        # graph small; per-object-pair tasks add scheduling overhead that can dwarf a cheap correspond() call.
        tasks = [delayed(matching_id_pairs)(step1, step2) for step1, step2 in zip(steps, steps[1:])]
        per_step_pairs = compute(*tasks)

        pairs = [pair for step_pairs in per_step_pairs for pair in step_pairs]

        # Fill the object array element by element. np.asarray(pairs, dtype=object) would build a 2-D (n, 2) array
        # from equal-length tuples, which would no longer be a sequence of id tuples.
        connections = np.empty(len(pairs), dtype=object)
        for idx, pair in enumerate(pairs):
            connections[idx] = pair
        return connections