From 48973d3e1817aeed507f88b6c3e25866d668eb4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Oriol=20Tint=C3=B3=20Prims?= <oriol.tinto@lmu.de>
Date: Fri, 16 Jun 2023 14:48:47 +0200
Subject: [PATCH] Small improvement to streamlit app.

---
 .../streamlit/component/analysis_section.py   |  55 +++++
 .../component/compression_section.py          | 125 ++++++++++
 examples/streamlit/component/data_source.py   | 116 +++++++++
 examples/streamlit/playground.py              | 233 ++----------------
 4 files changed, 320 insertions(+), 209 deletions(-)
 create mode 100644 examples/streamlit/component/analysis_section.py
 create mode 100644 examples/streamlit/component/compression_section.py
 create mode 100644 examples/streamlit/component/data_source.py

diff --git a/examples/streamlit/component/analysis_section.py b/examples/streamlit/component/analysis_section.py
new file mode 100644
index 0000000..b710b16
--- /dev/null
+++ b/examples/streamlit/component/analysis_section.py
@@ -0,0 +1,55 @@
+import streamlit as st
+
+import enstools.compression.xr_accessor  # noqa
+
+
+def analysis_section(data, slice_selection):
+    """Render the "Analysis" tab.
+
+    Runs the compression analyzer of the selected variable for every chosen
+    "<compressor>-<mode>" case and shows ratio, parameter and specification.
+
+    :param data: container exposing ``dataset`` and ``reference_da``.
+    :param slice_selection: not used by this section.
+    """
+    if data.dataset is None:
+        return
+
+    col1, col2 = st.columns(2)
+    with col1:
+        constrains = st.text_input(label="Constraint", value="correlation_I:5,ssim_I:3")
+
+    options = {
+        "sz": ["abs", "rel", "pw_rel"],
+        "sz3": ["abs", "rel"],
+        "zfp": ["accuracy", "rate", "precision"],
+    }
+    all_options = [f"{compressor}-{mode}" for compressor, modes in options.items() for mode in modes]
+
+    with col2:
+        cases = st.multiselect(label="Compressor and mode", options=all_options)
+
+    if data.reference_da is None or not cases:
+        return
+
+    st.markdown("# Results:")
+    n_cols = 4
+    cols = st.columns(n_cols)
+
+    for idx, case in enumerate(cases):
+        with cols[idx % n_cols]:
+            compressor, mode = case.split("-")
+            encoding, metrics = data.reference_da.compression.analyze(
+                constrains=constrains,
+                compressor=compressor,
+                compression_mode=mode,
+            )
+            # The parameter is the last field of the comma-separated specification.
+            parameter = encoding.split(",")[-1]
+            compression_ratio = metrics["compression_ratio"]
+            st.markdown(f"## {compressor},{mode}:\n\n"
+                        f"**Compression Ratio:** {compression_ratio:.2f}x\n\n"
+                        f"**Parameter:** {parameter}\n\n"
+                        f"**Specification String:**")
+            st.code(encoding)
+            st.markdown("___")
diff --git a/examples/streamlit/component/compression_section.py b/examples/streamlit/component/compression_section.py
new file mode 100644
index 0000000..1114e71
--- /dev/null
+++ b/examples/streamlit/component/compression_section.py
@@ -0,0 +1,125 @@
+import matplotlib.pyplot as plt
+import streamlit as st
+
+import enstools.compression.xr_accessor  # noqa
+from enstools.encoding.errors import InvalidCompressionSpecification
+
+# Default parameter for each compressor (outer key) and compression mode (inner key).
+default_parameters = {
+    "sz": {
+        "abs": 1,
+        "rel": 0.001,
+        "pw_rel": 0.001,
+    },
+    "sz3": {
+        "abs": 1,
+        "rel": 0.001,
+    },
+    "zfp": {
+        "accuracy": 1,
+        "rate": 3.2,
+        "precision": 14,
+    }
+}
+
+
+def compression_section(data, slice_selection):
+    """Render the compression-specification form and compress the selected variable.
+
+    The specification is either typed as a string or assembled from the
+    compressor/mode/parameter widgets; compressors and modes offered are the
+    keys of ``default_parameters`` so every choice has a default parameter.
+
+    :param data: container with the dataset, the selected variable and ``compress``.
+    :param slice_selection: not used by this section.
+    """
+    if data.dataset is None:
+        return
+
+    with st.expander("Compression Specifications"):
+        specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True)
+        if specification_mode == "String":
+            compression_specification = st.text_input(label="Compression", value="lossy,sz,abs,1")
+        else:
+            col1_, col2_, col3_ = st.columns(3)
+            with col1_:
+                compressor = st.selectbox(label="Compressor", options=list(default_parameters))
+            with col2_:
+                mode = st.selectbox(label="Mode", options=list(default_parameters[compressor]))
+            with col3_:
+                parameter = st.text_input(label="Parameter", value=default_parameters[compressor][mode])
+            compression_specification = f"lossy,{compressor},{mode},{parameter}"
+            st.markdown(f"**Compression Specification:** {compression_specification}")
+
+        # Without a selected variable there is nothing to compress.
+        if compression_specification and data.reference_da is not None:
+            try:
+                data.compress(compression_specification)
+            except InvalidCompressionSpecification:
+                st.warning("Invalid compression specification!")
+                st.markdown(
+                    "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)")
+        if data.compressed_da is not None:
+            st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}")
+
+
+def _select(data_array, only_slices, non_slices):
+    """Apply range selections exactly and scalar selections by nearest neighbour."""
+    if only_slices:
+        data_array = data_array.sel(**only_slices)
+    if non_slices:
+        data_array = data_array.sel(**non_slices, method="nearest")
+    return data_array
+
+
+def _render_plot(data_array, column):
+    """Plot ``data_array`` on its own figure and show it inside ``column``."""
+    # A fresh figure keeps the plot off axes left over from an earlier rerun;
+    # closing it afterwards stops figures from piling up in pyplot's registry.
+    figure = plt.figure()
+    try:
+        data_array.plot()
+        with column:
+            st.pyplot(figure)
+    except TypeError:
+        # Best effort: data that cannot be plotted is skipped silently.
+        pass
+    finally:
+        plt.close(figure)
+
+
+def plot_compressed(data, slice_selection):
+    """Plot the selected part of the reference and compressed arrays side by side.
+
+    Range selections are applied with exact ``sel``; scalar selections use
+    ``method="nearest"``.
+
+    :param data: container exposing ``reference_da`` and ``compressed_da``.
+    :param slice_selection: maps a dimension name to a scalar or a ``(start, stop)``
+        tuple. A tuple with equal ends is treated as a scalar.
+    """
+    col1, col2 = st.columns(2)
+
+    only_slices = {}
+    non_slices = {}
+    for key, value in slice_selection.items():
+        if isinstance(value, tuple):
+            start, stop = value
+            value = slice(start, stop) if start != stop else start
+        if not isinstance(value, slice):
+            # Scalars never reach the "lat" reversal below, which needs a slice.
+            non_slices[key] = value
+        elif key == "lat":
+            # The "lat" range is given in reverse order to ``sel``.
+            only_slices[key] = slice(value.stop, value.start)
+        else:
+            only_slices[key] = value
+
+    # Both selections are built up front so either plot can use them.
+    if data.reference_da is not None:
+        _render_plot(_select(data.reference_da, only_slices, non_slices), col1)
+
+    if data.compressed_da is not None:
+        _render_plot(_select(data.compressed_da, only_slices, non_slices), col2)
+    else:
+        st.text("Compress the data to show the plot!")
diff --git a/examples/streamlit/component/data_source.py b/examples/streamlit/component/data_source.py
new file mode 100644
index 0000000..1d52908
--- /dev/null
+++ b/examples/streamlit/component/data_source.py
@@ -0,0 +1,116 @@
+import io
+from typing import Optional
+
+import pandas as pd
+import streamlit as st
+import xarray as xr
+
+
+class DataContainer:  # a dataset, the variable selected from it and its compressed copy
+    def __init__(self, dataset: Optional[xr.Dataset] = None):
+        self.dataset = dataset
+        self.reference_da = None
+        self.compressed_da = None
+
+    def set_dataset(self, dataset):  # a new dataset invalidates both derived arrays
+        self.dataset = dataset
+        self.reference_da = None
+        self.compressed_da = None
+
+    def select_variable(self, variable):
+        self.reference_da = self.dataset[variable]
+        self.compressed_da = None  # a compressed copy of another variable would be stale
+
+    @classmethod
+    def from_tutorial_data(cls, dataset_name: str = "air_temperature"):
+        return cls(dataset=xr.tutorial.open_dataset(dataset_name))
+
+    @property
+    def time_steps(self):
+        """Return the time coordinate as datetimes (raw values if not convertible), or None."""
+        if self.reference_da is None or "time" not in self.reference_da.dims:
+            return None
+        try:
+            return pd.to_datetime(self.reference_da.time.values)
+        except TypeError:
+            return self.reference_da.time.values
+
+    def compress(self, compression):  # needs select_variable() first: reads reference_da
+        self.compressed_da = self.reference_da.compression(compression)
+
+
+@st.cache_resource  # NOTE(review): cached object is shared across sessions but mutated per rerun — confirm intended
+def create_data():
+    return DataContainer.from_tutorial_data()
+
+
+def select_dataset(data):
+    """Let the user choose a tutorial or uploaded dataset and one of its variables.
+
+    :param data: container updated through ``set_dataset`` and ``select_variable``.
+    """
+    st.title("Select Dataset")
+    source = st.radio(label="Data source", options=["Tutorial Dataset", "Custom Dataset"])
+    dataset_column, variable_column = st.columns(2)
+    if source == "Tutorial Dataset":
+        # Not offered: "basin_mask" (different coordinates), "rasm" (has nan)
+        # and "era5-2mt-2019-03-uk.grib".
+        names = [
+            "air_temperature",
+            "air_temperature_gradient",
+            "ROMS_example",
+            "tiny",
+            "eraint_uvz",
+            "ersstv5",
+        ]
+        with dataset_column:
+            chosen = st.selectbox(label="Dataset", options=names)
+        data.set_dataset(xr.tutorial.open_dataset(chosen))
+    elif source == "Custom Dataset":
+        uploaded = st.file_uploader(label="Your file")
+        # No dataset until a file has been uploaded.
+        data.set_dataset(None)
+        if uploaded:
+            custom = xr.open_dataset(io.BytesIO(uploaded.read()))
+            st.text("Custom dataset loaded!")
+            data.set_dataset(custom)
+
+    if data.dataset is None:
+        return
+    with variable_column:
+        variable = st.selectbox(label="Variable", options=data.dataset.data_vars)
+    if variable:
+        data.select_variable(variable)
+
+
+def select_slice(data):
+    """Show one tab per dimension of the selected variable and collect the selection.
+
+    :param data: container exposing ``reference_da``.
+    :return: dict mapping a dimension name to the chosen value ("time") or to
+        the value of its range slider; empty when no variable is selected.
+    """
+    st.title("Select Slice")
+    selection = {}
+    reference = data.reference_da
+    if reference is None or not reference.dims:
+        return selection
+
+    for tab, dimension in zip(st.tabs(tabs=reference.dims), reference.dims):
+        with tab:
+            if str(dimension) == "time":
+                if len(reference.time) > 1:
+                    selection[dimension] = st.select_slider(label=dimension,
+                                                            options=reference[dimension].values)
+                else:
+                    selection[dimension] = reference.time.values[0]
+                continue
+
+            first = float(reference[dimension].values[0])
+            last = float(reference[dimension].values[-1])
+            # Dimensions spanning 1000 or more get no slider and no selection entry.
+            if last - first < 1000:
+                selection[dimension] = st.slider(label=dimension, min_value=first, max_value=last,
+                                                 value=(first, last))
+
+    return selection
diff --git a/examples/streamlit/playground.py b/examples/streamlit/playground.py
index a451696..8106016 100644
--- a/examples/streamlit/playground.py
+++ b/examples/streamlit/playground.py
@@ -1,226 +1,41 @@
-import io
-from typing import Optional
 
 import streamlit as st
