From fe1bc00de309ea7abb2fdce7aaea2324deabf7af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Oriol=20Tint=C3=B3?= <oriol.tinto@lmu.de>
Date: Fri, 2 Jun 2023 15:58:41 +0200
Subject: [PATCH] Fix function find_direct_relation. Use a slice based on a
 chunk size instead of selecting a single time-step. Check that the constrains
 are fulfilled after a binary search.

---
 .../analyzer/analyze_data_array.py            | 42 ++++++++++++-------
 1 file changed, 27 insertions(+), 15 deletions(-)

diff --git a/enstools/compression/analyzer/analyze_data_array.py b/enstools/compression/analyzer/analyze_data_array.py
index 0ada7d5..308544b 100644
--- a/enstools/compression/analyzer/analyze_data_array.py
+++ b/enstools/compression/analyzer/analyze_data_array.py
@@ -17,8 +17,12 @@ from typing import Tuple, Callable
 import numpy as np
 import xarray
 
+import enstools.encoding.chunk_size
 from enstools.compression.emulators import DefaultEmulator
+from enstools.compression.errors import ConditionsNotFulfilledError
+from enstools.compression.slicing import MultiDimensionalSliceCollection
 from enstools.encoding.api import VariableEncoding
+from enstools.encoding.dataset_encoding import find_chunk_sizes, convert_to_bytes
 from enstools.encoding.rules import COMPRESSION_SPECIFICATION_SEPARATOR
 from .analysis_options import AnalysisOptions
 from .analyzer_utils import get_metrics, get_parameter_range, bisection_method
@@ -33,23 +37,34 @@ COUNTER = 0
 def find_direct_relation(parameter_range, function_to_nullify):
     """Return whether the nullified function has a direct relation between the parameter and the nullified value."""
     min_val, max_val = parameter_range
-    first_q = min_val + (max_val - min_val) / 10
-    third_q = min_val + 9 * (max_val - min_val) / 10
+    first_percentile = min_val + (max_val - min_val) / 100
+    last_percentile = min_val + 99 * (max_val - min_val) / 100
+
+    eval_first_percentile = function_to_nullify(first_percentile)
+    eval_last_percentile = function_to_nullify(last_percentile)
+    return eval_last_percentile > eval_first_percentile
+
+
+def  get_one_slice(data_array: xarray.DataArray, chunk_size: str = "100KB"):
+    chunk_memory_size = convert_to_bytes(chunk_size)
+    chunk_sizes = find_chunk_sizes(data_array, chunk_memory_size)
+    chunk_sizes = [chunk_sizes[dim] for dim in data_array.dims]
+    multi_dimensional_slice = MultiDimensionalSliceCollection(shape=data_array.shape, chunk_sizes=chunk_sizes)
+    big_chunk_size = max(set([s.size for s in multi_dimensional_slice.objects.ravel()]))
+    big_chunks = [s for s in multi_dimensional_slice.objects.ravel() if s.size == big_chunk_size]
+
+    return {dim: size for dim, size in zip(data_array.dims, big_chunks[0].slices)}
 
-    eval_first_q = function_to_nullify(first_q)
-    eval_third_q = function_to_nullify(third_q)
-    return eval_third_q > eval_first_q
 
 
 def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -> Tuple[str, dict]:
     """
     Find the compression specification corresponding to a certain data array and a given set of compression options.
     """
-    # In case there is a time dimension, select the last element.
-    # There are accumulative variables (like total precipitation) which have mostly 0 on the first time step.
-    # Using the last time-step can represent an advantage somehow.
-    if "time" in data_array.dims:
-        data_array = data_array.isel(time=-1)
+
+    slices = get_one_slice(data_array,
+                           chunk_size=enstools.encoding.chunk_size.analysis_chunk_size)
+    data_array = data_array.isel(**slices)
     # Check if the array contains any nan
     contains_nan = np.isnan(data_array.values).any()
     if contains_nan:
@@ -64,11 +79,6 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -
     # Define parameter range
     parameter_range = get_parameter_range(data_array, options)
 
-    # If the aim is a specific compression ratio, the parameter range needs to be reversed
-    # because the relation between compression ratio and quality is inverse.
-    # if COMPRESSION_RATIO_LABEL in options.thresholds:
-    #     parameter_range = tuple(reversed(parameter_range))
-
     #  Ignore warnings
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
@@ -80,6 +90,8 @@ def analyze_data_array(data_array: xarray.DataArray, options: AnalysisOptions) -
             fun=function_to_nullify,
             direct_relation=direct_relation)
 
+        if not constrain(parameter):
+            raise ConditionsNotFulfilledError("Condition not fulfilled!")
         # Compute metrics
         # When aiming for a compression ratio some other metrics need to be provided too.
         if COMPRESSION_RATIO_LABEL not in options.thresholds:
-- 
GitLab