From 48973d3e1817aeed507f88b6c3e25866d668eb4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oriol=20Tint=C3=B3=20Prims?= <oriol.tinto@lmu.de> Date: Fri, 16 Jun 2023 14:48:47 +0200 Subject: [PATCH] Small improvement to streamlit app. --- .../streamlit/component/analysis_section.py | 55 +++++ .../component/compression_section.py | 125 ++++++++++ examples/streamlit/component/data_source.py | 116 +++++++++ examples/streamlit/playground.py | 233 ++---------------- 4 files changed, 320 insertions(+), 209 deletions(-) create mode 100644 examples/streamlit/component/analysis_section.py create mode 100644 examples/streamlit/component/compression_section.py create mode 100644 examples/streamlit/component/data_source.py diff --git a/examples/streamlit/component/analysis_section.py b/examples/streamlit/component/analysis_section.py new file mode 100644 index 0000000..b710b16 --- /dev/null +++ b/examples/streamlit/component/analysis_section.py @@ -0,0 +1,55 @@ +import streamlit as st + +import enstools.compression.xr_accessor # noqa + + +def analysis_section(data, slice_selection): + if data.dataset is not None: + # st.markdown("# Compression") + col1, col2 = st.columns(2) + with col1: + constrains = st.text_input(label="Constraint", value="correlation_I:5,ssim_I:3") + + options = { + "sz": ["abs", "rel", "pw_rel"], + "sz3": ["abs", "rel"], + "zfp": ["accuracy", "rate", "precision"], + } + + all_options = [] + [all_options.extend([f"{compressor}-{mode}" for mode in options[compressor]]) for compressor in options] + + with col2: + cases = st.multiselect(label="Compressor and mode", options=all_options) + + if data.reference_da is None: + return + + if not cases: + return + + st.markdown("# Results:") + n_cols = 4 + cols = st.columns(n_cols) + + all_results = {} + + for idx, case in enumerate(cases): + with cols[idx % n_cols]: + compressor, mode = case.split("-") + encoding, metrics = data.reference_da.compression.analyze( + constrains=constrains, + compressor=compressor, + compression_mode=mode + ) + parameter = encoding.split(",")[-1] + compression_ratio = metrics["compression_ratio"] + all_results[case] = compression_ratio + st.markdown(f"## {compressor},{mode}:\n\n" + f"**Compression Ratio:** {compression_ratio:.2f}x\n\n" + f"**Parameter:** {parameter}\n\n" + f"**Specification String:**") + st.code(encoding) + st.markdown(f"___") + # st.markdown(encoding) + # st.markdown(metrics) diff --git a/examples/streamlit/component/compression_section.py b/examples/streamlit/component/compression_section.py new file mode 100644 index 0000000..1114e71 --- /dev/null +++ b/examples/streamlit/component/compression_section.py @@ -0,0 +1,125 @@ +import matplotlib.pyplot as plt +import streamlit as st + +import enstools.compression.xr_accessor # noqa +from enstools.encoding.errors import InvalidCompressionSpecification + +default_parameters = { + "sz": { + "abs": 1, + "rel": 0.001, + "pw_rel": 0.001, + }, + "sz3": { + "abs": 1, + "rel": 0.001, + }, + "zfp": { + "accuracy": 1, + "rate": 3.2, + "precision": 14, + } + +} + + +def compression_section(data, slice_selection): + if data.dataset is not None: + # st.markdown("# Compression") + + with st.expander("Compression Specifications"): + specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True) + # specification, options = st.tabs(["String", "Options"]) + if specification_mode == "String": + compression_specification = st.text_input(label="Compression", value="lossy,sz,abs,1") + elif specification_mode == "Options": + col1_, col2_, col3_ = st.columns(3) + with col1_: + compressor = st.selectbox(label="Compressor", options=["sz", "sz3", "zfp"]) + if compressor == "sz": + mode_options = ["abs", "rel", "pw_rel"] + elif compressor == "sz3": + mode_options = ["abs", "rel"] + elif compressor == "zfp": + mode_options = ["accuracy", "rate", "precision"] + else: + mode_options = [] + with col2_: + mode = st.selectbox(label="Mode", options=mode_options) + with col3_: + parameter = st.text_input(label="Parameter", value=default_parameters[compressor][mode]) + + compression_specification = f"lossy,{compressor},{mode},{parameter}" + st.markdown(f"**Compression Specification:** {compression_specification}") + + if compression_specification: + try: + data.compress(compression_specification) + except InvalidCompressionSpecification: + st.warning("Invalid compression specification!") + st.markdown( + "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)") + if data.compressed_da is not None: + st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}") + + +def plot_compressed(data, slice_selection): + col1, col2, *others = st.columns(2) + + new_slice = {} + + for key, values in slice_selection.items(): + if isinstance(values, tuple): + start, stop = values + if start != stop: + new_slice[key] = slice(start, stop) + else: + new_slice[key] = start + else: + new_slice[key] = values + + slice_selection = new_slice + + if data.reference_da is not None: + print(f"{slice_selection=}") + slice_selection = {key: (value if key != "lat" else slice(value.stop, value.start)) for key, value in + slice_selection.items()} + + only_slices = {key: value for key, value in slice_selection.items() if isinstance(value, slice)} + non_slices = {key: value for key, value in slice_selection.items() if not isinstance(value, slice)} + + if only_slices: + reference_slice = data.reference_da.sel(**only_slices) + else: + reference_slice = data.reference_da + + if non_slices: + reference_slice = reference_slice.sel(**non_slices, method="nearest") + print(reference_slice) + try: + reference_slice.plot() + fig1 = plt.gcf() + with col1: + st.pyplot(fig1) + except TypeError: + pass + + if data.compressed_da is not None: + plt.figure() + if only_slices: + compressed_slice = data.compressed_da.sel(**only_slices) + else: + compressed_slice = data.compressed_da + if non_slices: + compressed_slice = compressed_slice.sel(**non_slices, method="nearest") + + try: + compressed_slice.plot() + fig2 = plt.gcf() + with col2: + st.pyplot(fig2) + except TypeError: + pass + + else: + st.text("Compress the data to show the plot!") diff --git a/examples/streamlit/component/data_source.py b/examples/streamlit/component/data_source.py new file mode 100644 index 0000000..1d52908 --- /dev/null +++ b/examples/streamlit/component/data_source.py @@ -0,0 +1,116 @@ +import io +from typing import Optional + +import pandas as pd +import streamlit as st +import xarray as xr + + +class DataContainer: + def __init__(self, dataset: Optional[xr.Dataset] = None): + self.dataset = dataset + self.reference_da = None + self.compressed_da = None + + def set_dataset(self, dataset): + self.dataset = dataset + self.reference_da = None + self.compressed_da = None + + def select_variable(self, variable): + self.reference_da = self.dataset[variable] + + @classmethod + def from_tutorial_data(cls, dataset_name: str = "air_temperature"): + return cls(dataset=xr.tutorial.open_dataset(dataset_name)) + + @property + def time_steps(self): + if self.reference_da is not None: + if "time" in self.reference_da.dims: + try: + return pd.to_datetime(self.reference_da.time.values) + except TypeError: + return self.reference_da.time.values + + print(self.reference_da) + + def compress(self, compression): + self.compressed_da = self.reference_da.compression(compression) + + +@st.cache_resource +def create_data(): + return DataContainer.from_tutorial_data() + + +def select_dataset(data): + st.title("Select Dataset") + data_source = st.radio(label="Data source", options=["Tutorial Dataset", "Custom Dataset"]) + col1, col2 = st.columns(2) + if data_source == "Tutorial Dataset": + tutorial_dataset_options = [ + "air_temperature", + "air_temperature_gradient", + # "basin_mask", # Different coordinates + # "rasm", # Has nan + "ROMS_example", + "tiny", + # "era5-2mt-2019-03-uk.grib", + "eraint_uvz", + "ersstv5" + ] + with col1: + dataset_name = st.selectbox(label="Dataset", options=tutorial_dataset_options) + dataset = xr.tutorial.open_dataset(dataset_name) + + data.set_dataset(dataset) + + elif data_source == "Custom Dataset": + my_file = st.file_uploader(label="Your file") + data.set_dataset(None) + + if my_file: + my_virtual_file = io.BytesIO(my_file.read()) + my_dataset = xr.open_dataset(my_virtual_file) + st.text("Custom dataset loaded!") + data.set_dataset(my_dataset) + + if data.dataset is not None: + with col2: + variable = st.selectbox(label="Variable", options=data.dataset.data_vars) + if variable: + data.select_variable(variable) + + +def select_slice(data): + st.title("Select Slice") + slice_selection = {} + if data.reference_da is not None and data.reference_da.dims and True: + + tabs = st.tabs(tabs=data.reference_da.dims) + for idx, dimension in enumerate(data.reference_da.dims): + with tabs[idx]: + if str(dimension) == "time": + if len(data.reference_da.time) > 1: + slice_selection[dimension] = st.select_slider(label=dimension, + options=data.reference_da[dimension].values, + ) + else: + slice_selection[dimension] = data.reference_da.time.values[0] + + else: + _min = float(data.reference_da[dimension].values[0]) + _max = float(data.reference_da[dimension].values[-1]) + + if _max - _min < 1000: + slice_selection[dimension] = st.slider(label=dimension, + min_value=_min, + max_value=_max, + value=(_min, _max), + ) + + # if st.button("Clear Cache"): + # st.cache_resource.clear() + + return slice_selection diff --git a/examples/streamlit/playground.py b/examples/streamlit/playground.py index a451696..8106016 100644 --- a/examples/streamlit/playground.py +++ b/examples/streamlit/playground.py @@ -1,226 +1,41 @@ -import io -from typing import Optional import streamlit as st -import xarray as xr -import matplotlib.pyplot as plt -import pandas as pd -import enstools.compression.xr_accessor # noqa -from enstools.encoding.errors import InvalidCompressionSpecification +from component.data_source import create_data, select_dataset, select_slice +from component.compression_section import compression_section, plot_compressed +from component.analysis_section import analysis_section -options = ["Compression", "Analysis"] +st.set_page_config(layout="wide") -class DataContainer: - def __init__(self, dataset: Optional[xr.Dataset] = None): - self.dataset = dataset - self.reference_da = None - self.compressed_da = None - def set_dataset(self, dataset): - self.dataset = dataset - self.reference_da = None - self.compressed_da = None - - def select_variable(self, variable): - self.reference_da = self.dataset[variable] - - @classmethod - def from_tutorial_data(cls, dataset_name: str = "air_temperature"): - return cls(dataset=xr.tutorial.open_dataset(dataset_name)) - - @property - def time_steps(self): - if self.reference_da is not None: - if "time" in self.reference_da.dims: - try: - return pd.to_datetime(self.reference_da.time.values) - except TypeError: - return self.reference_da.time.values - - print(self.reference_da) - - def compress(self, compression): - self.compressed_da = self.reference_da.compression(compression) +data = create_data() def sidebar(): - with st.sidebar: - st.title("Data selection") - data_source = st.radio(label="Data source", options=["Tutorial Dataset", "Custom Dataset"]) - - if data_source == "Tutorial Dataset": - tutorial_dataset_options = [ - "air_temperature", - "air_temperature_gradient", - # "basin_mask", # Different coordinates - # "rasm", # Has nan - "ROMS_example", - "tiny", - # "era5-2mt-2019-03-uk.grib", - "eraint_uvz", - "ersstv5" - ] - dataset_name = st.selectbox(label="Dataset", options=tutorial_dataset_options) - dataset = xr.tutorial.open_dataset(dataset_name) - - data.set_dataset(dataset) - - elif data_source == "Custom Dataset": - my_file = st.file_uploader(label="Your file") - data.set_dataset(None) - - if my_file: - my_virtual_file = io.BytesIO(my_file.read()) - my_dataset = xr.open_dataset(my_virtual_file) - st.text("Custom dataset loaded!") - data.set_dataset(my_dataset) - - if data.dataset is not None: - variable = st.selectbox(label="Variable", options=data.dataset.data_vars) - if variable: - data.select_variable(variable) - - slice_selection = {} - if data.reference_da.dims and True: - for dimension in data.reference_da.dims: - if str(dimension) == "time": - if len(data.reference_da.time) > 1: - slice_selection[dimension] = st.select_slider(label=dimension, - options=data.reference_da[dimension].values, - ) - else: - slice_selection[dimension] = data.reference_da.time.values[0] - st.markdown(f"- {dimension} -> {slice_selection[dimension]}") - - else: - st.markdown(f"- {dimension}") - _min = float(data.reference_da[dimension].values[0]) - _max = float(data.reference_da[dimension].values[-1]) - - - st.markdown(f"- {_min, _max}") - if _max - _min < 1000: - slice_selection[dimension] = st.slider(label=dimension, - min_value=_min, - max_value=_max, - value=(_min, _max), - ) - - if st.button("Clear Cache"): - st.cache_resource.clear() - - return slice_selection - - -@st.cache_resource -def create_data(): - return DataContainer.from_tutorial_data() - - -data = create_data() + ... def setup_main_frame(): - slice_selection = sidebar() - st.title("enstools-compression playground") + st.title("Lossy Compression playground!") + # with st.expander("Data Selection"): + with st.sidebar: + select_dataset(data) + slice_selection = select_slice(data) + st.markdown("---") - compression, analysis = st.tabs(["Compression", "Analysis"]) + options = ["Compression", "Analysis"] + compression, analysis = st.tabs(options) with compression: - if data.dataset is not None: - st.markdown("# Compression") - - with st.expander("Compression Specifications"): - specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True) - # specification, options = st.tabs(["String", "Options"]) - if specification_mode == "String": - compression_specification = st.text_input(label="Compression") - elif specification_mode == "Options": - compressor = st.selectbox(label="Compressor", options=["sz", "zfp"]) - if compressor == "sz": - mode_options = ["abs", "rel", "pw_rel"] - elif compressor == "zfp": - mode_options = ["accuracy", "rate", "precision"] - else: - mode_options = [] - mode = st.selectbox(label="Mode", options=mode_options) - parameter = st.text_input(label="Parameter") - - compression_specification = f"lossy,{compressor},{mode},{parameter}" - st.markdown(f"**Compression Specification:** {compression_specification}") - - if compression_specification: - try: - data.compress(compression_specification) - except InvalidCompressionSpecification: - st.warning("Invalid compression specification!") - st.markdown( - "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)") - if data.compressed_da is not None: - st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}") - - col1, col2, *others = st.columns(2) - - new_slice = {} - - for key, values in slice_selection.items(): - if isinstance(values, tuple): - start, stop = values - if start != stop: - new_slice[key] = slice(start, stop) - else: - new_slice[key] = start - else: - new_slice[key] = values - - slice_selection = new_slice - - if data.reference_da is not None: - print(f"{slice_selection=}") - slice_selection = {key: (value if key != "lat" else slice(value.stop, value.start)) for key, value in - slice_selection.items()} - - only_slices = {key: value for key, value in slice_selection.items() if isinstance(value, slice)} - non_slices = {key: value for key, value in slice_selection.items() if not isinstance(value, slice)} - - if only_slices: - reference_slice = data.reference_da.sel(**only_slices) - else: - reference_slice = data.reference_da - - if non_slices: - reference_slice = reference_slice.sel(**non_slices, method="nearest") - print(reference_slice) - try: - reference_slice.plot() - fig1 = plt.gcf() - with col1: - st.pyplot(fig1) - except TypeError: - pass - - if data.compressed_da is not None: - plt.figure() - if only_slices: - compressed_slice = data.compressed_da.sel(**only_slices) - else: - compressed_slice = data.compressed_da - if non_slices: - compressed_slice = compressed_slice.sel(**non_slices, method="nearest") - - - try: - compressed_slice.plot() - fig2 = plt.gcf() - with col2: - st.pyplot(fig2) - except TypeError: - pass - - else: - st.text("Compress the data to show the plot!") - - + compression_section(data=data, slice_selection=slice_selection) + with st.spinner(): + try: + plot_compressed(data=data, slice_selection=slice_selection) + except TypeError as err: + st.warning(err) + with analysis: + analysis_section(data=data, slice_selection=slice_selection) + +sidebar() setup_main_frame() -- GitLab