#!/usr/bin/env python3
#################################################################################################################
# This is an older command-line script reused here: instead of reading the real command line, its arguments are
# assembled below from autosubmit variables.
from pathlib import Path
# Autosubmit variables: the %NAME% placeholders are presumably substituted by
# autosubmit before this script runs.
WORKDIR = "%HPCROOTDIR%"
STARTDATE = "%SDATE%"
# run directory of the idealized run: <WORKDIR>/<STARTDATE>/ideal
RUNDIR = Path("/".join((WORKDIR, STARTDATE, "ideal")))
# argument list handed to the argument parser of this script: read the const
# output of the idealized run and write the extpar file into the same directory
arguments = [
    "--const", "/".join((RUNDIR.as_posix(), "init-test-ext_DOM01_ML_0001.nc")),
    "--out", "/".join((RUNDIR.as_posix(), "extpar_DOM01.nc")),
]
##################################################################################################################

"""
Create an extpar file from an idealized model run so that it can be restarted as a real-data case.
"""

import argparse
import logging


import xarray
import numpy
from numba import jit
from enstools.io import read, write


def uppercase_add_variables(extpar):
    """
    Rename all variables of ``extpar`` to upper case, because ICON wants to read
    upper-case variable names.

    ``topography_c`` is the only variable that keeps its lower-case name.

    Parameters
    ----------
    extpar: xarray.Dataset
            the output dataset

    Returns
    -------
    xarray.Dataset
            the object returned by ``extpar.rename`` with the new names
    """
    keep_lowercase = ("topography_c",)
    renaming = {var: var.upper() for var in list(extpar.variables) if var not in keep_lowercase}
    return extpar.rename(renaming)


def add_LU_CLASS_FRACTION(extpar, lu_class, n_classes=23, water_class=20):
    """
    Add a constant land use class to all land grid points

    A new variable ``LU_CLASS_FRACTION`` with dimensions ``("nclass_lu", "ncells")``
    is added to ``extpar`` in place. For every cell with ``FR_LAND > 0``, the class
    ``lu_class`` receives the land fraction ``FR_LAND``. The remainder to one is
    always assigned to the water class ``water_class``. All other classes are zero.

    Parameters
    ----------
    extpar: xarray.Dataset
            the output dataset, must contain ``FR_LAND`` along dimension ``ncells``

    lu_class: int
            zero-based number of the landuse class to be used for all grid points on land

    n_classes: int, optional
            number of land use classes (length of dimension ``nclass_lu``), default 23

    water_class: int, optional
            zero-based index of the class that receives ``1 - land fraction``, default 20

    Raises
    ------
    ValueError
            if ``lu_class`` or ``water_class`` is outside ``[0, n_classes)``. Without this
            check an out-of-range class would index outside the array, and negative
            values would silently wrap around to the last classes.
    """
    for label, index in (("lu_class", lu_class), ("water_class", water_class)):
        if not 0 <= index < n_classes:
            raise ValueError(f"{label}={index} is outside the valid range 0..{n_classes - 1}")

    fr_land = extpar["FR_LAND"].values
    # Dataset.sizes instead of the deprecated mapping form of Dataset.dims
    data = numpy.zeros((n_classes, extpar.sizes["ncells"]), dtype=numpy.float32)
    # land class: FR_LAND where it is positive, 0 elsewhere (NaN > 0 is False, so also 0)
    data[lu_class] = numpy.where(fr_land > 0, fr_land, 0)
    # the remainder to one is always water; computed from the land row after it
    # was written, so lu_class == water_class yields 1 - FR_LAND exactly as before
    data[water_class] = 1.0 - data[lu_class]

    fractions = xarray.DataArray(data,
                                 dims=("nclass_lu", "ncells"),
                                 name="LU_CLASS_FRACTION",
                                 attrs={"standard_name": "Landuse class fraction",
                                        "long_name": "Fraction of land use classes in target grid element",
                                        "CDI_grid_type": "unstructured"})

    logging.info("added generated LU_CLASS_FRACTION to extpar.")
    extpar["LU_CLASS_FRACTION"] = fractions


