Skip to content
Snippets Groups Projects
Commit 48973d3e authored by Oriol Tintó Prims
Browse files

Small improvement to streamlit app.

parent e9db02cb
No related branches found
No related tags found
1 merge request: !18 Several changes including chunking and expansion of examples.
Pipeline #19050 passed
import streamlit as st
import enstools.compression.xr_accessor # noqa
def analysis_section(data, slice_selection):
    """Render the "Analysis" tab of the app.

    The user types a constraint string and picks one or more
    ``compressor-mode`` combinations.  Each pick is passed to the
    ``compression.analyze`` accessor of the reference data array, and the
    resulting compression ratio, parameter and specification string are shown.

    Args:
        data: Container exposing ``dataset`` and ``reference_da``.
        slice_selection: Current slice selection; not used by this function.

    Returns:
        None.  Nothing is rendered when no dataset is loaded, and no results
        are rendered when no variable is selected or no case is picked.
    """
    if data.dataset is None:
        return

    constraint_col, cases_col = st.columns(2)
    with constraint_col:
        constrains = st.text_input(label="Constraint", value="correlation_I:5,ssim_I:3")

    options = {
        "sz": ["abs", "rel", "pw_rel"],
        "sz3": ["abs", "rel"],
        "zfp": ["accuracy", "rate", "precision"],
    }
    all_options = [
        f"{compressor}-{mode}"
        for compressor, modes in options.items()
        for mode in modes
    ]
    with cases_col:
        cases = st.multiselect(label="Compressor and mode", options=all_options)

    if data.reference_da is None or not cases:
        return

    st.markdown("# Results:")
    n_cols = 4
    cols = st.columns(n_cols)
    for idx, case in enumerate(cases):
        # Wrap around so that cases beyond ``n_cols`` stack in the same columns.
        with cols[idx % n_cols]:
            # Split on the first dash only: the compressor name never holds one.
            compressor, mode = case.split("-", 1)
            encoding, metrics = data.reference_da.compression.analyze(
                constrains=constrains,
                compressor=compressor,
                compression_mode=mode,
            )
            # The parameter is the last comma-separated field of the encoding.
            parameter = encoding.split(",")[-1]
            compression_ratio = metrics["compression_ratio"]

            st.markdown(f"## {compressor},{mode}:\n\n"
                        f"**Compression Ratio:** {compression_ratio:.2f}x\n\n"
                        f"**Parameter:** {parameter}\n\n"
                        f"**Specification String:**")
            st.code(encoding)
            st.markdown("___")
import matplotlib.pyplot as plt
import streamlit as st
import enstools.compression.xr_accessor # noqa
from enstools.encoding.errors import InvalidCompressionSpecification
# Default value offered in the "Parameter" text input of ``compression_section``
# for every supported compressor (outer key) and compression mode (inner key).
default_parameters = {
    "sz": {
        "abs": 1,
        "rel": 0.001,
        "pw_rel": 0.001,
    },
    "sz3": {
        "abs": 1,
        "rel": 0.001,
    },
    "zfp": {
        "accuracy": 1,
        "rate": 3.2,
        "precision": 14,
    }
}
def compression_section(data, slice_selection):
    """Render the compression controls and compress the selected variable.

    The specification is either typed as a string or assembled from three
    widgets (compressor, mode, parameter).  The compressor and mode choices
    come from ``default_parameters`` so the widgets and the default values
    cannot drift apart.

    Args:
        data: Container exposing ``dataset``, ``reference_da``,
            ``compressed_da`` and ``compress``.
        slice_selection: Current slice selection; not used by this function.

    Returns:
        None.  Nothing is rendered when no dataset is loaded.
    """
    if data.dataset is None:
        return

    # Bound up-front so the checks below never see an undefined name.
    compression_specification = ""

    with st.expander("Compression Specifications"):
        # Streamlit discourages empty labels; use a real one and hide it.
        specification_mode = st.radio(
            label="Specification mode",
            options=["String", "Options"],
            horizontal=True,
            label_visibility="collapsed",
        )
        if specification_mode == "String":
            compression_specification = st.text_input(label="Compression", value="lossy,sz,abs,1")
        elif specification_mode == "Options":
            compressor_col, mode_col, parameter_col = st.columns(3)
            with compressor_col:
                compressor = st.selectbox(label="Compressor", options=list(default_parameters))
            with mode_col:
                mode = st.selectbox(label="Mode", options=list(default_parameters[compressor]))
            with parameter_col:
                parameter = st.text_input(label="Parameter", value=default_parameters[compressor][mode])
            compression_specification = f"lossy,{compressor},{mode},{parameter}"
            st.markdown(f"**Compression Specification:** {compression_specification}")

    # Without a selected variable there is nothing to compress.
    if compression_specification and data.reference_da is not None:
        try:
            data.compress(compression_specification)
        except InvalidCompressionSpecification:
            st.warning("Invalid compression specification!")
            st.markdown(
                "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)")

    if data.compressed_da is not None:
        st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}")
def _normalize_slice_selection(slice_selection):
    """Convert ``(start, stop)`` tuples to slices, or to a scalar when both ends are equal."""
    normalized = {}
    for key, values in slice_selection.items():
        if isinstance(values, tuple):
            start, stop = values
            normalized[key] = slice(start, stop) if start != stop else start
        else:
            normalized[key] = values
    return normalized


def _orient_slices(data_array, selection):
    """Swap the bounds of slices whose coordinate is stored in descending order."""
    oriented = {}
    for key, value in selection.items():
        if isinstance(value, slice) and key in data_array.dims:
            coordinate = data_array[key].values
            # Label-based slicing follows the coordinate's own order, so a
            # descending coordinate needs (high, low) bounds.
            if len(coordinate) > 1 and coordinate[0] > coordinate[-1]:
                value = slice(value.stop, value.start)
        oriented[key] = value
    return oriented


def _select(data_array, selection):
    """Apply slices as given and scalar selections with nearest-neighbour lookup."""
    only_slices = {key: value for key, value in selection.items() if isinstance(value, slice)}
    non_slices = {key: value for key, value in selection.items() if not isinstance(value, slice)}
    selected = data_array.sel(**only_slices) if only_slices else data_array
    if non_slices:
        selected = selected.sel(**non_slices, method="nearest")
    return selected


def _render(data_array, column):
    """Plot ``data_array`` on its own figure inside ``column`` and release the figure."""
    fig, ax = plt.subplots()
    try:
        data_array.plot(ax=ax)
        with column:
            st.pyplot(fig)
    except TypeError as err:
        # Report why the selection cannot be plotted instead of hiding it.
        with column:
            st.warning(str(err))
    finally:
        # The script reruns on every interaction; close to avoid piling up figures.
        plt.close(fig)


def plot_compressed(data, slice_selection):
    """Show the reference and the compressed data side by side.

    Args:
        data: Container exposing ``reference_da`` and ``compressed_da``.
        slice_selection: Mapping from dimension name to either a
            ``(start, stop)`` tuple or a single value.

    Returns:
        None.  Nothing is plotted when no reference variable is selected.
        When there is no compressed data yet, a hint is shown in place of
        the second plot.
    """
    col1, col2 = st.columns(2)
    if data.reference_da is None:
        return

    selection = _orient_slices(data.reference_da, _normalize_slice_selection(slice_selection))
    _render(_select(data.reference_da, selection), col1)

    if data.compressed_da is not None:
        _render(_select(data.compressed_da, selection), col2)
    else:
        st.text("Compress the data to show the plot!")
