Skip to content

Commit

Permalink
Ceviche challenges
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Schubert authored and Martin Schubert committed Sep 22, 2023
1 parent 2ab9be5 commit 4b145ad
Show file tree
Hide file tree
Showing 11 changed files with 966 additions and 198 deletions.
76 changes: 59 additions & 17 deletions notebooks/ceviche.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"from ceviche_challenges import units as u\n",
"from ceviche_challenges import beam_splitter, mode_converter, params, waveguide_bend, wdm\n",
"\n",
"from invrs_gym.challenge.ceviche import defaults"
"from invrs_gym.challenge.ceviche import challenge, defaults"
]
},
{
Expand All @@ -28,51 +28,93 @@
"metadata": {},
"outputs": [],
"source": [
"model = wdm.model.WdmModel(defaults.LIGHTWEIGHT_SIM_PARAMS, defaults.LIGHTWEIGHT_WDM_SPEC)\n",
"density = jnp.full(model.design_variable_shape, 1.0)\n",
"plt.imshow(model.density(density)[20:-20, 20:-20])"
"for ceviche_model in (\n",
" defaults.BEAM_SPLITTER_MODEL,\n",
" defaults.LIGHTWEIGHT_BEAM_SPLITTER_MODEL,\n",
" defaults.MODE_CONVERTER_MODEL,\n",
" defaults.LIGHTWEIGHT_MODE_CONVERTER_MODEL,\n",
" defaults.WAVEGUIDE_BEND_MODEL,\n",
" defaults.LIGHTWEIGHT_WAVEGUIDE_BEND_MODEL,\n",
" defaults.WDM_MODEL,\n",
" defaults.LIGHTWEIGHT_WDM_MODEL,\n",
"):\n",
" plt.figure()\n",
" ax = plt.subplot(111)\n",
" density = jnp.full(ceviche_model.design_variable_shape, 1.0)\n",
" ax.imshow(ceviche_model.density(density)[20:-20, 20:-20])\n",
" ax.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "33898107-7832-4ad1-a980-2873fb993dfb",
"id": "474e7001-99c1-4a7a-94bc-e1191628d313",
"metadata": {},
"outputs": [],
"source": [
"model = waveguide_bend.model.WaveguideBendModel(defaults.LIGHTWEIGHT_SIM_PARAMS, defaults.WAVEGUIDE_BEND_SPEC)\n",
"density = jnp.full(model.design_variable_shape, 1.0)\n",
"plt.imshow(model.density(density)[20:-20, 20:-20])"
"reload(defaults)\n",
"reload(challenge)\n",
"\n",
"bsc = challenge.lightweight_beam_splitter_challenge()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23bf803f-5359-48de-a497-064fe6ef688a",
"id": "d37779df-9cc3-43ac-bd76-9cb82e22ab17",
"metadata": {},
"outputs": [],
"source": [
"model = beam_splitter.model.BeamSplitterModel(defaults.LIGHTWEIGHT_SIM_PARAMS, defaults.BEAM_SPLITTER_SPEC)\n",
"density = jnp.full(model.design_variable_shape, 1.0)\n",
"plt.imshow(model.density(density)[20:-20, 20:-20])"
"bsc.component"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b553c6fc-0c0f-4450-b492-fb27ba84644d",
"id": "5392e86a-e897-4dee-89a7-ddac70794048",
"metadata": {},
"outputs": [],
"source": [
"model = mode_converter.model.ModeConverterModel(defaults.LIGHTWEIGHT_SIM_PARAMS, defaults.MODE_CONVERTER_SPEC)\n",
"density = jnp.full(model.design_variable_shape, 1.0)\n",
"plt.imshow(model.density(density)[20:-20, 20:-20])"
"params = bsc.component.init(jax.random.PRNGKey(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "474e7001-99c1-4a7a-94bc-e1191628d313",
"id": "c686d9db-0353-4c2c-b949-adccb8e3aff6",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"t0 = time.time()\n",
"s_params, fields = bsc.component.response(params)\n",
"print(time.time() - t0)\n",
"\n",
"\n",
"def loss_fn(params):\n",
" s_params, fields = bsc.component.response(params)\n",
" return jnp.sum(jnp.abs(s_params)**2)\n",
"\n",
"value, grad = jax.value_and_grad(loss_fn)(params)\n",
"\n",
"print(s_params.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "971eb2a6-7339-4b45-9b29-f468e25ee237",
"metadata": {},
"outputs": [],
"source": [
"print(value)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a01ab00-8edd-4aca-adff-892925829df3",
"metadata": {},
"outputs": [],
"source": []
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ maintainers = [
]

dependencies = [
"agjax",
"ceviche_challenges",
"jax",
"jaxlib",
Expand Down
10 changes: 10 additions & 0 deletions src/invrs_gym/challenge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from invrs_gym.challenge.ceviche.challenge import ( # noqa: F401
beam_splitter_challenge,
lightweight_beam_splitter_challenge,
lightweight_mode_converter_challenge,
lightweight_waveguide_bend_challenge,
lightweight_wdm_challenge,
mode_converter_challenge,
waveguide_bend_challenge,
wdm_challenge,
)
Loading

0 comments on commit 4b145ad

Please sign in to comment.