def add_const(extpar, name, value, only_land=True, standard_name=None, long_name=None, units=None, dtype=numpy.float32,
              monthly=False):
    """
    Add a variable filled with a constant value to ``extpar`` (in place).

    Nothing happens if a variable called ``name`` already exists.

    Parameters
    ----------
    extpar: xarray.Dataset
            the output dataset; must contain ``FR_LAND`` when ``only_land`` is True
    name: str
            name of the new variable
    value: scalar
            constant value to fill in
    only_land: bool
            if True, cells with ``FR_LAND == 0`` are set to 0 and all other cells to ``value``;
            if False, every cell is set to ``value``
    standard_name, long_name, units: str, optional
            stored as attributes of the new variable when given
    dtype: numpy dtype
            data type of the new variable
    monthly: bool
            if True, the variable gets dimensions ``("time", "ncells")`` with 12 identical
            time steps, and a ``time`` variable is created if it does not exist yet
    """
    # variable is already there?
    if name in extpar:
        logging.info(f"{name} already in input file, no new variable create")
        return

    # Dataset.sizes instead of the deprecated mapping form of Dataset.dims
    ncells = extpar.sizes["ncells"]

    # create time-dimension for monthly values
    if monthly and "time" not in extpar:
        logging.info("adding time variable to extpar.")
        # 12 values: 11110111, 11110211, ..., 11111211 (one per month), stored as float32
        time_values = numpy.arange(11110111, 11111311, 100).astype(numpy.float32)
        extpar["time"] = xarray.DataArray(time_values, dims="time")

    # dimensions for new array
    if monthly:
        new_shape = (12, ncells)
        new_dims = ("time", "ncells")
    else:
        new_shape = (ncells,)
        new_dims = ("ncells",)

    # fill in values; a per-cell land mask broadcasts over the leading time axis,
    # so the monthly case needs no per-month loop
    data = numpy.empty(new_shape, dtype=dtype)
    if only_land:
        data[...] = numpy.where(extpar["FR_LAND"] == 0, 0, value)
    else:
        data[...] = value

    attrs = {"CDI_grid_type": "unstructured"}
    if standard_name is not None:
        attrs["standard_name"] = standard_name
    if long_name is not None:
        attrs["long_name"] = long_name
    if units is not None:
        attrs["units"] = units

    new = xarray.DataArray(data, dims=new_dims, name=name, attrs=attrs)

    logging.info(f"added generated {name} to extpar.")
    extpar[name] = new


def convert_SOILTYP(extpar):
    """
    SOILTYPE is supposed to be an 32bit Integer

    Replaces ``extpar["SOILTYP"]`` in place with an int32 copy that carries the
    attributes ``standard_name="soil_type"`` and ``long_name="soil type"``.
    If SOILTYP is already int32, the variable is left untouched.

    Parameters
    ----------
    extpar: xarray.Dataset
            the output dataset
    """
    soiltyp_in = extpar["SOILTYP"]
    if soiltyp_in.dtype == numpy.int32:
        logging.info("SOILTYP has already data type integer.")
        # nothing to convert; previously execution fell through and rebuilt the variable anyway
        return

    soiltyp = xarray.DataArray(soiltyp_in.values.astype(numpy.int32),
                               # keep the original dimensions (the old ``("ncells")`` was a plain string)
                               dims=soiltyp_in.dims,
                               name="SOILTYP",
                               attrs={"standard_name": "soil_type",
                                      "long_name": "soil type"})

    extpar["SOILTYP"] = soiltyp
    logging.info("changed datatype of soiltype to int32.")


def remove_unsed(extpar):
    """
    Delete variables that are not needed in the extpar file.

    ``CLON``, ``CLAT``, ``CLON_BNDS`` and ``CLAT_BNDS`` are removed in place when
    present; names that are missing are skipped without error.

    Parameters
    ----------
    extpar: xarray.Dataset
            the output dataset, modified in place
    """
    for var in ("CLON", "CLAT", "CLON_BNDS", "CLAT_BNDS"):
        if var not in extpar:
            continue
        del extpar[var]
        logging.info(f"removed unsed variable {var} from extpar")


def copy_uuid(extpar, grid_file):
    """
    Copy the grid-identifying global attributes (uuid etc.) from a grid file into ``extpar``.

    Only attributes that exist in the grid file are copied. Attributes of the same
    name already present in ``extpar.attrs`` are overwritten.

    Parameters
    ----------
    extpar: xarray.Dataset
            the output dataset, its ``attrs`` are updated in place
    grid_file: str
            path of the grid file, opened with ``enstools.io.read``
    """
    grid_attrs = read(grid_file).attrs
    wanted = ("uuidOfHGrid", "grid_file_uri", "number_of_grid_used", "uuidOfParHGrid", "ICON_grid_file_uri")
    extpar.attrs.update({key: grid_attrs[key] for key in wanted if key in grid_attrs})


