Skip to content
Snippets Groups Projects
Commit ffdd0c84 authored by Oriol Tintó's avatar Oriol Tintó
Browse files

Update streamlit example.

parent 0b81d04d
No related branches found
No related tags found
1 merge request: !18 "Several changes including chunking and expansion of examples."
Pipeline #19106 passed
import xarray
from enstools.compression.analyzer.analyzer import analyze_dataset
import enstools.compression.xr_accessor # noqa
# Names of the datasets that main() opens through xarray.tutorial.open_dataset.
# Only the uncommented entries are processed; the commented-out names are
# currently disabled.
dataset_names = [
"air_temperature",
# "air_temperature_gradient",
# "basin_mask",
# "rasm",
# "ROMS_example",
# "tiny",
# "era5-2mt-2019-03-uk.grib",
# "eraint_uvz",
# "ersstv5"
]
def main():
    """Analyze every configured tutorial dataset and store it on disk twice.

    For each name in ``dataset_names`` the dataset is opened, passed to
    ``analyze_dataset`` and then written as ``reference_<name>.nc`` (plain
    netCDF) and ``compressed_<name>.nc`` (using the encoding returned by the
    analysis). Datasets that raise ``ValueError`` along the way are recorded
    instead of aborting the run. Both the collected results and the list of
    failed dataset names are printed at the end.
    """
    analysis_by_dataset = {}
    failures = []

    for name in dataset_names:
        reference_path = f"reference_{name}.nc"
        compressed_path = f"compressed_{name}.nc"
        try:
            with xarray.tutorial.open_dataset(name) as opened:
                # Two-element unpacking is kept on purpose: a result of any
                # other length raises ValueError, which is handled below.
                encoding, metrics = analyze_dataset(dataset=opened)
                analysis_by_dataset[name] = (encoding, metrics)
                opened.to_netcdf(reference_path)
                opened.to_compressed_netcdf(compressed_path, compression=encoding)
        except ValueError:
            failures.append(name)

    print(analysis_by_dataset)
    print(failures)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
