# -*- coding: utf-8 -*-

import numpy as np
import xarray as xr
import math
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, \
    get_latitude_dim


# NOTE(review): calculate_dx_dy is defined a second time further down in this module;
# that later definition is the one bound at import time.
def calculate_dx_dy(lats, lons):
    """Return the east-west (dx) and north-south (dy) spacing of every grid cell, in metres.

    Assumes a regular grid: the step sizes are taken from the first two entries of
    ``lats`` and ``lons``. The Earth is treated as a sphere of radius 6378100 m.

    Args:
        lats: 1-D sequence of latitudes in degrees (needs at least two entries).
        lons: 1-D sequence of longitudes in degrees (needs at least two entries).

    Returns:
        Tuple ``(dx, dy)`` of arrays with shape ``(len(lats), len(lons))``.
        ``dx`` scales with ``cos(latitude)``; ``dy`` is constant. Both carry the sign
        of the corresponding coordinate step, so descending latitudes give a negative
        ``dy``.

    Raises:
        IndexError: if ``lats`` or ``lons`` has fewer than two entries.
    """
    lats = np.asarray(lats)
    lons = np.asarray(lons)
    nlat = len(lats)
    nlon = len(lons)
    lat_res = lats[1] - lats[0]
    lon_res = lons[1] - lons[0]
    lon_percent = (lon_res * nlon) / 360.0  # fraction of all longitudes covered, e.g. 110W..70E -> 0.5
    lat_percent = (lat_res * nlat) / 180.0  # fraction of all latitudes covered
    Re = 6378100  # Earth radius in metres

    # dx depends only on latitude:
    # (circumference of that latitude circle) * (covered fraction) / (number of lon points).
    # Evaluate it once per latitude row, then repeat it along the longitude axis.
    dx_per_lat = np.cos(lats * math.pi / 180.0) * 2.0 * math.pi * Re * lon_percent / nlon
    dx = np.tile(dx_per_lat[:, np.newaxis], (1, nlon))
    # dy is identical everywhere: (half circumference) * (covered fraction) / (number of lat points)
    dy = np.full((nlat, nlon), lat_percent * math.pi * Re / nlat)

    return dx, dy


# NOTE(review): compute_cv is defined a second time further down in this module;
# that later definition is the one bound at import time.
def compute_cv(dataset, u_str, v_str, cv_str):
    """Add a curvature-vorticity variable ``cv_str`` to ``dataset`` and return it.

    Relative vorticity (dv/dx - du/dy) is computed with central differences and
    split as RV = shear + curvature, where shear = -dV/dn (V = wind speed, n = the
    direction normal to the wind). The remainder RV - shear is stored as curvature
    vorticity.

    Args:
        dataset: xarray Dataset holding the wind components on a lat/lon grid.
        u_str: name of the x-direction wind component variable.
        v_str: name of the y-direction wind component variable.
        cv_str: name of the variable to create (units ``s**-1``).

    Returns:
        The dataset sorted by ascending latitude, transposed so latitude and
        longitude are the two trailing dimensions, with ``cv_str`` added.
    """
    # Context-manager form: keep attributes through the xarray operations below and
    # restore the caller's previous option value on exit, even if an exception occurs.
    with xr.set_options(keep_attrs=True):
        lon_str = get_longitude_dim(dataset)
        lat_str = get_latitude_dim(dataset)

        # be careful if your data is non-continuous by the chosen window (e.g. +120E..-120W)
        # put lat/lon last so the (nlat, nlon) dx/dy fields broadcast against the data
        dataset = dataset.sortby(lat_str)
        dataset = dataset.transpose(..., lat_str, lon_str)

        # Read the coordinates AFTER sorting: the dx/dy rows must line up with the data
        # rows, and dy must have the sign of the (now ascending) latitude step.
        lats = dataset.coords[lat_str].data
        lons = dataset.coords[lon_str].data

        dataset[cv_str] = xr.zeros_like(dataset[u_str], dtype=float)
        dataset[cv_str].attrs = {'long_name': "Curvature Vorticity", 'units': "s**-1"}

        # grid spacing of each cell in metres, shape (nlat, nlon)
        dx, dy = calculate_dx_dy(lats, lons)

        u_arr = dataset[u_str].data
        v_arr = dataset[v_str].data

        # Central differences via np.roll; axis -1 = longitude, axis -2 = latitude.
        # Rolling is periodic: correct across the seam for a full 360-degree longitude
        # band, but in latitude the first and last rows wrap onto each other, so the
        # y-derivatives on those two boundary rows are not meaningful.
        du_dx = (np.roll(u_arr, -1, axis=-1) - np.roll(u_arr, 1, axis=-1)) / (2 * dx)
        du_dy = (np.roll(u_arr, -1, axis=-2) - np.roll(u_arr, 1, axis=-2)) / (2 * dy)
        dv_dx = (np.roll(v_arr, -1, axis=-1) - np.roll(v_arr, 1, axis=-1)) / (2 * dx)
        dv_dy = (np.roll(v_arr, -1, axis=-2) - np.roll(v_arr, 1, axis=-2)) / (2 * dy)

        # relative vorticity (results verified to ERA5 vo-data)
        RV = dv_dx - du_dy

        # Wind direction phi of each cell; the natural coordinate system is rotated by phi.
        ref_angle = np.arctan2(v_arr, u_arr)
        sin_angle = np.sin(ref_angle)
        cos_angle = np.cos(ref_angle)

        # dV/dn with n = (-sin(phi), cos(phi)) normal to the wind and the speed change
        # projected on the flow direction t = (cos(phi), sin(phi)):
        #   dV/dn = du/dx*(-sin*cos) + du/dy*cos^2 + dv/dx*(-sin^2) + dv/dy*(sin*cos)
        SV = du_dx * (-sin_angle * cos_angle) + \
             du_dy * (cos_angle ** 2) + \
             dv_dx * (-(sin_angle ** 2)) + \
             dv_dy * (sin_angle * cos_angle)
        SV = -SV  # shear vorticity = -dV/dn

        # the remainder of relative vorticity is curvature vorticity (V/R)
        CV = RV - SV
        dataset[cv_str].values = CV

    return dataset


