Skip to content

Commit

Permalink
Merge pull request #148 from invrs-io/rot2
Browse files Browse the repository at this point in the history
fix for jax > 0.4.31
  • Loading branch information
mfschubert authored Oct 14, 2024
2 parents f012244 + 65db00f commit d93baa6
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v1.4.2"
current_version = "v1.4.3"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
9 changes: 8 additions & 1 deletion .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ jobs:
- name: darglint docstring validation
run: darglint src --strictness=short --ignore-raise=ValueError

tests:
tests-misc:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -77,6 +78,7 @@ jobs:
tests-bayer:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -98,6 +100,7 @@ jobs:

tests-ceviche:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -119,6 +122,7 @@ jobs:

tests-diffract:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -140,6 +144,7 @@ jobs:

tests-extractor:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -161,6 +166,7 @@ jobs:

tests-library:
runs-on: ubuntu-latest
timeout-minutes: 8
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -182,6 +188,7 @@ jobs:

tests-metalens:
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# invrs-gym
`v1.4.2`
`v1.4.3`

## Overview
The `invrs_gym` package is an open-source gym containing a diverse set of photonic design challenges, which are relevant for a wide range of applications such as AR/VR, optical networking, LIDAR, and others.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_gym"
version = "v1.4.2"
version = "v1.4.3"
description = "A collection of inverse design challenges"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/invrs_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v1.4.2"
__version__ = "v1.4.3"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_gym import challenges as challenges
Expand Down
10 changes: 3 additions & 7 deletions src/invrs_gym/challenges/library/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import jax
import jax.numpy as jnp
import numpy as onp
from jax import tree_util
from fmmax import basis, fmm
from totypes import types
Expand Down Expand Up @@ -338,12 +337,9 @@ def _all_rotations(num_nanostructures: int) -> jnp.ndarray:

def _rotation_for_idx(idx: jnp.ndarray, num_nanostructures: int) -> jnp.ndarray:
"""Return the array for rotation `i`."""

def _fn(idx):
is_rotated = [int(j) for j in onp.binary_repr(idx, width=num_nanostructures)]
return jnp.asarray(is_rotated).astype(bool)[::-1]

return jax.pure_callback(_fn, jnp.zeros((num_nanostructures,), dtype=bool), idx)
i = jnp.arange(num_nanostructures - 1)
arr = (idx // 2**i) % 2 == 1
return jnp.concatenate([arr, jnp.asarray([False])])


# -----------------------------------------------------------------------------
Expand Down
15 changes: 10 additions & 5 deletions src/invrs_gym/utils/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,24 @@ def permittivity_from_database(
background_extinction_coeff: float,
) -> jnp.ndarray:
"""Return the permittivity for the specified material from the database."""
is_x64 = jax.config.read("jax_enable_x64")

def _jax_fn(wavelength_um: jnp.ndarray) -> jnp.ndarray:
wavelength_nm = 1000 * wavelength_um
def _refractive_index_fn(wavelength_um: jnp.ndarray) -> onp.ndarray:
wavelength_nm = 1000 * onp.asarray(wavelength_um)
try:
epsilon = material.get_epsilon(wavelength_nm)
refractive_index = onp.sqrt(epsilon)
except ri.refractiveindex.NoExtinctionCoefficient:
refractive_index = material.get_refractive_index(wavelength_nm)
epsilon = (refractive_index + 1j * background_extinction_coeff) ** 2
return jnp.asarray(epsilon, dtype=jnp.zeros((), dtype=complex).dtype)
return onp.asarray(
refractive_index, dtype=(onp.complex128 if is_x64 else onp.complex64)
)

result_shape_dtypes = jnp.zeros_like(wavelength_um, dtype=complex)
return jax.pure_callback(_jax_fn, result_shape_dtypes, wavelength_um)
refractive_index = jax.pure_callback(
_refractive_index_fn, result_shape_dtypes, wavelength_um
)
return (refractive_index + 1j * background_extinction_coeff) ** 2


def permittivity_vacuum(
Expand Down
17 changes: 17 additions & 0 deletions tests/challenges/library/test_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,20 @@ def test_optimal_rotation(self):
optimal_rotation_idx_for_rotated_response,
expected_rotation_idx,
)

def test_rotation_for_idx(self):
num_nanostructures = 8

def expected_fn(idx):
# Reference implementation of `_rotation_for_idx`.
is_rotated = [
int(j) for j in onp.binary_repr(idx, width=num_nanostructures)
]
return onp.asarray(is_rotated).astype(bool)[::-1]

for i in range(128):
expected = expected_fn(idx=i)
result = library_challenge._rotation_for_idx(
idx=i, num_nanostructures=num_nanostructures
)
onp.testing.assert_array_equal(result, expected)

0 comments on commit d93baa6

Please sign in to comment.