...@@ -23,12 +23,12 @@ default_parameters = { ...@@ -23,12 +23,12 @@ default_parameters = {
} }
def compression_section(data, slice_selection): def advanced_section(data, slice_selection):
if data.dataset is not None: if data.dataset is not None:
# st.markdown("# Compression") # st.markdown("# Compression")
with st.expander("Compression Specifications"): with st.expander("Compression Specifications"):
specification_mode = st.radio(label="", options=["String", "Options"], horizontal=True) specification_mode = st.radio(label="", options=["Options", "String"], horizontal=True)
# specification, options = st.tabs(["String", "Options"]) # specification, options = st.tabs(["String", "Options"])
if specification_mode == "String": if specification_mode == "String":
compression_specification = st.text_input(label="Compression", value="lossy,sz,abs,1") compression_specification = st.text_input(label="Compression", value="lossy,sz,abs,1")
...@@ -60,66 +60,4 @@ def compression_section(data, slice_selection): ...@@ -60,66 +60,4 @@ def compression_section(data, slice_selection):
st.markdown( st.markdown(
"Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)") "Check [the compression specification format](https://enstools-encoding.readthedocs.io/en/latest/CompressionSpecificationFormat.html)")
if data.compressed_da is not None: if data.compressed_da is not None:
st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}") st.markdown(f"**Compression Ratio**: {data.compressed_da.attrs['compression_ratio']}")
\ No newline at end of file
def plot_compressed(data, slice_selection):
    """Plot the selected part of the reference and compressed data side by side.

    Args:
        data: Object exposing ``reference_da`` and ``compressed_da``; either
            attribute may be ``None``.
        slice_selection: Mapping from coordinate name to either a
            ``(start, stop)`` tuple or a single value.

    The reference plot goes into the first of two streamlit columns and the
    compressed plot into the second. When there is no compressed data a short
    hint is shown instead. A ``TypeError`` raised while plotting is ignored
    so that the other panel is still rendered; errors raised while selecting
    the data propagate to the caller.
    """
    col1, col2, *_ = st.columns(2)

    # Split the selection into range selections (slices) and point selections.
    # A (start, stop) tuple with identical ends collapses to a single point.
    only_slices = {}
    non_slices = {}
    for key, values in slice_selection.items():
        if isinstance(values, tuple):
            start, stop = values
            values = slice(start, stop) if start != stop else start
        if isinstance(values, slice):
            if key == "lat":
                # The bounds of a "lat" range are swapped, presumably because
                # the latitude coordinate is stored in descending order (TODO
                # confirm). Only real slices are swapped: a point selection
                # has no .start/.stop and must be left untouched.
                values = slice(values.stop, values.start)
            only_slices[key] = values
        else:
            non_slices[key] = values

    def select(data_array):
        """Apply the range selections, then the nearest-point selections."""
        if only_slices:
            data_array = data_array.sel(**only_slices)
        if non_slices:
            data_array = data_array.sel(**non_slices, method="nearest")
        return data_array

    def draw(data_slice, column):
        """Plot ``data_slice`` on the current figure and show it in ``column``."""
        try:
            data_slice.plot()
            figure = plt.gcf()
            with column:
                st.pyplot(figure)
        except TypeError:
            # Deliberate best effort: a selection that cannot be plotted
            # leaves its column empty instead of breaking the whole page.
            pass

    if data.reference_da is not None:
        draw(select(data.reference_da), col1)

    if data.compressed_da is not None:
        # Start a new figure so the compressed plot is not drawn on top of
        # the reference plot.
        plt.figure()
        draw(select(data.compressed_da), col2)
    else:
        st.text("Compress the data to show the plot!")
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import xarray
import enstools.compression.xr_accessor # noqa
from .data_source import DataContainer
def get_compression_ratio(data_array: xarray.DataArray, relative_tolerance: float, mode: str) -> float:
    """Return the compression ratio obtained for ``data_array``.

    The array is compressed through its ``compression`` accessor using the
    specification ``lossy,sz,<mode>,<relative_tolerance>`` with
    ``in_place=False``, and the ``compression_ratio`` attribute of the result
    is returned as a plain ``float``.

    Args:
        data_array: Array to compress.
        relative_tolerance: Value placed in the last field of the
            specification string.
        mode: Value placed in the third field of the specification string.
    """
    specification = f"lossy,sz,{mode},{relative_tolerance}"
    compressed = data_array.compression(specification, in_place=False)
    ratio = compressed.attrs["compression_ratio"]
    return float(ratio)
def invert_function(function):
    """Build a numerical inverse of ``function`` based on Newton's method.

    Args:
        function: Callable that also provides ``deriv()``, returning its
            derivative as another callable (``numpy.poly1d`` does).

    Returns:
        A callable ``newtons_method(y_val, epsilon=1e-7, max_iterations=100)``
        that searches for an ``x`` with ``function(x) == y_val``. It returns
        that ``x`` once two consecutive iterates differ by less than
        ``epsilon``, and ``None`` when no solution is found: either
        ``max_iterations`` is exhausted or the derivative becomes zero or
        non-finite, which makes the Newton step undefined.
    """
    f_prime = function.deriv()

    def newtons_method(y_val, epsilon=1e-7, max_iterations=100):
        x = -2  # Starting guess, equal to np.log10(0.01)
        for _ in range(max_iterations):
            slope = f_prime(x)
            # A flat or non-finite slope would turn the update into inf/nan
            # and waste every remaining iteration; give up right away.
            if slope == 0 or not np.isfinite(slope):
                return None
            x_new = x - (function(x) - y_val) / slope
            if abs(x - x_new) < epsilon:
                return x_new
            x = x_new
        return None

    return newtons_method
