Skip to content
Snippets Groups Projects
Commit c524a4e7 authored by Oriol.Tinto's avatar Oriol.Tinto
Browse files

Merge branch 'dev-constant-values' into 'main'

Fix issues with the analysis of arrays with constant values

See merge request !19
parents e9593d36 28fb8557
No related branches found
No related tags found
1 merge request: !19 Fix issues with the analysis of arrays with constant values
Pipeline #20337 passed
2023.6 2023.11
...@@ -19,13 +19,14 @@ import xarray ...@@ -19,13 +19,14 @@ import xarray
import enstools.encoding.chunk_size import enstools.encoding.chunk_size
from enstools.compression.emulators import DefaultEmulator from enstools.compression.emulators import DefaultEmulator
from enstools.compression.errors import ConditionsNotFulfilledError from enstools.compression.errors import ConditionsNotFulfilledError, ConstantValues
from enstools.compression.slicing import MultiDimensionalSliceCollection from enstools.compression.slicing import MultiDimensionalSliceCollection
from enstools.encoding.api import VariableEncoding from enstools.encoding.api import VariableEncoding
from enstools.encoding.dataset_encoding import find_chunk_sizes, convert_to_bytes from enstools.encoding.dataset_encoding import find_chunk_sizes, convert_to_bytes
from enstools.encoding.rules import COMPRESSION_SPECIFICATION_SEPARATOR from enstools.encoding.rules import COMPRESSION_SPECIFICATION_SEPARATOR
from .analysis_options import AnalysisOptions from .analysis_options import AnalysisOptions
from .analyzer_utils import get_metrics, get_parameter_range, bisection_method from .analyzer_utils import get_metrics, get_parameter_range, bisection_method
from enstools.compression.emulation import emulate_compression_on_data_array
# These metrics will be used to select within the different encodings when aiming at a certain compression ratio. # These metrics will be used to select within the different encodings when aiming at a certain compression ratio.
ANALYSIS_DIAGNOSTIC_METRICS = ["correlation_I", "ssim_I"] ANALYSIS_DIAGNOSTIC_METRICS = ["correlation_I", "ssim_I"]
...@@ -45,7 +46,7 @@ def find_direct_relation(parameter_range, function_to_nullify): ...@@ -45,7 +46,7 @@ def find_direct_relation(parameter_range, function_to_nullify):
return eval_last_percentile > eval_first_percentile return eval_last_percentile > eval_first_percentile
def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"): def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"):
chunk_memory_size = convert_to_bytes(chunk_size) chunk_memory_size = convert_to_bytes(chunk_size)
chunk_sizes = find_chunk_sizes(data_array, chunk_memory_size) chunk_sizes = find_chunk_sizes(data_array, chunk_memory_size)
chunk_sizes = [chunk_sizes[dim] for dim in data_array.dims] chunk_sizes = [chunk_sizes[dim] for dim in data_array.dims]
...@@ -53,8 +54,16 @@ def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"): ...@@ -53,8 +54,16 @@ def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"):
big_chunk_size = max(set([s.size for s in multi_dimensional_slice.objects.ravel()])) big_chunk_size = max(set([s.size for s in multi_dimensional_slice.objects.ravel()]))
big_chunks = [s for s in multi_dimensional_slice.objects.ravel() if s.size == big_chunk_size] big_chunks = [s for s in multi_dimensional_slice.objects.ravel() if s.size == big_chunk_size]
return {dim: size for dim, size in zip(data_array.dims, big_chunks[0].slices)} for chunk_index in range(len(big_chunks)):
slices = {dim: size for dim, size in zip(data_array.dims, big_chunks[chunk_index].slices)}
data_array_slice = data_array.isel(**slices)
# Check if the range of the slice is greater than 0
if data_array_slice.size > 0 and np.ptp(data_array_slice.values) > 0:
return data_array_slice
# If all slices have a range of 0, raise an exception
raise ConstantValues("All slices have constant values or are empty.")
def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -> Tuple[str, dict]: def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -> Tuple[str, dict]:
...@@ -62,9 +71,32 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) - ...@@ -62,9 +71,32 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -
Find the compression specification corresponding to a certain data array and a given set of compression options. Find the compression specification corresponding to a certain data array and a given set of compression options.
""" """
slices = get_one_slice(data_array, try:
chunk_size=enstools.encoding.chunk_size.analysis_chunk_size) data_array = get_one_slice(data_array,
data_array = data_array.isel(**slices) chunk_size=enstools.encoding.chunk_size.analysis_chunk_size,
)
except ConstantValues:
# Issue a warning that all values in the data array are constant
warning_message = f"All values in the variable {data_array.name} are constant."
warnings.warn(warning_message)
# In case all values are constant, return lossless.
# First let's find out the compression ratio
_, metrics = emulate_compression_on_data_array(data_array,
compression_specification=VariableEncoding("lossless"),
in_place=False)
return "lossless", metrics
# Compute the range of the data values in the slice
data_range = np.ptp(data_array.values) # ptp (peak-to-peak) calculates the range
# Check if the range is zero
if data_range == 0:
raise ValueError("The range of the data_array slice is zero.")
# Check that the range is not 0
# Check if the array contains any nan # Check if the array contains any nan
contains_nan = np.isnan(data_array.values).any() contains_nan = np.isnan(data_array.values).any()
if contains_nan: if contains_nan:
......
...@@ -115,8 +115,8 @@ def emulate_compression_on_numpy_array(data: numpy.ndarray, compression_specific ...@@ -115,8 +115,8 @@ def emulate_compression_on_numpy_array(data: numpy.ndarray, compression_specific
""" """
if isinstance(compression_specification, (LosslessEncoding, NullEncoding)): if isinstance(compression_specification, NullEncoding):
return data, {} return data, {"compression_ratio": 1}
emulator_backend = DefaultEmulator emulator_backend = DefaultEmulator
......
from enstools.core.errors import EnstoolsError from enstools.core.errors import EnstoolsError
class ConditionsNotFulfilledError(EnstoolsError): class ConditionsNotFulfilledError(EnstoolsError):
... ...
\ No newline at end of file
class ConstantValues(Exception):
    """Signal that no usable slice exists because the data is constant or empty.

    Raised with a human-readable message; carries no extra state beyond the
    standard exception arguments.
    """
enstools>=2023.1 enstools>=2023.11
enstools-encoding>=2023.6 enstools-encoding>=2023.6
zfpy zfpy
hdf5plugin>=4.0.0 hdf5plugin>=4.0.0
......
...@@ -15,6 +15,32 @@ class TestAnalyzer(TestClass): ...@@ -15,6 +15,32 @@ class TestAnalyzer(TestClass):
input_path = input_tempdir / ds input_path = input_tempdir / ds
analyze_files(file_paths=[input_path]) analyze_files(file_paths=[input_path])
def test_analyzer_constant_array(self):
# Checks that analysing an array whose values are all identical emits the
# "constant values" UserWarning instead of failing.
# Imported for its side effect only; presumably this registers the
# `.compression` accessor used below — TODO confirm.
import enstools.compression.xr_accessor # noqa
import numpy as np
import xarray as xr
# 100 x 100 x 100 array in which every element is zero.
shape = (100, 100, 100)
data = np.zeros(shape)
data_array = xr.DataArray(data)
# Expect a warning about constant values
with pytest.warns(UserWarning, match="All values in the variable .* are constant."):
specs, metrics = data_array.compression.analyze()
# Call the accessor with the specification returned by the analysis;
# only the absence of an exception is checked here.
data_array.compression(specs)
def test_analyzer_without_lat_lon(self):
    """Analyze a random 3-D array built from bare values only.

    The DataArray is created without passing dimension names or
    coordinates. The test passes when both the analysis and the subsequent
    accessor call with the returned specification complete without raising.
    """
    import enstools.compression.xr_accessor  # noqa
    import numpy as np
    import xarray as xr

    shape = (100, 100, 100)
    # Use a seeded generator instead of the global RNG so that the input is
    # the same on every run and a failure can be reproduced exactly.
    rng = np.random.default_rng(seed=0)
    data = rng.random(size=shape)
    data_array = xr.DataArray(data)
    specs, metrics = data_array.compression.analyze()
    data_array.compression(specs)
def test_zfp_analyzer(self): def test_zfp_analyzer(self):
from enstools.compression.api import analyze_files from enstools.compression.api import analyze_files
input_tempdir = self.input_directory_path input_tempdir = self.input_directory_path
...@@ -60,8 +86,9 @@ class TestAnalyzer(TestClass): ...@@ -60,8 +86,9 @@ class TestAnalyzer(TestClass):
for var in metrics: for var in metrics:
if abs(metrics[var][cr_label] - thresholds[cr_label]) > TOLERANCE: if abs(metrics[var][cr_label] - thresholds[cr_label]) > TOLERANCE:
raise AssertionError(f"Case:{input_path.name}.The resulting compression ratio of {metrics[var][cr_label]:.2f}" raise AssertionError(
f"x is not close enough to the target of {thresholds[cr_label]:.2f}") f"Case:{input_path.name}.The resulting compression ratio of {metrics[var][cr_label]:.2f}"
f"x is not close enough to the target of {thresholds[cr_label]:.2f}")
def test_sz_analyzer(self): def test_sz_analyzer(self):
from enstools.compression.api import analyze_files from enstools.compression.api import analyze_files
...@@ -85,6 +112,7 @@ class TestAnalyzer(TestClass): ...@@ -85,6 +112,7 @@ class TestAnalyzer(TestClass):
compressor="zfp", compressor="zfp",
mode="rate", mode="rate",
) )
def test_rmse(self): def test_rmse(self):
from enstools.compression.api import analyze_files from enstools.compression.api import analyze_files
input_tempdir = self.input_directory_path input_tempdir = self.input_directory_path
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment