Skip to content

Commit

Permalink
Streamlit for Navier-Stokes dynamics
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed May 10, 2024
1 parent e9e9422 commit cc72a74
Showing 1 changed file with 292 additions and 0 deletions.
292 changes: 292 additions & 0 deletions navier_stokes_dynamics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
"""
This is a streamlit app.
"""
import base64
import dataclasses
import io
import json
import random
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import streamlit.components.v1 as components
from IPython.display import DisplayObject
from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap

import exponax as ex

st.set_page_config(layout="wide")
jax.config.update("jax_platform_name", "cpu")

with st.sidebar:
st.title("Exponax Dynamics Brewer")
dimension_type = st.select_slider(
"Number of Spatial Dimensions (ST=Spatio-Temporal plot)",
options=[
"2d",
"2d ST",
],
)
num_points = st.slider("Number of points", 16, 256, 48)
num_steps = st.slider("Number of steps", 1, 300, 50)
num_modes_init = st.slider("Number of modes in the initial condition", 1, 40, 5)
num_substeps = st.slider("Number of substeps", 1, 100, 1)

v_range = st.slider("Value range", 0.1, 10.0, 1.0)

st.divider()

# domain_extent = st.select_slider("Domain Extent", [1.0, 2 * jnp.pi])

dt_cols = st.columns(3)
with dt_cols[0]:
dt_mantissa = st.slider("dt mantissa", 0.0, 1.0, 0.1)
with dt_cols[1]:
dt_exponent = st.slider("dt exponent", -5, 5, 0)
dt_sign = "+"
dt = float(f"{dt_sign}{dt_mantissa}e{dt_exponent}")

diffusivity_cols = st.columns(3)
with diffusivity_cols[0]:
diffusivity_mantissa = st.slider("diffusivity mantissa", 0.0, 1.0, 0.1)
with diffusivity_cols[1]:
diffusivity_exponent = st.slider("diffusivity exponent", -5, 5, -2)
diffusivity_sign = "+"
diffusivity = float(
f"{diffusivity_sign}{diffusivity_mantissa}e{diffusivity_exponent}"
)

use_kolmogorov = st.toggle("Use Kolmogorov", value=False)

if use_kolmogorov:
domain_extent = 2 * jnp.pi

injection_mode = st.slider("Injection Mode", 1, 20, 4)
injection_scale = st.slider("Injection Scale", 0.1, 10.0, 1.0)

else:
domain_extent = 1.0

st.write(f"dt: {dt}")
st.write(f"diffusivity: {diffusivity}")
st.write(f"domain_extent: {domain_extent}")

# st.write(f"Linear: {linear_tuple}")
# st.write(f"Nonlinear: {nonlinear_tuple}")

if dimension_type in ["1d ST", "1d"]:
num_spatial_dims = 1
elif dimension_type in ["2d ST", "2d"]:
num_spatial_dims = 2
elif dimension_type == "3d":
num_spatial_dims = 3

if use_kolmogorov:
stepper = ex.RepeatedStepper(
ex.stepper.KolmogorovFlowVorticity(
num_spatial_dims,
domain_extent,
num_points,
dt / num_substeps,
diffusivity=diffusivity,
drag=-0.1,
injection_mode=injection_mode,
injection_scale=injection_scale,
),
num_substeps,
)
else:
stepper = ex.RepeatedStepper(
ex.stepper.NavierStokesVorticity(
num_spatial_dims,
domain_extent,
num_points,
dt / num_substeps,
diffusivity=diffusivity,
drag=0.0,
),
num_substeps,
)

if num_spatial_dims == 1:
ic_gen = ex.ic.RandomSineWaves1d(
num_spatial_dims, cutoff=num_modes_init, max_one=True
)
else:
ic_gen = ex.ic.RandomTruncatedFourierSeries(
num_spatial_dims, cutoff=num_modes_init, max_one=True
)
u_0 = ic_gen(num_points, key=jax.random.PRNGKey(0))

trj = ex.rollout(stepper, num_steps, include_init=True)(u_0)


