Skip to content

Commit

Permalink
Add simple streamlit notebook to understand difficulties
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed May 6, 2024
1 parent b7a0e90 commit 8f3d5e1
Showing 1 changed file with 154 additions and 0 deletions.
154 changes: 154 additions & 0 deletions examples/understanding_normalized_and_difficulty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""
This is a streamlit app.
"""
import jax
import matplotlib.pyplot as plt
import streamlit as st

import exponax as ex

jax.config.update("jax_platform_name", "cpu")

with st.sidebar:
num_points = st.slider("Number of points", 16, 256, 48)
num_steps = st.slider("Number of steps", 1, 300, 50)
num_modes_init = st.slider("Number of modes in the initial condition", 1, 40, 5)
num_substeps = st.slider("Number of substeps", 1, 100, 1)

use_difficulty = st.toggle("Use difficulty", value=True)

overall_scale = st.slider("Overall scale", 0.1, 50.0, 1.0)

a_0_cols = st.columns(3)
with a_0_cols[0]:
a_0_mantissa = st.slider("a_0 mantissa", 0.0, 10.0, 0.0)
with a_0_cols[1]:
a_0_exponent = st.slider("a_0 exponent", -5, 5, 0)
with a_0_cols[2]:
a_0_sign = st.select_slider("a_0 sign", options=["-", "+"])
a_0 = float(f"{a_0_sign}{a_0_mantissa}e{a_0_exponent}")

a_1_cols = st.columns(3)
with a_1_cols[0]:
a_1_mantissa = st.slider("a_1 mantissa", 0.0, 10.0, 0.1)
with a_1_cols[1]:
a_1_exponent = st.slider("a_1 exponent", -5, 5, 0)
with a_1_cols[2]:
a_1_sign = st.select_slider("a_1 sign", options=["-", "+"])
a_1 = float(f"{a_1_sign}{a_1_mantissa}e{a_1_exponent}")

a_2_cols = st.columns(3)
with a_2_cols[0]:
a_2_mantissa = st.slider("a_2 mantissa", 0.0, 10.0, 0.0)
with a_2_cols[1]:
a_2_exponent = st.slider("a_2 exponent", -5, 5, 0)
with a_2_cols[2]:
a_2_sign = st.select_slider("a_2 sign", options=["-", "+"])
a_2 = float(f"{a_2_sign}{a_2_mantissa}e{a_2_exponent}")

a_3_cols = st.columns(3)
with a_3_cols[0]:
a_3_mantissa = st.slider("a_3 mantissa", 0.0, 10.0, 0.0)
with a_3_cols[1]:
a_3_exponent = st.slider("a_3 exponent", -5, 5, 0)
with a_3_cols[2]:
a_3_sign = st.select_slider("a_3 sign", options=["-", "+"])
a_3 = float(f"{a_3_sign}{a_3_mantissa}e{a_3_exponent}")

a_4_cols = st.columns(3)
with a_4_cols[0]:
a_4_mantissa = st.slider("a_4 mantissa", 0.0, 10.0, 0.0)
with a_4_cols[1]:
a_4_exponent = st.slider("a_4 exponent", -5, 5, 0)
with a_4_cols[2]:
a_4_sign = st.select_slider("a_4 sign", options=["-", "+"])
a_4 = float(f"{a_4_sign}{a_4_mantissa}e{a_4_exponent}")

b_0_cols = st.columns(3)
with b_0_cols[0]:
b_0_mantissa = st.slider("b_0 mantissa", 0.0, 10.0, 0.0)
with b_0_cols[1]:
b_0_exponent = st.slider("b_0 exponent", -5, 5, 0)
with b_0_cols[2]:
b_0_sign = st.select_slider("b_0 sign", options=["-", "+"])
b_0 = float(f"{b_0_sign}{b_0_mantissa}e{b_0_exponent}")

b_1_cols = st.columns(3)
with b_1_cols[0]:
b_1_mantissa = st.slider("b_1 mantissa", 0.0, 10.0, 0.0)
with b_1_cols[1]:
b_1_exponent = st.slider("b_1 exponent", -5, 5, 0)
with b_1_cols[2]:
b_1_sign = st.select_slider("b_1 sign", options=["-", "+"])
b_1 = float(f"{b_1_sign}{b_1_mantissa}e{b_1_exponent}")

b_2_cols = st.columns(3)
with b_2_cols[0]:
b_2_mantissa = st.slider("b_2 mantissa", 0.0, 10.0, 0.0)
with b_2_cols[1]:
b_2_exponent = st.slider("b_2 exponent", -5, 5, 0)
with b_2_cols[2]:
b_2_sign = st.select_slider("b_2 sign", options=["-", "+"])
b_2 = float(f"{b_2_sign}{b_2_mantissa}e{b_2_exponent}")

# a_0 = st.slider("a_0", -10.0, 10.0, 0.0)
# a_1 = st.slider("a_1", -10.0, 10.0, 0.1)
# a_2 = st.slider("a_2", -10.0, 10.0, 0.0)
# a_3 = st.slider("a_3", -10.0, 10.0, 0.0)
# a_4 = st.slider("a_4", -10.0, 10.0, 0.0)
# b_0 = st.slider("b_0", -10.0, 10.0, 0.0)
# b_1 = st.slider("b_1", -10.0, 10.0, 0.0)
# b_2 = st.slider("b_2", -10.0, 10.0, 0.0)

linear_tuple = (a_0, a_1, a_2, a_3, a_4)
nonlinear_tuple = (b_0, b_1, b_2)

linear_tuple = tuple([overall_scale * x for x in linear_tuple])
nonlinear_tuple = tuple([overall_scale * x for x in nonlinear_tuple])

if use_difficulty:
stepper = ex.RepeatedStepper(
ex.normalized.DifficultyGeneralNonlinearStepper(
1,
num_points,
linear_difficulties=tuple(x / num_substeps for x in linear_tuple),
nonlinear_difficulties=tuple(x / num_substeps for x in nonlinear_tuple),
),
num_substeps,
)
else:
stepper = ex.RepeatedStepper(
ex.normalized.NormlizedGeneralNonlinearStepper(
1,
num_points,
normalized_coefficients_linear=tuple(
x / num_substeps for x in linear_tuple
),
normalized_coefficients_nonlinear=tuple(
x / num_substeps for x in nonlinear_tuple
),
),
num_substeps,
)

ic_gen = ex.ic.RandomSineWaves1d(1, cutoff=num_modes_init, max_one=True)
u_0 = ic_gen(num_points, key=jax.random.PRNGKey(0))

trj = ex.rollout(stepper, num_steps, include_init=True)(u_0)

v_range = st.slider("Colorbar range", 0.1, 10.0, 1.0)

fig, ax = plt.subplots()
ax.imshow(
trj[:, 0, :].T,
aspect="auto",
vmin=-v_range,
vmax=v_range,
cmap="RdBu_r",
origin="lower",
)

st.write(f"Linear: {linear_tuple}")
st.write(f"Nonlinear: {nonlinear_tuple}")

st.pyplot(fig)

0 comments on commit 8f3d5e1

Please sign in to comment.