From fe1bc00de309ea7abb2fdce7aaea2324deabf7af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oriol=20Tint=C3=B3?= <oriol.tinto@lmu.de> Date: Fri, 2 Jun 2023 15:58:41 +0200 Subject: [PATCH] Fix function find_direct_relation. Use a slice based on a chunk size instead of selecting a single time-step. Check that the constraints are fulfilled after a binary search. --- .../analyzer/analyze_data_array.py | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/enstools/compression/analyzer/analyze_data_array.py b/enstools/compression/analyzer/analyze_data_array.py index 0ada7d5..308544b 100644 --- a/enstools/compression/analyzer/analyze_data_array.py +++ b/enstools/compression/analyzer/analyze_data_array.py @@ -17,8 +17,12 @@ from typing import Tuple, Callable import numpy as np import xarray +import enstools.encoding.chunk_size from enstools.compression.emulators import DefaultEmulator +from enstools.compression.errors import ConditionsNotFulfilledError +from enstools.compression.slicing import MultiDimensionalSliceCollection from enstools.encoding.api import VariableEncoding +from enstools.encoding.dataset_encoding import find_chunk_sizes, convert_to_bytes from enstools.encoding.rules import COMPRESSION_SPECIFICATION_SEPARATOR from .analysis_options import AnalysisOptions from .analyzer_utils import get_metrics, get_parameter_range, bisection_method @@ -33,23 +37,34 @@ COUNTER = 0 def find_direct_relation(parameter_range, function_to_nullify): """Return whether the nullified function has a direct relation between the parameter and the nullified value.""" min_val, max_val = parameter_range - first_q = min_val + (max_val - min_val) / 10 - third_q = min_val + 9 * (max_val - min_val) / 10 + first_percentile = min_val + (max_val - min_val) / 100 + last_percentile = min_val + 99 * (max_val - min_val) / 100 + + eval_first_percentile = function_to_nullify(first_percentile) + eval_last_percentile = function_to_nullify(last_percentile) + return eval_last_percentile > eval_first_percentile + + +def get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"): + chunk_memory_size = convert_to_bytes(chunk_size) + chunk_sizes = find_chunk_sizes(data_array, chunk_memory_size) + chunk_sizes = [chunk_sizes[dim] for dim in data_array.dims] + multi_dimensional_slice = MultiDimensionalSliceCollection(shape=data_array.shape, chunk_sizes=chunk_sizes) + big_chunk_size = max(set([s.size for s in multi_dimensional_slice.objects.ravel()])) + big_chunks = [s for s in multi_dimensional_slice.objects.ravel() if s.size == big_chunk_size] + + return {dim: size for dim, size in zip(data_array.dims, big_chunks[0].slices)} - eval_first_q = function_to_nullify(first_q) - eval_third_q = function_to_nullify(third_q) - return eval_third_q > eval_first_q def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -> Tuple[str, dict]: """ Find the compression specification corresponding to a certain data array and a given set of compression options. """ - # In case there is a time dimension, select the last element. - # There are accumulative variables (like total precipitation) which have mostly 0 on the first time step. - # Using the last time-step can represent an advantage somehow. - if "time" in data_array.dims: - data_array = data_array.isel(time=-1) + + slices = get_one_slice(data_array, + chunk_size=enstools.encoding.chunk_size.analysis_chunk_size) + data_array = data_array.isel(**slices) # Check if the array contains any nan contains_nan = np.isnan(data_array.values).any() if contains_nan: @@ -64,11 +79,6 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) - # Define parameter range parameter_range = get_parameter_range(data_array, options) - # If the aim is a specific compression ratio, the parameter range needs to be reversed - # because the relation between compression ratio and quality is inverse. 
- # if COMPRESSION_RATIO_LABEL in options.thresholds: - # parameter_range = tuple(reversed(parameter_range)) - # Ignore warnings with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -80,6 +90,8 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) - fun=function_to_nullify, direct_relation=direct_relation) + if not constrain(parameter): + raise ConditionsNotFulfilledError("Condition not fulfilled!") # Compute metrics # When aiming for a compression ratio some other metrics need to be provided too. if COMPRESSION_RATIO_LABEL not in options.thresholds: -- GitLab