Commit 57d0d4a2 authored by Oriol Tintó

Fix dataset chunking.

parent ca70d019
Part of 3 merge requests: !10 "Code cleaning, better documentation and updated CI.", !9 "Fix CI publishing.", !8 "Modifying CI to make automatic releases."
Pipeline #17857 passed
import os
from copy import deepcopy
from pathlib import Path
from typing import Union, Dict
from typing import Hashable, Union, Dict
import xarray
import yaml
import numpy as np
from . import rules
from .errors import InvalidCompressionSpecification
@@ -12,7 +14,50 @@
from .variable_encoding import _Mapping, parse_variable_specification, Encoding, \
    NullEncoding
def compression_dictionary_to_string(compression_dictionary: dict) -> str:
def convert_size(size_bytes):
    """
    Convert a number of bytes into a human-readable size string (e.g. 5242880 -> '5.0MB').
    """
    import math
if size_bytes < 0:
prefix = "-"
size_bytes = -size_bytes
else:
prefix = ""
if size_bytes == 0:
return "0B"
size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
i = int(math.floor(math.log(size_bytes, 1024)))
p = math.pow(1024, i)
s = round(size_bytes / p, 2)
return f"{prefix}{s}{size_name[i]}"
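
# Illustrative doctest-style usage of convert_size (values chosen for the example):
#   >>> convert_size(5 * 1024 ** 2)
#   '5.0MB'
#   >>> convert_size(-2048)
#   '-2.0KB'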
def convert_to_bytes(size_string):
"""
This function converts a given size string (e.g. '5MB') to the number of bytes.
Args:
size_string (str): The size string to be converted (e.g. '5MB')
Returns:
int: The number of bytes.
"""
import re
size_string = size_string.upper()
digits = re.match(r'\d+(?:\.\d+)?', size_string) # matches digits and optionally a dot followed by more digits
if digits:
digits = digits.group() # get the matched digits
else:
raise ValueError(f"Invalid size string: {size_string}")
unit = size_string.replace(digits, "")
size_name_dict = {'B': 0, 'KB': 1, 'MB': 2, 'GB': 3, 'TB': 4, 'PB': 5, 'EB': 6, 'ZB': 7, 'YB': 8}
if unit in size_name_dict:
size_bytes = float(digits) * np.power(1024, size_name_dict[unit])
else:
raise ValueError(f"Invalid size string: {size_string}")
return int(size_bytes)
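
# Illustrative usage of convert_to_bytes; it round-trips with convert_size above:
#   >>> convert_to_bytes("10MB")
#   10485760
#   >>> convert_to_bytes("0.5KB")
#   512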
def compression_dictionary_to_string(compression_dictionary: Dict[str, str]) -> str:
"""
Convert a dictionary containing multiple entries to a single line specification
"""
@@ -20,7 +65,7 @@ def compression_dictionary_to_string(compression_dictionary: dict) -> str:
[f"{key}{rules.VARIABLE_NAME_SEPARATOR}{value}" for key, value in compression_dictionary.items()])
def parse_full_specification(spec: str) -> Dict[str, Encoding]:
def parse_full_specification(spec: Union[str, None]) -> Dict[str, Encoding]:
from enstools.encoding.rules import VARIABLE_SEPARATOR, VARIABLE_NAME_SEPARATOR, \
DATA_DEFAULT_LABEL, DATA_DEFAULT_VALUE, COORD_LABEL, COORD_DEFAULT_VALUE
result = {}
@@ -58,8 +103,8 @@ def parse_full_specification(spec: str) -> Dict[str, Encoding]:
result[COORD_LABEL] = parse_variable_specification(COORD_DEFAULT_VALUE)
# Check that each of the specifications is valid.
for key, spec in result.items():
spec.check_validity()
for _, _spec in result.items():
_spec.check_validity()
return result
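
# Illustrative call, assuming a spec format in which ':' separates a variable name
# from its specification and whitespace separates entries (the variable name and
# values below are hypothetical):
#   >>> parse_full_specification("temperature:lossy,zfp,rate,4.0 default:lossless")
# would return a mapping from variable names (plus the default and coordinates
# labels) to validated Encoding objects.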
@@ -70,7 +115,7 @@ class DatasetEncoding(_Mapping):
The kind of encoding that xarray expects is a mapping between the variables and their corresponding h5py encoding.
"""
def __init__(self, dataset: xarray.Dataset, compression: Union[str, dict, None]):
def __init__(self, dataset: xarray.Dataset, compression: Union[str, Dict[str, str], Path, None]):
self.dataset = dataset
# Process the compression argument to get a single string with per-variable specifications
@@ -80,7 +125,7 @@ class DatasetEncoding(_Mapping):
self.variable_encodings = parse_full_specification(compression)
@staticmethod
def get_a_single_compression_string(compression: Union[str, dict, Path]) -> str:
def get_a_single_compression_string(compression: Union[str, Dict[str, str], Path, None]) -> Union[str, None]:
# The compression parameter can be a string or a dictionary.
# In case it is a string, it can be directly a compression specification or a yaml file.
@@ -120,18 +165,35 @@
var
in self.dataset.data_vars}
# Add chunking?
for variable in self.dataset.data_vars:
chunks = {k: v if k != "time" else 1 for k, v in self.dataset[variable].sizes.items()}
chunk_sizes = tuple(chunks.values())
# Ugly python magic to add chunk sizes into the encoding mapping object.
data_variable_encodings[variable]._kwargs._kwargs["chunksizes"] = chunk_sizes # noqa
# Merge
all_encodings = {**coordinate_encodings, **data_variable_encodings}
# Chunk sizes need to be specified explicitly, otherwise writing the dataset fails.
self.chunk(encodings=all_encodings)
return all_encodings
def chunk(self, encodings: Dict[Union[Hashable, str], Encoding], chunk_memory_size="10MB"):
"""
Add a variable "chunksizes" to each variable encoding with the corresponding
Args:
encodings (dict): Dictionary with the corresponding encoding for each variable.
chunk_memory_size (str): Desired chunk size in memory.
"""
chunk_memory_size = convert_to_bytes(chunk_memory_size)
# Loop over all the variables
for variable in self.dataset.data_vars:
da = self.dataset[variable]
type_size = da.dtype.itemsize
optimal_chunk_size = chunk_memory_size / type_size
chunk_sizes = find_chunk_sizes(data_array=da, chunk_size=optimal_chunk_size)
chunk_sizes = tuple(chunk_sizes[d] for d in da.dims)
encodings[variable].set_chunk_sizes(chunk_sizes)
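
    # Worked example for the loop above (illustrative numbers): with the default
    # chunk_memory_size of "10MB" (10485760 bytes) and float32 data (itemsize 4),
    # each chunk should hold about 10485760 / 4 = 2621440 values; find_chunk_sizes
    # then distributes those values over the variable's dimensions.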
@property
def _kwargs(self):
return self.encoding()
@@ -153,3 +215,21 @@ def is_a_valid_dataset_compression_specification(specification):
return True
except InvalidCompressionSpecification:
return False
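
# Illustrative check (assuming plain "lossless" is a valid specification):
#   >>> is_a_valid_dataset_compression_specification("lossless")
#   True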
def find_chunk_sizes(data_array, chunk_size):
    """
    Find a chunk size per dimension so that each chunk holds roughly chunk_size points.
    """
    import math
    total_points = np.prod(data_array.shape)
    num_chunks = max(1, int(total_points // chunk_size))
    chunk_sizes = {}
    chunk_number = {}
    # Sort dimensions from smallest to largest so that the smallest ones are split first
    dims = sorted(data_array.dims, key=lambda x: data_array[x].shape)
    pending_num_chunks = num_chunks
    for dim in dims:
        # Try to reach the pending number of chunks by splitting along this dimension,
        # then carry whatever splitting factor remains over to the next (larger) one.
        chunk_sizes[dim] = max(1, int(data_array[dim].size // pending_num_chunks))
        chunk_number[dim] = data_array[dim].size // chunk_sizes[dim]
        pending_num_chunks = math.ceil(pending_num_chunks / chunk_number[dim])
    return chunk_sizes
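
# Worked example with hypothetical dimensions (assuming each dimension has a
# coordinate, since sizes are looked up via data_array[dim]): for a float32 array
# with dims (time=31, lat=721, lon=1440) and a target of 2621440 points per chunk,
# total_points = 32185440 and num_chunks = 12. Splitting the smallest dimension
# first yields {'time': 2, 'lat': 721, 'lon': 1440}, i.e. chunks of roughly 7.9MB.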
@@ -22,7 +22,9 @@ class _Mapping(Mapping):
"""
Subclass to implement dunder methods that are mandatory for Mapping to avoid repeating the code everywhere.
"""
_kwargs: Mapping
def __init__(self) -> None:
super().__init__()
self._kwargs = {}
def __getitem__(self, item):
return self._kwargs[item]
@@ -52,6 +54,19 @@ class Encoding(_Mapping):
def __repr__(self):
return f"{self.__class__.__name__}({self.to_string()})"
def set_chunk_sizes(self, chunk_sizes: tuple) -> None:
"""
Method to add chunksizes into the encoding dictionary.
Parameters
----------
chunk_sizes
Returns
-------
"""
self._kwargs["chunksizes"] = chunk_sizes
class VariableEncoding(_Mapping):
@@ -115,12 +130,14 @@ class NullEncoding(Encoding):
class LosslessEncoding(Encoding):
def __init__(self, backend: str, compression_level: int):
super().__init__()
self.backend = backend if backend is not None else rules.LOSSLESS_DEFAULT_BACKEND
self.compression_level = compression_level if compression_level is not None \
else rules.LOSSLESS_DEFAULT_COMPRESSION_LEVEL
self.check_validity()
self._kwargs = self.encoding()
        # Eagerly convert the encoding mapping to a plain dictionary.
self._kwargs = dict(self.encoding())
def check_validity(self) -> bool:
if self.backend not in definitions.lossless_backends:
@@ -146,13 +163,16 @@ class LosslessEncoding(Encoding):
class LossyEncoding(Encoding):
def __init__(self, compressor: str, mode: str, parameter: Union[float, int]):
super().__init__()
self.compressor = compressor
self.mode = mode
self.parameter = parameter
self.check_validity()
self._kwargs = self.encoding()
        # Eagerly convert the encoding mapping to a plain dictionary.
self._kwargs = dict(self.encoding())
def check_validity(self):
# Check compressor validity