Skip to content

Commit

Permalink
Merge pull request #141 from invrs-io/mirror
Browse files Browse the repository at this point in the history
Updated meta-atom library challenge
  • Loading branch information
mfschubert authored Aug 14, 2024
2 parents 79ce010 + 081559c commit 71fc881
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v1.2.0"
current_version = "v1.3.0"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# invrs-gym
`v1.2.0`
`v1.3.0`

## Overview
The `invrs_gym` package is an open-source gym containing a diverse set of photonic design challenges, which are relevant for a wide range of applications such as AR/VR, optical networking, LIDAR, and others.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_gym"
version = "v1.2.0"
version = "v1.3.0"
description = "A collection of inverse design challenges"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/invrs_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v1.2.0"
__version__ = "v1.3.0"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_gym import challenges as challenges
Expand Down
202 changes: 197 additions & 5 deletions src/invrs_gym/challenges/library/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import jax
import jax.numpy as jnp
import numpy as onp
from jax import tree_util
from fmmax import basis, fmm
from totypes import types

Expand All @@ -32,7 +34,8 @@ class LibraryChallenge(base.Challenge):

def loss(self, response: library_component.LibraryResponse) -> jnp.ndarray:
"""Compute a scalar loss from the component `response`."""
(efficiency_rhcp, efficiency_lhcp), _ = _metagrating_efficiency(
response = response_with_optimal_rotation(response, self.component.spec)
(efficiency_rhcp, efficiency_lhcp), _ = metagrating_efficiency(
response, self.component.spec
)
loss_rhcp = jnp.sum(jnp.abs(1 - efficiency_rhcp)) ** 2
Expand All @@ -57,10 +60,11 @@ def eval_metric(
Returns:
The scalar eval metric.
"""
response = response_with_optimal_rotation(response, self.component.spec)
(
_,
(relative_efficiency_rhcp, relative_efficiency_lhcp),
) = _metagrating_efficiency(response, self.component.spec)
) = metagrating_efficiency(response, self.component.spec)
return jnp.minimum(
jnp.amin(relative_efficiency_rhcp),
jnp.amin(relative_efficiency_lhcp),
Expand Down Expand Up @@ -96,10 +100,11 @@ def metrics(
- Average metagrating relative efficiency.
"""
metrics = super().metrics(response, params, aux)
response = response_with_optimal_rotation(response, self.component.spec)
(
(efficiency_rhcp, efficiency_lhcp),
(relative_efficiency_rhcp, relative_efficiency_lhcp),
) = _metagrating_efficiency(response, self.component.spec)
) = metagrating_efficiency(response, self.component.spec)
metrics.update(
{
METAGRATING_EFFICIENCY_RHCP: efficiency_rhcp,
Expand All @@ -117,7 +122,7 @@ def metrics(
return metrics


def _metagrating_efficiency(
def metagrating_efficiency(
response: library_component.LibraryResponse,
spec: library_component.LibrarySpec,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:
Expand Down Expand Up @@ -165,6 +170,187 @@ def _metagrating_efficiency(
)


# -----------------------------------------------------------------------------
# Functions related to nanostructure rotations.
# -----------------------------------------------------------------------------


def response_with_optimal_rotation(
response: library_component.LibraryResponse,
spec: library_component.LibrarySpec,
) -> library_component.LibraryResponse:
"""Return a modified response with the optimal per-nanostructure rotation.
The optimal rotation is the one which yields the highest relative efficiency
averaged across all wavelengths for a diffraction grating built from the meta
atom library.
Args:
response: The original response.
spec: The physical specification of the meta-atom library.
Returns:
The response for the optimal rotation.
"""
rotation_idx = optimal_rotation(response, spec)
responses = _all_rotations_response(response)
return tree_util.tree_map(lambda x: x[rotation_idx, ...], responses)


def optimal_rotation(
response: library_component.LibraryResponse,
spec: library_component.LibrarySpec,
) -> jnp.ndarray:
"""Return an integer representing the optimal per-nanostructure rotation.
The optimal rotation is the one which yields the highest relative efficiency
averaged across all wavelengths for a diffraction grating built from the meta
atom library.
Args:
response: The original response.
spec: The physical specification of the meta-atom library.
Returns:
The integer representing the optimal rotation. If n-th digit of the reversed
binary representation of the optimal representation is 1, this indicates that
the n-th nanostructure should be rotated.
"""
responses = _all_rotations_response(response)
(
_,
(relative_efficiency_rhcp, relative_efficiency_lhcp),
) = jax.vmap(
metagrating_efficiency, in_axes=(0, None)
)(responses, spec)

# Compute average relative efficiency across wavelength and incident polarization.
relative_efficiency = 0.5 * (relative_efficiency_rhcp + relative_efficiency_lhcp)
relative_efficiency = jnp.mean(
relative_efficiency,
axis=tuple(range(1, relative_efficiency.ndim)),
)
return jnp.argmax(relative_efficiency)


def rotate_params(
params: library_component.Params,
rotation_idx: jnp.ndarray,
) -> library_component.Params:
"""Applies the specified rotation to nanostructures in `params`.
Args:
params: The params to be rotated.
rotation_idx: The integer representing the rotation. If n-th digit of the
reversed binary representation of the representation is 1, this indicates
that the n-th nanostructure should be rotated.
Returns:
The parameters with rotated nanostructures.
"""
rotation_idx = jnp.asarray(rotation_idx)
assert rotation_idx.shape == ()
density: types.Density2DArray
density = params[library_component.DENSITY] # type: ignore[assignment]
num_nanostructures = density.shape[0]
is_rotated = _rotation_for_idx(rotation_idx, num_nanostructures)

array = jnp.stack(
[
jnp.rot90(x) if is_rotated_i else x
for x, is_rotated_i in zip(density.array, is_rotated, strict=True)
],
axis=0,
)
return {
library_component.THICKNESS: params[library_component.THICKNESS],
library_component.DENSITY: dataclasses.replace(density, array=array),
}


def rotation_idx_from_is_rotated(is_rotated: jnp.ndarray) -> jnp.ndarray:
"""Return the rotation index for given `is_rotated`."""
assert is_rotated.ndim == 1
assert is_rotated.dtype == bool
return jnp.sum(is_rotated * 2 ** jnp.arange(is_rotated.size)).astype(int)


def _all_rotations_response(
response: library_component.LibraryResponse,
) -> library_component.LibraryResponse:
"""Return a batch of responses corresponding to all possible unique rotations."""
num_nanostructures = response.transmission_rhcp.shape[0]
is_rotated_nanostructure = _all_rotations(num_nanostructures)
responses = jax.vmap(_rotate_response, in_axes=(None, 0))(
response, is_rotated_nanostructure
)
return responses


def _rotate_response(
response: library_component.LibraryResponse,
is_rotated: jnp.ndarray,
) -> library_component.LibraryResponse:
"""Return response that results from optional 90 degree rotation.
Args:
response: The response to be modified.
is_rotated: Array whose elements indicate whether specific meta-atoms are to
be rotated. Must have length equal to size of the meta-atom library (i.e.
the number of nanostructures in the library).
Returns:
The response for the specified rotation.
"""
assert is_rotated.ndim == 1
assert is_rotated.size == response.transmission_rhcp.shape[0]

shape = is_rotated.shape + (1,) * (response.transmission_rhcp.ndim - 2)
ones = jnp.ones(shape)
shifted = jnp.where(is_rotated.reshape(shape), -ones, ones)

rhcp_phase = jnp.stack([ones, shifted], axis=-1)
transmission_rhcp = response.transmission_rhcp * rhcp_phase
reflection_rhcp = response.reflection_rhcp * rhcp_phase

lhcp_phase = jnp.stack([shifted, ones], axis=-1)
transmission_lhcp = response.transmission_lhcp * lhcp_phase
reflection_lhcp = response.reflection_lhcp * lhcp_phase

return dataclasses.replace(
response,
transmission_rhcp=transmission_rhcp,
transmission_lhcp=transmission_lhcp,
reflection_rhcp=reflection_rhcp,
reflection_lhcp=reflection_lhcp,
)


def _all_rotations(num_nanostructures: int) -> jnp.ndarray:
"""Return all unique rotations, accounting for degeneracy."""
num = 2 ** (num_nanostructures - 1)
is_rotated_nanostructure = jnp.stack(
[_rotation_for_idx(jnp.asarray(i), num_nanostructures) for i in range(num)],
axis=0,
)
return is_rotated_nanostructure


def _rotation_for_idx(idx: jnp.ndarray, num_nanostructures: int) -> jnp.ndarray:
"""Return the array for rotation `i`."""

def _fn(idx):
is_rotated = [int(j) for j in onp.binary_repr(idx, width=num_nanostructures)]
return jnp.asarray(is_rotated).astype(bool)[::-1]

return jax.pure_callback(_fn, jnp.zeros((num_nanostructures,), dtype=bool), idx)


# -----------------------------------------------------------------------------
# Define the challenge.
# -----------------------------------------------------------------------------


def library_density_initializer(
key: jax.Array,
seed_density: types.Density2DArray,
Expand Down Expand Up @@ -212,7 +398,7 @@ def library_density_initializer(
truncation=basis.Truncation.CIRCULAR,
)

SYMMETRIES = ("reflection_n_s", "reflection_e_w")
SYMMETRIES = ("reflection_n_s",)


def meta_atom_library(
Expand All @@ -233,6 +419,12 @@ def meta_atom_library(
reaching broadband 90% relative diffraction efficiency" by Chen et al.
https://www.nature.com/articles/s41467-023-38185-2
In the library design challenge, all possible rotations of the individual meta-
atoms are considered. The optimal rotation (i.e. that which yield the highest
wavelength-averaged relative efficiency) is used to compute evaluation metrics
and loss. The optimal rotations can be found and applied using the
`optimal_rotation` and `rotate_params` functions in this module.
Args:
minimum_width: The minimum width target for the challenge, in pixels.
minimum_spacing: The minimum spacing target for the challenge, in pixels.
Expand Down
4 changes: 1 addition & 3 deletions tests/challenges/library/test_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def test_density_has_expected_attrs(self, min_width, min_spacing):
self.assertEqual(density.lower_bound, 0.0)
self.assertEqual(density.upper_bound, 1.0)
self.assertSequenceEqual(density.periodic, (False, False))
self.assertSequenceEqual(
density.symmetries, ("reflection_n_s", "reflection_e_w")
)
self.assertSequenceEqual(density.symmetries, ("reflection_n_s",))
self.assertEqual(density.minimum_width, min_width)
self.assertEqual(density.minimum_spacing, min_spacing)
self.assertIsNone(density.fixed_solid)
Expand Down
Loading

0 comments on commit 71fc881

Please sign in to comment.