Skip to content

Commit

Permalink
Tweak readme and notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Schubert authored and Martin Schubert committed Oct 26, 2023
1 parent 4b17963 commit 204a670
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ Each `Challenge` has a `Component` as an attribute, and also has a target that c

## Example
```python
# Select the challenge.
challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()

# Define loss function, which also returns auxilliary quantities.
def loss_fn(params):
response, aux = challenge.component.response(params)
loss = challenge.loss(response)
Expand All @@ -30,12 +32,15 @@ def loss_fn(params):

value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Select an optimizer.
opt = invrs_opt.density_lbfgsb(beta=4)

# Generate initial parameters, and use these to initialize the optimizer state.
params = challenge.component.init(jax.random.PRNGKey(0))
state = opt.init(params)

for _ in range(steps):
# Carry out the optimization.
for i in range(steps):
params = opt.params(state)
(value, (response, distance, aux)), grad = value_and_grad_fn(params)
state = opt.update(grad=grad, value=value, params=params, state=state)
Expand Down
7 changes: 5 additions & 2 deletions notebooks/readme_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,24 @@
"source": [
"challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()\n",
"\n",
"\n",
"# Define loss function, which also returns auxilliary quantities.\n",
"def loss_fn(params):\n",
" response, aux = challenge.component.response(params)\n",
" loss = challenge.loss(response)\n",
" distance = challenge.distance_to_target(response)\n",
" metrics = challenge.metrics(response, params, aux)\n",
" return loss, (response, distance, aux)\n",
"\n",
"\n",
"value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
"\n",
"# Select an optimizer.\n",
"opt = invrs_opt.density_lbfgsb(beta=4)\n",
"\n",
"# Generate initial parameters, and use these to initialize the optimizer state.\n",
"params = challenge.component.init(jax.random.PRNGKey(0))\n",
"state = opt.init(params)\n",
"\n",
"# Carry out the optimization.\n",
"data = []\n",
"for i in range(36):\n",
" params = opt.params(state)\n",
Expand All @@ -59,6 +61,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Create an animated gif showing the evolution of the waveguide bend design.\n",
"anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))\n",
"\n",
"for i, _, params, aux in data:\n",
Expand Down

0 comments on commit 204a670

Please sign in to comment.