Skip to content

Commit

Permalink
Update for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Nov 17, 2023
1 parent 2d053cc commit 36319e8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 34 deletions.
41 changes: 16 additions & 25 deletions src/invrs_gym/challenges/ceviche/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,8 @@
import jax
import jax.numpy as jnp
import numpy as onp
from ceviche_challenges import params # type: ignore[import-untyped]
from ceviche_challenges import units as u
from ceviche_challenges.beam_splitter import ( # type: ignore[import-untyped]
model as beam_splitter_model,
)
from ceviche_challenges.mode_converter import ( # type: ignore[import-untyped]
model as mode_converter_model,
)
from ceviche_challenges.waveguide_bend import ( # type: ignore[import-untyped]
model as waveguide_bend_model,
)
import ceviche_challenges as cc # type: ignore[import-untyped]
from ceviche_challenges import units as u # type: ignore[import-untyped]
from jax import tree_util
from totypes import types

Expand Down Expand Up @@ -318,8 +309,8 @@ def beam_splitter(
"""
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=beam_splitter_model.BeamSplitterModel(
params=params.CevicheSimParams(
ceviche_model=cc.beam_splitter.model.BeamSplitterModel(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -360,8 +351,8 @@ def lightweight_beam_splitter(
"""
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=beam_splitter_model.BeamSplitterModel(
params=params.CevicheSimParams(
ceviche_model=cc.beam_splitter.model.BeamSplitterModel(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -402,8 +393,8 @@ def mode_converter(
"""
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=mode_converter_model.ModeConverterModel(
params=params.CevicheSimParams(
ceviche_model=cc.mode_converter.model.ModeConverterModel(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -443,8 +434,8 @@ def lightweight_mode_converter(
"""
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=mode_converter_model.ModeConverterModel(
params=params.CevicheSimParams(
ceviche_model=cc.mode_converter.model.ModeConverterModel(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -484,8 +475,8 @@ def waveguide_bend(
"""
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=waveguide_bend_model.WaveguideBendModel(
params=params.CevicheSimParams(
ceviche_model=cc.waveguide_bend.model.WaveguideBendModel(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -526,8 +517,8 @@ def lightweight_waveguide_bend(
"""
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=waveguide_bend_model.WaveguideBendModel(
params=params.CevicheSimParams(
ceviche_model=cc.waveguide_bend.model.WaveguideBendModel(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -569,7 +560,7 @@ def wdm(
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=wdm_model.WdmModel(
params=params.CevicheSimParams(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down Expand Up @@ -613,7 +604,7 @@ def lightweight_wdm(
return CevicheChallenge(
component=CevicheComponent(
ceviche_model=wdm_model.WdmModel(
params=params.CevicheSimParams(
params=cc.params.CevicheSimParams(
resolution=resolution_nm * u.nm,
wavelengths=u.Array(wavelengths_nm, u.nm),
),
Expand Down
22 changes: 13 additions & 9 deletions src/invrs_gym/challenges/extractor/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,15 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
)
solve_result_oxide = eigensolve_pml(
permittivity=utils.interpolate_permittivity(
permittivity_solid=spec.permittivity_oxide,
permittivity_void=spec.permittivity_ambient,
permittivity_solid=jnp.asarray(spec.permittivity_oxide),
permittivity_void=jnp.asarray(spec.permittivity_ambient),
density=density_array,
),
)
solve_result_extractor = eigensolve_pml(
permittivity=utils.interpolate_permittivity(
permittivity_solid=spec.permittivity_extractor,
permittivity_void=spec.permittivity_ambient,
permittivity_solid=jnp.asarray(spec.permittivity_extractor),
permittivity_void=jnp.asarray(spec.permittivity_ambient),
density=density_array,
),
)
Expand Down Expand Up @@ -485,12 +485,14 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
# Compute the Poynting flux in the layer before the source, at the monitor.
fwd_amplitude_before_monitor = fields.propagate_amplitude(
amplitude=fwd_amplitude_before_start,
distance=spec.thickness_substrate_before_source - spec.offset_monitor_source,
distance=jnp.asarray(
spec.thickness_substrate_before_source - spec.offset_monitor_source
),
layer_solve_result=solve_result_substrate,
)
bwd_amplitude_before_monitor = fields.propagate_amplitude(
amplitude=bwd_amplitude_before_end,
distance=spec.offset_monitor_source,
distance=jnp.asarray(spec.offset_monitor_source),
layer_solve_result=solve_result_substrate,
)
fwd_flux_before_monitor, bwd_flux_before_monitor = fields.directional_poynting_flux(
Expand All @@ -502,12 +504,14 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
# Compute the Poynting flux in the layer after the source, at the monitor.
fwd_amplitude_after_monitor = fields.propagate_amplitude(
amplitude=fwd_amplitude_after_start,
distance=spec.offset_monitor_source,
distance=jnp.asarray(spec.offset_monitor_source),
layer_solve_result=solve_result_substrate,
)
bwd_amplitude_after_monitor = fields.propagate_amplitude(
amplitude=bwd_amplitude_after_end,
distance=spec.thickness_substrate_after_source - spec.offset_monitor_source,
distance=jnp.asarray(
spec.thickness_substrate_after_source - spec.offset_monitor_source
),
layer_solve_result=solve_result_substrate,
)
fwd_flux_after_monitor, bwd_flux_after_monitor = fields.directional_poynting_flux(
Expand Down Expand Up @@ -535,7 +539,7 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:
# Compute the eigenmode amplitudes at the ambient flux monitor.
bwd_amplitude_ambient_monitor = fields.propagate_amplitude(
amplitude=bwd_amplitude_ambient_end,
distance=spec.offset_monitor_ambient,
distance=jnp.asarray(spec.offset_monitor_ambient),
layer_solve_result=solve_result_ambient,
)
_, bwd_flux_ambient_monitor = fields.directional_poynting_flux(
Expand Down

0 comments on commit 36319e8

Please sign in to comment.