Skip to content

Commit

Permalink
Update for new interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 1, 2024
1 parent 50ae48e commit 472d815
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 120 deletions.
39 changes: 25 additions & 14 deletions examples/additional_features.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-03-01 14:59:05.217187: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -64,9 +71,11 @@
"NUM_POINTS = 100\n",
"DT = 1.0\n",
"\n",
"ks_stepper = ex.KuramotoSivashinskyConservative(1, DOMAIN_EXTENT, NUM_POINTS, DT)\n",
"ks_stepper = ex.stepper.KuramotoSivashinskyConservative(\n",
" 1, DOMAIN_EXTENT, NUM_POINTS, DT\n",
")\n",
"\n",
"grid = ex.get_grid(1, DOMAIN_EXTENT, NUM_POINTS)\n",
"grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)\n",
"u_0 = jax.random.normal(\n",
" jax.random.PRNGKey(0),\n",
" (\n",
Expand Down Expand Up @@ -99,16 +108,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f1d740f25f0>"
"<matplotlib.image.AxesImage at 0x7ff6e801c580>"
]
},
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -124,7 +133,7 @@
}
],
"source": [
"ks_stepper_half_step = ex.KuramotoSivashinskyConservative(\n",
"ks_stepper_half_step = ex.stepper.KuramotoSivashinskyConservative(\n",
" 1, DOMAIN_EXTENT, NUM_POINTS, DT / 2\n",
")\n",
"ks_stepper_substepper = ex.RepeatedStepper(ks_stepper_half_step, 2)\n",
Expand Down Expand Up @@ -153,7 +162,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -162,7 +171,7 @@
"Text(0.5, 1.0, 'Forced Heat Equation')"
]
},
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
Expand All @@ -183,10 +192,12 @@
"DT = 0.01\n",
"NU = 0.01\n",
"\n",
"diffusion_stepper = ex.Diffusion(1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU)\n",
"diffusion_stepper = ex.stepper.Diffusion(\n",
" 1, DOMAIN_EXTENT, NUM_POINTS, DT, diffusivity=NU\n",
")\n",
"forced_diffusion_stepper = ex.ForcedStepper(diffusion_stepper)\n",
"\n",
"grid = ex.get_grid(1, DOMAIN_EXTENT, NUM_POINTS)\n",
"grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)\n",
"u_0 = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)\n",
"\n",
"# Have a constant forcing term but we could also supply a time-dependent forcing\n",
Expand Down Expand Up @@ -218,7 +229,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -227,7 +238,7 @@
"Array(4.742706e-05, dtype=float32)"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -236,7 +247,7 @@
"DOMAIN_EXTENT = 1.0\n",
"NUM_POINTS = 100\n",
"\n",
"grid = ex.get_grid(1, DOMAIN_EXTENT, NUM_POINTS)\n",
"grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS)\n",
"u = jnp.sin(6 * jnp.pi * grid / DOMAIN_EXTENT)\n",
"u_prime_exact = jnp.cos(6 * jnp.pi * grid / DOMAIN_EXTENT) * 6 * jnp.pi / DOMAIN_EXTENT\n",
"\n",
Expand Down
27 changes: 17 additions & 10 deletions examples/learning_burgers_autoregressive_neural_operator.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions examples/simple_advection_example_1d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024-02-29 10:05:44.604191: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
"2024-03-01 14:57:24.421211: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.3.52). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
]
},
{
Expand Down Expand Up @@ -141,7 +141,7 @@
}
],
"source": [
"grid_from_exponax = ex.get_grid(\n",
"grid_from_exponax = ex.make_grid(\n",
" 1,\n",
" DOMAIN_EXTENT,\n",
" NUM_POINTS,\n",
Expand Down Expand Up @@ -271,7 +271,7 @@
}
],
"source": [
"full_grid = ex.get_grid(1, DOMAIN_EXTENT, NUM_POINTS, full=True)\n",
"full_grid = ex.make_grid(1, DOMAIN_EXTENT, NUM_POINTS, full=True)\n",
"full_ic = ex.wrap_bc(ic)\n",
"plt.plot(full_grid[0], full_ic[0])\n",
"plt.xlim(-0.1, 1.1)\n",
Expand Down Expand Up @@ -330,7 +330,7 @@
"VELOCITY = 1.0\n",
"DT = 0.2\n",
"\n",
"advection_stepper = ex.Advection(\n",
"advection_stepper = ex.stepper.Advection(\n",
" 1,\n",
" DOMAIN_EXTENT,\n",
" NUM_POINTS,\n",
Expand Down Expand Up @@ -376,7 +376,7 @@
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fe164079cc0>"
"<matplotlib.legend.Legend at 0x7fbc48571540>"
]
},
"execution_count": 12,
Expand Down Expand Up @@ -435,7 +435,7 @@
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fe164230a90>"
"<matplotlib.legend.Legend at 0x7fbc4846b760>"
]
},
"execution_count": 13,
Expand Down Expand Up @@ -564,7 +564,7 @@
"outputs": [],
"source": [
"SMALLER_DT = 0.01\n",
"slower_advection_stepper = ex.Advection(\n",
"slower_advection_stepper = ex.stepper.Advection(\n",
" 1, DOMAIN_EXTENT, NUM_POINTS, SMALLER_DT, velocity=VELOCITY\n",
")\n",
"longer_rollout_advection_stepper = ex.rollout(\n",
Expand Down
113 changes: 59 additions & 54 deletions examples/solver_showcase_1d.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 472d815

Please sign in to comment.