From ffdd0c841e6c945194e6d0452262ad2ebbf4d59d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oriol=20Tint=C3=B3?= <oriol.tinto@lmu.de> Date: Thu, 22 Jun 2023 15:40:39 +0200 Subject: [PATCH] Update streamlit example. --- .../python_scripts/compress_dummy_dataset.py | 36 ++++++++ ...ression_section.py => advanced_section.py} | 68 +-------------- examples/streamlit/component/basic_section.py | 73 ++++++++++++++++ examples/streamlit/component/data_source.py | 2 + examples/streamlit/component/plotter.py | 83 +++++++++++++++++++ examples/streamlit/playground.py | 35 ++++---- 6 files changed, 216 insertions(+), 81 deletions(-) create mode 100644 examples/python_scripts/compress_dummy_dataset.py rename examples/streamlit/component/{compression_section.py => advanced_section.py} (52%) create mode 100644 examples/streamlit/component/basic_section.py create mode 100644 examples/streamlit/component/plotter.py diff --git a/examples/python_scripts/compress_dummy_dataset.py b/examples/python_scripts/compress_dummy_dataset.py new file mode 100644 index 0000000..402aca3 --- /dev/null +++ b/examples/python_scripts/compress_dummy_dataset.py @@ -0,0 +1,36 @@ +import xarray +from enstools.compression.analyzer.analyzer import analyze_dataset +import enstools.compression.xr_accessor # noqa + +dataset_names = [ + "air_temperature", + # "air_temperature_gradient", + # "basin_mask", + # "rasm", + # "ROMS_example", + # "tiny", + # "era5-2mt-2019-03-uk.grib", + # "eraint_uvz", + # "ersstv5" +] + + +def main(): + results = {} + failed_datasets = [] + for dataset_name in dataset_names: + try: + with xarray.tutorial.open_dataset(dataset_name) as dataset: + encoding, metrics = analyze_dataset(dataset=dataset) + results[dataset_name] = (encoding, metrics) + dataset.to_netcdf(f"reference_{dataset_name}.nc") + dataset.to_compressed_netcdf(f"compressed_{dataset_name}.nc", compression=encoding) + except ValueError: + failed_datasets.append(dataset_name) + + print(results) + print(failed_datasets) + + +if __name__ == "__main__": + main() diff --git a/examples/streamlit/component/compression_section.py b/examples/streamlit/component/advanced_section.py similarity index 52% rename from examples/streamlit/component/compression_section.py rename to examples/streamlit/component/advanced_section.py index 1114e71..ba0fe5d 100644 --- a/examples/streamlit/component/compression_section.py +++ b/examples/streamlit/component/advanced_section.py @@ -23,12 +23,12 @@ default_parameters = { } -def compression_section(data, slice_selection): +def advanced_section(data, slice_selection): if data.dataset is not None: # st.markdown("# Compression") with st.expander("Compression Specifications"): - specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True) + specification_mode = st.radio(label="", options=["Options", "String"], horizontal=True) # specification, options = st.tabs(["String", "Options"]) if specification_mode == "String": compression_specification = st.text_input(label="Compression", value="lossy,sz,abs,1") @@ -60,66 +60,4 @@ def compression_section(data, slice_selection): st.markdown( "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)") if data.compressed_da is not None: - st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}") - - -def plot_compressed(data, slice_selection): - col1, col2, *others = st.columns(2) - - new_slice = {} - - for key, values in slice_selection.items(): - if isinstance(values, tuple): - start, stop = values - if start != stop: - new_slice[key] = slice(start, stop) - else: - new_slice[key] = start - else: - new_slice[key] = values - - slice_selection = new_slice - - if data.reference_da is not None: - print(f"{slice_selection=}") - slice_selection = {key: (value if key != "lat" else slice(value.stop, value.start)) for key, value in - slice_selection.items()} - - only_slices = {key: value for key, value in slice_selection.items() if isinstance(value, slice)} - non_slices = {key: value for key, value in slice_selection.items() if not isinstance(value, slice)} - - if only_slices: - reference_slice = data.reference_da.sel(**only_slices) - else: - reference_slice = data.reference_da - - if non_slices: - reference_slice = reference_slice.sel(**non_slices, method="nearest") - print(reference_slice) - try: - reference_slice.plot() - fig1 = plt.gcf() - with col1: - st.pyplot(fig1) - except TypeError: - pass - - if data.compressed_da is not None: - plt.figure() - if only_slices: - compressed_slice = data.compressed_da.sel(**only_slices) - else: - compressed_slice = data.compressed_da - if non_slices: - compressed_slice = compressed_slice.sel(**non_slices, method="nearest") - - try: - compressed_slice.plot() - fig2 = plt.gcf() - with col2: - st.pyplot(fig2) - except TypeError: - pass - - else: - st.text("Compress the data to show the plot!") + st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}") \ No newline at end of file diff --git a/examples/streamlit/component/basic_section.py b/examples/streamlit/component/basic_section.py new file mode 100644 index 0000000..c0a44e3 --- /dev/null +++ b/examples/streamlit/component/basic_section.py @@ -0,0 +1,73 @@ +import matplotlib.pyplot as plt +import numpy as np +import streamlit as st +import xarray + +import enstools.compression.xr_accessor # noqa +from .data_source import DataContainer + + +def get_compression_ratio(data_array: xarray.DataArray, relative_tolerance: float, mode: str) -> float: + what = data_array.compression(f"lossy,sz,{mode},{relative_tolerance}", in_place=False) + return float(what.attrs["compression_ratio"]) + + +def invert_function(function): + # Define its derivative + f_prime = function.deriv() + + # Define the function for which we want to find the root + def func(x, y_val): + return function(x) - y_val + + def newtons_method(y_val, epsilon=1e-7, max_iterations=100): + x = -2 # np.log10(0.01) + print(f"{y_val=}") + for _ in range(max_iterations): + x_new = x - func(x, y_val) / f_prime(x) + if abs(x - x_new) < epsilon: + return x_new + x = x_new + print(x_new) + return None + + return newtons_method + + +def create_parameter_from_compression_ratio(data: DataContainer, mode: str): + train_x = np.logspace(-12, -.5, 15) + train_y = [get_compression_ratio(data.reference_da, parameter, mode=mode) for parameter in train_x] + + parameter_range = min(train_y), min(100., max(train_y)) + + x_log = np.log10(train_x) + y_log = np.log10(train_y) + + coeff = np.polyfit(x_log, y_log, 10) + + # Create a polynomial function from the coefficients + f = np.poly1d(coeff) + + f_inverse = invert_function(f) + + def function_to_return(compression_ratio: float) -> float: + return 10 ** f_inverse(np.log10(compression_ratio)) + + return parameter_range, function_to_return + + +def basic_section(data: DataContainer, slice_selection): + mode = st.selectbox(label="Mode", options=["rel", "pw_rel"]) + parameter_range, get_parameter = create_parameter_from_compression_ratio(data, mode=mode) + + _min, _max = parameter_range + options = [_min + (_max - _min) * _x for _x in np.logspace(-2, 0)] + options = [f"{op:.2f}" for op in options] + + compression_ratio = st.select_slider(label="Compression Ratio", options=options) + compression_ratio = float(compression_ratio) + + parameter = get_parameter(compression_ratio) + + with st.spinner(): + data.compress(f"lossy,sz,rel,{parameter}") diff --git a/examples/streamlit/component/data_source.py b/examples/streamlit/component/data_source.py index 1d52908..279cbe1 100644 --- a/examples/streamlit/component/data_source.py +++ b/examples/streamlit/component/data_source.py @@ -38,6 +38,8 @@ class DataContainer: def compress(self, compression): self.compressed_da = self.reference_da.compression(compression) + def __hash__(self): + return hash(self.reference_da.name) @st.cache_resource def create_data(): diff --git a/examples/streamlit/component/plotter.py b/examples/streamlit/component/plotter.py new file mode 100644 index 0000000..5e9ae39 --- /dev/null +++ b/examples/streamlit/component/plotter.py @@ -0,0 +1,83 @@ +import streamlit as st +import matplotlib.pyplot as plt + +def plot_comparison(data, slice_selection): + col1, col2, col3, *others = st.columns(3) + + new_slice = {} + + for key, values in slice_selection.items(): + if isinstance(values, tuple): + start, stop = values + if start != stop: + new_slice[key] = slice(start, stop) + else: + new_slice[key] = start + else: + new_slice[key] = values + + slice_selection = new_slice + + if data.reference_da is not None: + print(f"{slice_selection=}") + slice_selection = {key: (value if key != "lat" else slice(value.stop, value.start)) for key, value in + slice_selection.items()} + + only_slices = {key: value for key, value in slice_selection.items() if isinstance(value, slice)} + non_slices = {key: value for key, value in slice_selection.items() if not isinstance(value, slice)} + + if only_slices: + reference_slice = data.reference_da.sel(**only_slices) + else: + reference_slice = data.reference_da + + if non_slices: + reference_slice = reference_slice.sel(**non_slices, method="nearest") + print(reference_slice) + try: + reference_slice.plot() + fig1 = plt.gcf() + with col1: + st.pyplot(fig1) + except TypeError: + pass + + if data.compressed_da is not None: + plt.figure() + if only_slices: + compressed_slice = data.compressed_da.sel(**only_slices) + else: + compressed_slice = data.compressed_da + if non_slices: + compressed_slice = compressed_slice.sel(**non_slices, method="nearest") + + try: + compressed_slice.plot() + fig2 = plt.gcf() + with col2: + st.pyplot(fig2) + except TypeError: + pass + + diff = data.compressed_da - data.reference_da + plt.figure() + if only_slices: + diff_slice = diff.sel(**only_slices) + else: + diff_slice = data.compressed_da + if non_slices: + diff_slice = diff_slice.sel(**non_slices, method="nearest") + + try: + diff_slice.plot() + fig3 = plt.gcf() + with col3: + st.pyplot(fig3) + except TypeError: + pass + + + + + else: + st.text("Compress the data to show the plot!") diff --git a/examples/streamlit/playground.py b/examples/streamlit/playground.py index 8106016..8603b62 100644 --- a/examples/streamlit/playground.py +++ b/examples/streamlit/playground.py @@ -1,41 +1,44 @@ - import streamlit as st from component.data_source import create_data, select_dataset, select_slice -from component.compression_section import compression_section, plot_compressed +from component.basic_section import basic_section +from component.advanced_section import advanced_section from component.analysis_section import analysis_section +from component.plotter import plot_comparison - -st.set_page_config(layout="wide") +st.set_page_config(layout="wide", initial_sidebar_state="collapsed") data = create_data() -def sidebar(): - ... - - def setup_main_frame(): - st.title("Lossy Compression playground!") - # with st.expander("Data Selection"): + st.title("Welcome to the :green[enstools-compression] playground!") with st.sidebar: select_dataset(data) slice_selection = select_slice(data) st.markdown("---") - options = ["Compression", "Analysis"] - compression, analysis = st.tabs(options) + options = ["Compression", "Advanced Compression", "Analysis"] + basic, advanced, analysis = st.tabs(options) + + with basic: + basic_section(data=data, slice_selection=slice_selection) + with st.spinner(): + try: + plot_comparison(data=data, slice_selection=slice_selection) + except TypeError as err: + st.warning(err) - with compression: - compression_section(data=data, slice_selection=slice_selection) + with advanced: + advanced_section(data=data, slice_selection=slice_selection) with st.spinner(): try: - plot_compressed(data=data, slice_selection=slice_selection) + plot_comparison(data=data, slice_selection=slice_selection) except TypeError as err: st.warning(err) with analysis: analysis_section(data=data, slice_selection=slice_selection) -sidebar() + setup_main_frame() -- GitLab