diff --git a/navier_stokes_dynamics.py b/navier_stokes_dynamics.py new file mode 100644 index 0000000..5a5bff6 --- /dev/null +++ b/navier_stokes_dynamics.py @@ -0,0 +1,292 @@ +""" +This is a streamlit app. +""" +import base64 +import dataclasses +import io +import json +import random +from dataclasses import dataclass +from typing import Optional + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import streamlit as st +import streamlit.components.v1 as components +from IPython.display import DisplayObject +from matplotlib.colors import Colormap, LinearSegmentedColormap, ListedColormap + +import exponax as ex + +st.set_page_config(layout="wide") +jax.config.update("jax_platform_name", "cpu") + +with st.sidebar: + st.title("Exponax Dynamics Brewer") + dimension_type = st.select_slider( + "Number of Spatial Dimensions (ST=Spatio-Temporal plot)", + options=[ + "2d", + "2d ST", + ], + ) + num_points = st.slider("Number of points", 16, 256, 48) + num_steps = st.slider("Number of steps", 1, 300, 50) + num_modes_init = st.slider("Number of modes in the initial condition", 1, 40, 5) + num_substeps = st.slider("Number of substeps", 1, 100, 1) + + v_range = st.slider("Value range", 0.1, 10.0, 1.0) + + st.divider() + + # domain_extent = st.select_slider("Domain Extent", [1.0, 2 * jnp.pi]) + + dt_cols = st.columns(3) + with dt_cols[0]: + dt_mantissa = st.slider("dt mantissa", 0.0, 1.0, 0.1) + with dt_cols[1]: + dt_exponent = st.slider("dt exponent", -5, 5, 0) + dt_sign = "+" + dt = float(f"{dt_sign}{dt_mantissa}e{dt_exponent}") + + diffusivity_cols = st.columns(3) + with diffusivity_cols[0]: + diffusivity_mantissa = st.slider("diffusivity mantissa", 0.0, 1.0, 0.1) + with diffusivity_cols[1]: + diffusivity_exponent = st.slider("diffusivity exponent", -5, 5, -2) + diffusivity_sign = "+" + diffusivity = float( + f"{diffusivity_sign}{diffusivity_mantissa}e{diffusivity_exponent}" + ) + + use_kolmogorov = st.toggle("Use Kolmogorov", value=False) + + if use_kolmogorov: + domain_extent = 2 * jnp.pi + + injection_mode = st.slider("Injection Mode", 1, 20, 4) + injection_scale = st.slider("Injection Scale", 0.1, 10.0, 1.0) + + else: + domain_extent = 1.0 + + st.write(f"dt: {dt}") + st.write(f"diffusivity: {diffusivity}") + st.write(f"domain_extent: {domain_extent}") + + # st.write(f"Linear: {linear_tuple}") + # st.write(f"Nonlinear: {nonlinear_tuple}") + +if dimension_type in ["1d ST", "1d"]: + num_spatial_dims = 1 +elif dimension_type in ["2d ST", "2d"]: + num_spatial_dims = 2 +elif dimension_type == "3d": + num_spatial_dims = 3 + +if use_kolmogorov: + stepper = ex.RepeatedStepper( + ex.stepper.KolmogorovFlowVorticity( + num_spatial_dims, + domain_extent, + num_points, + dt / num_substeps, + diffusivity=diffusivity, + drag=-0.1, + injection_mode=injection_mode, + injection_scale=injection_scale, + ), + num_substeps, + ) +else: + stepper = ex.RepeatedStepper( + ex.stepper.NavierStokesVorticity( + num_spatial_dims, + domain_extent, + num_points, + dt / num_substeps, + diffusivity=diffusivity, + drag=0.0, + ), + num_substeps, + ) + +if num_spatial_dims == 1: + ic_gen = ex.ic.RandomSineWaves1d( + num_spatial_dims, cutoff=num_modes_init, max_one=True + ) +else: + ic_gen = ex.ic.RandomTruncatedFourierSeries( + num_spatial_dims, cutoff=num_modes_init, max_one=True + ) +u_0 = ic_gen(num_points, key=jax.random.PRNGKey(0)) + +trj = ex.rollout(stepper, num_steps, include_init=True)(u_0) + + +TEMPLATE_IFRAME = """ +
+ +
+ +""" + + +@dataclass(unsafe_hash=True) +class ViewerSettings: + width: int + height: int + background_color: tuple + show_colormap_editor: bool + show_volume_info: bool + vmin: Optional[float] + vmax: Optional[float] + + +def show( + data: np.ndarray, + colormap, + width: int = 800, + height: int = 600, + background_color=(0.0, 0.0, 0.0, 1.0), + show_colormap_editor=False, + show_volume_info=False, + vmin=None, + vmax=None, +): + return VolumeRenderer( + data, + colormap, + ViewerSettings( + width, + height, + background_color, + show_colormap_editor, + show_volume_info, + vmin, + vmax, + ), + ) + + +class VolumeRenderer(DisplayObject): + def __init__(self, data: np.ndarray, colormap, settings: ViewerSettings): + super(VolumeRenderer, self).__init__( + data={"volume": data, "cmap": colormap, "settings": settings} + ) + + def _repr_html_(self): + data = self.data["volume"] + colormap = self.data["cmap"] + settings = self.data["settings"] + buffer = io.BytesIO() + np.save(buffer, data.astype(np.float32)) + data_code = base64.b64encode(buffer.getvalue()) + + buffer2 = io.BytesIO() + colormap_data = colormap(np.linspace(0, 1, 256)).astype(np.float32) + np.save(buffer2, colormap_data) + cmap_code = base64.b64encode(buffer2.getvalue()) + + canvas_id = f"v4dv_canvas_{str(random.randint(0,2**32))}" + html_code = TEMPLATE_IFRAME.format( + canvas_id=canvas_id, + data_code=data_code.decode("utf-8"), + cmap_code=cmap_code.decode("utf-8"), + canvas_width=settings.width, + canvas_height=settings.height, + settings_json=json.dumps(dataclasses.asdict(settings)), + ) + return html_code + + def __html__(self): + """ + This method exists to inform other HTML-using modules (e.g. Markupsafe, + htmltag, etc) that this object is HTML and does not need things like + special characters (<>&) escaped. + """ + return self._repr_html_() + + +def felix_cmap_hack(cmap: Colormap) -> Colormap: + """changes the alpha channel of a colormap to be diverging (0->1, 0.5 > 0, 1->1) + + Args: + cmap (Colormap): colormap + + Returns: + Colormap: new colormap + """ + cmap = cmap.copy() + if isinstance(cmap, ListedColormap): + for i, a in enumerate(cmap.colors): + a.append(2 * abs(i / cmap.N - 0.5)) + elif isinstance(cmap, LinearSegmentedColormap): + cmap._segmentdata["alpha"] = np.array( + [[0.0, 1.0, 1.0], [0.5, 0.0, 0.0], [1.0, 1.0, 1.0]] + ) + else: + raise TypeError( + "cmap must be either a ListedColormap or a LinearSegmentedColormap" + ) + return cmap + + +if dimension_type == "1d ST": + ex.viz.plot_spatio_temporal(trj, vlim=(-v_range, v_range)) + fig = plt.gcf() + st.pyplot(fig) +elif dimension_type == "1d": + ani = ex.viz.animate_state_1d(trj, vlim=(-v_range, v_range)) + components.html(ani.to_jshtml(), height=800) +elif dimension_type == "2d": + ani = ex.viz.animate_state_2d(trj, vlim=(-v_range, v_range)) + components.html(ani.to_jshtml(), height=800) +elif dimension_type == "2d ST": + trj_rearranged = trj.transpose(1, 0, 2, 3)[None] + components.html( + show( + trj_rearranged, + plt.get_cmap("RdBu_r"), + width=1500, + height=800, + show_colormap_editor=True, + show_volume_info=True, + ).__html__(), + height=800, + ) +elif dimension_type == "3d": + components.html( + show( + trj, + plt.get_cmap("RdBu_r"), + width=1500, + height=800, + show_colormap_editor=True, + show_volume_info=True, + ).__html__(), + height=800, + )