import dask.array
import numpy as np
from abc import ABC, abstractmethod
import xarray as xr
from datetime import datetime
import bisect
import logging
from enstools.feature.util.data_utils import SplitDimension, print_lock, get_split_dimensions, get_time_split_dimension
from itertools import product


class IdentificationStrategy(ABC):
    """
    Base abstract class for new feature identification algorithms. Subclasses need to implement the abstract methods
    precompute(), identify() and postprocess(). An identification strategy is called for a dataset, which is split
    along dimensions to parallelize the identification. The abstract method identify() is called for each spatial
    block of the dataset (2D or 3D), for example for each timestamp and each ensemble member, and, if processing_mode
    is set to '2d', also for each level. Beforehand, precompute() is called once for one-time precomputations.
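
    Examples
    --------
    A minimal sketch of a concrete strategy; ``template_pb2`` and the field names are placeholders for the
    strategy's own generated protobuf module and data variables::

        class ThresholdStrategy(IdentificationStrategy):
            def __init__(self):
                # generated pb2 module providing DatasetDescription and Object
                self.pb_reference = template_pb2

            def precompute(self, dataset, **kwargs):
                # add fields that identify() will populate later
                dataset['mask'] = xr.zeros_like(dataset['field'], dtype=int)
                return dataset

            def identify(self, process_data, **kwargs):
                objects = []
                # detect features in the squeezed 2D/3D fields and create pb2 objects here
                return process_data, objects

            def postprocess(self, dataset, pb2_desc, **kwargs):
                return dataset, pb2_desc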
    """

    def __init__(self):
        pass

    @abstractmethod
    def precompute(self, dataset: xr.Dataset, **kwargs):
        """
        This is called once before the identification. Precomputations can be done here. If the user wants
        to add fields to the dataset (for example to output them later), they should be added here, because the
        parallelized identify() needs to know the structure of the final dataset beforehand.

        Parameters
        ----------
        dataset : xarray.Dataset
                The dataset to process.

        Returns
        -------
        The dataset with additional initialized fields which will be populated in identify().
        """
        # self.precomputed_stuff = self.foo()...
        # dataset['bar'] = xr.zeros_like(dataset['foo'], dtype=int)
        # ...
        return dataset

    @abstractmethod
    def identify(self, process_data: xr.Dataset, **kwargs):
        """
        Abstract identification method. This gets called in parallel for each spatial block.

        Parameters
        ----------
        process_data : xarray.Dataset
                The subset of data which is processed in parallel. It contains the same dimensions as the input
                dataset, but the parallelized ones have length 1; the DataArrays in the dataset are squeezed
                accordingly to 2D or 3D fields.

        Returns
        -------
        process_data : xarray.Dataset
                The (altered) dataset
        objects : iterable of pb2.Object
                List of objects of the corresponding protobuf type; see the template for an example.
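
        Examples
        --------
        A minimal sketch; the variable ``field`` and the threshold are placeholders::

            def identify(self, process_data, **kwargs):
                objects = []
                if (process_data['field'].values > 0.0).any():
                    obj = self.get_new_object(id=1)
                    # populate the fields of the pb2 object here
                    objects.append(obj)
                return process_data, objects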
        """
        objects = []
        return process_data, objects

    def identify_block(self, data_block: xr.Dataset, index_accessor, **kwargs):
        """
        Internal intermediate method which is called for every data block and calls identify() on it. It also does
        some preparation and result handling, and adds the identified objects to the protobuf structure.

        Parameters
        ----------
        data_block : xarray.Dataset
                The data block to be processed.
        index_accessor : xarray.DataArray
                DataArray of tuples specifying the set and timestep indices of this block's pb2 objects.

        Returns
        -------
        The (altered) data block.
        """
        # get order of dimensions to re-order them after
        dim_orders = dict()
        for v in data_block.data_vars:
            dim_orders[v] = data_block[v].dims

        # called in parallel from map_blocks
        split_dims = kwargs.keys()
        split_string = ''
        for dim in split_dims:
            dim_val = data_block.coords[dim].data[0]
            if isinstance(dim_val, np.datetime64):
                dt = datetime.utcfromtimestamp(dim_val.astype('O') / 1e9)
                dim_val = dt.replace(microsecond=0).isoformat()
            split_string += str(dim) + ":" + str(dim_val) + " \t"
        # split_string = '; '.join([str(dim) + ": " + str(data_block.coords[dim].data[0]) for dim in split_dims])
        print_lock("Start processing data block with dimensions:\t" + split_string)

        # record which data variables' coords get squeezed (time, ens, ...); some fields may not have them all
        squeeze_da_dim = []
        for split_dim in split_dims:
            for da_str in data_block.data_vars:
                if split_dim in data_block[da_str].coords:
                    squeeze_da_dim.append((da_str, split_dim))
                    data_block[da_str] = data_block[da_str].squeeze(dim=[split_dim])

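        # the accessor holds a (set_index, timestep_index) tuple for this block, pointing to the slot in
        # self.pb_dataset where the objects identified in this block get appended (built in execute())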
        access_indices = index_accessor.values.item()
    
        data_block, object_block = self.identify(data_block)
        self.pb_dataset.sets[access_indices[0]].timesteps[access_indices[1]].objects.extend(object_block)

        for unsqueeze_da, unsqueeze_dim in squeeze_da_dim:
            data_block[unsqueeze_da] = data_block[unsqueeze_da].expand_dims(dim=unsqueeze_dim)

        # restore dim order
        for v in data_block.data_vars:
            data_block[v] = data_block[v].transpose(*(dim_orders[v]))
        
        return data_block

    def execute(self, dataset: xr.Dataset, **kwargs):
        """
        Execute the identification of the features in the dataset. The properties of each feature are computed in
        parallel via the abstract identify() and collected here.

        Parameters
        ----------
        dataset : xarray.Dataset
                The dataset.

        Returns
        -------
        pb_dataset : pb2_structure.DatasetDescription
                The description struct containing all detected feature descriptions.
        ds_mapped : xarray.Dataset
                Dataset after being mapped via identify()
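
        Examples
        --------
        Assuming ``MyStrategy`` is a concrete subclass (placeholder name)::

            strategy = MyStrategy()
            descriptions, ds_mapped = strategy.execute(dataset)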
        """
        # build meta data for protobuf structure

        if not hasattr(self, 'pb_reference'):
            print("The protobuf type has not been set in the ID strategy.")
            print("Check the templates, the IdentificationStrategy.__init__() needs to set the pb2 type.")
            exit(1)

        self.pb_dataset = self.pb_reference.DatasetDescription()

        # precomputations according to strategy - may alter the dataset
        prec_ds = self.precompute(dataset)
        if prec_ds is not None:
            dataset = prec_ds
        print("Precomputation done.")

        # from enstools.misc import get_ensemble_dim, get_time_dim
        # from enstools.feature.util.enstools_utils import get_vertical_dim, get_init_time_dim

        # dimension names along which data is split (init_time, level, ens, ...)
        split_dimensions = get_split_dimensions(dataset,
                                                self.processing_mode)  # {ENUM.init_time: "init_time", len, is_scalar}
        valid_time_sd = get_time_split_dimension(dataset)

        # make sure the dataset is sorted so the indices work out at the end
        dataset = dataset.sortby([sd.name for sd in split_dimensions])

        # field of indices into descriptions
        rechunk_coords = split_dimensions  # [dataset.coords[sd.name] for sd in split_dimensions]
        if not valid_time_sd.is_scalar:
            rechunk_coords.append(valid_time_sd)
        index_field = xr.DataArray(coords=[dataset.coords[sd.name] for sd in rechunk_coords]).astype(dtype=tuple)

        # rechunk along parallelized dimensions
        ds_rechunked = dataset.chunk(dict((sd.name, 1) for sd in rechunk_coords))  # sizes: 1

        # init pb_dataset: set member + valid time for each element in the lists, if present

        if self.processing_mode != '2d':
            level_dim_len = 1

        # split dimensions that define the protobuf sets: all split dimensions except valid time
        split_dims_set = [sd for sd in split_dimensions if sd.dim != SplitDimension.SplitDimensionDim.VALID_TIME]
        split_dims_set_enum = [sd.dim for sd in split_dims_set]
        prod_set = [range(sd.size) for sd in split_dims_set]

        iter_prod = product(*prod_set)

        # iterate over all combinations of the split dimension indices and create a protobuf set for each
        for elem in iter_prod:
            cur_indices = dict()
            s_i = self.pb_dataset.sets.add()

            # init time
            # TODO simplify this?
            if SplitDimension.SplitDimensionDim.INIT_TIME in split_dims_set_enum:
                init_time_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.INIT_TIME)
                init_time_cur = dataset.coords[split_dims_set[init_time_idx].name].data[elem[init_time_idx]]
                s_i.init_time = str(np.datetime_as_string(init_time_cur, unit='s'))
                cur_indices[split_dims_set[init_time_idx].name] = s_i.init_time

            if SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER in split_dims_set_enum:
                ens_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER)
                ens_array = dataset.coords[split_dims_set[ens_idx].name].data

                if not np.issubdtype(ens_array.dtype.type, np.integer):
                    print("Ensemble dimension not of integer type, parsing to np.int32.")
                    dataset = dataset.assign_coords(coords={split_dims_set[ens_idx].name: ens_array.astype(dtype=np.int32)})

                s_i.member = dataset.coords[split_dims_set[ens_idx].name].data[elem[ens_idx]]
                cur_indices[split_dims_set[ens_idx].name] = s_i.member

            if SplitDimension.SplitDimensionDim.LEVEL in split_dims_set_enum:
                level_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.LEVEL)
                s_i.level = dataset.coords[split_dims_set[level_idx].name].data[elem[level_idx]]
                cur_indices[split_dims_set[level_idx].name] = s_i.level

            for t in range(valid_time_sd.size):
                timestep = s_i.timesteps.add()

                if not valid_time_sd.is_scalar:
                    valid_time_cur = dataset.coords[valid_time_sd.name].data[t]
                else:
                    valid_time_cur = dataset[valid_time_sd.name]
                timestep.valid_time = str(np.datetime_as_string(valid_time_cur, unit='s'))
                if not valid_time_sd.is_scalar:
                    cur_indices[valid_time_sd.name] = valid_time_cur

                # ID to access this object block in protobuf set
                sets_id = len(self.pb_dataset.sets) - 1
                tl_id = len(s_i.timesteps) - 1
                index_field.loc[cur_indices] = (sets_id, tl_id)

        # dict of values of the split dimensions -> to access them within each block
        dict_split_arrays = dict((dim.name, dataset.coords[dim.name]) for dim in split_dimensions)

        # map the blocks, computation here
        mb = xr.map_blocks(self.identify_block, ds_rechunked, args=[index_field], kwargs=dict_split_arrays,
                           template=ds_rechunked)
        print("Start")
        ds_mapped = mb.compute(num_workers=1)  # num_workers=4

        # postprocess
        ds_mapped, self.pb_dataset = self.postprocess(ds_mapped, self.pb_dataset)

        return self.pb_dataset, ds_mapped

    @abstractmethod
    def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
        """
        Abstract method for postprocessing operations.
        It takes the dataset and the corresponding descriptions as input and returns them, possibly altered.

        Parameters
        ----------
        dataset : xarray.Dataset
                The dataset.
        pb2_desc : pb2_structure.DatasetDescription
                The description.
        kwargs

        Returns
        -------
        The dataset and the description, possibly altered.
        """

        return dataset, pb2_desc

    def make_ids_unique(self, data_desc): # TODO abstract?
        """
        By default, IDs are unique per processing block.
        This method changes the object IDs to be unique across the data set.
        Existing IDs are not preserved, even if they already appear only once.

        Parameters
        ----------
        data_desc : pb2_structure.DatasetDescription
                The dataset description.

        Returns
        -------
        The same description, but with IDs changed to be unique across the data set.
        """
        next_id = 1

        for desc_set in data_desc.sets:
            desc_times = desc_set.timesteps
            for timestep in desc_times:
                for obj in timestep.objects:
                    obj.id = next_id
                    next_id += 1

        # TODO also return mapping?
        return data_desc

    def get_new_object(self, id=None):
        """
        Create a new pb2 object to be populated.

        Parameters
        ----------
        id : int, optional
            ID to assign to the new object.

        Returns
        -------
        The new object that can be populated
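
        Examples
        --------
        Typically called from within identify()::

            obj = self.get_new_object(id=1)
            objects.append(obj)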
        """

        # infer from type
        obj = self.pb_reference.Object()
        # set ID if given
        if id is not None:
            obj.id = id
        return obj