-import xarray as xr
-import matplotlib.pyplot as plt
-import pandas as pd
 
-import enstools.compression.xr_accessor  # noqa
-from enstools.encoding.errors import InvalidCompressionSpecification
+from component.data_source import create_data, select_dataset, select_slice
+from component.compression_section import compression_section, plot_compressed
+from component.analysis_section import analysis_section
 
-options = ["Compression", "Analysis"]
 
+st.set_page_config(layout="wide")
 
-class DataContainer:
-    def __init__(self, dataset: Optional[xr.Dataset] = None):
-        self.dataset = dataset
-        self.reference_da = None
-        self.compressed_da = None
 
-    def set_dataset(self, dataset):
-        self.dataset = dataset
-        self.reference_da = None
-        self.compressed_da = None
-
-    def select_variable(self, variable):
-        self.reference_da = self.dataset[variable]
-
-    @classmethod
-    def from_tutorial_data(cls, dataset_name: str = "air_temperature"):
-        return cls(dataset=xr.tutorial.open_dataset(dataset_name))
-
-    @property
-    def time_steps(self):
-        if self.reference_da is not None:
-            if "time" in self.reference_da.dims:
-                try:
-                    return pd.to_datetime(self.reference_da.time.values)
-                except TypeError:
-                    return self.reference_da.time.values
-
-            print(self.reference_da)
-
-    def compress(self, compression):
-        self.compressed_da = self.reference_da.compression(compression)
+data = create_data()
 
 
 def sidebar():
