Skip to content

Commit

Permalink
Merge pull request #1 from Ceyron/add-docs
Browse files Browse the repository at this point in the history
Add docs
  • Loading branch information
Ceyron authored Jul 3, 2024
2 parents fc88c33 + 6ab2302 commit ec1726d
Show file tree
Hide file tree
Showing 16 changed files with 354 additions and 5 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
on:
push:
branches:
- main
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .
python -m pip install -r docs/requirements.txt
- uses: actions/cache@v4
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- run: pip install mkdocs-material
- run: mkdocs gh-deploy --force
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

<h1 align="center">
<img src="img/pdequinox_logo.png" width="120">
<img src="docs/imgs/pdequinox_logo.png" width="120">
<br>
PDEquinox
<br>
Expand All @@ -21,7 +21,7 @@
</p>

<p align="center">
<img width=600 src="img/pdequinox_teaser.png">
<img width=600 src="docs/imgs/pdequinox_teaser.png">
</p>

## Installation
Expand Down Expand Up @@ -137,7 +137,7 @@ emulator. Hence, most components allow setting `boundary_mode` which can be
`"dirichlet"`, `"neumann"`, or `"periodic"`. This affects what is considered a
degree of freedom in the grid.

![](img/three_boundary_conditions.svg)
![](docs/imgs/three_boundary_conditions.svg)

Dirichlet boundaries fully eliminate degrees of freedom on the boundary.
Periodic boundaries only keep one end of the domain as a degree of freedom (This
Expand All @@ -150,7 +150,7 @@ Networks that allow for composability with the `PDEquinox` blocks.

### Sequential Constructor

![](img/sequential_net.svg)
![](docs/imgs/sequential_net.svg)

The squential network constructor is defined by:
* a lifting block $\mathcal{L}$
Expand All @@ -161,7 +161,7 @@ The squential network constructor is defined by:

### Hierarchical Constructor

![](img/hierarchical_net.svg)
![](docs/imgs/hierarchical_net.svg)

The hierarchical network constructor is defined by:
* a lifting block $\mathcal{L}$
Expand Down
20 changes: 20 additions & 0 deletions docs/examples/architecture_showcase.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
File renamed without changes
189 changes: 189 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Getting Started

## Installation

Clone the repository, navigate to the folder and install the package with pip:
```bash
pip install .
```

Requires Python 3.10+ and JAX 0.4.13+. 👉 [JAX install guide](https://jax.readthedocs.io/en/latest/installation.html).


## Quickstart

Train a UNet to become an emulator for the 1D Poisson equation.

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax # `pip install optax`
import pdequinox as pdeqx
from tqdm import tqdm # `pip install tqdm`

force_fields, displacement_fields = pdeqx.sample_data.poisson_1d_dirichlet(
key=jax.random.PRNGKey(0)
)

force_fields_train = force_fields[:800]
force_fields_test = force_fields[800:]
displacement_fields_train = displacement_fields[:800]
displacement_fields_test = displacement_fields[800:]

unet = pdeqx.arch.ClassicUNet(1, 1, 1, key=jax.random.PRNGKey(1))

def loss_fn(model, x, y):
y_pref = jax.vmap(model)(x)
return jnp.mean((y_pref - y) ** 2)

opt = optax.adam(3e-4)
opt_state = opt.init(eqx.filter(unet, eqx.is_array))

@eqx.filter_jit
def update_fn(model, state, x, y):
loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
updates, new_state = opt.update(grad, state, model)
new_model = eqx.apply_updates(model, updates)
return new_model, new_state, loss

loss_history = []
shuffle_key = jax.random.PRNGKey(151)
for epoch in tqdm(range(100)):
shuffle_key, subkey = jax.random.split(shuffle_key)

for batch in pdeqx.dataloader(
(force_fields_train, displacement_fields_train),
batch_size=32,
key=subkey
):
unet, opt_state, loss = update_fn(
unet,
opt_state,
*batch,
)
loss_history.append(loss)
```
## Background

Neural Emulators are networks learned to efficienty predict physical phenomena,
often associated with PDEs. In the simplest case this can be a linear advection
equation, all the way to more complicated Navier-Stokes cases. If we work on
Uniform Cartesian grids* (which this package assumes), one can borrow plenty of
architectures from image-to-image tasks in computer vision (e.g., for
segmentation). This includes:

* Standard Feedforward ConvNets
* Convolutional ResNets ([He et al.](https://arxiv.org/abs/1512.03385))
* U-Nets ([Ronneberger et al.](https://arxiv.org/abs/1505.04597))
* Dilated ResNets ([Yu et al.](https://arxiv.org/abs/1511.07122), [Stachenfeld et al.](https://arxiv.org/abs/2112.15275))
* Fourier Neural Operators ([Li et al.](https://arxiv.org/abs/2010.08895))

It is interesting to note that most of these architectures resemble classical
numerical methods or at least share similarities with them. For example,
ConvNets (or convolutions in general) are related to finite differences, while
U-Nets resemble multigrid methods. Fourier Neural Operators are related to
spectral methods. The difference is that the emulators' free parameters are
found based on a (data-driven) numerical optimization not a symbolic
manipulation of the differential equations.

(*) This means that we essentially have a pixel or voxel grid on which space is
discretized. Hence, the space can only be the scaled unit cube $\Omega = (0,
L)^D$

## Features

* Based on [JAX](https://github.com/google/jax):
* One of the best Automatic Differentiation engines (forward & reverse)
* Automatic vectorization
* Backend-agnostic code (run on CPU, GPU, and TPU)
* Based on [Equinox](https://github.com/patrick-kidger/equinox):
* Single-Batch by design
* Integration into the Equinox SciML ecosystem
* Agnostic to the spatial dimension (works for 1D, 2D, and 3D)
* Agnostic to the boundary condition (works for Dirichlet, Neumann, and periodic
BCs)
* Composability
* Tools to count parameters and assess receptive fields

## Boundary Conditions

This package assumes that the boundary condition is baked into the neural
emulator. Hence, most components allow setting `boundary_mode` which can be
`"dirichlet"`, `"neumann"`, or `"periodic"`. This affects what is considered a
degree of freedom in the grid.

![](imgs/three_boundary_conditions.svg)

Dirichlet boundaries fully eliminate degrees of freedom on the boundary.
Periodic boundaries only keep one end of the domain as a degree of freedom (This
package follows the convention that the left boundary is the degree of freedom). Neumann boundaries keep both ends as degrees of freedom.

## Constructors

There are two primary architectural constructors for Sequential and Hierarchical
Networks that allow for composability with the `PDEquinox` blocks.

### Sequential Constructor

![](imgs/sequential_net.svg)

The squential network constructor is defined by:
* a lifting block $\mathcal{L}$
* $N$ blocks $\left \{ \mathcal{B}_i \right\}_{i=1}^N$
* a projection block $\mathcal{P}$
* the hidden channels within the sequential processing
* the number of blocks $N$ (one can also supply a list of hidden channels if they shall be different between blocks)

### Hierarchical Constructor

![](imgs/hierarchical_net.svg)

The hierarchical network constructor is defined by:
* a lifting block $\mathcal{L}$
* The number of levels $D$ (i.e., the number of additional hierarchies). Setting $D = 0$ recovers the sequential processing.
* a list of $D$ blocks $\left \{ \mathcal{D}_i \right\}_{i=1}^D$ for
downsampling, i.e. mapping downwards to the lower hierarchy (oftentimes this
is that they halve the spatial axes while keeping the number of channels)
* a list of $D$ blocks $\left \{ \mathcal{B}_i^l \right\}_{i=1}^D$ for
processing in the left arc (oftentimes this changes the number of channels,
e.g. doubles it such that the combination of downsampling and left processing
halves the spatial resolution and doubles the feature count)
* a list of $D$ blocks $\left \{ \mathcal{U}_i \right\}_{i=1}^D$ for upsamping,
i.e., mapping upwards to the higher hierarchy (oftentimes this doubles the
spatial resolution; at the same time it halves the feature count such that we
can concatenate a skip connection)
* a list of $D$ blocks $\left \{ \mathcal{B}_i^r \right\}_{i=1}^D$ for
processing in the right arc (oftentimes this changes the number of channels,
e.g. halves it such that the combination of upsampling and right processing
doubles the spatial resolution and halves the feature count)
* a projection block $\mathcal{P}$
* the hidden channels within the hierarchical processing (if just an integer is
provided; this is assumed to be the number of hidden channels in the highest
hierarchy.)

### Beyond Architectural Constructors

For completion, `pdequinox.arch` also provides a `ConvNet` which is a simple
feed-forward convolutional network. It also provides `MLP` which is a dense
networks which also requires pre-defining the number of resolution points.

## Related

Similar packages that provide a collection of emulator architectures are
[PDEBench](https://github.com/pdebench/PDEBench) and
[PDEArena](https://github.com/pdearena/pdearena). With focus on Phyiscs-informed
Neural Networks and Neural Operators, there are also
[DeepXDE](https://github.com/lululxvi/deepxde) and [NVIDIA
Modulus](https://developer.nvidia.com/modulus).

## License

MIT, see [here](https://github.com/Ceyron/pdequinox/blob/main/LICENSE.txt)

---

> [fkoehler.site](https://fkoehler.site/) &nbsp;&middot;&nbsp;
> GitHub [@ceyron](https://github.com/ceyron) &nbsp;&middot;&nbsp;
> X [@felix_m_koehler](https://twitter.com/felix_m_koehler)
16 changes: 16 additions & 0 deletions docs/javascripts/mathjax.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
window.MathJax = {
tex: {
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
processEscapes: true,
processEnvironments: true
},
options: {
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};

document$.subscribe(() => {
MathJax.typesetPromise()
})
6 changes: 6 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
mkdocs==1.6.0
black==24.4.2
mkdocs-material==9.5.27
mkdocstrings==0.25.1
mkdocstrings-python==1.10.5
mknotebooks==0.8.0
Empty file.
Loading

0 comments on commit ec1726d

Please sign in to comment.