from ..object_compare_tracking import ObjectComparisonTracking

import enstools.feature.identification.african_easterly_waves.configuration as cfg
from datetime import datetime, timedelta
from collections import defaultdict
import statistics

from shapely.geometry import Polygon, LineString
from enstools.feature.util.data_utils import pb_str_to_datetime
from enstools.feature.util.graph import DataGraph

# Zero-length time span; a remaining time budget at or below this ends a path walk.
zero_dt = timedelta()

class AEWTracking(ObjectComparisonTracking):
    """
    Tracking of African Easterly Waves (AEWs).

    The tracked objects are wave troughs. Two troughs at different times are
    considered the same wave if the later trough lies inside the area the
    earlier one can reach with the configured propagation speeds (see
    ``correspond``). Tracks are afterwards filtered by duration, trough
    orientation and average zonal speed (see ``adjust_track``).
    """

    def __init__(self, **kwargs):
        """Bind the AEW configuration module and set the comparison time window.

        ``kwargs`` are accepted but ignored.
        """
        # NOTE(review): the base-class __init__ is not called here; confirm that
        # ObjectComparisonTracking does not require initialisation.
        self.config = cfg
        self.set_max_delta_compare(cfg.max_cmp_delta)

    def correspond(self, time1, obj1, time2, obj2):
        """Decide whether two wave-trough objects belong to the same wave.

        The trough line of ``obj1`` is shifted zonally over ``time2 - time1``
        once with the minimum and once with the maximum configured speed
        (``config.speed_deg_per_h[0]`` / ``[1]``, degrees longitude per hour).
        The two shifted copies span a polygon; the objects correspond if the
        trough line of ``obj2`` intersects that polygon.

        Parameters
        ----------
        time1, time2
            Times of ``obj1`` and ``obj2``; ``time2 - time1`` must be a
            ``timedelta``.
        obj1, obj2
            Objects whose ``properties.line_pts`` are points with ``lon``/``lat``.

        Returns
        -------
        bool
            True if ``obj2``'s trough lies in the area predicted from ``obj1``.
            False as well if either trough has fewer than two points, since no
            polygon / line string can be built from it.
        """
        pts1 = list(obj1.properties.line_pts)
        pts2 = list(obj2.properties.line_pts)
        if len(pts1) < 2 or len(pts2) < 2:
            return False

        # total_seconds() includes the days component (timedelta.seconds does
        # not) and is correct for negative differences.
        time_diff_h = (time2 - time1).total_seconds() / 3600.0
        min_shift = self.config.speed_deg_per_h[0] * time_diff_h
        max_shift = self.config.speed_deg_per_h[1] * time_diff_h

        # Walk the trough forward at min speed and back at max speed so the
        # vertices form a closed ring.
        predicted_polygon = Polygon(
            [(pt.lon + min_shift, pt.lat) for pt in pts1]
            + [(pt.lon + max_shift, pt.lat) for pt in reversed(pts1)]
        )
        actual_linestring = LineString([(pt.lon, pt.lat) for pt in pts2])

        return predicted_polygon.intersects(actual_linestring)

    def postprocess(self, obj_desc):
        """Post-processing hook; for AEWs it only prints a marker message."""
        print("postprocess @ aew")

    @staticmethod
    def get_nodes_after(time_delta, edges, start_index):
        """Collect the nodes reached from ``edges[start_index]`` after ``time_delta``.

        Child links are followed depth-first. The time step of each link is
        subtracted from the remaining budget. A path ends at the first node
        where the budget is used up (``<= 0``) or that has no children.

        Parameters
        ----------
        time_delta : timedelta
            Time budget to walk downstream.
        edges
            Graph edges. Each edge has a ``parent`` node (with a ``time``
            string) and the ``childs`` nodes it links to.
        start_index : int
            Index in ``edges`` of the start node.

        Returns
        -------
        defaultdict(list)
            Maps the edge index of each end node to the edge indices on the way
            there (end node first, ``start_index`` last). If an end node is
            reachable via several paths, the paths are concatenated, so indices
            may repeat.
        """
        parents = [edge.parent for edge in edges]
        return AEWTracking._nodes_after(time_delta, edges, parents, start_index)

    @staticmethod
    def _nodes_after(time_delta, edges, parents, start_index):
        # Recursive worker for get_nodes_after; `parents` is built once by the caller.
        edge = edges[start_index]
        result = defaultdict(list)
        if time_delta <= zero_dt or len(edge.childs) == 0:
            result[start_index] = [start_index]
            return result

        start_time = pb_str_to_datetime(edge.parent.time)
        for child in edge.childs:
            remaining = time_delta - (pb_str_to_datetime(child.time) - start_time)
            # index of the edge whose parent is this child node
            child_idx = parents.index(child)
            sub_paths = AEWTracking._nodes_after(remaining, edges, parents, child_idx)
            for end_idx, path in sub_paths.items():
                result[end_idx].extend(path + [start_index])
        return result

    def adjust_track(self, tracking_graph):
        """Reduce a track to the parts that behave like an AEW.

        1. Drop the whole track (return ``None``) if it has no nodes or spans
           less than ``config.duration_threshold``.
        2. Keep a node if it lies on some path that meets all of these:
           - the path covers at least ``config.avg_speed_t``;
           - its trough bounding boxes are, on average, not wider (lon) than
             tall (lat);
           - its average zonal speed of the bounding-box centre is below
             ``config.avg_speed_min_deg_per_h``.
        3. Rebuild a graph from the edges of the kept nodes.

        Returns
        -------
        None or DataGraph
            ``None`` if the track is too short. Otherwise a cleaned-up
            ``DataGraph``. If it still has as many edges as the input, its
            whole graph is appended as its single track. Otherwise tracks are
            regenerated with ``generate_tracks(apply_filter=True)``.
        """
        track = tracking_graph.graph
        nodes = [edge.parent for edge in track.edges]
        if not nodes:
            return None

        # keep track only if it persists longer than duration_threshold;
        # min/max do not rely on the edges being sorted by time
        times = [pb_str_to_datetime(node.time) for node in nodes]
        if max(times) - min(times) < self.config.duration_threshold:
            return None

        bbs = [node.object.properties.bb for node in nodes]
        center_lons = [(bb.max.lon + bb.min.lon) / 2.0 for bb in bbs]
        extent_lons = [bb.max.lon - bb.min.lon for bb in bbs]
        extent_lats = [bb.max.lat - bb.min.lat for bb in bbs]

        # init nodes as False = should be discarded
        keep_node = [False] * len(nodes)

        for node_idx, start_time in enumerate(times):
            ds_nodes = AEWTracking.get_nodes_after(self.config.avg_speed_t, track.edges, node_idx)

            for end_node_idx, way_node_idxs in ds_nodes.items():
                way_node_idxs = set(way_node_idxs)  # unique; paths may repeat indices

                # path ended early (no children) -> wave not tracked long enough
                elapsed = times[end_node_idx] - start_time
                if elapsed < self.config.avg_speed_t:
                    continue

                # discard paths whose troughs are predominantly horizontally
                # oriented; a trough without latitudinal extent counts as
                # horizontal (avoids division by zero)
                lon_lat_ratio = [
                    extent_lons[i] / extent_lats[i] if extent_lats[i] != 0 else float("inf")
                    for i in way_node_idxs
                ]
                if statistics.mean(lon_lat_ratio) > 1:
                    continue

                # average zonal speed over the actually elapsed time
                # (the end node may lie beyond avg_speed_t)
                elapsed_h = elapsed.total_seconds() / 3600.0
                deg_per_hr_avg = (center_lons[end_node_idx] - center_lons[node_idx]) / elapsed_h
                # NOTE(review): "<" presumably means "faster westward" with a
                # negative threshold; confirm against the configuration.
                if deg_per_hr_avg < self.config.avg_speed_min_deg_per_h:
                    for way_node_idx in way_node_idxs:
                        keep_node[way_node_idx] = True

        kept_edges = [edge for keep, edge in zip(keep_node, track.edges) if keep]

        # create return object holding only the kept edges
        reduced_tg = tracking_graph.tr_tech.pb_reference.TrackableSet()
        reduced_tg.CopyFrom(tracking_graph.set_desc)
        del reduced_tg.graph.edges[:]
        del reduced_tg.tracks[:]
        reduced_tg.graph.edges.extend(kept_edges)

        # clean up -> degenerate nodes (child without connections)
        dg = DataGraph(reduced_tg, self, fix=True)
        # if nothing changed, don't regenerate tracks: regenerating would
        # call back into this method and loop forever
        if len(dg.graph.edges) == len(tracking_graph.graph.edges):
            dg.set_desc.tracks.append(dg.set_desc.graph)
            print("Nothing changed.")
            return dg

        # regenerating splits the track where edges were removed and drops
        # the resulting short pieces
        print("Regenerate.")
        dg.generate_tracks(apply_filter=True)

        # TODO: redo clim, plot which WTs not considered
        return dg