-    with st.sidebar:
-        st.title("Data selection")
-        data_source = st.radio(label="Data source", options=["Tutorial Dataset", "Custom Dataset"])
-
-        if data_source == "Tutorial Dataset":
-            tutorial_dataset_options = [
-                "air_temperature",
-                "air_temperature_gradient",
-                # "basin_mask",  # Different coordinates
-                # "rasm",  # Has nan
-                "ROMS_example",
-                "tiny",
-                # "era5-2mt-2019-03-uk.grib",
-                "eraint_uvz",
-                "ersstv5"
-            ]
-            dataset_name = st.selectbox(label="Dataset", options=tutorial_dataset_options)
-            dataset = xr.tutorial.open_dataset(dataset_name)
-
-            data.set_dataset(dataset)
-
-        elif data_source == "Custom Dataset":
-            my_file = st.file_uploader(label="Your file")
-            data.set_dataset(None)
-
-            if my_file:
-                my_virtual_file = io.BytesIO(my_file.read())
-                my_dataset = xr.open_dataset(my_virtual_file)
-                st.text("Custom dataset loaded!")
-                data.set_dataset(my_dataset)
-
-        if data.dataset is not None:
-            variable = st.selectbox(label="Variable", options=data.dataset.data_vars)
-            if variable:
-                data.select_variable(variable)
-
-        slice_selection = {}
-        if data.reference_da.dims and True:
-            for dimension in data.reference_da.dims:
-                if str(dimension) == "time":
-                    if len(data.reference_da.time) > 1:
-                        slice_selection[dimension] = st.select_slider(label=dimension,
-                                                                      options=data.reference_da[dimension].values,
-                                                                      )
-                    else:
-                        slice_selection[dimension] = data.reference_da.time.values[0]
-                    st.markdown(f"- {dimension} -> {slice_selection[dimension]}")
-
-                else:
-                    st.markdown(f"- {dimension}")
-                    _min = float(data.reference_da[dimension].values[0])
-                    _max = float(data.reference_da[dimension].values[-1])
-
-
-                    st.markdown(f"- {_min, _max}")
-                    if _max - _min < 1000:
-                        slice_selection[dimension] = st.slider(label=dimension,
-                                                               min_value=_min,
-                                                               max_value=_max,
-                                                               value=(_min, _max),
-                                                               )
-
-            if st.button("Clear Cache"):
-                st.cache_resource.clear()
-
-        return slice_selection
-
-
-@st.cache_resource
-def create_data():
-    return DataContainer.from_tutorial_data()
-
-
-data = create_data()
+    ...
 
 
 def setup_main_frame():
