diff --git a/notebooks/sorter_challenge.ipynb b/notebooks/sorter_challenge.ipynb index 772c4a5..2a441c8 100644 --- a/notebooks/sorter_challenge.ipynb +++ b/notebooks/sorter_challenge.ipynb @@ -12,12 +12,20 @@ "import time\n", "\n", "import jax\n", + "\n", + "# The sorter challenge appears to require 64-bit precision to enable consistent\n", + "# results across platforms.\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", "import numpy as onp\n", "from skimage import measure\n", "\n", - "import invrs_opt" + "import invrs_opt\n", + "\n", + "from invrs_gym import challenges\n", + "from invrs_gym.utils import initializers" ] }, { @@ -27,12 +35,11 @@ "metadata": {}, "outputs": [], "source": [ - "from totypes import types\n", - "from importlib import reload\n", - "from invrs_gym.challenges.sorter import polarization_challenge, common\n", - "from invrs_gym.utils import initializers\n", - "reload(polarization_challenge)\n", - "reload(common)\n", + "# The polarization sorter challenge optimizes both film thicknesses and\n", + "# a metasurface density. Gradients with respect to film thicknesses have\n", + "# far larger magnitude than gradients with respect to the value of a metasurface\n", + "# pixel density. 
To ensure the optimizer does not only focus on film thicknesses,\n", + "# we rescale the density so that its gradient becomes larger.\n", "\n", "def rescale_density(density, scale):\n", " rescaled_array = density.array - density.lower_bound\n", @@ -54,47 +61,29 @@ " return rescale_density(density, 0.001)\n", "\n", "# Select the challenge to be solved.\n", - "challenge = polarization_challenge.polarization_sorter(\n", + "challenge = challenges.polarization_sorter(\n", " density_initializer=density_initializer,\n", " minimum_width=8,\n", " minimum_spacing=8,\n", ")\n", "\n", - "\n", - "def transform_density(density: types.Density2DArray, beta: float) -> types.Density2DArray:\n", - " transformed = types.symmetrize_density(density)\n", - " with jax.ensure_compile_time_eval():\n", - " transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)\n", - " # Scale to ensure that the full valid range of the density array is reachable.\n", - " mid_value = (density.lower_bound + density.upper_bound) / 2\n", - " transformed = tree_util.tree_map(\n", - " lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed\n", - " )\n", - " return transform.apply_fixed_pixels(transformed)\n", - "\n", - "\n", "# Define the loss function; in this case we simply use the default challenge\n", "# loss. 
Note that the loss function can return auxilliary quantities.\n", "def loss_fn(params):\n", " response, aux = challenge.component.response(params)\n", + " metrics = challenge.metrics(response, params, aux)\n", " loss = challenge.loss(response)\n", - " return loss, (response, aux)\n", + " return loss, (response, metrics, aux)\n", "\n", "\n", "# Get the initial parameters, and initialize the optimizer.\n", "seed = 2\n", "params = challenge.component.init(jax.random.PRNGKey(seed))\n", "opt = invrs_opt.density_lbfgsb(beta=2)\n", - "# opt = invrs_opt.lbfgsb()\n", "state = opt.init(params)\n", - "params = opt.params(state)\n", "\n", - "# _ = challenge.component.response(params)\n", - "\n", - "# The metagrating challenge can be jit-compiled.\n", - "# value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))\n", - "value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n", - "# (value, (response, aux)), grad = value_and_grad_fn(params)\n", + "# The polarization sorter challenge can be jit-compiled.\n", + "value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))\n", "\n", "# Carry out optimization for a fixed number of steps.\n", "loss_values = []\n", @@ -102,16 +91,14 @@ "for i in range(150):\n", " t0 = time.time()\n", " params = opt.params(state)\n", - " (value, (response, aux)), grad = value_and_grad_fn(params)\n", + " (value, (response, metrics, aux)), grad = value_and_grad_fn(params)\n", " t1 = time.time()\n", " state = opt.update(grad=grad, value=value, params=params, state=state)\n", "\n", " print(\n", " f\"{i:03} ({t1 - t0:.2f}/{time.time() - t1:.2f}s): loss={value:.3f}, \"\n", - " f\"power={response.reflection + jnp.sum(response.quadrant_transmission, axis=-1)}\"\n", + " f\"power={response.reflection + jnp.sum(response.transmission, axis=-1)}\"\n", " )\n", - " # for p in aux:\n", - " # print(p)\n", " loss_values.append(value)" ] }, @@ -122,19 +109,27 @@ "metadata": {}, "outputs": [], "source": [ - "seed = 2\n", + "# Plot the 
initial and optimized parameters, and the loss trajectory.\n", "initial_params = challenge.component.init(jax.random.PRNGKey(seed))\n", "\n", + "from invrs_gym.challenges.sorter import common\n", + "\n", "plt.figure(figsize=(10, 3))\n", - "plt.subplot(121)\n", + "plt.subplot(131)\n", + "plt.plot(loss_values)\n", + "ax = plt.subplot(132)\n", "plt.imshow(common._density_array(initial_params[\"density_metasurface\"]))\n", + "ax.axis(False)\n", "plt.colorbar()\n", - "plt.subplot(122)\n", + "ax = plt.subplot(133)\n", "plt.imshow(common._density_array(params[\"density_metasurface\"]))\n", + "ax.axis(False)\n", "plt.colorbar()\n", + "\n", + "# Print the optimized thicknesses.\n", "print(f\" cap initial={initial_params['thickness_cap'].array}, final={params['thickness_cap'].array}\")\n", "print(f\"metasurface initial={initial_params['thickness_metasurface'].array}, final={params['thickness_metasurface'].array}\")\n", - "print(f\" spacer initial={initial_params['thickness_spacer'].array}, final={params['thickness_spacer'].array}\")\n" + "print(f\" spacer initial={initial_params['thickness_spacer'].array}, final={params['thickness_spacer'].array}\")" ] }, { @@ -144,36 +139,25 @@ "metadata": {}, "outputs": [], "source": [ - "plt.plot(loss_values)\n", - "print(response.quadrant_transmission)\n", - "print(response.quadrant_target_transmission)\n", - "print(response.reflection)\n", + "# Plot the transmission into each of the four quadrants\n", "\n", - "jnp.sum(response.quadrant_transmission, axis=0)\n", + "plt.figure(figsize=(10, 5))\n", + "plt.subplot(121)\n", + "plt.imshow(response.transmission)\n", + "plt.clim([0, 0.5])\n", + "plt.colorbar()\n", "\n", - "# jnp.sum(response.transmission, axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2018ad7b-8ec7-45cb-ae34-8598a3b0c284", - "metadata": {}, - "outputs": [], - "source": [ "sz = aux[\"poynting_flux_z\"]\n", - "\n", - "plt.figure(figsize=(5, 5))\n", - "ax = plt.subplot(221)\n", + "ax = 
plt.subplot(243)\n", "ax.imshow(sz[..., 0])\n", "ax.axis(False)\n", - "ax = plt.subplot(222)\n", + "ax = plt.subplot(244)\n", "ax.imshow(sz[..., 1])\n", "ax.axis(False)\n", - "ax = plt.subplot(223)\n", + "ax = plt.subplot(247)\n", "ax.imshow(sz[..., 2])\n", "ax.axis(False)\n", - "ax = plt.subplot(224)\n", + "ax = plt.subplot(248)\n", "ax.imshow(sz[..., 3])\n", "ax.axis(False)\n", "plt.tight_layout()" @@ -182,100 +166,17 @@ { "cell_type": "code", "execution_count": null, - "id": "b22afb41-a151-40bc-98c2-80b572ab9bbc", - "metadata": {}, - "outputs": [], - "source": [ - "plt.imshow(grad[\"density_metasurface\"].array)\n", - "plt.colorbar()\n", - "print(grad[\"thickness_cap\"])\n", - "print(grad[\"thickness_metasurface\"])\n", - "print(grad[\"thickness_spacer\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d502e53a-4349-4a73-85b2-a52f70d577e6", + "id": "da4e4325-e41e-4890-82b0-c9e758628546", "metadata": {}, "outputs": [], "source": [ - " with jax.ensure_compile_time_eval():\n", - " sz_fwd_N, sz_bwd_N = fields.amplitude_poynting_flux(\n", - " forward_amplitude=fwd_substrate_offset,\n", - " backward_amplitude=bwd_substrate_offset,\n", - " layer_solve_result=layer_solve_results[-1],\n", - " )\n", - " \n", - " sz_fwd_substrate_sum = jnp.sum(jnp.abs(sz_fwd_N), axis=-2)\n", - " sz_bwd_substrate_sum = jnp.sum(jnp.abs(sz_bwd_N), axis=-2)\n", - " printvals = [\n", - " sz_fwd_ambient_sum,\n", - " sz_bwd_ambient_sum,\n", - " sz_fwd_substrate_sum,\n", - " sz_bwd_substrate_sum,\n", - " jnp.mean(sz, axis=(-3, -2)),\n", - " sz_bwd_ambient_sum + sz_fwd_substrate_sum,\n", - " sz_bwd_ambient_sum + jnp.mean(sz, axis=(-3, -2)),\n", - " ]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ac64ebe-5c83-4aaa-ac80-107198eb6274", - "metadata": {}, - "outputs": [], - "source": [ - "from jax import tree_util\n", - "from totypes import types\n", - "from invrs_opt.lbfgsb import transform, lbfgsb\n", - "\n", - "def transform_density(density: 
types.Density2DArray, beta: float) -> types.Density2DArray:\n", - " transformed = types.symmetrize_density(density)\n", - " transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)\n", - " # Scale to ensure that the full valid range of the density array is reachable.\n", - " mid_value = (density.lower_bound + density.upper_bound) / 2\n", - " transformed = tree_util.tree_map(\n", - " lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed\n", - " )\n", - " return transform.apply_fixed_pixels(transformed)\n", - "\n", - "\n", - "params, lbfgsb_state_dict = state\n", - "lbfgsb_state = lbfgsb.ScipyLbfgsbState(**lbfgsb_state_dict)\n", - "latent_params = lbfgsb._to_pytree(lbfgsb_state.x, params)\n", - "\n", - "transform_density(latent_params[\"density_metasurface\"], beta=4)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec46adbb-9583-4a6e-85bd-fac31f5e98bb", - "metadata": {}, - "outputs": [], - "source": [ - "params" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25d97a2d-7ad3-4a85-a34c-6dff0718b1d2", - "metadata": {}, - "outputs": [], - "source": [ - "array = onp.zeros((20, 20))\n", - "array[5, :10] = 1\n", - "\n", - "plt.imshow(array)" + "print(metrics)" ] }, { "cell_type": "code", "execution_count": null, - "id": "56ea8893-5348-433e-ab85-6c1ad203481a", + "id": "3719cf29-5f05-442c-bb31-43f23920fe53", "metadata": {}, "outputs": [], "source": [] diff --git a/src/invrs_gym/challenges/__init__.py b/src/invrs_gym/challenges/__init__.py index 9478a78..b36ef00 100644 --- a/src/invrs_gym/challenges/__init__.py +++ b/src/invrs_gym/challenges/__init__.py @@ -23,6 +23,7 @@ from invrs_gym.challenges.diffract.metagrating_challenge import metagrating from invrs_gym.challenges.diffract.splitter_challenge import diffractive_splitter from invrs_gym.challenges.extractor.challenge import photon_extractor +from invrs_gym.challenges.sorter.polarization_challenge import polarization_sorter BY_NAME 
= { "ceviche_beam_splitter": ceviche_beam_splitter, @@ -36,4 +37,5 @@ "metagrating": metagrating, "diffractive_splitter": diffractive_splitter, "photon_extractor": photon_extractor, + "polarization_sorter": polarization_sorter, } diff --git a/src/invrs_gym/challenges/ceviche/challenge.py b/src/invrs_gym/challenges/ceviche/challenge.py index 807c4f7..61ee7b8 100644 --- a/src/invrs_gym/challenges/ceviche/challenge.py +++ b/src/invrs_gym/challenges/ceviche/challenge.py @@ -8,11 +8,11 @@ from typing import Any, Optional, Sequence, Tuple import agjax # type: ignore[import-untyped] +import ceviche_challenges as cc # type: ignore[import-untyped] import ceviche_challenges.wdm.model as wdm_model # type: ignore[import-untyped] import jax import jax.numpy as jnp import numpy as onp -import ceviche_challenges as cc # type: ignore[import-untyped] from ceviche_challenges import units as u # type: ignore[import-untyped] from jax import tree_util from totypes import types diff --git a/src/invrs_gym/challenges/diffract/splitter_challenge.py b/src/invrs_gym/challenges/diffract/splitter_challenge.py index 9be63ec..43a282d 100644 --- a/src/invrs_gym/challenges/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenges/diffract/splitter_challenge.py @@ -352,9 +352,9 @@ def diffractive_splitter( minimum_width: The minimum width target for the challenge, in pixels. The physical minimum width is approximately 180 nm. minimum_spacing: The minimum spacing target for the challenge, in pixels. - thickness_initializer: Callble which returns the initial thickness, given a + thickness_initializer: Callable which returns the initial thickness, given a key and seed thickness. - density_initializer: Callble which returns the initial density, given a + density_initializer: Callable which returns the initial density, given a key and seed density. splitting: Defines shape of the beam array to be created by the splitter. normalized_efficiency_lower_bound: The lower bound for normalized efficiency. 
diff --git a/src/invrs_gym/challenges/sorter/polarization_challenge.py b/src/invrs_gym/challenges/sorter/polarization_challenge.py index 5cfd566..2525adc 100644 --- a/src/invrs_gym/challenges/sorter/polarization_challenge.py +++ b/src/invrs_gym/challenges/sorter/polarization_challenge.py @@ -32,8 +32,18 @@ class PolarizationSorterChallenge(base.Challenge): """Defines the polarization sorter challenge. - The target of the polarization sorter challenge is to achieve coupling into target - + The target of the polarization sorter challenge is to achieve coupling of incident + plane waves into four individual pixels, depending upon the polarization of the + incident wave. + + Attributes: + component: The component to be optimized. + efficiency_target: The target efficiency for the coupling of e.g. an + x-polarized plane wave into its designated pixel. The theoretical maximum + is 0.5. + polarization_ratio_target: The target ratio of power coupled for e.g. an + x-polarized plane wave into the "x-polarized pixel" and the power for + the x-polarized plane wave into the "y-polarized pixel". 
""" component: common.SorterComponent @@ -63,15 +73,17 @@ def loss(self, response: common.SorterResponse) -> jnp.ndarray: def distance_to_target(self, response: common.SorterResponse) -> jnp.ndarray: """Compute distance from the component `response` to the challenge target.""" - target_transmission = response.transmission[ + on_target_transmission = response.transmission[ ..., tuple(range(4)), tuple(range(4)) ] - min_efficiency = jnp.amin(target_transmission / 0.5) + min_efficiency = jnp.amin(on_target_transmission) off_target_transmission = response.transmission[ ..., tuple(range(4))[::-1], tuple(range(4)) ] - min_polarization_ratio = jnp.amin(target_transmission / off_target_transmission) + min_polarization_ratio = jnp.amin( + on_target_transmission / off_target_transmission + ) return jnp.maximum( self.polarization_ratio_target - min_polarization_ratio, 0.0 ) + jnp.maximum(self.efficiency_target - min_efficiency, 0.0) @@ -97,15 +109,15 @@ def metrics( - mean efficiency """ del params, aux - target_transmission = response.transmission[ + on_target_transmission = response.transmission[ ..., tuple(range(4)), tuple(range(4)) ] - efficiency = target_transmission / 0.5 + efficiency = on_target_transmission off_target_transmission = response.transmission[ ..., tuple(range(4))[::-1], tuple(range(4)) ] - polarization_ratio = target_transmission / off_target_transmission + polarization_ratio = on_target_transmission / off_target_transmission return { EFFICIENCY_MEAN: jnp.mean(efficiency), EFFICIENCY_MIN: jnp.amin(efficiency), @@ -121,9 +133,9 @@ def metrics( permittivity_metasurface_void=(1.5 + 0.00001j) ** 2, permittivity_spacer=(1.5 + 0.00001j) ** 2, permittivity_substrate=(4.0730 + 0.028038j) ** 2, - thickness_cap=types.BoundedArray(0.05, lower_bound=0.00, upper_bound=0.7), - thickness_metasurface=types.BoundedArray(0.15, lower_bound=0.05, upper_bound=0.3), - thickness_spacer=types.BoundedArray(1.0, lower_bound=0.5, upper_bound=1.2), + 
thickness_cap=types.BoundedArray(0.05, lower_bound=0.00, upper_bound=0.5), + thickness_metasurface=types.BoundedArray(0.15, lower_bound=0.1, upper_bound=0.3), + thickness_spacer=types.BoundedArray(1.0, lower_bound=0.8, upper_bound=1.2), pitch=2.0, offset_monitor_substrate=0.1, ) @@ -143,7 +155,7 @@ def metrics( MINIMUM_SPACING = 8 # Target metrics for the sorter component. -EFFICIENCY_TARGET = 0.8 +EFFICIENCY_TARGET = 0.4 POLARIZATION_RATIO_TARGET = 10 @@ -154,12 +166,33 @@ def polarization_sorter( initializers.identity_initializer ), density_initializer: base.DensityInitializer = density_initializer, - spec: common.SorterSpec = POLARIZATION_SORTER_SPEC, - sim_params: common.SorterSimParams = POLARIZATION_SORTER_SIM_PARAMS, efficiency_target: float = EFFICIENCY_TARGET, polarization_ratio_target: float = POLARIZATION_RATIO_TARGET, + spec: common.SorterSpec = POLARIZATION_SORTER_SPEC, + sim_params: common.SorterSimParams = POLARIZATION_SORTER_SIM_PARAMS, ) -> PolarizationSorterChallenge: - """Polarization sorter challenge.""" + """Polarization sorter challenge. + + Args: + minimum_width: The minimum width target for the challenge, in pixels. The + physical minimum width is approximately 80 nm. + minimum_spacing: The minimum spacing target for the challenge, in pixels. + thickness_initializer: Callable which returns the initial thickness, given a + key and seed thickness. + density_initializer: Callable which returns the initial density, given a + key and seed density. + efficiency_target: The target efficiency for the coupling of e.g. an + x-polarized plane wave into its designated pixel. The theoretical maximum + is 0.5. + polarization_ratio_target: The target ratio of power coupled for e.g. an + x-polarized plane wave into the "x-polarized pixel" and the power for + the x-polarized plane wave into the "y-polarized pixel". + spec: Defines the physical specification of the polarization sorter. + sim_params: Defines the simulation settings of the polarization sorter. 
+ + Returns: + The `PolarizationSorterChallenge`. + """ return PolarizationSorterChallenge( component=common.SorterComponent( spec=spec,