Skip to content

Commit

Permalink
Merge pull request #4 from invrs-io/loss
Browse files Browse the repository at this point in the history
Change loss normalization
  • Loading branch information
mfschubert authored Sep 22, 2023
2 parents cd5ed9b + 5ae08a7 commit d6e3804
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 73 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Test pre-commit hooks
run: |
python -m pip install --upgrade pip
pip install "totypes/"
pip install pre-commit
pre-commit run -a
Expand Down
122 changes: 50 additions & 72 deletions notebooks/ceviche.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,107 +14,85 @@
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as onp\n",
"import optax\n",
"\n",
"from ceviche_challenges import units as u\n",
"from ceviche_challenges import beam_splitter, mode_converter, params, waveguide_bend, wdm\n",
"\n",
"from invrs_gym.loss import transmission_loss\n",
"from invrs_gym.challenge.ceviche import challenge, defaults"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95f003fc-c137-4099-81ad-1e6cfae0f41b",
"metadata": {},
"outputs": [],
"source": [
"for ceviche_model in (\n",
" defaults.BEAM_SPLITTER_MODEL,\n",
" defaults.LIGHTWEIGHT_BEAM_SPLITTER_MODEL,\n",
" defaults.MODE_CONVERTER_MODEL,\n",
" defaults.LIGHTWEIGHT_MODE_CONVERTER_MODEL,\n",
" defaults.WAVEGUIDE_BEND_MODEL,\n",
" defaults.LIGHTWEIGHT_WAVEGUIDE_BEND_MODEL,\n",
" defaults.WDM_MODEL,\n",
" defaults.LIGHTWEIGHT_WDM_MODEL,\n",
"):\n",
" plt.figure()\n",
" ax = plt.subplot(111)\n",
" density = jnp.full(ceviche_model.design_variable_shape, 1.0)\n",
" ax.imshow(ceviche_model.density(density)[20:-20, 20:-20])\n",
" ax.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "474e7001-99c1-4a7a-94bc-e1191628d313",
"metadata": {},
"outputs": [],
"source": [
"reload(defaults)\n",
"reload(challenge)\n",
"\n",
"bsc = challenge.lightweight_beam_splitter_challenge()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d37779df-9cc3-43ac-bd76-9cb82e22ab17",
"metadata": {},
"outputs": [],
"source": [
"bsc.component"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5392e86a-e897-4dee-89a7-ddac70794048",
"metadata": {},
"outputs": [],
"source": [
"params = bsc.component.init(jax.random.PRNGKey(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c686d9db-0353-4c2c-b949-adccb8e3aff6",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"t0 = time.time()\n",
"s_params, fields = bsc.component.response(params)\n",
"print(time.time() - t0)\n",
"reload(transmission_loss)\n",
"\n",
"bsc = challenge.lightweight_beam_splitter_challenge()\n",
"params = bsc.component.init(jax.random.PRNGKey(0))\n",
"\n",
"def loss_fn(params):\n",
" s_params, fields = bsc.component.response(params)\n",
" return jnp.sum(jnp.abs(s_params)**2)\n",
" response, aux = bsc.component.response(params)\n",
" loss = bsc.loss(response)\n",
" return loss, (response, aux)\n",
"\n",
"opt = optax.adam(0.01)\n",
"state = opt.init(params)\n",
"\n",
"value, grad = jax.value_and_grad(loss_fn)(params)\n",
"loss_values = []\n",
"for _ in range(20):\n",
" (value, (response, aux)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)\n",
" loss_values.append(value)\n",
" updates, state = opt.update(grad, state)\n",
" params = optax.apply_updates(params, updates)\n",
"\n",
"print(s_params.shape)"
"plt.plot(loss_values)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "971eb2a6-7339-4b45-9b29-f468e25ee237",
"id": "6a01ab00-8edd-4aca-adff-892925829df3",
"metadata": {},
"outputs": [],
"source": [
"print(value)\n",
" "
"reload(transmission_loss)\n",
"\n",
"t1, t2 = jnp.meshgrid(\n",
" jnp.linspace(0, 1),\n",
" jnp.linspace(0, 1),\n",
" indexing=\"ij\",\n",
")\n",
"transmission = jnp.stack([t1, t2], axis=-1)\n",
"\n",
"lower_bound = defaults.WAVEGUIDE_BEND_TRANSMISSION_LOWER_BOUND\n",
"upper_bound = defaults.WAVEGUIDE_BEND_TRANSMISSION_UPPER_BOUND\n",
"\n",
"def loss_fn(transmission):\n",
" return transmission_loss.orthotope_smooth_transmission_loss(\n",
" transmission,\n",
" lower_bound,\n",
" upper_bound,\n",
" transmission_exponent=0.5,\n",
" scalar_exponent=2.0,\n",
" )\n",
"\n",
"loss = jax.vmap(jax.vmap(loss_fn))(transmission)\n",
"\n",
"plt.pcolor(t1, t2, jnp.log10(loss))\n",
"plt.colorbar()\n",
"plt.plot(\n",
" [lower_bound[0], upper_bound[0], upper_bound[0], lower_bound[0], lower_bound[0]],\n",
" [lower_bound[1], lower_bound[1], upper_bound[1], upper_bound[1], lower_bound[1]],\n",
" \"r\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a01ab00-8edd-4aca-adff-892925829df3",
"id": "5423b32a-44cd-4bb5-97e8-a0c36813b2c5",
"metadata": {},
"outputs": [],
"source": []
Expand Down
3 changes: 2 additions & 1 deletion src/invrs_gym/loss/transmission_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ def orthotope_smooth_transmission_loss(
window_upper_bound, _MAX_PHYSICAL_TRANSMISSION
) - jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION)

# TODO: check whether dividing by max or min is appropriate.
transformed_elementwise_signed_distance = jax.nn.softplus(
elementwise_signed_distance / jnp.amin(window_size)
elementwise_signed_distance / jnp.amax(window_size)
)

return jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent
Expand Down

0 comments on commit d6e3804

Please sign in to comment.