import os.path

from matplotlib import pyplot as plt
import numpy as np
import matplotlib
import matplotlib.patches as patches
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime


def plot_wavetroughs(ts, fig_name, cv=None):
    """Plot all wave troughs of one timestep on a map and save it as a PNG.

    Parameters
    ----------
    ts
        Timestep description (a ``pb2.Timestep`` according to the original
        header comment). Only ``ts.objects`` is read; every object must
        provide ``properties.line_pts``, an iterable of points exposing
        ``lon`` and ``lat``.
    fig_name : str
        Output path prefix. Every ``':'`` is replaced by ``'_'`` and
        ``'_aew_troughs.png'`` is appended.
    cv : optional
        Background field drawn as filled contours when given. It must offer
        ``.plot.contourf`` (presumably a 2-D xarray ``DataArray`` -- confirm
        against the callers).

    Returns
    -------
    str
        Path of the written PNG file.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4),
                           subplot_kw=dict(projection=ccrs.PlateCarree()))
    try:
        x_ticks = [-100, -95, -85, -75, -65, -55, -45, -35, -25, -15, -5, 5, 15, 25, 35]
        y_ticks = [0, 10, 20, 30]
        # Handed to set_extent() as (x0, x1, y0, y1).
        # NOTE(review): this yields lon -100..-45 / lat -35..35, which does not
        # match the tick ranges above -- confirm the intended map window.
        extent = [-100, -45, -35, 35]

        if cv is not None:
            levelfc = np.asarray([0, 0.5, 1, 2, 3]) * 1e-5
            # Target the axes explicitly instead of relying on pyplot's
            # "current axes" global state.
            cv.plot.contourf(ax=ax, levels=levelfc, vmin=0, extend='max', cmap='Blues')

        for obj in ts.objects:
            vertices = [[p.lon, p.lat] for p in obj.properties.line_pts]
            if not vertices:
                # A Path cannot be built from an empty vertex list.
                continue
            trough = patches.PathPatch(patches.Path(vertices), linewidth=2,
                                       facecolor='none', edgecolor='red')
            ax.add_patch(trough)

        ax.coastlines()
        ax.add_feature(cfeature.BORDERS.with_scale('50m'))
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        ax.set_yticks(y_ticks, crs=ccrs.PlateCarree())
        ax.set_xticks(x_ticks, crs=ccrs.PlateCarree())

        figure_name = fig_name.replace(':', '_') + '_aew_troughs.png'
        fig.savefig(figure_name, format='png')
        return figure_name
    finally:
        # Close exactly the figure created above (also on errors). The former
        # plt.figure().clear()/plt.close() sequence closed a freshly created
        # figure instead and left this one registered in pyplot.
        plt.close(fig)


def plot_timesteps_from_desc(object_desc, cv=None):
    """Plot, for every set and every timestep, all objects that were detected.

    One figure per timestep is written through ``plot_wavetroughs``; the file
    name prefix is ``set_<set index>_<valid_time>``. The prefix and the
    resulting file name are printed.

    Parameters
    ----------
    object_desc
        Object description; ``object_desc.sets`` is iterated and each set must
        provide ``timesteps`` whose entries have a string ``valid_time``.
    cv : optional
        Dataset holding the background field. It is reduced per set with
        ``get_subset_by_description`` and per timestep with
        ``.sel(time=valid_time).cv`` (presumably an xarray ``Dataset`` with a
        ``cv`` variable -- confirm). When ``None`` the troughs are plotted
        without a background field.

    Returns
    -------
    None
    """
    # Imported at call time, as in the original code.
    from enstools.feature.util.data_utils import get_subset_by_description

    for set_idx, od_set in enumerate(object_desc.sets):
        fn = "set_" + str(set_idx)  # TODO better to_string

        # The default cv=None used to crash below; skip the subsetting instead.
        cv_set = get_subset_by_description(cv, od_set) if cv is not None else None

        for ts in od_set.timesteps:
            cv_st = cv_set.sel(time=ts.valid_time).cv if cv_set is not None else None
            fnt = fn + "_" + ts.valid_time
            print(fnt)
            fout_name = plot_wavetroughs(ts, fnt, cv=cv_st)
            print(fout_name)

    return None

def plot_track_from_graph(track_desc, fig_name_prefix, cv=None):
    """Plot all troughs of one track, colored by time, and save it as a PNG.

    Parameters
    ----------
    track_desc
        Track description; ``track_desc.track_nodes`` is iterated and each
        entry's ``this_node`` must provide ``time`` (a string formatted as
        ``%Y-%m-%dT%H:%M:%S``) and ``object.properties.line_pts`` (points
        exposing ``lon`` and ``lat``).
    fig_name_prefix : str
        Output path prefix. Every ``':'`` is replaced by ``'_'`` and
        ``'_troughs.png'`` is appended.
    cv : optional
        Background field with a ``time`` dimension; its first time index is
        drawn as filled contours (presumably an xarray ``DataArray`` --
        confirm against the callers).

    Returns
    -------
    str or None
        Path of the written PNG file, or ``None`` when the track has no nodes.
    """
    nodes = [node.this_node for node in track_desc.track_nodes]
    if not nodes:
        # Nothing to draw; previously this raised IndexError further down.
        return None

    # Parse all node times once, before any figure exists, so a malformed
    # time string cannot leave an open figure behind.
    time_format = '%Y-%m-%dT%H:%M:%S'
    timestamps = [datetime.strptime(node.time, time_format).timestamp() for node in nodes]
    # The first and last node define the color range (assumes the nodes are
    # ordered by time -- TODO confirm).
    min_time = timestamps[0]
    max_time = timestamps[-1]

    # matplotlib.cm.get_cmap() was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap() is available in old and new releases alike.
    cmap = plt.get_cmap('rainbow')

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4),
                           subplot_kw=dict(projection=ccrs.PlateCarree()))
    try:
        x_ticks = [-100, -95, -85, -75, -65, -55, -45, -35, -25, -15, -5, 5, 15, 25, 35]
        y_ticks = [0, 10, 20, 30]
        # Handed to set_extent() as (x0, x1, y0, y1).
        # NOTE(review): this yields lon -100..-45 / lat -35..35, which does not
        # match the tick ranges above -- confirm the intended map window.
        extent = [-100, -45, -35, 35]

        if cv is not None:
            levelfc = np.asarray([0, 0.5, 1, 2, 3]) * 1e-5
            # Target the axes explicitly instead of relying on pyplot's
            # "current axes" global state.
            cv.isel(time=0).plot.contourf(ax=ax, levels=levelfc, vmin=0,
                                          extend='max', cmap='Blues')

        for node, time_d in zip(nodes, timestamps):
            # Position of this node within the track's time span, in [0, 1]
            # for time-ordered nodes; a zero-length span maps to 1.0.
            time_weight = (time_d - min_time) / (max_time - min_time) if max_time > min_time else 1.0

            vertices = [[p.lon, p.lat] for p in node.object.properties.line_pts]
            if not vertices:
                # A Path cannot be built from an empty vertex list.
                continue
            trough = patches.PathPatch(patches.Path(vertices), linewidth=2,
                                       facecolor='none', edgecolor=cmap(time_weight))
            ax.add_patch(trough)

        ax.coastlines()
        ax.add_feature(cfeature.BORDERS.with_scale('50m'))
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        ax.set_yticks(y_ticks, crs=ccrs.PlateCarree())
        ax.set_xticks(x_ticks, crs=ccrs.PlateCarree())

        figure_name = fig_name_prefix.replace(':', '_') + '_troughs.png'
        ax.set_title(nodes[0].time + " - " + nodes[-1].time)
        fig.savefig(figure_name, format='png')
        return figure_name
    finally:
        # Close exactly the figure created above (also on errors). The former
        # plt.figure().clear()/plt.close() sequence closed a freshly created
        # figure instead and left this one registered in pyplot.
        plt.close(fig)



def plot_tracks_from_graph(graph_desc, ds=None, out_dir=None):
    """Plot every track of every set in a graph description.

    Each track is drawn by ``plot_track_from_graph`` using the file name
    prefix ``<out_dir>/set_<set index>_track_<track index>``.

    Parameters
    ----------
    graph_desc
        Graph description; ``graph_desc.sets`` is iterated and each set must
        provide ``tracks``.
    ds : optional
        Dataset holding the background field. It is reduced per set with
        ``get_subset_by_description`` and its ``cv`` attribute is passed on
        (presumably an xarray ``Dataset`` -- confirm). When ``None`` the
        tracks are plotted without a background field.
    out_dir : str, optional
        Directory receiving the plots; created if missing. Defaults to the
        previously hard-coded ``~/phd/data/aew/plots``.

    Returns
    -------
    None
    """
    # Imported at call time, as in the original code.
    from enstools.feature.util.data_utils import get_subset_by_description

    if out_dir is None:
        out_dir = os.path.join(os.path.expanduser('~'), 'phd', 'data', 'aew', 'plots')
    # Saving a figure fails when the target directory does not exist.
    os.makedirs(out_dir, exist_ok=True)

    for set_idx, od_set in enumerate(graph_desc.sets):
        # The default ds=None used to crash on `.cv`; plot without background.
        cv = get_subset_by_description(ds, od_set).cv if ds is not None else None

        for track_idx, track in enumerate(od_set.tracks):
            fn = os.path.join(out_dir, "set_" + str(set_idx) + "_track_" + str(track_idx))
            plot_track_from_graph(track, fn, cv)