from ..object_compare_tracking import ObjectComparisonTracking

import enstools.feature.identification.african_easterly_waves.configuration as cfg
from datetime import datetime, timedelta

from shapely.geometry import Polygon, LineString


class AEWTracking(ObjectComparisonTracking):
    """
    Tracking of African Easterly Waves (AEWs).

    Correspondence between two timesteps is decided by comparing wave troughs:
    the trough at the earlier time is shifted in longitude by the configured
    range of propagation speeds, and a trough at the later time is considered
    the same wave if it intersects the region swept between those two shifts.
    """

    def __init__(self, **kwargs):
        # ``kwargs`` is accepted for interface compatibility but currently unused.
        # The config is set before the base initializer runs, in case it reads
        # ``self.config`` -- keep this order.
        self.config = cfg
        super().__init__()  # TODO this in template? need empty graph to be generated

    def correspond(self, time1, obj1, time2, obj2):
        """
        Decide whether two trough objects represent the same wave.

        The trough of ``obj1`` is shifted in longitude by
        ``speed_deg_per_h[0] * hours`` and by ``speed_deg_per_h[1] * hours``
        (latitude unchanged). The polygon bounded by these two shifted copies is
        the region where the trough is predicted to lie at ``time2``.

        :param time1: time of ``obj1``; ``time2 - time1`` must produce a
            timedelta-like object providing ``total_seconds()``
            (e.g. ``datetime`` values).
        :param obj1: object whose ``properties.line_pts`` is a sequence of
            points exposing ``lon`` and ``lat``.
        :param time2: time of ``obj2``.
        :param obj2: candidate object at ``time2`` with the same layout.
        :return: True if the trough of ``obj2`` intersects the predicted region,
            False otherwise (including when either trough has fewer than two
            points and therefore cannot form a polygon / line string).
        """
        lineseg1 = obj1.properties.line_pts
        lineseg2 = obj2.properties.line_pts

        # Shapely cannot build a polygon from a single trough point (only 2
        # vertices) nor a line string from fewer than 2 points; it raises instead.
        if len(lineseg1) < 2 or len(lineseg2) < 2:
            return False

        # total_seconds() rather than .seconds: the latter drops whole days and
        # wraps around for negative differences (-6 h would become 18 h).
        time_diff_h = (time2 - time1).total_seconds() / 3600.0

        min_speed = self.config.speed_deg_per_h[0]
        max_speed = self.config.speed_deg_per_h[1]

        # Walk the trough forward shifted by the min speed, then backward shifted
        # by the max speed, so the vertices trace the swept region's boundary in
        # order (shapely closes the ring automatically).
        min_shifted = [(pt.lon + min_speed * time_diff_h, pt.lat) for pt in lineseg1]
        max_shifted = [(pt.lon + max_speed * time_diff_h, pt.lat) for pt in reversed(lineseg1)]
        predicted_polygon = Polygon(min_shifted + max_shifted)

        actual_linestring = LineString([(pt.lon, pt.lat) for pt in lineseg2])

        return bool(predicted_polygon.intersects(actual_linestring))

    def postprocess(self, obj_desc):
        """Post-processing hook; here it only prints a marker and returns None."""
        print("postprocess @ aew")
        return

    def keep_track(self, track):
        """
        Decide whether a track persists long enough to be kept.

        :param track: track whose ``nodes`` each expose ``this_node.time`` as a
            string in ``'%Y-%m-%dT%H:%M:%S'`` format.
        :return: True if the span between the earliest and latest node time is
            at least ``config.duration_threshold``; False for an empty track.
        """
        time_format = '%Y-%m-%dT%H:%M:%S'
        times = [datetime.strptime(node.this_node.time, time_format) for node in track.nodes]
        if not times:
            return False

        # min/max instead of first/last node so the result does not depend on
        # the order in which nodes are stored.
        duration = max(times) - min(times)
        return duration >= self.config.duration_threshold