-    slice_selection = sidebar()
-    st.title("enstools-compression playground")
+    st.title("Lossy Compression playground!")
+    # with st.expander("Data Selection"):
+    with st.sidebar:
+        select_dataset(data)
+        slice_selection = select_slice(data)
+
     st.markdown("---")
-    compression, analysis = st.tabs(["Compression", "Analysis"])
+    options = ["Compression", "Analysis"]
+    compression, analysis = st.tabs(options)
 
     with compression:
-        if data.dataset is not None:
-            st.markdown("# Compression")
-
-            with st.expander("Compression Specifications"):
-                specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True)
-                # specification, options = st.tabs(["String", "Options"])
-                if specification_mode == "String":
-                    compression_specification = st.text_input(label="Compression")
-                elif specification_mode == "Options":
-                    compressor = st.selectbox(label="Compressor", options=["sz", "zfp"])
-                    if compressor == "sz":
-                        mode_options = ["abs", "rel", "pw_rel"]
-                    elif compressor == "zfp":
-                        mode_options = ["accuracy", "rate", "precision"]
-                    else:
-                        mode_options = []
-                    mode = st.selectbox(label="Mode", options=mode_options)
-                    parameter = st.text_input(label="Parameter")
-
-                    compression_specification = f"lossy,{compressor},{mode},{parameter}"
-                    st.markdown(f"**Compression Specification:** {compression_specification}")
-
-                if compression_specification:
-                    try:
-                        data.compress(compression_specification)
-                    except InvalidCompressionSpecification:
-                        st.warning("Invalid compression specification!")
-                        st.markdown(
-                            "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)")
-                if data.compressed_da is not None:
-                    st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}")
-
-            col1, col2, *others = st.columns(2)
-
-            new_slice = {}
-
-            for key, values in slice_selection.items():
-                if isinstance(values, tuple):
-                    start, stop = values
-                    if start != stop:
-                        new_slice[key] = slice(start, stop)
-                    else:
-                        new_slice[key] = start
-                else:
-                    new_slice[key] = values
-
-            slice_selection = new_slice
-
-            if data.reference_da is not None:
-                print(f"{slice_selection=}")
-                slice_selection = {key: (value if key != "lat" else slice(value.stop, value.start)) for key, value in
-                                   slice_selection.items()}
-
-                only_slices = {key: value for key, value in slice_selection.items() if isinstance(value, slice)}
-                non_slices = {key: value for key, value in slice_selection.items() if not isinstance(value, slice)}
-
-                if only_slices:
-                    reference_slice = data.reference_da.sel(**only_slices)
-                else:
-                    reference_slice = data.reference_da
-
-                if non_slices:
-                    reference_slice = reference_slice.sel(**non_slices, method="nearest")
-                print(reference_slice)
-                try:
-                    reference_slice.plot()
-                    fig1 = plt.gcf()
-                    with col1:
-                        st.pyplot(fig1)
-                except TypeError:
-                    pass
-
-            if data.compressed_da is not None:
-                plt.figure()
-                if only_slices:
-                    compressed_slice = data.compressed_da.sel(**only_slices)
-                else:
-                    compressed_slice = data.compressed_da
-                if non_slices:
-                    compressed_slice = compressed_slice.sel(**non_slices, method="nearest")
-
-
-                try:
-                    compressed_slice.plot()
-                    fig2 = plt.gcf()
-                    with col2:
-                        st.pyplot(fig2)
-                except TypeError:
-                    pass
-
-            else:
-                st.text("Compress the data to show the plot!")
-
-
+        compression_section(data=data, slice_selection=slice_selection)
+        with st.spinner():
+            try:
+                plot_compressed(data=data, slice_selection=slice_selection)
+            except TypeError as err:
+                st.warning(err)
+    with analysis:
+        analysis_section(data=data, slice_selection=slice_selection)
+
+sidebar()
 setup_main_frame()
-- 
GitLab