From 378f17efbc728b54a3dbc404a8ff37a1cb87a4c7 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 3 Jul 2024 08:35:31 +0200 Subject: [PATCH 1/3] Move examples and images into docs subfoldeR --- README.md | 10 +++++----- .../examples}/architecture_showcase.ipynb | 0 .../parameter_count_and_receptive_field.ipynb | 0 .../examples}/seed_parallel_training.ipynb | 0 .../examples}/train_unet_as_poisson_solver.ipynb | 0 {img => docs/imgs}/hierarchical_net.svg | 0 {img => docs/imgs}/pdequinox_logo.png | Bin {img => docs/imgs}/pdequinox_teaser.png | Bin {img => docs/imgs}/sequential_net.svg | 0 {img => docs/imgs}/three_boundary_conditions.svg | 0 10 files changed, 5 insertions(+), 5 deletions(-) rename {examples => docs/examples}/architecture_showcase.ipynb (100%) rename {examples => docs/examples}/parameter_count_and_receptive_field.ipynb (100%) rename {examples => docs/examples}/seed_parallel_training.ipynb (100%) rename {examples => docs/examples}/train_unet_as_poisson_solver.ipynb (100%) rename {img => docs/imgs}/hierarchical_net.svg (100%) rename {img => docs/imgs}/pdequinox_logo.png (100%) rename {img => docs/imgs}/pdequinox_teaser.png (100%) rename {img => docs/imgs}/sequential_net.svg (100%) rename {img => docs/imgs}/three_boundary_conditions.svg (100%) diff --git a/README.md b/README.md index 685294e..0ba15bc 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@

- +
PDEquinox
@@ -21,7 +21,7 @@

- +

## Installation @@ -137,7 +137,7 @@ emulator. Hence, most components allow setting `boundary_mode` which can be `"dirichlet"`, `"neumann"`, or `"periodic"`. This affects what is considered a degree of freedom in the grid. -![](img/three_boundary_conditions.svg) +![](docs/imgs/three_boundary_conditions.svg) Dirichlet boundaries fully eliminate degrees of freedom on the boundary. Periodic boundaries only keep one end of the domain as a degree of freedom (This @@ -150,7 +150,7 @@ Networks that allow for composability with the `PDEquinox` blocks. ### Sequential Constructor -![](img/sequential_net.svg) +![](docs/imgs/sequential_net.svg) The squential network constructor is defined by: * a lifting block $\mathcal{L}$ @@ -161,7 +161,7 @@ The squential network constructor is defined by: ### Hierarchical Constructor -![](img/hierarchical_net.svg) +![](docs/imgs/hierarchical_net.svg) The hierarchical network constructor is defined by: * a lifting block $\mathcal{L}$ diff --git a/examples/architecture_showcase.ipynb b/docs/examples/architecture_showcase.ipynb similarity index 100% rename from examples/architecture_showcase.ipynb rename to docs/examples/architecture_showcase.ipynb diff --git a/examples/parameter_count_and_receptive_field.ipynb b/docs/examples/parameter_count_and_receptive_field.ipynb similarity index 100% rename from examples/parameter_count_and_receptive_field.ipynb rename to docs/examples/parameter_count_and_receptive_field.ipynb diff --git a/examples/seed_parallel_training.ipynb b/docs/examples/seed_parallel_training.ipynb similarity index 100% rename from examples/seed_parallel_training.ipynb rename to docs/examples/seed_parallel_training.ipynb diff --git a/examples/train_unet_as_poisson_solver.ipynb b/docs/examples/train_unet_as_poisson_solver.ipynb similarity index 100% rename from examples/train_unet_as_poisson_solver.ipynb rename to docs/examples/train_unet_as_poisson_solver.ipynb diff --git a/img/hierarchical_net.svg b/docs/imgs/hierarchical_net.svg similarity index 100% rename from img/hierarchical_net.svg rename to docs/imgs/hierarchical_net.svg diff --git a/img/pdequinox_logo.png b/docs/imgs/pdequinox_logo.png similarity index 100% rename from img/pdequinox_logo.png rename to docs/imgs/pdequinox_logo.png diff --git a/img/pdequinox_teaser.png b/docs/imgs/pdequinox_teaser.png similarity index 100% rename from img/pdequinox_teaser.png rename to docs/imgs/pdequinox_teaser.png diff --git a/img/sequential_net.svg b/docs/imgs/sequential_net.svg similarity index 100% rename from img/sequential_net.svg rename to docs/imgs/sequential_net.svg diff --git a/img/three_boundary_conditions.svg b/docs/imgs/three_boundary_conditions.svg similarity index 100% rename from img/three_boundary_conditions.svg rename to docs/imgs/three_boundary_conditions.svg From 780eca71767839e053e0a8e9d3550eeafee41c11 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 3 Jul 2024 08:40:13 +0200 Subject: [PATCH 2/3] Setup barebone mkdocs --- docs/examples/architecture_showcase.ipynb | 20 +++ docs/index.md | 189 ++++++++++++++++++++++ docs/javascripts/mathjax.js | 16 ++ docs/requirements.txt | 6 + mkdocs.yml | 84 ++++++++++ 5 files changed, 315 insertions(+) create mode 100644 docs/index.md create mode 100644 docs/javascripts/mathjax.js create mode 100644 docs/requirements.txt create mode 100644 mkdocs.yml diff --git a/docs/examples/architecture_showcase.ipynb b/docs/examples/architecture_showcase.ipynb index e69de29..48822f3 100644 --- a/docs/examples/architecture_showcase.ipynb +++ b/docs/examples/architecture_showcase.ipynb @@ -0,0 +1,20 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..401e4f5 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,189 @@ +# Getting Started + +## Installation + +Clone the repository, navigate to the folder and install the package with pip: +```bash +pip install . +``` + +Requires Python 3.10+ and JAX 0.4.13+. ๐Ÿ‘‰ [JAX install guide](https://jax.readthedocs.io/en/latest/installation.html). + + +## Quickstart + +Train a UNet to become an emulator for the 1D Poisson equation. + +```python +import jax +import jax.numpy as jnp +import equinox as eqx +import optax # `pip install optax` +import pdequinox as pdeqx +from tqdm import tqdm # `pip install tqdm` + +force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet( + key=jax.random.PRNGKey(0) +) + +force_fields_train = force_fields[:800] +force_fields_test = force_fields[800:] +displacement_fields_train = displacement_fields[:800] +displacement_fields_test = displacement_fields[800:] + +unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1)) + +def loss_fn(model, x, y): + y_pref = jax.vmap(model)(x) + return jnp.mean((y_pref - y) ** 2) + +opt = optax.adam(3e-4) +opt_state = opt.init(eqx.filter(unet, eqx.is_array)) + +@eqx.filter_jit +def update_fn(model, state, x, y): + loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y) + updates, new_state = opt.update(grad, state, model) + new_model = eqx.apply_updates(model, updates) + return new_model, new_state, loss + +loss_history = [] +shuffle_key = jax.random.PRNGKey(151) +for epoch in tqdm(range(100)): + shuffle_key, subkey = jax.random.split(shuffle_key) + + for batch in pdeqx.dataloader( + (force_fields_train, displacement_fields_train), + batch_size=32, + key=subkey + ): + unet, opt_state, loss = update_fn( + unet, + opt_state, + *batch, + ) + loss_history.append(loss) +``` +## Background + +Neural Emulators are networks learned to efficienty predict physical phenomena, +often associated with PDEs. In the simplest case this can be a linear advection +equation, all the way to more complicated Navier-Stokes cases. If we work on +Uniform Cartesian grids* (which this package assumes), one can borrow plenty of +architectures from image-to-image tasks in computer vision (e.g., for +segmentation). This includes: + +* Standard Feedforward ConvNets +* Convolutional ResNets ([He et al.](https://arxiv.org/abs/1512.03385)) +* U-Nets ([Ronneberger et al.](https://arxiv.org/abs/1505.04597)) +* Dilated ResNets ([Yu et al.](https://arxiv.org/abs/1511.07122), [Stachenfeld et al.](https://arxiv.org/abs/2112.15275)) +* Fourier Neural Operators ([Li et al.](https://arxiv.org/abs/2010.08895)) + +It is interesting to note that most of these architectures resemble classical +numerical methods or at least share similarities with them. For example, +ConvNets (or convolutions in general) are related to finite differences, while +U-Nets resemble multigrid methods. Fourier Neural Operators are related to +spectral methods. The difference is that the emulators' free parameters are +found based on a (data-driven) numerical optimization not a symbolic +manipulation of the differential equations. + +(*) This means that we essentially have a pixel or voxel grid on which space is +discretized. Hence, the space can only be the scaled unit cube $\Omega = (0, +L)^D$ + +## Features + +* Based on [JAX](https://github.com/google/jax): + * One of the best Automatic Differentiation engines (forward & reverse) + * Automatic vectorization + * Backend-agnostic code (run on CPU, GPU, and TPU) +* Based on [Equinox](https://github.com/patrick-kidger/equinox): + * Single-Batch by design + * Integration into the Equinox SciML ecosystem +* Agnostic to the spatial dimension (works for 1D, 2D, and 3D) +* Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic + BCs) +* Composability +* Tools to count parameters and assess receptive fields + +## Boundary Conditions + +This package assumes that the boundary condition is baked into the neural +emulator. Hence, most components allow setting `boundary_mode` which can be +`"dirichlet"`, `"neumann"`, or `"periodic"`. This affects what is considered a +degree of freedom in the grid. + +![](imgs/three_boundary_conditions.svg) + +Dirichlet boundaries fully eliminate degrees of freedom on the boundary. +Periodic boundaries only keep one end of the domain as a degree of freedom (This +package follows the convention that the left boundary is the degree of freedom). Neumann boundaries keep both ends as degrees of freedom. + +## Constructors + +There are two primary architectural constructors for Sequential and Hierarchical +Networks that allow for composability with the `PDEquinox` blocks. + +### Sequential Constructor + +![](imgs/sequential_net.svg) + +The squential network constructor is defined by: +* a lifting block $\mathcal{L}$ +* $N$ blocks $\left \{ \mathcal{B}_i \right\}_{i=1}^N$ +* a projection block $\mathcal{P}$ +* the hidden channels within the sequential processing +* the number of blocks $N$ (one can also supply a list of hidden channels if they shall be different between blocks) + +### Hierarchical Constructor + +![](imgs/hierarchical_net.svg) + +The hierarchical network constructor is defined by: +* a lifting block $\mathcal{L}$ +* The number of levels $D$ (i.e., the number of additional hierarchies). Setting $D = 0$ recovers the sequential processing. +* a list of $D$ blocks $\left \{ \mathcal{D}_i \right\}_{i=1}^D$ for + downsampling, i.e. mapping downwards to the lower hierarchy (oftentimes this + is that they halve the spatial axes while keeping the number of channels) +* a list of $D$ blocks $\left \{ \mathcal{B}_i^l \right\}_{i=1}^D$ for + processing in the left arc (oftentimes this changes the number of channels, + e.g. doubles it such that the combination of downsampling and left processing + halves the spatial resolution and doubles the feature count) +* a list of $D$ blocks $\left \{ \mathcal{U}_i \right\}_{i=1}^D$ for upsamping, + i.e., mapping upwards to the higher hierarchy (oftentimes this doubles the + spatial resolution; at the same time it halves the feature count such that we + can concatenate a skip connection) +* a list of $D$ blocks $\left \{ \mathcal{B}_i^r \right\}_{i=1}^D$ for + processing in the right arc (oftentimes this changes the number of channels, + e.g. halves it such that the combination of upsampling and right processing + doubles the spatial resolution and halves the feature count) +* a projection block $\mathcal{P}$ +* the hidden channels within the hierarchical processing (if just an integer is + provided; this is assumed to be the number of hidden channels in the highest + hierarchy.) + +### Beyond Architectural Constructors + +For completion, `pdequinox.arch` also provides a `ConvNet` which is a simple +feed-forward convolutional network. It also provides `MLP` which is a dense +networks which also requires pre-defining the number of resolution points. + +## Related + +Similar packages that provide a collection of emulator architectures are +[PDEBench](https://github.com/pdebench/PDEBench) and +[PDEArena](https://github.com/pdearena/pdearena). With focus on Phyiscs-informed +Neural Networks and Neural Operators, there are also +[DeepXDE](https://github.com/lululxvi/deepxde) and [NVIDIA +Modulus](https://developer.nvidia.com/modulus). + +## License + +MIT, see [here](https://github.com/Ceyron/pdequinox/blob/main/LICENSE.txt) + +--- + +> [fkoehler.site](https://fkoehler.site/)  ·  +> GitHub [@ceyron](https://github.com/ceyron)  ·  +> X [@felix_m_koehler](https://twitter.com/felix_m_koehler) + diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js new file mode 100644 index 0000000..080801e --- /dev/null +++ b/docs/javascripts/mathjax.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.typesetPromise() +}) diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..09fe41f --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +mkdocs==1.6.0 +black==24.4.2 +mkdocs-material==9.5.27 +mkdocstrings==0.25.1 +mkdocstrings-python==1.10.5 +mknotebooks==0.8.0 \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..795c7d8 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,84 @@ +site_name: PDEquinox +site_description: Neural PDE Emulator Architectures in JAX & Equinox. +site_author: Felix Koehler +site_url: https://fkoehler.site/pdequinox + +repo_url: https://github.com/Ceyron/pdequinox +repo_name: Ceyron/pdequinox +edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples via symlink, so it's impossible for them all to be accurate + +theme: + name: material + features: + - navigation.sections # Sections are included in the navigation on the left. + - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right. + - header.autohide # header disappears as you scroll + palette: + - scheme: default + primary: indigo + accent: amber + toggle: + icon: material/weather-night + name: Switch to dark mode + - scheme: slate + primary: indigo + accent: amber + toggle: + icon: material/weather-sunny + name: Switch to light mode + icon: + repo: fontawesome/brands/github # GitHub logo in top right + +extra: + social: + - icon: fontawesome/brands/twitter + link: https://twitter.com/felix_m_koehler + - icon: fontawesome/brands/github + link: https://github.com/ceyron + - icon: fontawesome/brands/youtube + link: https://youtube.com/@MachineLearningSimulation + + +strict: true # Don't allow warnings during the build process + +markdown_extensions: + - pymdownx.arithmatex: # Render LaTeX via MathJax + generic: true + - pymdownx.superfences # Seems to enable syntax highlighting when used with the Material theme. + - pymdownx.details # Allowing hidden expandable regions denoted by ??? + - pymdownx.snippets: # Include one Markdown file into another + base_path: docs + - admonition + - toc: + permalink: "ยค" # Adds a clickable permalink to each section heading + toc_depth: 4 + - pymdownx.arithmatex: + generic: true + +extra_javascript: + - javascripts/mathjax.js + - https://polyfill.io/v3/polyfill.min.js?features=es6 + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + +plugins: + - search # default search plugin; needs manually re-enabling when using any other plugins + - autorefs # Cross-links to headings + - mknotebooks # Jupyter notebooks + - mkdocstrings: + handlers: + python: + options: + inherited_members: true # Allow looking up inherited methods + show_root_heading: true # actually display anything at all... + show_root_full_path: true # display "diffrax.asdf" not just "asdf" + show_if_no_docstring: true + show_signature_annotations: true + separate_signature: true + show_source: true # don't include source code + members_order: source # order methods according to their order of definition in the source code, not alphabetical order + heading_level: 4 + show_symbol_type_heading: true + docstring_style: null + +nav: + - 'index.md' \ No newline at end of file From 6ab2302cd30290158a5cbb90daa28a10802a42e8 Mon Sep 17 00:00:00 2001 From: Felix Koehler Date: Wed, 3 Jul 2024 08:41:11 +0200 Subject: [PATCH 3/3] Add doc building workflow --- .github/workflows/build_docs.yml | 34 ++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 .github/workflows/build_docs.yml diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml new file mode 100644 index 0000000..3b5dd29 --- /dev/null +++ b/.github/workflows/build_docs.yml @@ -0,0 +1,34 @@ +on: + push: + branches: + - main +permissions: + contents: write +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Configure Git Credentials + run: | + git config user.name github-actions[bot] + git config user.email 41898282+github-actions[bot]@users.noreply.github.com + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install . + python -m pip install -r docs/requirements.txt + + - uses: actions/cache@v4 + with: + key: mkdocs-material-${{ env.cache_id }} + path: .cache + restore-keys: | + mkdocs-material- + - run: pip install mkdocs-material + - run: mkdocs gh-deploy --force