import io
from typing import Optional
import pandas as pd
import streamlit as st
import xarray as xr
class DataContainer:
    """Holder for the loaded dataset, the selected variable and its compressed copy."""

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, dataset: "Optional[xr.Dataset]" = None):
        self.dataset = dataset
        # Variable of ``dataset`` picked through ``select_variable``.
        self.reference_da = None
        # Result of the most recent ``compress`` call.
        self.compressed_da = None

    def set_dataset(self, dataset):
        """Replace the dataset and discard the previous selection and compression result."""
        self.dataset = dataset
        self.reference_da = None
        self.compressed_da = None

    def select_variable(self, variable):
        """Make ``dataset[variable]`` the reference data array."""
        self.reference_da = self.dataset[variable]

    @classmethod
    def from_tutorial_data(cls, dataset_name: str = "air_temperature"):
        """Build a container around one of xarray's tutorial datasets."""
        return cls(dataset=xr.tutorial.open_dataset(dataset_name))

    @property
    def time_steps(self):
        """Values of the ``time`` coordinate of the reference data array.

        Returns:
            The values converted with ``pd.to_datetime``, or the raw values
            when that conversion raises ``TypeError``.  ``None`` when no
            variable is selected or it has no ``time`` dimension.
        """
        if self.reference_da is None or "time" not in self.reference_da.dims:
            return None
        time_values = self.reference_da.time.values
        try:
            return pd.to_datetime(time_values)
        except TypeError:
            return time_values

    def compress(self, compression):
        """Compress the reference data array and store it in ``compressed_da``.

        When no variable is selected there is nothing to compress, so
        ``compressed_da`` is reset to ``None`` instead of raising.

        Args:
            compression: Specification forwarded to the ``compression``
                accessor of the reference data array.
        """
        if self.reference_da is None:
            self.compressed_da = None
            return
        self.compressed_da = self.reference_da.compression(compression)
@st.cache_resource
def create_data():
    """Build the data container used by the app, loaded with the default tutorial dataset.

    The function is wrapped in ``st.cache_resource``.
    """
    container = DataContainer.from_tutorial_data()
    return container
def select_dataset(data):
    """Render the dataset and variable pickers and store the choices in ``data``.

    The dataset is either one of xarray's tutorial datasets or a file
    uploaded by the user.  An uploaded file that cannot be opened is reported
    with an error message and leaves ``data`` without a dataset.

    Args:
        data: Container exposing ``dataset``, ``set_dataset`` and
            ``select_variable``.
    """
    st.title("Select Dataset")
    data_source = st.radio(label="Data source", options=["Tutorial Dataset", "Custom Dataset"])
    dataset_col, variable_col = st.columns(2)

    if data_source == "Tutorial Dataset":
        tutorial_dataset_options = [
            "air_temperature",
            "air_temperature_gradient",
            # "basin_mask",  # Different coordinates
            # "rasm",  # Has nan
            "ROMS_example",
            "tiny",
            # "era5-2mt-2019-03-uk.grib",
            "eraint_uvz",
            "ersstv5"
        ]
        with dataset_col:
            dataset_name = st.selectbox(label="Dataset", options=tutorial_dataset_options)
        data.set_dataset(xr.tutorial.open_dataset(dataset_name))
    elif data_source == "Custom Dataset":
        my_file = st.file_uploader(label="Your file")
        # Forget the previous dataset until the upload has been read.
        data.set_dataset(None)
        if my_file:
            try:
                my_dataset = xr.open_dataset(io.BytesIO(my_file.read()))
            except (ValueError, OSError) as err:
                # Any file can be uploaded; report unreadable ones in the page.
                st.error(f"Could not open the uploaded file: {err}")
            else:
                st.text("Custom dataset loaded!")
                data.set_dataset(my_dataset)

    if data.dataset is not None:
        with variable_col:
            variable = st.selectbox(label="Variable", options=data.dataset.data_vars)
        if variable:
            data.select_variable(variable)