import math
from enstools.feature.util.enstools_utils import get_u_var, get_v_var, get_vertical_dim, get_longitude_dim, \
    get_latitude_dim



def calculate_dx_dy(lats, lons):
    """Return the east-west (dx) and north-south (dy) spacing of every grid cell, in metres.

    Assumes a regular grid: the step sizes are taken from the first two entries of
    ``lats`` and ``lons``. The Earth is treated as a sphere of radius 6378100 m.

    Args:
        lats: 1-D sequence of latitudes in degrees (needs at least two entries).
        lons: 1-D sequence of longitudes in degrees (needs at least two entries).

    Returns:
        Tuple ``(dx, dy)`` of arrays with shape ``(len(lats), len(lons))``.
        ``dx`` scales with ``cos(latitude)``; ``dy`` is constant. Both carry the sign
        of the corresponding coordinate step, so descending latitudes give a negative
        ``dy``.

    Raises:
        IndexError: if ``lats`` or ``lons`` has fewer than two entries.
    """
    lats = np.asarray(lats)
    lons = np.asarray(lons)
    nlat = len(lats)
    nlon = len(lons)
    lat_res = lats[1] - lats[0]
    lon_res = lons[1] - lons[0]
    lon_percent = (lon_res * nlon) / 360.0  # fraction of all longitudes covered, e.g. 110W..70E -> 0.5
    lat_percent = (lat_res * nlat) / 180.0  # fraction of all latitudes covered
    Re = 6378100  # Earth radius in metres

    # dx depends only on latitude:
    # (circumference of that latitude circle) * (covered fraction) / (number of lon points).
    # Evaluate it once per latitude row, then repeat it along the longitude axis.
    dx_per_lat = np.cos(lats * math.pi / 180.0) * 2.0 * math.pi * Re * lon_percent / nlon
    dx = np.tile(dx_per_lat[:, np.newaxis], (1, nlon))
    # dy is identical everywhere: (half circumference) * (covered fraction) / (number of lat points)
    dy = np.full((nlat, nlon), lat_percent * math.pi * Re / nlat)

    return dx, dy


