Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Nov 18, 2023
1 parent 180b97a commit 893186e
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 163 deletions.
191 changes: 46 additions & 145 deletions notebooks/sorter_challenge.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@
"import time\n",
"\n",
"import jax\n",
"\n",
"# The sorter challenge appears to require 64 bit precision to enable consistent\n",
"# results across platforms.\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as onp\n",
"from skimage import measure\n",
"\n",
"import invrs_opt"
"import invrs_opt\n",
"\n",
"from invrs_gym import challenges\n",
"from invrs_gym.utils import initializers"
]
},
{
Expand All @@ -27,12 +35,11 @@
"metadata": {},
"outputs": [],
"source": [
"from totypes import types\n",
"from importlib import reload\n",
"from invrs_gym.challenges.sorter import polarization_challenge, common\n",
"from invrs_gym.utils import initializers\n",
"reload(polarization_challenge)\n",
"reload(common)\n",
"# The polarization sorter challenge optimizes both film thicknesses and\n",
"# a metasurface density. Gradient with respect to film thicknesses have\n",
"# far larger magnitude than gradient with respect to the value of a metasurface\n",
"# pixel density. To ensure the optimizer does not only focus on film thicknesses,\n",
"# we rescale the density so that its gradient becomes larger.\n",
"\n",
"def rescale_density(density, scale):\n",
" rescaled_array = density.array - density.lower_bound\n",
Expand All @@ -54,64 +61,44 @@
" return rescale_density(density, 0.001)\n",
"\n",
"# Select the challenge to be solved.\n",
"challenge = polarization_challenge.polarization_sorter(\n",
"challenge = challenges.polarization_sorter(\n",
" density_initializer=density_initializer,\n",
" minimum_width=8,\n",
" minimum_spacing=8,\n",
")\n",
"\n",
"\n",
"def transform_density(density: types.Density2DArray, beta: float) -> types.Density2DArray:\n",
" transformed = types.symmetrize_density(density)\n",
" with jax.ensure_compile_time_eval():\n",
" transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)\n",
" # Scale to ensure that the full valid range of the density array is reachable.\n",
" mid_value = (density.lower_bound + density.upper_bound) / 2\n",
" transformed = tree_util.tree_map(\n",
" lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed\n",
" )\n",
" return transform.apply_fixed_pixels(transformed)\n",
"\n",
"\n",
"# Define the loss function; in this case we simply use the default challenge\n",
"# loss. Note that the loss function can return auxilliary quantities.\n",
"def loss_fn(params):\n",
" response, aux = challenge.component.response(params)\n",
" metrics = challenge.metrics(response, params, aux)\n",
" loss = challenge.loss(response)\n",
" return loss, (response, aux)\n",
" return loss, (response, metrics, aux)\n",
"\n",
"\n",
"# Get the initial parameters, and initialize the optimizer.\n",
"seed = 2\n",
"params = challenge.component.init(jax.random.PRNGKey(seed))\n",
"opt = invrs_opt.density_lbfgsb(beta=2)\n",
"# opt = invrs_opt.lbfgsb()\n",
"state = opt.init(params)\n",
"params = opt.params(state)\n",
"\n",
"# _ = challenge.component.response(params)\n",
"\n",
"# The metagrating challenge can be jit-compiled.\n",
"# value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))\n",
"value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
"# (value, (response, aux)), grad = value_and_grad_fn(params)\n",
"# The polarization sorter challenge can be jit-compiled.\n",
"value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))\n",
"\n",
"# Carry out optimization for a fixed number of steps.\n",
"loss_values = []\n",
"distance_values = []\n",
"for i in range(150):\n",
" t0 = time.time()\n",
" params = opt.params(state)\n",
" (value, (response, aux)), grad = value_and_grad_fn(params)\n",
" (value, (response, metrics, aux)), grad = value_and_grad_fn(params)\n",
" t1 = time.time()\n",
" state = opt.update(grad=grad, value=value, params=params, state=state)\n",
"\n",
" print(\n",
" f\"{i:03} ({t1 - t0:.2f}/{time.time() - t1:.2f}s): loss={value:.3f}, \"\n",
" f\"power={response.reflection + jnp.sum(response.quadrant_transmission, axis=-1)}\"\n",
" f\"power={response.reflection + jnp.sum(response.transmission, axis=-1)}\"\n",
" )\n",
" # for p in aux:\n",
" # print(p)\n",
" loss_values.append(value)"
]
},
Expand All @@ -122,19 +109,27 @@
"metadata": {},
"outputs": [],
"source": [
"seed = 2\n",
"# Plot the initial and optimized parameters, and the loss trajectory.\n",
"initial_params = challenge.component.init(jax.random.PRNGKey(seed))\n",
"\n",
"from invrs_gym.challenges.sorter import common\n",
"\n",
"plt.figure(figsize=(10, 3))\n",
"plt.subplot(121)\n",
"plt.subplot(131)\n",
"plt.plot(loss_values)\n",
"ax = plt.subplot(132)\n",
"plt.imshow(common._density_array(initial_params[\"density_metasurface\"]))\n",
"ax.axis(False)\n",
"plt.colorbar()\n",
"plt.subplot(122)\n",
"ax = plt.subplot(133)\n",
"plt.imshow(common._density_array(params[\"density_metasurface\"]))\n",
"ax.axis(False)\n",
"plt.colorbar()\n",
"\n",
"# Print the optimized thicknesses.\n",
"print(f\" cap initial={initial_params['thickness_cap'].array}, final={params['thickness_cap'].array}\")\n",
"print(f\"metasurface initial={initial_params['thickness_metasurface'].array}, final={params['thickness_metasurface'].array}\")\n",
"print(f\" spacer initial={initial_params['thickness_spacer'].array}, final={params['thickness_spacer'].array}\")\n"
"print(f\" spacer initial={initial_params['thickness_spacer'].array}, final={params['thickness_spacer'].array}\")"
]
},
{
Expand All @@ -144,36 +139,25 @@
"metadata": {},
"outputs": [],
"source": [
"plt.plot(loss_values)\n",
"print(response.quadrant_transmission)\n",
"print(response.quadrant_target_transmission)\n",
"print(response.reflection)\n",
"# Plot the transmission into each of the four quadrants\n",
"\n",
"jnp.sum(response.quadrant_transmission, axis=0)\n",
"plt.figure(figsize=(10, 5))\n",
"plt.subplot(121)\n",
"plt.imshow(response.transmission)\n",
"plt.clim([0, 0.5])\n",
"plt.colorbar()\n",
"\n",
"# jnp.sum(response.transmission, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2018ad7b-8ec7-45cb-ae34-8598a3b0c284",
"metadata": {},
"outputs": [],
"source": [
"sz = aux[\"poynting_flux_z\"]\n",
"\n",
"plt.figure(figsize=(5, 5))\n",
"ax = plt.subplot(221)\n",
"ax = plt.subplot(243)\n",
"ax.imshow(sz[..., 0])\n",
"ax.axis(False)\n",
"ax = plt.subplot(222)\n",
"ax = plt.subplot(244)\n",
"ax.imshow(sz[..., 1])\n",
"ax.axis(False)\n",
"ax = plt.subplot(223)\n",
"ax = plt.subplot(247)\n",
"ax.imshow(sz[..., 2])\n",
"ax.axis(False)\n",
"ax = plt.subplot(224)\n",
"ax = plt.subplot(248)\n",
"ax.imshow(sz[..., 3])\n",
"ax.axis(False)\n",
"plt.tight_layout()"
Expand All @@ -182,100 +166,17 @@
{
"cell_type": "code",
"execution_count": null,
"id": "b22afb41-a151-40bc-98c2-80b572ab9bbc",
"metadata": {},
"outputs": [],
"source": [
"plt.imshow(grad[\"density_metasurface\"].array)\n",
"plt.colorbar()\n",
"print(grad[\"thickness_cap\"])\n",
"print(grad[\"thickness_metasurface\"])\n",
"print(grad[\"thickness_spacer\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d502e53a-4349-4a73-85b2-a52f70d577e6",
"id": "da4e4325-e41e-4890-82b0-c9e758628546",
"metadata": {},
"outputs": [],
"source": [
" with jax.ensure_compile_time_eval():\n",
" sz_fwd_N, sz_bwd_N = fields.amplitude_poynting_flux(\n",
" forward_amplitude=fwd_substrate_offset,\n",
" backward_amplitude=bwd_substrate_offset,\n",
" layer_solve_result=layer_solve_results[-1],\n",
" )\n",
" \n",
" sz_fwd_substrate_sum = jnp.sum(jnp.abs(sz_fwd_N), axis=-2)\n",
" sz_bwd_substrate_sum = jnp.sum(jnp.abs(sz_bwd_N), axis=-2)\n",
" printvals = [\n",
" sz_fwd_ambient_sum,\n",
" sz_bwd_ambient_sum,\n",
" sz_fwd_substrate_sum,\n",
" sz_bwd_substrate_sum,\n",
" jnp.mean(sz, axis=(-3, -2)),\n",
" sz_bwd_ambient_sum + sz_fwd_substrate_sum,\n",
" sz_bwd_ambient_sum + jnp.mean(sz, axis=(-3, -2)),\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ac64ebe-5c83-4aaa-ac80-107198eb6274",
"metadata": {},
"outputs": [],
"source": [
"from jax import tree_util\n",
"from totypes import types\n",
"from invrs_opt.lbfgsb import transform, lbfgsb\n",
"\n",
"def transform_density(density: types.Density2DArray, beta: float) -> types.Density2DArray:\n",
" transformed = types.symmetrize_density(density)\n",
" transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)\n",
" # Scale to ensure that the full valid range of the density array is reachable.\n",
" mid_value = (density.lower_bound + density.upper_bound) / 2\n",
" transformed = tree_util.tree_map(\n",
" lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed\n",
" )\n",
" return transform.apply_fixed_pixels(transformed)\n",
"\n",
"\n",
"params, lbfgsb_state_dict = state\n",
"lbfgsb_state = lbfgsb.ScipyLbfgsbState(**lbfgsb_state_dict)\n",
"latent_params = lbfgsb._to_pytree(lbfgsb_state.x, params)\n",
"\n",
"transform_density(latent_params[\"density_metasurface\"], beta=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec46adbb-9583-4a6e-85bd-fac31f5e98bb",
"metadata": {},
"outputs": [],
"source": [
"params"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25d97a2d-7ad3-4a85-a34c-6dff0718b1d2",
"metadata": {},
"outputs": [],
"source": [
"array = onp.zeros((20, 20))\n",
"array[5, :10] = 1\n",
"\n",
"plt.imshow(array)"
"print(metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "56ea8893-5348-433e-ab85-6c1ad203481a",
"id": "3719cf29-5f05-442c-bb31-43f23920fe53",
"metadata": {},
"outputs": [],
"source": []
Expand Down
2 changes: 2 additions & 0 deletions src/invrs_gym/challenges/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from invrs_gym.challenges.diffract.metagrating_challenge import metagrating
from invrs_gym.challenges.diffract.splitter_challenge import diffractive_splitter
from invrs_gym.challenges.extractor.challenge import photon_extractor
from invrs_gym.challenges.sorter.polarization_challenge import polarization_sorter

BY_NAME = {
"ceviche_beam_splitter": ceviche_beam_splitter,
Expand All @@ -36,4 +37,5 @@
"metagrating": metagrating,
"diffractive_splitter": diffractive_splitter,
"photon_extractor": photon_extractor,
"polarization_sorter": polarization_sorter,
}
2 changes: 1 addition & 1 deletion src/invrs_gym/challenges/ceviche/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from typing import Any, Optional, Sequence, Tuple

import agjax # type: ignore[import-untyped]
import ceviche_challenges as cc # type: ignore[import-untyped]
import ceviche_challenges.wdm.model as wdm_model # type: ignore[import-untyped]
import jax
import jax.numpy as jnp
import numpy as onp
import ceviche_challenges as cc # type: ignore[import-untyped]
from ceviche_challenges import units as u # type: ignore[import-untyped]
from jax import tree_util
from totypes import types
Expand Down
4 changes: 2 additions & 2 deletions src/invrs_gym/challenges/diffract/splitter_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,9 @@ def diffractive_splitter(
minimum_width: The minimum width target for the challenge, in pixels. The
physical minimum width is approximately 180 nm.
minimum_spacing: The minimum spacing target for the challenge, in pixels.
thickness_initializer: Callble which returns the initial thickness, given a
thickness_initializer: Callable which returns the initial thickness, given a
key and seed thickness.
density_initializer: Callble which returns the initial density, given a
density_initializer: Callable which returns the initial density, given a
key and seed density.
splitting: Defines shape of the beam array to be created by the splitter.
normalized_efficiency_lower_bound: The lower bound for normalized efficiency.
Expand Down
Loading

0 comments on commit 893186e

Please sign in to comment.