TEMPLATE_IFRAME = """
<div>
<iframe id="{canvas_id}" src="https://keksboter.github.io/v4dv/?inline" width="{canvas_width}" height="{canvas_height}" frameBorder="0" sandbox="allow-same-origin allow-scripts"></iframe>
</div>
<script>
window.addEventListener(
"message",
(event) => {{
if (event.data !== "ready") {{
return;
}}
let data_decoded = Uint8Array.from(atob("{data_code}"), c => c.charCodeAt(0));
let cmap_decoded = Uint8Array.from(atob("{cmap_code}"), c => c.charCodeAt(0));
const iframe = document.getElementById("{canvas_id}");
if (iframe === null) return;
iframe.contentWindow.postMessage({{
volume: data_decoded,
cmap: cmap_decoded,
settings: {settings_json}
}},
"*");
}},
false,
);
</script>
"""


@dataclass(unsafe_hash=True)
class ViewerSettings:
width: int
height: int
background_color: tuple
show_colormap_editor: bool
show_volume_info: bool
vmin: Optional[float]
vmax: Optional[float]


def show(
data: np.ndarray,
colormap,
width: int = 800,
height: int = 600,
background_color=(0.0, 0.0, 0.0, 1.0),
show_colormap_editor=False,
show_volume_info=False,
vmin=None,
vmax=None,
):
return VolumeRenderer(
data,
colormap,
ViewerSettings(
width,
height,
background_color,
show_colormap_editor,
show_volume_info,
vmin,
vmax,
),
)


class VolumeRenderer(DisplayObject):
def __init__(self, data: np.ndarray, colormap, settings: ViewerSettings):
super(VolumeRenderer, self).__init__(
data={"volume": data, "cmap": colormap, "settings": settings}
)

def _repr_html_(self):
data = self.data["volume"]
colormap = self.data["cmap"]
settings = self.data["settings"]
buffer = io.BytesIO()
np.save(buffer, data.astype(np.float32))
data_code = base64.b64encode(buffer.getvalue())

buffer2 = io.BytesIO()
colormap_data = colormap(np.linspace(0, 1, 256)).astype(np.float32)
np.save(buffer2, colormap_data)
cmap_code = base64.b64encode(buffer2.getvalue())

canvas_id = f"v4dv_canvas_{str(random.randint(0,2**32))}"
html_code = TEMPLATE_IFRAME.format(
canvas_id=canvas_id,
data_code=data_code.decode("utf-8"),
cmap_code=cmap_code.decode("utf-8"),
canvas_width=settings.width,
canvas_height=settings.height,
settings_json=json.dumps(dataclasses.asdict(settings)),
)
return html_code

def __html__(self):
"""
This method exists to inform other HTML-using modules (e.g. Markupsafe,
htmltag, etc) that this object is HTML and does not need things like
special characters (<>&) escaped.
"""
return self._repr_html_()


def felix_cmap_hack(cmap: Colormap) -> Colormap:
"""changes the alpha channel of a colormap to be diverging (0->1, 0.5 > 0, 1->1)
Args:
cmap (Colormap): colormap
Returns:
Colormap: new colormap
"""
cmap = cmap.copy()
if isinstance(cmap, ListedColormap):
for i, a in enumerate(cmap.colors):
a.append(2 * abs(i / cmap.N - 0.5))
elif isinstance(cmap, LinearSegmentedColormap):
cmap._segmentdata["alpha"] = np.array(
[[0.0, 1.0, 1.0], [0.5, 0.0, 0.0], [1.0, 1.0, 1.0]]
)
else:
raise TypeError(
"cmap must be either a ListedColormap or a LinearSegmentedColormap"
)
return cmap


if dimension_type == "1d ST":
ex.viz.plot_spatio_temporal(trj, vlim=(-v_range, v_range))
fig = plt.gcf()
st.pyplot(fig)
elif dimension_type == "1d":
ani = ex.viz.animate_state_1d(trj, vlim=(-v_range, v_range))
components.html(ani.to_jshtml(), height=800)
elif dimension_type == "2d":
ani = ex.viz.animate_state_2d(trj, vlim=(-v_range, v_range))
components.html(ani.to_jshtml(), height=800)
elif dimension_type == "2d ST":
trj_rearranged = trj.transpose(1, 0, 2, 3)[None]
components.html(
show(
trj_rearranged,
plt.get_cmap("RdBu_r"),
width=1500,
height=800,
show_colormap_editor=True,
show_volume_info=True,
).__html__(),
height=800,
)
elif dimension_type == "3d":
components.html(
show(
trj,
plt.get_cmap("RdBu_r"),
width=1500,
height=800,
show_colormap_editor=True,
show_volume_info=True,
).__html__(),
height=800,
)

0 comments on commit cc72a74

Please sign in to comment.