Skip to content

Commit

Permalink
Add sorter component
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Nov 8, 2023
1 parent c5f5865 commit a6895a6
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 18 deletions.
5 changes: 2 additions & 3 deletions scripts/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import multiprocessing as mp
import os
import random
import time
from typing import Any, Dict, List, Tuple

from invrs_utils.experiment import sweep
Expand Down Expand Up @@ -103,11 +104,9 @@ def run_work_unit(

# The use of multiprocessing requires that some modules be imported here, as they
# cannot be imported in the main process which is forked.
import time

import invrs_opt
from invrs_utils.experiment import checkpoint
import jax
from invrs_utils.experiment import checkpoint
from jax import numpy as jnp
from totypes import json_utils

Expand Down
14 changes: 7 additions & 7 deletions src/invrs_gym/challenges/ceviche/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
from typing import Any, Optional, Sequence, Tuple

import agjax # type: ignore[import-untyped]
import ceviche_challenges.wdm.model as wdm_model # type: ignore[import-untyped]
import jax
import jax.numpy as jnp
import numpy as onp
from ceviche_challenges import params # type: ignore[import-untyped]
from ceviche_challenges import units as u
from ceviche_challenges.beam_splitter import ( # type: ignore[import-untyped]
model as beam_splitter_model,
from ceviche_challenges.beam_splitter import (
model as beam_splitter_model, # type: ignore[import-untyped]
)
from ceviche_challenges.mode_converter import ( # type: ignore[import-untyped]
model as mode_converter_model,
from ceviche_challenges.mode_converter import (
model as mode_converter_model, # type: ignore[import-untyped]
)
from ceviche_challenges.waveguide_bend import ( # type: ignore[import-untyped]
model as waveguide_bend_model,
from ceviche_challenges.waveguide_bend import (
model as waveguide_bend_model, # type: ignore[import-untyped]
)
import ceviche_challenges.wdm.model as wdm_model # type: ignore[import-untyped]
from jax import tree_util
from totypes import types

Expand Down
173 changes: 165 additions & 8 deletions src/invrs_gym/challenges/sorter/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,24 @@
"""

import dataclasses
from typing import Any, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as onp
from fmmax import basis, fields, fmm, scattering, utils # type: ignore[import-untyped]
from jax import tree_util
from totypes import types

Params = Any
from invrs_gym.challenges import base

Params = Dict[str, types.BoundedArray | types.Density2DArray]
ThicknessInitializer = Callable[[jax.Array, types.BoundedArray], types.BoundedArray]


DENSITY_METASURFACE = "density_metasurface"
THICKNESS_CAP = "thickness_cap"
THICKNESS_METASURFACE = "thickness_metasurface"
THICKNESS_SPACER = "thickness_spacer"

DENSITY_LOWER_BOUND = 0.0
DENSITY_UPPER_BOUND = 1.0
Expand All @@ -22,7 +30,7 @@
@dataclasses.dataclass
class SorterSpec:
"""Defines the physical specification of a sorter.
Attributes:
permittivity_ambient: Permittivity of the ambient material.
permittivity_cap: Permittivity of the cap layer.
Expand All @@ -37,6 +45,7 @@ class SorterSpec:
offset_monitor_substrate: Offset of the monitor plane from the interface
between spacer and substrate.
"""

permittivity_ambient: complex
permittivity_cap: complex
permittivity_metasurface_solid: complex
Expand All @@ -56,7 +65,7 @@ class SorterSpec:
@dataclasses.dataclass
class SorterSimParams:
"""Parameters that configure the simulation of a sorter.
Attributes:
grid_spacing: The spacing of points on the real-space grid.
wavelength: The wavelength of the excitation.
Expand All @@ -66,6 +75,7 @@ class SorterSimParams:
approximate_num_terms: Defines the number of terms in the Fourier expansion.
truncation: Determines how the Fourier basis is truncated.
"""

grid_spacing: float
wavelength: float | jnp.ndarray
polar_angle: float | jnp.ndarray
Expand All @@ -78,7 +88,7 @@ class SorterSimParams:
@dataclasses.dataclass
class SorterResponse:
"""Contains the response of the sorter.
Attributes:
wavelength: The wavelength for the sorter response.
polar_angle: The polar angle for the sorter response.
Expand All @@ -87,6 +97,7 @@ class SorterResponse:
polarizations (i.e. x, y, x + y, x - y).
reflection: The reflection back to the ambient for the four polarizations.
"""

wavelength: jnp.ndarray
polar_angle: jnp.ndarray
azimuthal_angle: jnp.ndarray
Expand All @@ -110,6 +121,152 @@ class SorterResponse:
)


class SorterComponent(base.Component):
"""Defines a photon extractor component."""

def __init__(
self,
spec: SorterSpec,
sim_params: SorterSimParams,
thickness_initializer: ThicknessInitializer,
density_initializer: base.DensityInitializer,
**seed_density_kwargs: Any,
) -> None:
"""Initializes the sorter component.
Args:
spec: Defines the physical specification of the sorter.
sim_params: Defines simulation parameters for the sorter.
thickness_initializer: Callable which returns the initial thickness for
a layer from a random key and a bounded array with value equal the
thickness from `spec`.
density_initializer: Callable which generates the initial density from
a random key and the seed density.
**seed_density_kwargs: Keyword arguments which set the attributes of
the seed density used to generate the inital parameters.
"""

self.spec = spec
self.sim_params = sim_params
self.thickness_initializer = thickness_initializer
self.density_initializer = density_initializer
self.grid_shape = (divide_and_round(spec.pitch, sim_params.grid_spacing),) * 2

self.seed_density = seed_density(
grid_shape=self.grid_shape, **seed_density_kwargs
)
self.expansion = basis.generate_expansion(
primitive_lattice_vectors=basis.LatticeVectors(
u=self.spec.pitch * basis.X,
v=self.spec.pitch * basis.Y,
),
approximate_num_terms=self.sim_params.approximate_num_terms,
truncation=self.sim_params.truncation,
)

def init(self, key: jax.Array) -> Params:
"""Return the initial parameters for the sorter component."""
(
key_thickness_cap,
key_thickness_metasurface,
key_density_metasurface,
key_thickness_spacer,
) = jax.random.split(key, 4)
params = {
THICKNESS_CAP: self.thickness_initializer(
key_thickness_cap,
types.BoundedArray(
self.spec.thickness_cap,
lower_bound=0.0,
upper_bound=None,
),
),
THICKNESS_METASURFACE: self.thickness_initializer(
key_thickness_metasurface,
types.BoundedArray(
self.spec.thickness_metasurface,
lower_bound=0.0,
upper_bound=None,
),
),
DENSITY_METASURFACE: self.density_initializer(
key_density_metasurface, self.seed_density
),
THICKNESS_SPACER: self.thickness_initializer(
key_thickness_spacer,
types.BoundedArray(
self.spec.thickness_spacer,
lower_bound=0.0,
upper_bound=None,
),
),
}
# Ensure that there are no weak types in the initial parameters.
return tree_util.tree_map(
lambda x: jnp.asarray(x, jnp.asarray(x).dtype), params
)

def response(
self,
params: Params,
*,
wavelength: Optional[float | jnp.ndarray] = None,
polar_angle: Optional[float | jnp.ndarray] = None,
azimuthal_angle: Optional[float | jnp.ndarray] = None,
expansion: Optional[basis.Expansion] = None,
) -> Tuple[SorterResponse, base.AuxDict]:
"""Computes the response of the sorter.
Args:
params: The parameters defining the sorter, with structure matching that
of the parameters returned by the `init` method.
wavelength: Optional wavelength to override the default in `sim_params`.
polar_angle: Optional polar angle to override the default.
azimuthal_angle: Optional azimuthal angle to override the default.
expansion: Optional expansion to override the default `expansion`.
Returns:
The `(response, aux)` tuple.
"""
if expansion is None:
expansion = self.expansion
if wavelength is None:
wavelength = self.sim_params.wavelength
if polar_angle is None:
polar_angle = self.sim_params.polar_angle
if azimuthal_angle is None:
azimuthal_angle = self.sim_params.azimuthal_angle

spec = dataclasses.replace(
self.spec,
thickness_cap=params[THICKNESS_CAP].array, # type: ignore[arg-type]
thickness_metasurface=(
params[THICKNESS_METASURFACE].array # type: ignore[arg-type]
),
thickness_spacer=params[THICKNESS_SPACER].array, # type: ignore[arg-type]
)
return simulate_sorter(
density_array=params[DENSITY_METASURFACE].array, # type: ignore[arg-type]
spec=spec,
wavelength=jnp.asarray(wavelength),
polar_angle=jnp.asarray(polar_angle),
azimuthal_angle=jnp.asarray(azimuthal_angle),
expansion=expansion,
formulation=self.sim_params.formulation,
)


def divide_and_round(a: float, b: float) -> int:
"""Checks that `a` is nearly evenly divisible by `b`, and returns `a / b`."""
result = int(jnp.around(a / b))
if not jnp.isclose(a / b, result):
raise ValueError(
f"`a` must be nearly evenly divisible by `b` spacing, but got `a` "
f"{a} with `b` {b}."
)
return result


def seed_density(grid_shape: Tuple[int, int], **kwargs: Any) -> types.Density2DArray:
"""Return the seed density for a sorter component.
Expand Down Expand Up @@ -149,7 +306,7 @@ def simulate_sorter(
azimuthal_angle: jnp.ndarray,
expansion: basis.Expansion,
formulation: fmm.Formulation,
) -> Tuple[SorterResponse, Dict[str, Any]]:
) -> Tuple[SorterResponse, base.AuxDict]:
"""Simulates a sorter component, e.g. a wavelength or polarization sorter.
This code is adapted from the fmmax.examples.sorter script.
Expand All @@ -170,7 +327,7 @@ def simulate_sorter(
| | | /
substrate --> | q1 | q2 | /
|____________|____________|/
The sorter is illuminated by plane waves incident from the ambient, and its
response consists of the power captured by substrate monitors within each of
the quadrants, as well as the power reflected back toward the ambient.
Expand Down

0 comments on commit a6895a6

Please sign in to comment.