def select_slice(data):
    """Render one selector per dimension of the reference data array.

    Args:
        data: Container exposing ``reference_da``.

    Returns:
        dict: Maps each dimension name to its selection:
            - ``time``: the value picked on a select-slider, or the only
              value when there is a single time step.
            - other dimensions: a ``(start, stop)`` tuple from a range
              slider, or the only value when the coordinate starts and ends
              on the same value.
        Dimensions with an empty coordinate, or spanning 1000 or more, get
        no entry.  The dict is empty when no variable is selected.
    """
    st.title("Select Slice")
    slice_selection = {}

    reference = data.reference_da
    if reference is None or not reference.dims:
        return slice_selection

    tabs = st.tabs(tabs=reference.dims)
    for tab, dimension in zip(tabs, reference.dims):
        coordinate = reference[dimension].values
        if len(coordinate) == 0:
            continue

        with tab:
            if str(dimension) == "time":
                if len(coordinate) > 1:
                    slice_selection[dimension] = st.select_slider(label=dimension,
                                                                  options=coordinate,
                                                                  )
                else:
                    slice_selection[dimension] = coordinate[0]
                continue

            # Order the bounds so descending coordinates still give a valid
            # slider range and the extent check below sees a positive span.
            low, high = sorted((float(coordinate[0]), float(coordinate[-1])))
            if low == high:
                # Nothing to choose from; keep the value, as for ``time``.
                slice_selection[dimension] = low
            elif high - low < 1000:
                slice_selection[dimension] = st.slider(label=dimension,
                                                       min_value=low,
                                                       max_value=high,
                                                       value=(low, high),
                                                       )
    return slice_selection