def create_parameter_from_compression_ratio(data: DataContainer, mode: str):
    """Fit a mapping from compression ratio back to the compression parameter.

    The reference data is compressed with 15 parameter values spaced
    logarithmically between ``1e-12`` and ``10 ** -0.5``. A degree-10
    polynomial is fitted to ``log10(ratio)`` as a function of
    ``log10(parameter)`` and then inverted numerically with
    ``invert_function``.

    Args:
        data: Container whose ``reference_da`` is compressed for the samples.
        mode: Compression mode forwarded to ``get_compression_ratio``.

    Returns:
        A pair ``(parameter_range, function_to_return)``.
        ``parameter_range`` is ``(smallest measured ratio, largest measured
        ratio capped at 100)``. ``function_to_return`` maps a compression
        ratio to the parameter predicted by the inverted fit.
    """
    sample_parameters = np.logspace(-12, -.5, 15)

    sample_ratios = []
    for sample in sample_parameters:
        sample_ratios.append(get_compression_ratio(data.reference_da, sample, mode=mode))

    lowest_ratio = min(sample_ratios)
    highest_ratio = min(100., max(sample_ratios))

    # Work in log-log space for both the parameter and the ratio.
    coefficients = np.polyfit(np.log10(sample_parameters), np.log10(sample_ratios), 10)
    log_ratio_from_log_parameter = np.poly1d(coefficients)
    log_parameter_from_log_ratio = invert_function(log_ratio_from_log_parameter)

    def function_to_return(compression_ratio: float) -> float:
        log_ratio = np.log10(compression_ratio)
        return 10 ** log_parameter_from_log_ratio(log_ratio)

    return (lowest_ratio, highest_ratio), function_to_return
def basic_section(data: DataContainer, slice_selection):
    """Render the basic compression controls and compress the data.

    The user picks a compression mode and a target compression ratio. The
    ratio is translated into a compression parameter with the function
    returned by ``create_parameter_from_compression_ratio`` and the data is
    compressed with that parameter.

    Args:
        data: Container holding the reference data; its ``compress`` method is
            called with the resulting specification string.
        slice_selection: Not used by this section; accepted so that all
            sections can be called with the same arguments.
    """
    if data.reference_da is None:
        # Nothing to compress yet; the ratio fit below needs reference data.
        return

    mode = st.selectbox(label="Mode", options=["rel", "pw_rel"])
    parameter_range, get_parameter = create_parameter_from_compression_ratio(data, mode=mode)

    # Offer ratios between the reachable minimum and maximum, spaced
    # logarithmically so that small ratios get finer steps.
    lowest, highest = parameter_range
    options = [f"{lowest + (highest - lowest) * fraction:.2f}" for fraction in np.logspace(-2, 0)]

    compression_ratio = float(st.select_slider(label="Compression Ratio", options=options))
    parameter = get_parameter(compression_ratio)

    with st.spinner():
        # Use the mode the parameter was fitted for; a fixed "rel" here would
        # apply a parameter derived for "pw_rel" to the wrong mode.
        data.compress(f"lossy,sz,{mode},{parameter}")
