From 57d0d4a2d2d5bb0274fb89af9d7d26e81658a07d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oriol=20Tint=C3=B3?= <oriol.tinto@lmu.de> Date: Thu, 2 Feb 2023 11:46:47 +0100 Subject: [PATCH] Fix dataset chunking. --- enstools/encoding/dataset_encoding.py | 108 +++++++++++++++++++++---- enstools/encoding/variable_encoding.py | 26 +++++- 2 files changed, 117 insertions(+), 17 deletions(-) diff --git a/enstools/encoding/dataset_encoding.py b/enstools/encoding/dataset_encoding.py index 8a4ff90..6ce3336 100644 --- a/enstools/encoding/dataset_encoding.py +++ b/enstools/encoding/dataset_encoding.py @@ -1,10 +1,12 @@ +import chunk import os from copy import deepcopy from pathlib import Path -from typing import Union, Dict +from typing import Hashable, Union, Dict import xarray import yaml +import numpy as np from . import rules from .errors import InvalidCompressionSpecification @@ -12,7 +14,50 @@ from .variable_encoding import _Mapping, parse_variable_specification, Encoding, NullEncoding -def compression_dictionary_to_string(compression_dictionary: dict) -> str: +def convert_size(size_bytes): + import math + if size_bytes < 0: + prefix = "-" + size_bytes = -size_bytes + else: + prefix = "" + + if size_bytes == 0: + return "0B" + size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") + i = int(math.floor(math.log(size_bytes, 1024))) + p = math.pow(1024, i) + s = round(size_bytes / p, 2) + return f"{prefix}{s}{size_name[i]}" + + +def convert_to_bytes(size_string): + """ + This function converts a given size string (e.g. '5MB') to the number of bytes. + + Args: + size_string (str): The size string to be converted (e.g. '5MB') + + Returns: + int: The number of bytes. + """ + import re + size_string = size_string.upper() + digits = re.match(r'\d+(?:\.\d+)?', size_string) # matches digits and optionally a dot followed by more digits + if digits: + digits = digits.group() # get the matched digits + else: + raise ValueError(f"Invalid size string: {size_string}") + unit = size_string.replace(digits, "") + size_name_dict = {'B': 0, 'KB': 1, 'MB': 2, 'GB': 3, 'TB': 4, 'PB': 5, 'EB': 6, 'ZB': 7, 'YB': 8} + if unit in size_name_dict: + size_bytes = float(digits) * np.power(1024, size_name_dict[unit]) + else: + raise ValueError(f"Invalid size string: {size_string}") + return int(size_bytes) + + +def compression_dictionary_to_string(compression_dictionary: Dict[str, str]) -> str: """ Convert a dictionary containing multiple entries to a single line specification """ @@ -20,7 +65,7 @@ def compression_dictionary_to_string(compression_dictionary: dict) -> str: [f"{key}{rules.VARIABLE_NAME_SEPARATOR}{value}" for key, value in compression_dictionary.items()]) -def parse_full_specification(spec: str) -> Dict[str, Encoding]: +def parse_full_specification(spec: Union[str, None]) -> Dict[str, Encoding]: from enstools.encoding.rules import VARIABLE_SEPARATOR, VARIABLE_NAME_SEPARATOR, \ DATA_DEFAULT_LABEL, DATA_DEFAULT_VALUE, COORD_LABEL, COORD_DEFAULT_VALUE result = {} @@ -58,8 +103,8 @@ def parse_full_specification(spec: str) -> Dict[str, Encoding]: result[COORD_LABEL] = parse_variable_specification(COORD_DEFAULT_VALUE) # For each specification, check that the specifications are valid. - for key, spec in result.items(): - spec.check_validity() + for _, _spec in result.items(): + _spec.check_validity() return result @@ -70,7 +115,7 @@ class DatasetEncoding(_Mapping): The kind of encoding that xarray expects is a mapping between the variables and their corresponding h5py encoding. """ - def __init__(self, dataset: xarray.Dataset, compression: Union[str, dict, None]): + def __init__(self, dataset: xarray.Dataset, compression: Union[str, Dict[str, str], Path, None]): self.dataset = dataset # Process the compression argument to get a single string with per-variable specifications @@ -80,7 +125,7 @@ class DatasetEncoding(_Mapping): self.variable_encodings = parse_full_specification(compression) @staticmethod - def get_a_single_compression_string(compression: Union[str, dict, Path]) -> str: + def get_a_single_compression_string(compression: Union[str, Dict[str, str], Path, None]) -> Union[str, None]: # The compression parameter can be a string or a dictionary. # In case it is a string, it can be directly a compression specification or a yaml file. @@ -120,18 +165,35 @@ class DatasetEncoding(_Mapping): var in self.dataset.data_vars} - # Add chunking? - for variable in self.dataset.data_vars: - chunks = {k: v if k != "time" else 1 for k, v in self.dataset[variable].sizes.items()} - chunk_sizes = tuple(chunks.values()) - # Ugly python magic to add chunk sizes into the encoding mapping object. - data_variable_encodings[variable]._kwargs._kwargs["chunksizes"] = chunk_sizes # noqa - # Merge all_encodings = {**coordinate_encodings, **data_variable_encodings} + # Need to specify chunk size, otherwise it breaks down. + self.chunk(encodings=all_encodings) + return all_encodings + def chunk(self, encodings: Dict[Union[Hashable, str], Encoding], chunk_memory_size="10MB"): + """ + Add a variable "chunksizes" to each variable encoding with the corresponding + + Args: + encodings (dict): Dictionary with the corresponding encoding for each variable. + chunk_memory_size (str): Desired chunk size in memory. + """ + + chunk_memory_size = convert_to_bytes(chunk_memory_size) + + # Loop over all the variables + for variable in self.dataset.data_vars: + da = self.dataset[variable] + type_size = da.dtype.itemsize + + optimal_chunk_size = chunk_memory_size / type_size + chunk_sizes = find_chunk_sizes(data_array=da, chunk_size=optimal_chunk_size) + chunk_sizes = tuple(chunk_sizes[d] for d in da.dims) + encodings[variable].set_chunk_sizes(chunk_sizes) + @property def _kwargs(self): return self.encoding() @@ -153,3 +215,21 @@ def is_a_valid_dataset_compression_specification(specification): return True except InvalidCompressionSpecification: return False + + +def find_chunk_sizes(data_array, chunk_size): + import math + total_points = np.prod(data_array.shape) + num_chunks = max(1, int(total_points // chunk_size)) + chunk_sizes = {} + chunk_number = {} + + # Sort dimensions by size + dims = sorted(data_array.dims, key=lambda x: data_array[x].shape) + pending_num_chunks = num_chunks + for dim in dims: + chunk_sizes[dim] = max(1, int(data_array[dim].size // pending_num_chunks)) + chunk_number[dim] = data_array[dim].size // chunk_sizes[dim] + + pending_num_chunks = math.ceil(pending_num_chunks / chunk_number[dim]) + return chunk_sizes diff --git a/enstools/encoding/variable_encoding.py b/enstools/encoding/variable_encoding.py index ec53fbf..51a1fc6 100644 --- a/enstools/encoding/variable_encoding.py +++ b/enstools/encoding/variable_encoding.py @@ -22,7 +22,9 @@ class _Mapping(Mapping): """ Subclass to implement dunder methods that are mandatory for Mapping to avoid repeating the code everywhere. """ - _kwargs: Mapping + def __init__(self) -> None: + super().__init__() + self._kwargs = {} def __getitem__(self, item): return self._kwargs[item] @@ -52,6 +54,19 @@ class Encoding(_Mapping): def __repr__(self): return f"{self.__class__.__name__}({self.to_string()})" + + def set_chunk_sizes(self, chunk_sizes: tuple) -> None: + """ + Method to add chunksizes into the encoding dictionary. + Parameters + ---------- + chunk_sizes + + Returns + ------- + + """ + self._kwargs["chunksizes"] = chunk_sizes class VariableEncoding(_Mapping): @@ -115,12 +130,14 @@ class NullEncoding(Encoding): class LosslessEncoding(Encoding): def __init__(self, backend: str, compression_level: int): + super().__init__() self.backend = backend if backend is not None else rules.LOSSLESS_DEFAULT_BACKEND self.compression_level = compression_level if compression_level is not None \ else rules.LOSSLESS_DEFAULT_COMPRESSION_LEVEL self.check_validity() - self._kwargs = self.encoding() + # Trying to convert it to a dictionary already here. + self._kwargs = dict(self.encoding()) def check_validity(self) -> bool: if self.backend not in definitions.lossless_backends: @@ -146,13 +163,16 @@ class LosslessEncoding(Encoding): class LossyEncoding(Encoding): def __init__(self, compressor: str, mode: str, parameter: Union[float, int]): + super().__init__() self.compressor = compressor self.mode = mode self.parameter = parameter self.check_validity() - self._kwargs = self.encoding() + + # Trying to convert it to a dictionary already here. + self._kwargs = dict(self.encoding()) def check_validity(self): # Check compressor validity -- GitLab