import bisect
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from itertools import product

import dask.array
import google
import numpy as np
import xarray as xr

from enstools.feature.util.data_utils import (SplitDimension, get_split_dimensions,
                                              get_time_split_dimension, print_lock)


class IdentificationStrategy(ABC):
"""
    Abstract base class for feature identification algorithms. Subclasses need to implement the abstract methods
    precompute(), identify() and postprocess(). An identification strategy is called for a dataset. The dataset is
    split along selected dimensions to parallelize the identification. The abstract method identify() is called for
    each spatial block of the dataset (2D or 3D), for example for each timestamp and for each ensemble member, and,
    if processing_mode is set to '2d', also for each level. Beforehand, precompute() is called once, where one-time
    precomputations can be done.
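
    Examples
    --------
    A minimal sketch of a concrete strategy. Here ``my_pb2`` and ``my_field`` are placeholders for the protobuf
    module generated from a template and for a variable of the input dataset; a real strategy would contain
    actual detection logic::

        class ThresholdIdentification(IdentificationStrategy):
            def __init__(self):
                self.pb_reference = my_pb2      # pb2 module; needs to be set (see execute())
                self.processing_mode = '2d'

            def precompute(self, dataset, **kwargs):
                dataset['mask'] = xr.zeros_like(dataset['my_field'], dtype=int)
                return dataset

            def identify(self, process_data, **kwargs):
                obj = self.get_new_object()
                objects = [obj] if (process_data['my_field'] > 0).any() else []
                return process_data, objects

            def postprocess(self, dataset, pb2_desc, **kwargs):
                return dataset, pb2_desc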
"""
def __init__(self):
pass
@abstractmethod
def precompute(self, dataset: xr.Dataset, **kwargs):
"""
        This is called once before the identification. Here, precomputations can be done. If the user wants to add
        fields to the dataset (for example, to output them later), they should be added here, because the
        parallelized identify() needs to know the structure of the final dataset beforehand.
Parameters
----------
dataset : xarray.Dataset
The dataset to process.
Returns
-------
The dataset with additional initialized fields which will be populated in identify().
"""
# self.precomputed_stuff = self.foo()...
# dataset['bar'] = xr.zeros_like(dataset['foo'], dtype=int)
# ...
return dataset
@abstractmethod
def identify(self, process_data: xr.Dataset, **kwargs):
"""
Abstract identification method. This gets called in parallel for each spatial block.
Parameters
----------
process_data : xarray.Dataset
            The subset of the data which is processed in parallel. It contains the same dimensions as the input
            dataset, but the parallelized dimensions have length 1; the data arrays in the dataset are squeezed
            accordingly to 2D or 3D fields.
Returns
-------
process_data : xarray.Dataset
The (altered) dataset
objects : iterable of pb2.Object
            List of objects of the corresponding protobuf type; see the template for an example.
"""
objects = []
return process_data, objects
def identify_block(self, data_block: xr.Dataset, index_accessor, **kwargs):
"""
Internal intermediate method which gets called for every data block and calls identify() on it. It also does
some preparation and object result handling. It adds the identified objects into the protobuf struct.
Parameters
----------
data_block : xarray.Dataset
The data block to be processed.
index_accessor: xarray.DataArray
DataArray of tuples specifying the set and timestep indices of this block's pb2 objects
Returns
-------
The (altered) data block.
"""
# get order of dimensions to re-order them after
dim_orders = dict()
for v in data_block.data_vars:
dim_orders[v] = data_block[v].dims
# called in parallel from map_blocks
split_dims = kwargs.keys()
split_string = ''
for dim in split_dims:
dim_val = data_block.coords[dim].data[0]
            if isinstance(dim_val, np.datetime64):
                # convert np.datetime64 (nanosecond resolution, as used by xarray) to an ISO string for logging
                dt = datetime.utcfromtimestamp(dim_val.astype('O') / 1e9)
                dim_val = dt.replace(microsecond=0).isoformat()
split_string += str(dim) + ":" + str(dim_val) + " \t"
# split_string = '; '.join([str(dim) + ": " + str(data_block.coords[dim].data[0]) for dim in split_dims])
print_lock("Start processing data block with dimensions:\t" + split_string)
        # record which data variables get which split dimensions squeezed (time, ens, ...);
        # some fields may not have all of them
squeeze_da_dim = []
for split_dim in split_dims:
for da_str in data_block.data_vars:
if split_dim in data_block[da_str].coords:
squeeze_da_dim.append((da_str, split_dim))
data_block[da_str] = data_block[da_str].squeeze(dim=[split_dim])
access_indices = index_accessor.values.item()
data_block, object_block = self.identify(data_block)
self.pb_dataset.sets[access_indices[0]].timesteps[access_indices[1]].objects.extend(object_block)
for unsqueeze_da, unsqueeze_dim in squeeze_da_dim:
data_block[unsqueeze_da] = data_block[unsqueeze_da].expand_dims(dim=unsqueeze_dim)
# restore dim order
for v in data_block.data_vars:
data_block[v] = data_block[v].transpose(*(dim_orders[v]))
return data_block
def execute(self, dataset: xr.Dataset, **kwargs):
"""
        Execute the identification of features in the dataset. The abstract identify() is called in parallel on the
        blocks of the dataset, and the identified features are collected here.
Parameters
----------
dataset : xarray.Dataset
The dataset.
Returns
-------
pb_dataset : pb2_structure.DatasetDescription
The description struct containing all detected feature descriptions.
ds_mapped : xarray.Dataset
Dataset after being mapped via identify()
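
        Examples
        --------
        A usage sketch, assuming ``MyStrategy`` is a concrete subclass and ``ds`` an already loaded
        xarray.Dataset (both placeholders)::

            strategy = MyStrategy()
            pb_dataset, ds_mapped = strategy.execute(ds)
            # pb_dataset.sets: one entry per combination of the split dimensions (init time, member, level)
            # pb_dataset.sets[i].timesteps[t].objects: the features identified for that block and valid time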
"""
# build meta data for protobuf structure
        if not hasattr(self, 'pb_reference'):
            raise AttributeError("The protobuf type has not been set in the ID strategy. "
                                 "Check the templates: IdentificationStrategy.__init__() needs to set the pb2 type.")
self.pb_dataset = self.pb_reference.DatasetDescription()
# precomputations according to strategy - may alter the dataset
prec_ds = self.precompute(dataset)
if prec_ds is not None:
dataset = prec_ds
print("Precomputation done.")
# from enstools.misc import get_ensemble_dim, get_time_dim
# from enstools.feature.util.enstools_utils import get_vertical_dim, get_init_time_dim
# dimension names along which data is split (init_time, level, ens, ...)
        # each entry is a SplitDimension (dim enum, name, size, is_scalar)
        split_dimensions = get_split_dimensions(dataset, self.processing_mode)
valid_time_sd = get_time_split_dimension(dataset)
# make sure is sorted so indices work out at the end
dataset = dataset.sortby([sd.name for sd in split_dimensions])
        # field of indices into the descriptions
        # note: rechunk_coords binds the same list object as split_dimensions, so the valid time appended below
        # also ends up in split_dimensions (and is therefore squeezed per block in identify_block())
        rechunk_coords = split_dimensions
if not valid_time_sd.is_scalar:
rechunk_coords.append(valid_time_sd)
index_field = xr.DataArray(coords=[dataset.coords[sd.name] for sd in rechunk_coords]).astype(dtype=tuple)
# rechunk along parallelized dimensions
ds_rechunked = dataset.chunk(dict((sd.name, 1) for sd in rechunk_coords)) # sizes: 1
# init pb_dataset. set member + valid time for each element in lists if existent
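        # resulting layout (built below): pb_dataset.sets[i] holds one combination of the split dimensions
        # (init time / member / level), sets[i].timesteps[t] one valid time, and timesteps[t].objects the
        # features returned by identify() for that block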
if self.processing_mode != '2d':
level_dim_len = 1
        # split dimensions that define the protobuf sets (valid time is handled per timestep below)
split_dims_set = [sd for sd in split_dimensions if sd.dim != SplitDimension.SplitDimensionDim.VALID_TIME]
split_dims_set_enum = [sd.dim for sd in split_dims_set]
prod_set = [range(sd.size) for sd in split_dims_set]
iter_prod = product(*prod_set)
        # for each combination of the split dimensions, add one set to the protobuf structure
for elem in iter_prod:
cur_indices = dict()
s_i = self.pb_dataset.sets.add()
# init time
# TODO simplify this?
if SplitDimension.SplitDimensionDim.INIT_TIME in split_dims_set_enum:
init_time_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.INIT_TIME)
init_time_cur = dataset.coords[split_dims_set[init_time_idx].name].data[elem[init_time_idx]]
s_i.init_time = str(np.datetime_as_string(init_time_cur, unit='s'))
cur_indices[split_dims_set[init_time_idx].name] = s_i.init_time
if SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER in split_dims_set_enum:
ens_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.ENSEMBLE_MEMBER)
ens_array = dataset.coords[split_dims_set[ens_idx].name].data
if not np.issubdtype(ens_array.dtype.type, np.integer):
print("Ensemble dimension not of integer type, parsing to np.int32.")
dataset = dataset.assign_coords(coords={split_dims_set[ens_idx].name: ens_array.astype(dtype=np.int32)})
s_i.member = dataset.coords[split_dims_set[ens_idx].name].data[elem[ens_idx]]
cur_indices[split_dims_set[ens_idx].name] = s_i.member
if SplitDimension.SplitDimensionDim.LEVEL in split_dims_set_enum:
level_idx = split_dims_set_enum.index(SplitDimension.SplitDimensionDim.LEVEL)
s_i.level = dataset.coords[split_dims_set[level_idx].name].data[elem[level_idx]]
cur_indices[split_dims_set[level_idx].name] = s_i.level
for t in range(valid_time_sd.size):
timestep = s_i.timesteps.add()
                valid_time_cur = (dataset.coords[valid_time_sd.name].data[t] if not valid_time_sd.is_scalar
                                  else dataset[valid_time_sd.name])
timestep.valid_time = str(np.datetime_as_string(valid_time_cur, unit='s'))
if not valid_time_sd.is_scalar:
cur_indices[valid_time_sd.name] = valid_time_cur
# ID to access this object block in protobuf set
sets_id = len(self.pb_dataset.sets) - 1
tl_id = len(s_i.timesteps) - 1
index_field.loc[cur_indices] = (sets_id, tl_id)
# dict of values of the split dimensions -> to access them within each block
dict_split_arrays = dict((dim.name, dataset.coords[dim.name]) for dim in split_dimensions)
# map the blocks, computation here
mb = xr.map_blocks(self.identify_block, ds_rechunked, args=[index_field], kwargs=dict_split_arrays,
template=ds_rechunked)
print("Start")
ds_mapped = mb.compute(num_workers=1) # num_workers=4
# postprocess
ds_mapped, self.pb_dataset = self.postprocess(ds_mapped, self.pb_dataset)
return self.pb_dataset, ds_mapped
@abstractmethod
def postprocess(self, dataset: xr.Dataset, pb2_desc, **kwargs):
"""
Abstract method for postprocessing operations.
        It takes as input the dataset and the corresponding descriptions, and returns them in an arbitrarily altered
        state.
Parameters
----------
dataset : xarray.Dataset
The dataset.
pb2_desc : pb2_structure.DatasetDescription
The description.
kwargs
Returns
-------
        The same as the inputs, but both can be altered.
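
        Examples
        --------
        A sketch of a minimal override that only renumbers the object IDs via the helper below::

            def postprocess(self, dataset, pb2_desc, **kwargs):
                pb2_desc = self.make_ids_unique(pb2_desc)
                return dataset, pb2_desc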
"""
return dataset, pb2_desc
def make_ids_unique(self, data_desc): # TODO abstract?
"""
        By default, object IDs are only unique per processing block.
        This method changes the object IDs so that they are unique across the entire data set.
        Existing IDs are not preserved, even if they already occur only once.

        Parameters
        ----------
        data_desc : pb2_structure.DatasetDescription
            The dataset description.

        Returns
        -------
        The same description, but with IDs changed to be unique across the data set.
"""
next_id = 1
for desc_set in data_desc.sets:
desc_times = desc_set.timesteps
for timestep in desc_times:
for obj in timestep.objects:
obj.id = next_id
next_id += 1
# TODO also return mapping?
return data_desc
def get_new_object(self, id=None):
"""
        Create a new object of the corresponding protobuf type, ready to be populated.

        Parameters
        ----------
        id : int, optional
            The ID to assign to the new object.

        Returns
        -------
        The new object, which can then be populated.
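
        Examples
        --------
        ::

            obj = self.get_new_object()
            obj.id = 1
            # fill the properties defined in the corresponding protobuf template, then return the object
            # from identify() as part of the objects list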
"""
# infer from type
obj = self.pb_reference.Object()
# set ID if given
if id is not None:
obj.id = id
return obj