...@@ -38,6 +38,8 @@ class DataContainer: ...@@ -38,6 +38,8 @@ class DataContainer:
def compress(self, compression): def compress(self, compression):
self.compressed_da = self.reference_da.compression(compression) self.compressed_da = self.reference_da.compression(compression)
def __hash__(self):
return hash(self.reference_da.name)
@st.cache_resource @st.cache_resource
def create_data(): def create_data():
......
import streamlit as st
import matplotlib.pyplot as plt
def plot_comparison(data, slice_selection):
    """Plot reference data, compressed data and their difference in three columns.

    Args:
        data: Object exposing ``reference_da`` and ``compressed_da``; either
            attribute may be ``None``.
        slice_selection: Mapping from coordinate name to either a
            ``(start, stop)`` tuple or a single value.

    The first column shows the selected reference data, the second the
    selected compressed data and the third ``compressed - reference`` for the
    same selection. When there is no compressed data a short hint is shown
    instead. A ``TypeError`` raised while plotting is ignored so that the
    remaining panels are still rendered; errors raised while selecting the
    data propagate to the caller.
    """
    col1, col2, col3, *_ = st.columns(3)

    # Split the selection into range selections (slices) and point selections.
    # A (start, stop) tuple with identical ends collapses to a single point.
    only_slices = {}
    non_slices = {}
    for key, values in slice_selection.items():
        if isinstance(values, tuple):
            start, stop = values
            values = slice(start, stop) if start != stop else start
        if isinstance(values, slice):
            if key == "lat":
                # The bounds of a "lat" range are swapped, presumably because
                # the latitude coordinate is stored in descending order (TODO
                # confirm). Only real slices are swapped: a point selection
                # has no .start/.stop and must be left untouched.
                values = slice(values.stop, values.start)
            only_slices[key] = values
        else:
            non_slices[key] = values

    def select(data_array):
        """Apply the range selections, then the nearest-point selections."""
        if only_slices:
            data_array = data_array.sel(**only_slices)
        if non_slices:
            data_array = data_array.sel(**non_slices, method="nearest")
        return data_array

    def draw(data_slice, column):
        """Plot ``data_slice`` on the current figure and show it in ``column``."""
        try:
            data_slice.plot()
            figure = plt.gcf()
            with column:
                st.pyplot(figure)
        except TypeError:
            # Deliberate best effort: a selection that cannot be plotted
            # leaves its column empty instead of breaking the whole page.
            pass

    if data.reference_da is not None:
        draw(select(data.reference_da), col1)

    if data.compressed_da is not None:
        # Each panel gets its own figure so the plots do not overlap.
        plt.figure()
        draw(select(data.compressed_da), col2)

        if data.reference_da is not None:
            difference = data.compressed_da - data.reference_da
            plt.figure()
            # The third panel always shows the difference, also when no range
            # selection is active.
            draw(select(difference), col3)
    else:
        st.text("Compress the data to show the plot!")
import streamlit as st import streamlit as st
from component.data_source import create_data, select_dataset, select_slice from component.data_source import create_data, select_dataset, select_slice
from component.compression_section import compression_section, plot_compressed from component.basic_section import basic_section
from component.advanced_section import advanced_section
from component.analysis_section import analysis_section from component.analysis_section import analysis_section
from component.plotter import plot_comparison
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.set_page_config(layout="wide")
data = create_data() data = create_data()
def sidebar():
...
def setup_main_frame(): def setup_main_frame():
st.title("Lossy Compression playground!") st.title("Welcome to the :green[enstools-compression] playground!")
# with st.expander("Data Selection"):
with st.sidebar: with st.sidebar:
select_dataset(data) select_dataset(data)
slice_selection = select_slice(data) slice_selection = select_slice(data)
st.markdown("---") st.markdown("---")
options = ["Compression", "Analysis"] options = ["Compression", "Advanced Compression", "Analysis"]
compression, analysis = st.tabs(options) basic, advanced, analysis = st.tabs(options)
with basic:
basic_section(data=data, slice_selection=slice_selection)
with st.spinner():
try:
plot_comparison(data=data, slice_selection=slice_selection)
except TypeError as err:
st.warning(err)
with compression: with advanced:
compression_section(data=data, slice_selection=slice_selection) advanced_section(data=data, slice_selection=slice_selection)
with st.spinner(): with st.spinner():
try: try:
plot_compressed(data=data, slice_selection=slice_selection) plot_comparison(data=data, slice_selection=slice_selection)
except TypeError as err: except TypeError as err:
st.warning(err) st.warning(err)
with analysis: with analysis:
analysis_section(data=data, slice_selection=slice_selection) analysis_section(data=data, slice_selection=slice_selection)
sidebar()
setup_main_frame() setup_main_frame()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment