diff --git a/examples/correction_learning__differentiable_physics.ipynb b/examples/correction_learning__differentiable_physics.ipynb
new file mode 100644
index 0000000..1aa1b98
--- /dev/null
+++ b/examples/correction_learning__differentiable_physics.ipynb
@@ -0,0 +1,1039 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Correction Learning and Differentiable Physics\n",
+    "\n",
+    "In this notebook, a neural emulator $f_\theta$ is trained to mimic the simulator\n",
+    "for the 1d advection equation $\mathcal{P}$. However, its receptive field is not\n",
+    "sufficient to propagate the state at the given CFL number or difficulty\n",
+    "$\gamma_1$. To assist the neural emulator, a *corrected stepper* is built that\n",
+    "contains a defective numerical scheme $\tilde{\mathcal{P}}$, and the network's\n",
+    "task is to correct its output. Here, this defective scheme is only aware of\n",
+    "\"half the difficulty\" $\tilde{\gamma}_1 = \gamma_1/2$. However, if the defective\n",
+    "scheme already takes care of half the difficulty, the neural network only needs\n",
+    "half the receptive field to correct it (in the sequential setup).\n",
+    "\n",
+    "We will train the corrected stepper both in a one-step mode, which does not\n",
+    "require the defective/coarse physics to be differentiable, and in a supervised\n",
+    "rollout training, which requires differentiable physics."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import equinox as eqx\n",
+    "import matplotlib.pyplot as plt\n",
+    "from typing import Callable\n",
+    "import optax\n",
+    "from tqdm.autonotebook import tqdm"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import exponax as ex"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from IPython.display import HTML"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We will work in 1d and choose a relatively coarse discretization. The problem\n",
+    "should be more or less agnostic to the spatial resolution as long as all modes\n",
+    "of the initial conditions are resolved properly."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "NUM_POINTS = 48\n",
+    "GAMMA_1 = 6.5\n",
+    "\n",
+    "NUM_TRAIN_SAMPLES = 40\n",
+    "TRAIN_DATA_SEED = 773\n",
+    "TRAIN_TEMPORAL_HORIZON = 50\n",
+    "\n",
+    "NUM_TEST_SAMPLES = 30\n",
+    "TEST_DATA_SEED = 774\n",
+    "TEST_TEMPORAL_HORIZON = 200"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fine_stepper = ex.normalized.DifficultyLinearStepperSimple(\n",
+    "    1, NUM_POINTS, difficulty=-GAMMA_1, order=1\n",
+    ")\n",
+    "coarse_stepper = ex.normalized.DifficultyLinearStepperSimple(\n",
+    "    1, NUM_POINTS, difficulty=-GAMMA_1 / 2, order=1\n",
+    ")"
+   ]
+  },
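+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a rough sanity check of the receptive-field argument: the difficulty\n",
+    "$\gamma_1$ plays the role of a CFL number, i.e., the state is transported by\n",
+    "about $\gamma_1$ cells per time step, so a one-step emulator needs a receptive\n",
+    "field of at least that many cells. The defective scheme with\n",
+    "$\tilde{\gamma}_1 = \gamma_1/2$ leaves only half of that for the network. The\n",
+    "arithmetic below is purely illustrative and not part of the training pipeline."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "\n",
+    "# Cells the state travels per time step -> minimal one-sided receptive field\n",
+    "print(\"full physics   :\", GAMMA_1, \"-> needs >=\", math.ceil(GAMMA_1), \"cells\")\n",
+    "print(\"coarse residual:\", GAMMA_1 / 2, \"-> needs >=\", math.ceil(GAMMA_1 / 2), \"cells\")"
+   ]
+  },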
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The distribution of initial conditions is a truncated Fourier series with up to\n",
+    "5 modes. We limit its amplitude to 1 to ease plotting."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ic_distribution = ex.ic.RandomTruncatedFourierSeries(1, max_one=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's create the set of initial conditions out of it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_ic_set = ex.build_ic_set(\n",
+    "    ic_distribution,\n",
+    "    num_points=NUM_POINTS,\n",
+    "    num_samples=NUM_TRAIN_SAMPLES,\n",
+    "    key=jax.random.PRNGKey(TRAIN_DATA_SEED),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Visualizing the initial states, we indeed see that they are combinations of\n",
+    "Fourier modes, each normalized to a maximum absolute value of 1."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ex.viz.plot_state_1d(train_ic_set[:3, 0, :]);"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Rolling them out produces the train data set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_trj_set = jax.vmap(\n",
+    "    ex.rollout(fine_stepper, TRAIN_TEMPORAL_HORIZON, include_init=True)\n",
+    ")(train_ic_set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# (NUM_TRAIN_SAMPLES, TRAIN_TEMPORAL_HORIZON + 1, 1, NUM_POINTS)\n",
+    "train_trj_set.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's quickly do the same for the test set to have it at hand."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_ic_set = ex.build_ic_set(\n",
+    "    ic_distribution,\n",
+    "    num_points=NUM_POINTS,\n",
+    "    num_samples=NUM_TEST_SAMPLES,\n",
+    "    key=jax.random.PRNGKey(TEST_DATA_SEED),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test_trj_set = jax.vmap(\n",
+    "    ex.rollout(fine_stepper, TEST_TEMPORAL_HORIZON, include_init=True)\n",
+    ")(test_ic_set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# (NUM_TEST_SAMPLES, TEST_TEMPORAL_HORIZON + 1, 1, NUM_POINTS)\n",
+    "test_trj_set.shape"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's visualize a couple of the train trajectories. Since we have a high CFL\n",
+    "number and a low spatial resolution, the spatio-temporal plot looks glitchy."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ex.viz.plot_spatio_temporal_facet(\n",
+    "    train_trj_set, facet_over_channels=False, figsize=(12, 6)\n",
+    ");"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An animation shows more clearly that it is just advection happening. (It is\n",
+    "just very fast! ;)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "HTML(ex.viz.animate_state_1d(train_trj_set[:, :3, 0, :]).to_jshtml())"
+   ]
+  },
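+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since the exact spectral stepper merely transports the state, the discrete\n",
+    "$L^2$ energy of each trajectory should be conserved over time. A quick,\n",
+    "purely illustrative consistency check on the training data:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Advection only moves the state around, so the mean squared amplitude of\n",
+    "# the first and last snapshots should agree up to floating-point error.\n",
+    "print(jnp.mean(train_trj_set[:, 0] ** 2))\n",
+    "print(jnp.mean(train_trj_set[:, -1] ** 2))"
+   ]
+  },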
;)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "HTML(ex.viz.animate_state_1d(train_trj_set[:, :3, 0, :]).to_jshtml())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training the neural emulator naively as predictor\n", + "\n", + "Let's build a simple convolution ResNet with periodic padding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class ResBlockPeriodic1d(eqx.Module):\n", + " conv_1: eqx.nn.Conv1d\n", + " conv_2: eqx.nn.Conv1d\n", + " activation: Callable\n", + "\n", + " def __init__(\n", + " self,\n", + " channels: int,\n", + " activation: Callable,\n", + " *,\n", + " key,\n", + " ):\n", + " c_1_key, c_2_key = jax.random.split(key)\n", + " self.conv_1 = eqx.nn.Conv1d(channels, channels, kernel_size=3, key=c_1_key)\n", + " self.conv_2 = eqx.nn.Conv1d(channels, channels, kernel_size=3, key=c_2_key)\n", + " self.activation = activation\n", + "\n", + " def periodic_padding(\n", + " self,\n", + " x,\n", + " ):\n", + " # padding over channels space\n", + " return jnp.pad(x, ((0, 0), (1, 1)), mode=\"wrap\")\n", + "\n", + " def __call__(self, x):\n", + " x_skip = x\n", + " x = self.periodic_padding(x)\n", + " x = self.conv_1(x)\n", + " x = self.activation(x)\n", + " x = self.periodic_padding(x)\n", + " x = self.conv_2(x)\n", + " x = x + x_skip\n", + " x = self.activation(x)\n", + " return x\n", + "\n", + "\n", + "class ResNetPeriodic1d(eqx.Module):\n", + " lifting: eqx.nn.Conv1d\n", + " blocks: tuple[ResBlockPeriodic1d]\n", + " projection: eqx.nn.Conv1d\n", + "\n", + " def __init__(\n", + " self,\n", + " hidden_channels: int,\n", + " num_blocks: int,\n", + " activation: Callable,\n", + " *,\n", + " key,\n", + " ):\n", + " lifting_key, *block_keys, projection_key = jax.random.split(key, 2 + num_blocks)\n", + " self.lifting = eqx.nn.Conv1d(1, hidden_channels, kernel_size=1, key=lifting_key)\n", + " self.blocks = tuple(\n", + " ResBlockPeriodic1d(hidden_channels, activation=activation, key=block_key)\n", + " for block_key in block_keys\n", + " )\n", + " self.projection = eqx.nn.Conv1d(\n", + " hidden_channels, 1, kernel_size=1, key=projection_key\n", + " )\n", + "\n", + " def __call__(self, x):\n", + " x = self.lifting(x)\n", + " for block in self.blocks:\n", + " x = block(x)\n", + " x = self.projection(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def dataloader(\n", + " data,\n", + " *,\n", + " batch_size: int,\n", + " key,\n", + "):\n", + " n_samples = data.shape[0]\n", + "\n", + " n_batches = int(jnp.ceil(n_samples / batch_size))\n", + "\n", + " permutation = jax.random.permutation(key, n_samples)\n", + "\n", + " for batch_id in range(n_batches):\n", + " start = batch_id * batch_size\n", + " end = min((batch_id + 1) * batch_size, n_samples)\n", + "\n", + " batch_indices = permutation[start:end]\n", + "\n", + " sub_data = data[batch_indices]\n", + "\n", + " yield sub_data\n", + "\n", + "\n", + "def cycling_dataloader(\n", + " data,\n", + " *,\n", + " batch_size: int,\n", + " num_steps: int,\n", + " key,\n", + " return_info: bool = False,\n", + "):\n", + " epoch_id = 0\n", + " total_step_id = 0\n", + "\n", + " while True:\n", + " key, subkey = jax.random.split(key)\n", + "\n", + " for batch_id, sub_data in enumerate(\n", + " dataloader(data, batch_size=batch_size, key=subkey)\n", + " ):\n", + " if total_step_id == num_steps:\n", + " 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def dataloader(\n",
+    "    data,\n",
+    "    *,\n",
+    "    batch_size: int,\n",
+    "    key,\n",
+    "):\n",
+    "    n_samples = data.shape[0]\n",
+    "\n",
+    "    n_batches = int(jnp.ceil(n_samples / batch_size))\n",
+    "\n",
+    "    permutation = jax.random.permutation(key, n_samples)\n",
+    "\n",
+    "    for batch_id in range(n_batches):\n",
+    "        start = batch_id * batch_size\n",
+    "        end = min((batch_id + 1) * batch_size, n_samples)\n",
+    "\n",
+    "        batch_indices = permutation[start:end]\n",
+    "\n",
+    "        sub_data = data[batch_indices]\n",
+    "\n",
+    "        yield sub_data\n",
+    "\n",
+    "\n",
+    "def cycling_dataloader(\n",
+    "    data,\n",
+    "    *,\n",
+    "    batch_size: int,\n",
+    "    num_steps: int,\n",
+    "    key,\n",
+    "    return_info: bool = False,\n",
+    "):\n",
+    "    epoch_id = 0\n",
+    "    total_step_id = 0\n",
+    "\n",
+    "    while True:\n",
+    "        key, subkey = jax.random.split(key)\n",
+    "\n",
+    "        for batch_id, sub_data in enumerate(\n",
+    "            dataloader(data, batch_size=batch_size, key=subkey)\n",
+    "        ):\n",
+    "            if total_step_id == num_steps:\n",
+    "                return\n",
+    "\n",
+    "            if return_info:\n",
+    "                yield sub_data, epoch_id, batch_id\n",
+    "            else:\n",
+    "                yield sub_data\n",
+    "\n",
+    "            total_step_id += 1\n",
+    "\n",
+    "        epoch_id += 1"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's train a ResNet with 2 blocks (each containing two kernel-size-3\n",
+    "convolutions) on the train dataset. We will train with one-step supervised\n",
+    "learning, so we first have to sub-stack the data into input/output pairs by\n",
+    "creating windows of size two."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "LEARNING_RATE = 3e-4\n",
+    "OPTIMIZER = optax.adam(LEARNING_RATE)\n",
+    "NUM_STEPS = 20_000\n",
+    "SHUFFLE_KEY = jax.random.PRNGKey(99)\n",
+    "BATCH_SIZE = 16\n",
+    "\n",
+    "\n",
+    "one_substacked_train_trj_set = jax.vmap(\n",
+    "    ex.stack_sub_trajectories,\n",
+    "    in_axes=(0, None),\n",
+    ")(train_trj_set, 2)\n",
+    "# Merge the two batch axes\n",
+    "one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)\n",
+    "\n",
+    "prediction_neural_emulator = ResNetPeriodic1d(\n",
+    "    32, 2, jax.nn.relu, key=jax.random.PRNGKey(0)\n",
+    ")\n",
+    "\n",
+    "opt_state = OPTIMIZER.init(eqx.filter(prediction_neural_emulator, eqx.is_array))\n",
+    "\n",
+    "\n",
+    "def one_step_loss_fn(model, batch):\n",
+    "    x, y = batch[:, 0], batch[:, 1]\n",
+    "    y_hat = jax.vmap(model)(x)\n",
+    "    return jnp.mean((y - y_hat) ** 2)\n",
+    "\n",
+    "\n",
+    "@eqx.filter_jit\n",
+    "def step_fn(model, state, batch):\n",
+    "    loss, grads = eqx.filter_value_and_grad(one_step_loss_fn)(model, batch)\n",
+    "    updates, new_opt_state = OPTIMIZER.update(grads, state, model)\n",
+    "    new_model = eqx.apply_updates(model, updates)\n",
+    "    return new_model, new_opt_state, loss\n",
+    "\n",
+    "\n",
+    "shuffle_key = SHUFFLE_KEY\n",
+    "train_loss_history = []\n",
+    "\n",
+    "p_meter = tqdm(total=NUM_STEPS)\n",
+    "\n",
+    "for batch in cycling_dataloader(\n",
+    "    one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key\n",
+    "):\n",
+    "    prediction_neural_emulator, opt_state, loss = step_fn(\n",
+    "        prediction_neural_emulator, opt_state, batch\n",
+    "    )\n",
+    "    train_loss_history.append(loss)\n",
+    "    p_meter.update(1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.semilogy(train_loss_history)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's use the final network state to make a prediction trajectory on all the\n",
+    "test initial states."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prediction_trj = jax.vmap(\n",
+    "    ex.rollout(prediction_neural_emulator, TEST_TEMPORAL_HORIZON, include_init=True)\n",
+    ")(test_ic_set)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "And compute the mean nRMSE over the rollout."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(\n",
+    "    prediction_trj, test_trj_set\n",
+    ")"
+   ]
+  },
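+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For orientation, `ex.metrics.mean_nRMSE` aggregates a normalized\n",
+    "root-mean-square error over the batch. A hand-rolled sketch of such a metric\n",
+    "(the library's exact normalization may differ in detail) could look like this:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def manual_mean_nrmse(pred, ref):\n",
+    "    # RMSE per sample, normalized by the reference magnitude, then averaged\n",
+    "    rmse = jnp.sqrt(jnp.mean((pred - ref) ** 2, axis=(-2, -1)))\n",
+    "    norm = jnp.sqrt(jnp.mean(ref**2, axis=(-2, -1)))\n",
+    "    return jnp.mean(rmse / norm)\n",
+    "\n",
+    "\n",
+    "print(manual_mean_nrmse(prediction_trj[:, 1], test_trj_set[:, 1]))"
+   ]
+  },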
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We see that the rollout already diverges after two time steps. This is caused\n",
+    "by an insufficient receptive field."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Just as a baseline, let us train a ResNet with more blocks."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "LEARNING_RATE = 3e-4\n",
+    "OPTIMIZER = optax.adam(LEARNING_RATE)\n",
+    "NUM_STEPS = 20_000\n",
+    "SHUFFLE_KEY = jax.random.PRNGKey(99)\n",
+    "BATCH_SIZE = 16\n",
+    "\n",
+    "\n",
+    "one_substacked_train_trj_set = jax.vmap(\n",
+    "    ex.stack_sub_trajectories,\n",
+    "    in_axes=(0, None),\n",
+    ")(train_trj_set, 2)\n",
+    "# Merge the two batch axes\n",
+    "one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)\n",
+    "\n",
+    "# Below we increase the number of blocks from 2 to 4\n",
+    "prediction_neural_emulator_more_reception = ResNetPeriodic1d(\n",
+    "    32, 4, jax.nn.relu, key=jax.random.PRNGKey(0)\n",
+    ")\n",
+    "\n",
+    "opt_state = OPTIMIZER.init(\n",
+    "    eqx.filter(prediction_neural_emulator_more_reception, eqx.is_array)\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def one_step_loss_fn(model, batch):\n",
+    "    x, y = batch[:, 0], batch[:, 1]\n",
+    "    y_hat = jax.vmap(model)(x)\n",
+    "    return jnp.mean((y - y_hat) ** 2)\n",
+    "\n",
+    "\n",
+    "@eqx.filter_jit\n",
+    "def step_fn(model, state, batch):\n",
+    "    loss, grads = eqx.filter_value_and_grad(one_step_loss_fn)(model, batch)\n",
+    "    updates, new_opt_state = OPTIMIZER.update(grads, state, model)\n",
+    "    new_model = eqx.apply_updates(model, updates)\n",
+    "    return new_model, new_opt_state, loss\n",
+    "\n",
+    "\n",
+    "shuffle_key = SHUFFLE_KEY\n",
+    "train_loss_history = []\n",
+    "\n",
+    "p_meter = tqdm(total=NUM_STEPS)\n",
+    "\n",
+    "for batch in cycling_dataloader(\n",
+    "    one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key\n",
+    "):\n",
+    "    prediction_neural_emulator_more_reception, opt_state, loss = step_fn(\n",
+    "        prediction_neural_emulator_more_reception, opt_state, batch\n",
+    "    )\n",
+    "    train_loss_history.append(loss)\n",
+    "    p_meter.update(1)"
+   ]
+  },
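+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, the two predictors differ in size as well as in receptive\n",
+    "field. A quick, illustrative parameter count:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def count_parameters(model):\n",
+    "    # Count all trainable array entries in the module\n",
+    "    return sum(\n",
+    "        p.size for p in jax.tree_util.tree_leaves(eqx.filter(model, eqx.is_array))\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "print(\"2-block ResNet:\", count_parameters(prediction_neural_emulator))\n",
+    "print(\"4-block ResNet:\", count_parameters(prediction_neural_emulator_more_reception))"
+   ]
+  },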
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Notice the distinctly different loss curve!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.semilogy(train_loss_history)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's again create the rollout and compute the mean nRMSE."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prediction_trj = jax.vmap(\n",
+    "    ex.rollout(\n",
+    "        prediction_neural_emulator_more_reception,\n",
+    "        TEST_TEMPORAL_HORIZON,\n",
+    "        include_init=True,\n",
+    "    )\n",
+    ")(test_ic_set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(\n",
+    "    prediction_trj, test_trj_set\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "At first glance, the rollout looks equally bad (but note that the zoomed plot\n",
+    "further below uses a different x-axis limit)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Zooming in a bit, we see that the new predictor does not immediately explode.\n",
+    "Its performance is still not good, but at least better than before.\n",
+    "\n",
+    "Feel free to play around with the number of blocks and other parameters!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlim(0, 25)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Correction Learning"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's start by creating an `Equinox` wrapper that enables us to treat the\n",
+    "sequential corrector as a single deep learning module."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class SequentialCorrector(eqx.Module):\n",
+    "    coarse_predictor: eqx.Module\n",
+    "    neural_corrector: eqx.Module\n",
+    "\n",
+    "    def __call__(\n",
+    "        self,\n",
+    "        x,\n",
+    "    ):\n",
+    "        # We detach the coarse predictor so that its parameters are **not**\n",
+    "        # changed during training\n",
+    "        coarse_predictor_detached = jax.lax.stop_gradient(self.coarse_predictor)\n",
+    "\n",
+    "        coarse_prediction = coarse_predictor_detached(x)\n",
+    "        corrected_prediction = self.neural_corrector(coarse_prediction)\n",
+    "\n",
+    "        return corrected_prediction"
+   ]
+  },
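+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A single application of the (still untrained) composite module, just to see\n",
+    "the wrapper in action. `demo_corrector` is a throwaway instance used only for\n",
+    "this illustration:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "demo_corrector = SequentialCorrector(\n",
+    "    coarse_stepper,\n",
+    "    ResNetPeriodic1d(32, 2, jax.nn.relu, key=jax.random.PRNGKey(1)),\n",
+    ")\n",
+    "# Coarse physics step first, then the network's correction\n",
+    "print(demo_corrector(train_ic_set[0]).shape)"
+   ]
+  },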
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's again use the ResNet with two blocks, now as a corrector network, and\n",
+    "train the composite module similarly to before."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "LEARNING_RATE = 3e-4\n",
+    "OPTIMIZER = optax.adam(LEARNING_RATE)\n",
+    "NUM_STEPS = 20_000\n",
+    "SHUFFLE_KEY = jax.random.PRNGKey(99)\n",
+    "BATCH_SIZE = 16\n",
+    "\n",
+    "\n",
+    "one_substacked_train_trj_set = jax.vmap(\n",
+    "    ex.stack_sub_trajectories,\n",
+    "    in_axes=(0, None),\n",
+    ")(train_trj_set, 2)\n",
+    "# Merge the two batch axes\n",
+    "one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)\n",
+    "\n",
+    "# Again using only two blocks\n",
+    "corrector_network = ResNetPeriodic1d(32, 2, jax.nn.relu, key=jax.random.PRNGKey(0))\n",
+    "corrected_stepper = SequentialCorrector(coarse_stepper, corrector_network)\n",
+    "\n",
+    "opt_state = OPTIMIZER.init(eqx.filter(corrected_stepper, eqx.is_array))\n",
+    "\n",
+    "\n",
+    "def one_step_loss_fn(model, batch):\n",
+    "    x, y = batch[:, 0], batch[:, 1]\n",
+    "    y_hat = jax.vmap(model)(x)\n",
+    "    return jnp.mean((y - y_hat) ** 2)\n",
+    "\n",
+    "\n",
+    "@eqx.filter_jit\n",
+    "def step_fn(model, state, batch):\n",
+    "    loss, grads = eqx.filter_value_and_grad(one_step_loss_fn)(model, batch)\n",
+    "    updates, new_opt_state = OPTIMIZER.update(grads, state, model)\n",
+    "    new_model = eqx.apply_updates(model, updates)\n",
+    "    return new_model, new_opt_state, loss\n",
+    "\n",
+    "\n",
+    "shuffle_key = SHUFFLE_KEY\n",
+    "train_loss_history = []\n",
+    "\n",
+    "p_meter = tqdm(total=NUM_STEPS)\n",
+    "\n",
+    "for batch in cycling_dataloader(\n",
+    "    one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key\n",
+    "):\n",
+    "    corrected_stepper, opt_state, loss = step_fn(corrected_stepper, opt_state, batch)\n",
+    "    train_loss_history.append(loss)\n",
+    "    p_meter.update(1)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the training loss history now looks similar to the one we got for\n",
+    "the predictor with a sufficient receptive field."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.semilogy(train_loss_history)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prediction_trj = jax.vmap(\n",
+    "    ex.rollout(corrected_stepper, TEST_TEMPORAL_HORIZON, include_init=True)\n",
+    ")(test_ic_set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(\n",
+    "    prediction_trj, test_trj_set\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that we are now even better than the predictor with more blocks!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")\n",
+    "plt.xlim(0, 25)"
+   ]
+  },
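+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Because of the `stop_gradient`, training should have left the coarse physics\n",
+    "untouched. A quick, illustrative check that its parameters still match the\n",
+    "original stepper:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# All array leaves of the trained corrector's coarse predictor should still\n",
+    "# equal those of the original coarse stepper.\n",
+    "print(\n",
+    "    jax.tree_util.tree_all(\n",
+    "        jax.tree_util.tree_map(\n",
+    "            lambda a, b: jnp.allclose(a, b),\n",
+    "            eqx.filter(corrected_stepper.coarse_predictor, eqx.is_array),\n",
+    "            eqx.filter(coarse_stepper, eqx.is_array),\n",
+    "        )\n",
+    "    )\n",
+    ")"
+   ]
+  },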
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Rollout training\n",
+    "\n",
+    "It turns out that corrected steppers greatly benefit from rollout training!\n",
+    "Let's do a training with three autoregressive supervised steps. This also\n",
+    "requires sub-stacking **windows of length 4** (one initial state plus three\n",
+    "targets). This will slightly reduce the number of samples available per epoch.\n",
+    "However, we keep the number of update steps fixed, so we automatically\n",
+    "compensate for this.\n",
+    "\n",
+    "The training will take slightly longer because of the additional computation\n",
+    "per update step."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "LEARNING_RATE = 3e-4\n",
+    "OPTIMIZER = optax.adam(LEARNING_RATE)\n",
+    "NUM_STEPS = 20_000\n",
+    "SHUFFLE_KEY = jax.random.PRNGKey(99)\n",
+    "BATCH_SIZE = 16\n",
+    "\n",
+    "\n",
+    "one_substacked_train_trj_set = jax.vmap(\n",
+    "    ex.stack_sub_trajectories,\n",
+    "    in_axes=(0, None),\n",
+    ")(\n",
+    "    train_trj_set, 4\n",
+    ")  # ! HERE WE USE 4 (window of length 4 = one initial state + three targets)\n",
+    "# Merge the two batch axes\n",
+    "one_train_dataset = jnp.concatenate(one_substacked_train_trj_set)\n",
+    "\n",
+    "# Again using only two blocks\n",
+    "corrector_network = ResNetPeriodic1d(32, 2, jax.nn.relu, key=jax.random.PRNGKey(0))\n",
+    "corrected_stepper_rollout_trained = SequentialCorrector(\n",
+    "    coarse_stepper, corrector_network\n",
+    ")\n",
+    "\n",
+    "opt_state = OPTIMIZER.init(eqx.filter(corrected_stepper_rollout_trained, eqx.is_array))\n",
+    "\n",
+    "\n",
+    "def rollout_loss_fn(model, batch):\n",
+    "    ic, ref_trj = batch[:, 0], batch[:, 1:]\n",
+    "    pred = ic\n",
+    "    loss = 0.0\n",
+    "    for i in range(\n",
+    "        3\n",
+    "    ):  # ! HERE WE USE 3 for a three-step autoregressive rollout during training\n",
+    "        pred = jax.vmap(model)(pred)\n",
+    "        ref = ref_trj[:, i]\n",
+    "        loss += jnp.mean((ref - pred) ** 2)\n",
+    "\n",
+    "    return loss\n",
+    "\n",
+    "\n",
+    "@eqx.filter_jit\n",
+    "def step_fn(model, state, batch):\n",
+    "    loss, grads = eqx.filter_value_and_grad(rollout_loss_fn)(model, batch)\n",
+    "    updates, new_opt_state = OPTIMIZER.update(grads, state, model)\n",
+    "    new_model = eqx.apply_updates(model, updates)\n",
+    "    return new_model, new_opt_state, loss\n",
+    "\n",
+    "\n",
+    "shuffle_key = SHUFFLE_KEY\n",
+    "train_loss_history = []\n",
+    "\n",
+    "p_meter = tqdm(total=NUM_STEPS)\n",
+    "\n",
+    "for batch in cycling_dataloader(\n",
+    "    one_train_dataset, batch_size=BATCH_SIZE, num_steps=NUM_STEPS, key=shuffle_key\n",
+    "):\n",
+    "    corrected_stepper_rollout_trained, opt_state, loss = step_fn(\n",
+    "        corrected_stepper_rollout_trained, opt_state, batch\n",
+    "    )\n",
+    "    train_loss_history.append(loss)\n",
+    "    p_meter.update(1)"
+   ]
+  },
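+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This rollout loss backpropagates through the coarse solver at every step,\n",
+    "which is exactly where differentiable physics enters. As a minimal sketch\n",
+    "(`two_step_energy` is a hypothetical helper, not part of the training loop),\n",
+    "we can differentiate a two-step coarse rollout with respect to its input\n",
+    "state:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def two_step_energy(u_0):\n",
+    "    # Chain the coarse solver twice; JAX differentiates through both steps\n",
+    "    u_2 = coarse_stepper(coarse_stepper(u_0))\n",
+    "    return jnp.mean(u_2**2)\n",
+    "\n",
+    "\n",
+    "input_gradient = jax.grad(two_step_energy)(test_ic_set[0])\n",
+    "print(input_gradient.shape)"
+   ]
+  },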
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the rollout loss level will differ from the one-step loss level by\n",
+    "roughly a factor of three because we simply added up the three time-level\n",
+    "losses."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.semilogy(train_loss_history)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prediction_trj = jax.vmap(\n",
+    "    ex.rollout(\n",
+    "        corrected_stepper_rollout_trained, TEST_TEMPORAL_HORIZON, include_init=True\n",
+    "    )\n",
+    ")(test_ic_set)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mean_nRMSE_trj = jax.vmap(ex.metrics.mean_nRMSE, in_axes=1)(\n",
+    "    prediction_trj, test_trj_set\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(jnp.arange(0, 200 + 1), mean_nRMSE_trj)\n",
+    "plt.ylim(-0.05, 1.05)\n",
+    "plt.xlabel(\"Time Step\")\n",
+    "plt.ylabel(\"Mean nRMSE\")\n",
+    "plt.xlim(0, 25)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "jax_fresh",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}