From e19d652b81b4851bae4d88a2d9403cc5d15f32f6 Mon Sep 17 00:00:00 2001 From: "oriol.tinto" <oriol.tinto@lmu.de> Date: Thu, 16 Nov 2023 15:05:28 +0100 Subject: [PATCH] analyze_data_array now checks if the slices are constant. --- .../analyzer/analyze_data_array.py | 44 ++++++++++++++++--- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/enstools/compression/analyzer/analyze_data_array.py b/enstools/compression/analyzer/analyze_data_array.py index 308544b..dcf14b9 100644 --- a/enstools/compression/analyzer/analyze_data_array.py +++ b/enstools/compression/analyzer/analyze_data_array.py @@ -19,13 +19,14 @@ import xarray import enstools.encoding.chunk_size from enstools.compression.emulators import DefaultEmulator -from enstools.compression.errors import ConditionsNotFulfilledError +from enstools.compression.errors import ConditionsNotFulfilledError, ConstantValues from enstools.compression.slicing import MultiDimensionalSliceCollection from enstools.encoding.api import VariableEncoding from enstools.encoding.dataset_encoding import find_chunk_sizes, convert_to_bytes from enstools.encoding.rules import COMPRESSION_SPECIFICATION_SEPARATOR from .analysis_options import AnalysisOptions from .analyzer_utils import get_metrics, get_parameter_range, bisection_method +from enstools.compression.emulation import emulate_compression_on_data_array # These metrics will be used to select within the different encodings when aiming at a certain compression ratio. ANALYSIS_DIAGNOSTIC_METRICS = ["correlation_I", "ssim_I"] @@ -45,7 +46,7 @@ def find_direct_relation(parameter_range, function_to_nullify): return eval_last_percentile > eval_first_percentile -def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"): +def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"): chunk_memory_size = convert_to_bytes(chunk_size) chunk_sizes = find_chunk_sizes(data_array, chunk_memory_size) chunk_sizes = [chunk_sizes[dim] for dim in data_array.dims] @@ -53,8 +54,16 @@ def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"): big_chunk_size = max(set([s.size for s in multi_dimensional_slice.objects.ravel()])) big_chunks = [s for s in multi_dimensional_slice.objects.ravel() if s.size == big_chunk_size] - return {dim: size for dim, size in zip(data_array.dims, big_chunks[0].slices)} + for chunk_index in range(len(big_chunks)): + slices = {dim: size for dim, size in zip(data_array.dims, big_chunks[chunk_index].slices)} + data_array_slice = data_array.isel(**slices) + # Check if the range of the slice is greater than 0 + if data_array_slice.size > 0 and np.ptp(data_array_slice.values) > 0: + return data_array_slice + + # If all slices have a range of 0, raise an exception + raise ConstantValues("All slices have constant values or are empty.") def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -> Tuple[str, dict]: @@ -62,9 +71,32 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) - Find the compression specification corresponding to a certain data array and a given set of compression options. """ - slices = get_one_slice(data_array, - chunk_size=enstools.encoding.chunk_size.analysis_chunk_size) - data_array = data_array.isel(**slices) + try: + data_array = get_one_slice(data_array, + chunk_size=enstools.encoding.chunk_size.analysis_chunk_size, + ) + except ConstantValues: + # Issue a warning that all values in the data array are constant + warning_message = f"All values in the variable {data_array.name} are constant." + warnings.warn(warning_message) + + # In case all values are constant, return lossless. + # First let's find out the compression ratio + _, metrics = emulate_compression_on_data_array(data_array, + compression_specification=VariableEncoding("lossless"), + in_place=False) + + return "lossless", metrics + + # Compute the range of the data values in the slice + data_range = np.ptp(data_array.values) # ptp (peak-to-peak) calculates the range + + # Check if the range is zero + if data_range == 0: + raise ValueError("The range of the data_array slice is zero.") + + # Check that the range is not 0 + # Check if the array contains any nan contains_nan = np.isnan(data_array.values).any() if contains_nan: -- GitLab