diff --git a/README.md b/README.md index b9695a5..b63039a 100644 --- a/README.md +++ b/README.md @@ -19,8 +19,10 @@ Each `Challenge` has a `Component` as an attribute, and also has a target that c ## Example ```python +# Select the challenge. challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend() +# Define loss function, which also returns auxilliary quantities. def loss_fn(params): response, aux = challenge.component.response(params) loss = challenge.loss(response) @@ -30,12 +32,15 @@ def loss_fn(params): value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True) +# Select an optimizer. opt = invrs_opt.density_lbfgsb(beta=4) +# Generate initial parameters, and use these to initialize the optimizer state. params = challenge.component.init(jax.random.PRNGKey(0)) state = opt.init(params) -for _ in range(steps): +# Carry out the optimization. +for i in range(steps): params = opt.params(state) (value, (response, distance, aux)), grad = value_and_grad_fn(params) state = opt.update(grad=grad, value=value, params=params, state=state) diff --git a/notebooks/readme_example.ipynb b/notebooks/readme_example.ipynb index 7a5de06..8977693 100644 --- a/notebooks/readme_example.ipynb +++ b/notebooks/readme_example.ipynb @@ -28,7 +28,7 @@ "source": [ "challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()\n", "\n", - "\n", + "# Define loss function, which also returns auxilliary quantities.\n", "def loss_fn(params):\n", " response, aux = challenge.component.response(params)\n", " loss = challenge.loss(response)\n", @@ -36,14 +36,16 @@ " metrics = challenge.metrics(response, params, aux)\n", " return loss, (response, distance, aux)\n", "\n", - "\n", "value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n", "\n", + "# Select an optimizer.\n", "opt = invrs_opt.density_lbfgsb(beta=4)\n", "\n", + "# Generate initial parameters, and use these to initialize the optimizer state.\n", "params = challenge.component.init(jax.random.PRNGKey(0))\n", "state = opt.init(params)\n", "\n", + "# Carry out the optimization.\n", "data = []\n", "for i in range(36):\n", " params = opt.params(state)\n", @@ -59,6 +61,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Create an animated gif showing the evolution of the waveguide bend design.\n", "anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))\n", "\n", "for i, _, params, aux in data:\n",