def main(argv=None):
    """
    Build the extpar file from the const output of an idealized run and write it.

    Parameters
    ----------
    argv: list of str, optional
            arguments to parse. Defaults to the module-level ``arguments`` list built
            from the autosubmit variables, so the real command line is not used.
    """
    # every progress message of this script uses logging.info; without configuration
    # the root logger only shows warnings and all of these messages would be dropped
    logging.basicConfig(level=logging.INFO)

    # parse command line arguments
    # NOTE: the descriptive string literal near the top of the file is not the first
    # statement of the module, so it is not the module docstring and __doc__ is None.
    parser = argparse.ArgumentParser(
        description=__doc__ or "Create an extpar file from an idealized model run "
                               "so that it can be restarted as a real-data case.")
    parser.add_argument("--const", required=True,
                        help="file with constant variables written by a previous icon run. "
                             "Expected are: 'depth_lk', 'emis_rad', 'fr_lake', 'fr_land', "
                             "'topography_c', 'soiltyp', 'sso_stdh', 'sso_theta', 'sso_gamma', "
                             "'sso_sigma'.")
    parser.add_argument("--lu-class", type=int, default=0,
                        help="Land use class used for land grid points. "
                             "If not given, all gridpoints are set to class 0")
    parser.add_argument("--ndvi-max", type=float, default=0.5, help="NDVI_MAX for Land-Grid-points.")
    parser.add_argument("--grid-file", required=False, help="if given, the UUID is read from this file.")
    parser.add_argument("--out", required=True, help="Name of the output file")
    args = parser.parse_args(arguments if argv is None else argv)

    # read const input file
    extpar = read(args.const)
    extpar = uppercase_add_variables(extpar)

    # create variable LU_CLASS_FRACTION
    add_LU_CLASS_FRACTION(extpar, args.lu_class)

    # change datatype of SOILTYP
    convert_SOILTYP(extpar)

    # add constant values over land for land use and type related variables
    add_const(extpar, "NDVI_MAX", args.ndvi_max,
              standard_name="normalized_difference_vegetation_index",
              long_name="Constant NDVI over land")
    add_const(extpar, "T_CL", 287.0,
              standard_name="soil_temperature",
              long_name="Constant Values for Soil temperature")
    add_const(extpar, "PLCOV_MX", 0.6,
              standard_name="vegetation_area_fraction_vegetation_period",
              long_name="Constant Values for Plant cover maximum due to land use data")
    add_const(extpar, "LAI_MX", 2.73,
              standard_name="leaf_area_index_vegetation_period",
              long_name="Constant Values for Leaf Area Index Maximum")
    add_const(extpar, "ROOTDP", 0.73,
              standard_name="root_depth",
              long_name="Constant Values for Root depth",
              units="m")
    add_const(extpar, "RSMIN", 215.0,
              standard_name="RSMIN",
              long_name="Constant Values for Minimal stomata resistence",
              units="s/m")
    add_const(extpar, "FOR_D", 0.06,
              standard_name="fraction_of_deciduous_forest_cover",
              long_name="Constant values for Fraction of deciduous forest")
    add_const(extpar, "FOR_E", 0.15,
              standard_name="fraction_of_evergreen_forest_cover",
              long_name="Constant values for Fraction of evergreen forest")
    add_const(extpar, "ICE", 0.1,
              standard_name="Ice fraction",
              long_name="Constant values for Ice fraction due to Land Use Data")
    add_const(extpar, "NDVI_MRAT", 0.78,
              standard_name="normalized_difference_vegetation_index",
              long_name="Constant values for (monthly) proportion of actual value/maximum normalized differential vegetation index",
              monthly=True)

    # copy uuid from grid
    if args.grid_file is not None:
        copy_uuid(extpar, args.grid_file)

    # write extpar-file
    remove_unsed(extpar)
    logging.info(f"writing {args.out}...")
    # Icon requires the attribute "rawdata" to be set to either "GLC2000" or "GLOBCOVER2009"
    # It was discussed in the following issue: https://gitlab.physik.uni-muenchen.de/w2w/icon-examples/-/issues/2
    extpar.attrs["rawdata"] = "GLC2000"
    write(extpar, args.out)


if __name__ == "__main__":
    main()