From a6895a6d78521beeadb3d4c12913caf26d989f82 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Wed, 8 Nov 2023 15:39:51 -0800 Subject: [PATCH] Add sorter component --- scripts/experiment.py | 5 +- src/invrs_gym/challenges/ceviche/challenge.py | 14 +- src/invrs_gym/challenges/sorter/common.py | 173 +++++++++++++++++- 3 files changed, 174 insertions(+), 18 deletions(-) diff --git a/scripts/experiment.py b/scripts/experiment.py index ed8c11b..6c25001 100644 --- a/scripts/experiment.py +++ b/scripts/experiment.py @@ -15,6 +15,7 @@ import multiprocessing as mp import os import random +import time from typing import Any, Dict, List, Tuple from invrs_utils.experiment import sweep @@ -103,11 +104,9 @@ def run_work_unit( # The use of multiprocessing requires that some modules be imported here, as they # cannot be imported in the main process which is forked. - import time - import invrs_opt - from invrs_utils.experiment import checkpoint import jax + from invrs_utils.experiment import checkpoint from jax import numpy as jnp from totypes import json_utils diff --git a/src/invrs_gym/challenges/ceviche/challenge.py b/src/invrs_gym/challenges/ceviche/challenge.py index 5a88f9c..886c616 100644 --- a/src/invrs_gym/challenges/ceviche/challenge.py +++ b/src/invrs_gym/challenges/ceviche/challenge.py @@ -8,21 +8,21 @@ from typing import Any, Optional, Sequence, Tuple import agjax # type: ignore[import-untyped] +import ceviche_challenges.wdm.model as wdm_model # type: ignore[import-untyped] import jax import jax.numpy as jnp import numpy as onp from ceviche_challenges import params # type: ignore[import-untyped] from ceviche_challenges import units as u -from ceviche_challenges.beam_splitter import ( # type: ignore[import-untyped] - model as beam_splitter_model, +from ceviche_challenges.beam_splitter import ( + model as beam_splitter_model, # type: ignore[import-untyped] ) -from ceviche_challenges.mode_converter import ( # type: ignore[import-untyped] - model as mode_converter_model, +from ceviche_challenges.mode_converter import ( + model as mode_converter_model, # type: ignore[import-untyped] ) -from ceviche_challenges.waveguide_bend import ( # type: ignore[import-untyped] - model as waveguide_bend_model, +from ceviche_challenges.waveguide_bend import ( + model as waveguide_bend_model, # type: ignore[import-untyped] ) -import ceviche_challenges.wdm.model as wdm_model # type: ignore[import-untyped] from jax import tree_util from totypes import types diff --git a/src/invrs_gym/challenges/sorter/common.py b/src/invrs_gym/challenges/sorter/common.py index 6a3081c..6be9367 100644 --- a/src/invrs_gym/challenges/sorter/common.py +++ b/src/invrs_gym/challenges/sorter/common.py @@ -4,16 +4,24 @@ """ import dataclasses -from typing import Any, Dict, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import jax import jax.numpy as jnp -import numpy as onp from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import-untyped] from jax import tree_util from totypes import types -Params = Any +from invrs_gym.challenges import base + +Params = Dict[str, types.BoundedArray | types.Density2DArray] +ThicknessInitializer = Callable[[jax.Array, types.BoundedArray], types.BoundedArray] + + +DENSITY_METASURFACE = "density_metasurface" +THICKNESS_CAP = "thickness_cap" +THICKNESS_METASURFACE = "thickness_metasurface" +THICKNESS_SPACER = "thickness_spacer" DENSITY_LOWER_BOUND = 0.0 DENSITY_UPPER_BOUND = 1.0 @@ -22,7 +30,7 @@ @dataclasses.dataclass class SorterSpec: """Defines the physical specification of a sorter. - + Attributes: permittivity_ambient: Permittivity of the ambient material. permittivity_cap: Permittivity of the cap layer. @@ -37,6 +45,7 @@ class SorterSpec: offset_monitor_substrate: Offset of the monitor plane from the interface between spacer and substrate. """ + permittivity_ambient: complex permittivity_cap: complex permittivity_metasurface_solid: complex @@ -56,7 +65,7 @@ class SorterSpec: @dataclasses.dataclass class SorterSimParams: """Parameters that configure the simulation of a sorter. - + Attributes: grid_spacing: The spacing of points on the real-space grid. wavelength: The wavelength of the excitation. @@ -66,6 +75,7 @@ class SorterSimParams: approximate_num_terms: Defines the number of terms in the Fourier expansion. truncation: Determines how the Fourier basis is truncated. """ + grid_spacing: float wavelength: float | jnp.ndarray polar_angle: float | jnp.ndarray @@ -78,7 +88,7 @@ class SorterSimParams: @dataclasses.dataclass class SorterResponse: """Contains the response of the sorter. - + Attributes: wavelength: The wavelength for the sorter response. polar_angle: The polar angle for the sorter response. @@ -87,6 +97,7 @@ class SorterResponse: polarizations (i.e. x, y, x + y, x - y). reflection: The reflection back to the ambient for the four polarizations. """ + wavelength: jnp.ndarray polar_angle: jnp.ndarray azimuthal_angle: jnp.ndarray @@ -110,6 +121,152 @@ class SorterResponse: ) +class SorterComponent(base.Component): + """Defines a photon extractor component.""" + + def __init__( + self, + spec: SorterSpec, + sim_params: SorterSimParams, + thickness_initializer: ThicknessInitializer, + density_initializer: base.DensityInitializer, + **seed_density_kwargs: Any, + ) -> None: + """Initializes the sorter component. + + Args: + spec: Defines the physical specification of the sorter. + sim_params: Defines simulation parameters for the sorter. + thickness_initializer: Callable which returns the initial thickness for + a layer from a random key and a bounded array with value equal the + thickness from `spec`. + density_initializer: Callable which generates the initial density from + a random key and the seed density. + **seed_density_kwargs: Keyword arguments which set the attributes of + the seed density used to generate the inital parameters. + """ + + self.spec = spec + self.sim_params = sim_params + self.thickness_initializer = thickness_initializer + self.density_initializer = density_initializer + self.grid_shape = (divide_and_round(spec.pitch, sim_params.grid_spacing),) * 2 + + self.seed_density = seed_density( + grid_shape=self.grid_shape, **seed_density_kwargs + ) + self.expansion = basis.generate_expansion( + primitive_lattice_vectors=basis.LatticeVectors( + u=self.spec.pitch * basis.X, + v=self.spec.pitch * basis.Y, + ), + approximate_num_terms=self.sim_params.approximate_num_terms, + truncation=self.sim_params.truncation, + ) + + def init(self, key: jax.Array) -> Params: + """Return the initial parameters for the sorter component.""" + ( + key_thickness_cap, + key_thickness_metasurface, + key_density_metasurface, + key_thickness_spacer, + ) = jax.random.split(key, 4) + params = { + THICKNESS_CAP: self.thickness_initializer( + key_thickness_cap, + types.BoundedArray( + self.spec.thickness_cap, + lower_bound=0.0, + upper_bound=None, + ), + ), + THICKNESS_METASURFACE: self.thickness_initializer( + key_thickness_metasurface, + types.BoundedArray( + self.spec.thickness_metasurface, + lower_bound=0.0, + upper_bound=None, + ), + ), + DENSITY_METASURFACE: self.density_initializer( + key_density_metasurface, self.seed_density + ), + THICKNESS_SPACER: self.thickness_initializer( + key_thickness_spacer, + types.BoundedArray( + self.spec.thickness_spacer, + lower_bound=0.0, + upper_bound=None, + ), + ), + } + # Ensure that there are no weak types in the initial parameters. + return tree_util.tree_map( + lambda x: jnp.asarray(x, jnp.asarray(x).dtype), params + ) + + def response( + self, + params: Params, + *, + wavelength: Optional[float | jnp.ndarray] = None, + polar_angle: Optional[float | jnp.ndarray] = None, + azimuthal_angle: Optional[float | jnp.ndarray] = None, + expansion: Optional[basis.Expansion] = None, + ) -> Tuple[SorterResponse, base.AuxDict]: + """Computes the response of the sorter. + + Args: + params: The parameters defining the sorter, with structure matching that + of the parameters returned by the `init` method. + wavelength: Optional wavelength to override the default in `sim_params`. + polar_angle: Optional polar angle to override the default. + azimuthal_angle: Optional azimuthal angle to override the default. + expansion: Optional expansion to override the default `expansion`. + + Returns: + The `(response, aux)` tuple. + """ + if expansion is None: + expansion = self.expansion + if wavelength is None: + wavelength = self.sim_params.wavelength + if polar_angle is None: + polar_angle = self.sim_params.polar_angle + if azimuthal_angle is None: + azimuthal_angle = self.sim_params.azimuthal_angle + + spec = dataclasses.replace( + self.spec, + thickness_cap=params[THICKNESS_CAP].array, # type: ignore[arg-type] + thickness_metasurface=( + params[THICKNESS_METASURFACE].array # type: ignore[arg-type] + ), + thickness_spacer=params[THICKNESS_SPACER].array, # type: ignore[arg-type] + ) + return simulate_sorter( + density_array=params[DENSITY_METASURFACE].array, # type: ignore[arg-type] + spec=spec, + wavelength=jnp.asarray(wavelength), + polar_angle=jnp.asarray(polar_angle), + azimuthal_angle=jnp.asarray(azimuthal_angle), + expansion=expansion, + formulation=self.sim_params.formulation, + ) + + +def divide_and_round(a: float, b: float) -> int: + """Checks that `a` is nearly evenly divisible by `b`, and returns `a / b`.""" + result = int(jnp.around(a / b)) + if not jnp.isclose(a / b, result): + raise ValueError( + f"`a` must be nearly evenly divisible by `b` spacing, but got `a` " + f"{a} with `b` {b}." + ) + return result + + def seed_density(grid_shape: Tuple[int, int], **kwargs: Any) -> types.Density2DArray: """Return the seed density for a sorter component. @@ -149,7 +306,7 @@ def simulate_sorter( azimuthal_angle: jnp.ndarray, expansion: basis.Expansion, formulation: fmm.Formulation, -) -> Tuple[SorterResponse, Dict[str, Any]]: +) -> Tuple[SorterResponse, base.AuxDict]: """Simulates a sorter component, e.g. a wavelength or polarization sorter. This code is adapted from the fmmax.examples.sorter script. @@ -170,7 +327,7 @@ def simulate_sorter( | | | / substrate --> | q1 | q2 | / |____________|____________|/ - + The sorter is illuminated by plane waves incident from the ambient, and its response consists of the power captured by substrate monitors within each of the quadrants, as well as the power reflected back toward the ambient.