def compute_cv(dataset, u_str, v_str, cv_str):
    """Add a curvature-vorticity variable ``cv_str`` to ``dataset`` and return it.

    Relative vorticity (dv/dx - du/dy) is computed with central differences and
    split as RV = shear + curvature, where shear = -dV/dn (V = wind speed, n = the
    direction normal to the wind). The remainder RV - shear is stored as curvature
    vorticity.

    Args:
        dataset: xarray Dataset holding the wind components on a lat/lon grid.
        u_str: name of the x-direction wind component variable.
        v_str: name of the y-direction wind component variable.
        cv_str: name of the variable to create (units ``s**-1``).

    Returns:
        The dataset sorted by ascending latitude, transposed so latitude and
        longitude are the two trailing dimensions, with ``cv_str`` added.
    """
    # Context-manager form: keep attributes through the xarray operations below and
    # restore the caller's previous option value on exit, even if an exception occurs.
    with xr.set_options(keep_attrs=True):
        lon_str = get_longitude_dim(dataset)
        lat_str = get_latitude_dim(dataset)

        # be careful if your data is non-continuous by the chosen window (e.g. +120E..-120W)
        # put lat/lon last so the (nlat, nlon) dx/dy fields broadcast against the data
        dataset = dataset.sortby(lat_str)
        dataset = dataset.transpose(..., lat_str, lon_str)

        # Read the coordinates AFTER sorting: the dx/dy rows must line up with the data
        # rows, and dy must have the sign of the (now ascending) latitude step.
        lats = dataset.coords[lat_str].data
        lons = dataset.coords[lon_str].data

        dataset[cv_str] = xr.zeros_like(dataset[u_str], dtype=float)
        dataset[cv_str].attrs = {'long_name': "Curvature Vorticity", 'units': "s**-1"}

        # grid spacing of each cell in metres, shape (nlat, nlon)
        dx, dy = calculate_dx_dy(lats, lons)

        u_arr = dataset[u_str].data
        v_arr = dataset[v_str].data

        # Central differences via np.roll; axis -1 = longitude, axis -2 = latitude.
        # Rolling is periodic: correct across the seam for a full 360-degree longitude
        # band, but in latitude the first and last rows wrap onto each other, so the
        # y-derivatives on those two boundary rows are not meaningful.
        du_dx = (np.roll(u_arr, -1, axis=-1) - np.roll(u_arr, 1, axis=-1)) / (2 * dx)
        du_dy = (np.roll(u_arr, -1, axis=-2) - np.roll(u_arr, 1, axis=-2)) / (2 * dy)
        dv_dx = (np.roll(v_arr, -1, axis=-1) - np.roll(v_arr, 1, axis=-1)) / (2 * dx)
        dv_dy = (np.roll(v_arr, -1, axis=-2) - np.roll(v_arr, 1, axis=-2)) / (2 * dy)

        # relative vorticity (results verified to ERA5 vo-data)
        RV = dv_dx - du_dy

        # Wind direction phi of each cell; the natural coordinate system is rotated by phi.
        ref_angle = np.arctan2(v_arr, u_arr)
        sin_angle = np.sin(ref_angle)
        cos_angle = np.cos(ref_angle)

        # dV/dn with n = (-sin(phi), cos(phi)) normal to the wind and the speed change
        # projected on the flow direction t = (cos(phi), sin(phi)):
        #   dV/dn = du/dx*(-sin*cos) + du/dy*cos^2 + dv/dx*(-sin^2) + dv/dy*(sin*cos)
        SV = du_dx * (-sin_angle * cos_angle) + \
             du_dy * (cos_angle ** 2) + \
             dv_dx * (-(sin_angle ** 2)) + \
             dv_dy * (sin_angle * cos_angle)
        SV = -SV  # shear vorticity = -dV/dn

        # the remainder of relative vorticity is curvature vorticity (V/R)
        CV = RV - SV
        dataset[cv_str].values = CV

    return dataset


def populate_object(obj_props, path, cfg):
    """Fill ``obj_props`` with the in-area vertices of ``path``, their bounding box and length.

    Vertices for which ``cfg.is_point_in_area(lon, lat)`` is false are dropped. The
    remaining vertices are appended, in order, to ``obj_props.line_pts``. The box
    ``obj_props.bb`` is set from their extremes. ``obj_props.length_deg`` receives
    the summed straight-line length of consecutive segments, measured in degrees of
    (lon, lat).

    If no vertex lies inside the area:
        - no points are added;
        - the bounding box gets the inverted sentinel values (min lat 90, max lat -90,
          min lon 180, max lon -180);
        - the length is 0.0.

    Args:
        obj_props: output message to fill; presumably a protobuf message for the
            fields defined in the project's .proto file (TODO confirm). Must provide
            ``line_pts.add()``, ``bb.min``/``bb.max`` with ``lat``/``lon`` fields,
            and ``length_deg``.
        path: object with a ``vertices`` sequence of (lon, lat) pairs.
        cfg: object providing ``is_point_in_area(lon, lat) -> bool``.
    """
    # first, drop the vertices that do not fulfil the area restriction
    filtered_vertices = [v for v in path.vertices if cfg.is_point_in_area(v[0], v[1])]

    # vertices of the path
    for v in filtered_vertices:
        line_pt = obj_props.line_pts.add()
        line_pt.lon = v[0]
        line_pt.lat = v[1]

    if filtered_vertices:
        lons = [v[0] for v in filtered_vertices]
        lats = [v[1] for v in filtered_vertices]
        # Take the extremes from the data itself instead of comparing against fixed
        # [-180, 180] sentinels; otherwise 0..360 longitudes that are all > 180
        # would leave min_lon stuck at 180.
        min_lat, max_lat = min(lats), max(lats)
        min_lon, max_lon = min(lons), max(lons)
    else:
        min_lat, max_lat, min_lon, max_lon = 90.0, -90.0, 180.0, -180.0

    # path length: sum of segment lengths in (lon, lat) degree space
    dist_deg = sum(
        (math.hypot(b[0] - a[0], b[1] - a[1])
         for a, b in zip(filtered_vertices, filtered_vertices[1:])),
        0.0,
    )

    # bounding box
    obj_props.bb.min.lat = min_lat
    obj_props.bb.min.lon = min_lon
    obj_props.bb.max.lat = max_lat
    obj_props.bb.max.lon = max_lon

    obj_props.length_deg = dist_deg

# identify troughs in data (should contain U,V,cv), based on the cv climatology
# def identify_troughs(data, cv_clim, cfg):