import io
from typing import Optional
import streamlit as st import streamlit as st
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import enstools.compression.xr_accessor # noqa from component.data_source import create_data, select_dataset, select_slice
from enstools.encoding.errors import InvalidCompressionSpecification from component.compression_section import compression_section, plot_compressed
from component.analysis_section import analysis_section
options = ["Compression", "Analysis"]
st.set_page_config(layout="wide")
class DataContainer:
def __init__(self, dataset: Optional[xr.Dataset] = None):
self.dataset = dataset
self.reference_da = None
self.compressed_da = None
def set_dataset(self, dataset): data = create_data()
self.dataset = dataset
self.reference_da = None
self.compressed_da = None
def select_variable(self, variable):
self.reference_da = self.dataset[variable]
@classmethod
def from_tutorial_data(cls, dataset_name: str = "air_temperature"):
return cls(dataset=xr.tutorial.open_dataset(dataset_name))
@property
def time_steps(self):
if self.reference_da is not None:
if "time" in self.reference_da.dims:
try:
return pd.to_datetime(self.reference_da.time.values)
except TypeError:
return self.reference_da.time.values
print(self.reference_da)
def compress(self, compression):
self.compressed_da = self.reference_da.compression(compression)
def sidebar(): def sidebar():
with st.sidebar: ...
st.title("Data selection")
data_source = st.radio(label="Data source", options=["Tutorial Dataset", "Custom Dataset"])
if data_source == "Tutorial Dataset":
tutorial_dataset_options = [
"air_temperature",
"air_temperature_gradient",
# "basin_mask", # Different coordinates
# "rasm", # Has nan
"ROMS_example",
"tiny",
# "era5-2mt-2019-03-uk.grib",
"eraint_uvz",
"ersstv5"
]
dataset_name = st.selectbox(label="Dataset", options=tutorial_dataset_options)
dataset = xr.tutorial.open_dataset(dataset_name)
data.set_dataset(dataset)
elif data_source == "Custom Dataset":
my_file = st.file_uploader(label="Your file")
data.set_dataset(None)
if my_file:
my_virtual_file = io.BytesIO(my_file.read())
my_dataset = xr.open_dataset(my_virtual_file)
st.text("Custom dataset loaded!")
data.set_dataset(my_dataset)
if data.dataset is not None:
variable = st.selectbox(label="Variable", options=data.dataset.data_vars)
if variable:
data.select_variable(variable)
slice_selection = {}
if data.reference_da.dims and True:
for dimension in data.reference_da.dims:
if str(dimension) == "time":
if len(data.reference_da.time) > 1:
slice_selection[dimension] = st.select_slider(label=dimension,
options=data.reference_da[dimension].values,
)
else:
slice_selection[dimension] = data.reference_da.time.values[0]
st.markdown(f"- {dimension} -> {slice_selection[dimension]}")
else:
st.markdown(f"- {dimension}")
_min = float(data.reference_da[dimension].values[0])
_max = float(data.reference_da[dimension].values[-1])
st.markdown(f"- {_min, _max}")
if _max - _min < 1000:
slice_selection[dimension] = st.slider(label=dimension,
min_value=_min,
max_value=_max,
value=(_min, _max),
)
if st.button("Clear Cache"):
st.cache_resource.clear()
return slice_selection
@st.cache_resource
def create_data():
return DataContainer.from_tutorial_data()
data = create_data()
def setup_main_frame(): def setup_main_frame():
slice_selection = sidebar() st.title("Lossy Compression playground!")
st.title("enstools-compression playground") # with st.expander("Data Selection"):
with st.sidebar:
select_dataset(data)
slice_selection = select_slice(data)
st.markdown("---") st.markdown("---")
compression, analysis = st.tabs(["Compression", "Analysis"]) options = ["Compression", "Analysis"]
compression, analysis = st.tabs(options)
with compression: with compression:
if data.dataset is not None: compression_section(data=data, slice_selection=slice_selection)
st.markdown("# Compression") with st.spinner():
try:
with st.expander("Compression Specifications"): plot_compressed(data=data, slice_selection=slice_selection)
specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True) except TypeError as err:
# specification, options = st.tabs(["String", "Options"]) st.warning(err)
if specification_mode == "String": with analysis:
compression_specification = st.text_input(label="Compression") analysis_section(data=data, slice_selection=slice_selection)
elif specification_mode == "Options":
compressor = st.selectbox(label="Compressor", options=["sz", "zfp"]) sidebar()
if compressor == "sz":
mode_options = ["abs", "rel", "pw_rel"]
elif compressor == "zfp":
mode_options = ["accuracy", "rate", "precision"]
else:
mode_options = []
mode = st.selectbox(label="Mode", options=mode_options)
parameter = st.text_input(label="Parameter")
compression_specification = f"lossy,{compressor},{mode},{parameter}"
st.markdown(f"**Compression Specification:** {compression_specification}")
if compression_specification:
try:
data.compress(compression_specification)
except InvalidCompressionSpecification:
st.warning("Invalid compression specification!")
st.markdown(
"Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)")
if data.compressed_da is not None:
st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}")
col1, col2, *others = st.columns(2)
new_slice = {}
for key, values in slice_selection.items():
if isinstance(values, tuple):
start, stop = values
if start != stop:
new_slice[key] = slice(start, stop)
else:
new_slice[key] = start
else:
new_slice[key] = values
slice_selection = new_slice
if data.reference_da is not None:
print(f"{slice_selection=}")
slice_selection = {key: (value if key != "lat" else slice(value.stop, value.start)) for key, value in
slice_selection.items()}
only_slices = {key: value for key, value in slice_selection.items() if isinstance(value, slice)}
non_slices = {key: value for key, value in slice_selection.items() if not isinstance(value, slice)}
if only_slices:
reference_slice = data.reference_da.sel(**only_slices)
else:
reference_slice = data.reference_da
if non_slices:
reference_slice = reference_slice.sel(**non_slices, method="nearest")
print(reference_slice)
try:
reference_slice.plot()
fig1 = plt.gcf()
with col1:
st.pyplot(fig1)
except TypeError:
pass
if data.compressed_da is not None:
plt.figure()
if only_slices:
compressed_slice = data.compressed_da.sel(**only_slices)
else:
compressed_slice = data.compressed_da
if non_slices:
compressed_slice = compressed_slice.sel(**non_slices, method="nearest")
try:
compressed_slice.plot()
fig2 = plt.gcf()
with col2:
st.pyplot(fig2)
except TypeError:
pass
else:
st.text("Compress the data to show the plot!")
setup_main_frame() setup_main_frame()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment