From 36319e8b44b6d74b6be6669f8ed16e44ca610226 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Fri, 17 Nov 2023 14:26:18 -0800 Subject: [PATCH] Update for mypy --- src/invrs_gym/challenges/ceviche/challenge.py | 41 ++++++++----------- .../challenges/extractor/component.py | 22 ++++++---- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/src/invrs_gym/challenges/ceviche/challenge.py b/src/invrs_gym/challenges/ceviche/challenge.py index 8f7964e..807c4f7 100644 --- a/src/invrs_gym/challenges/ceviche/challenge.py +++ b/src/invrs_gym/challenges/ceviche/challenge.py @@ -12,17 +12,8 @@ import jax import jax.numpy as jnp import numpy as onp -from ceviche_challenges import params # type: ignore[import-untyped] -from ceviche_challenges import units as u -from ceviche_challenges.beam_splitter import ( # type: ignore[import-untyped] - model as beam_splitter_model, -) -from ceviche_challenges.mode_converter import ( # type: ignore[import-untyped] - model as mode_converter_model, -) -from ceviche_challenges.waveguide_bend import ( # type: ignore[import-untyped] - model as waveguide_bend_model, -) +import ceviche_challenges as cc # type: ignore[import-untyped] +from ceviche_challenges import units as u # type: ignore[import-untyped] from jax import tree_util from totypes import types @@ -318,8 +309,8 @@ def beam_splitter( """ return CevicheChallenge( component=CevicheComponent( - ceviche_model=beam_splitter_model.BeamSplitterModel( - params=params.CevicheSimParams( + ceviche_model=cc.beam_splitter.model.BeamSplitterModel( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -360,8 +351,8 @@ def lightweight_beam_splitter( """ return CevicheChallenge( component=CevicheComponent( - ceviche_model=beam_splitter_model.BeamSplitterModel( - params=params.CevicheSimParams( + ceviche_model=cc.beam_splitter.model.BeamSplitterModel( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -402,8 +393,8 @@ def mode_converter( """ return CevicheChallenge( component=CevicheComponent( - ceviche_model=mode_converter_model.ModeConverterModel( - params=params.CevicheSimParams( + ceviche_model=cc.mode_converter.model.ModeConverterModel( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -443,8 +434,8 @@ def lightweight_mode_converter( """ return CevicheChallenge( component=CevicheComponent( - ceviche_model=mode_converter_model.ModeConverterModel( - params=params.CevicheSimParams( + ceviche_model=cc.mode_converter.model.ModeConverterModel( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -484,8 +475,8 @@ def waveguide_bend( """ return CevicheChallenge( component=CevicheComponent( - ceviche_model=waveguide_bend_model.WaveguideBendModel( - params=params.CevicheSimParams( + ceviche_model=cc.waveguide_bend.model.WaveguideBendModel( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -526,8 +517,8 @@ def lightweight_waveguide_bend( """ return CevicheChallenge( component=CevicheComponent( - ceviche_model=waveguide_bend_model.WaveguideBendModel( - params=params.CevicheSimParams( + ceviche_model=cc.waveguide_bend.model.WaveguideBendModel( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -569,7 +560,7 @@ def wdm( return CevicheChallenge( component=CevicheComponent( ceviche_model=wdm_model.WdmModel( - params=params.CevicheSimParams( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), @@ -613,7 +604,7 @@ def lightweight_wdm( return CevicheChallenge( component=CevicheComponent( ceviche_model=wdm_model.WdmModel( - params=params.CevicheSimParams( + params=cc.params.CevicheSimParams( resolution=resolution_nm * u.nm, wavelengths=u.Array(wavelengths_nm, u.nm), ), diff --git a/src/invrs_gym/challenges/extractor/component.py b/src/invrs_gym/challenges/extractor/component.py index 0a20101..3318b0a 100644 --- a/src/invrs_gym/challenges/extractor/component.py +++ b/src/invrs_gym/challenges/extractor/component.py @@ -362,15 +362,15 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: ) solve_result_oxide = eigensolve_pml( permittivity=utils.interpolate_permittivity( - permittivity_solid=spec.permittivity_oxide, - permittivity_void=spec.permittivity_ambient, + permittivity_solid=jnp.asarray(spec.permittivity_oxide), + permittivity_void=jnp.asarray(spec.permittivity_ambient), density=density_array, ), ) solve_result_extractor = eigensolve_pml( permittivity=utils.interpolate_permittivity( - permittivity_solid=spec.permittivity_extractor, - permittivity_void=spec.permittivity_ambient, + permittivity_solid=jnp.asarray(spec.permittivity_extractor), + permittivity_void=jnp.asarray(spec.permittivity_ambient), density=density_array, ), ) @@ -485,12 +485,14 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: # Compute the Poynting flux in the layer before the source, at the monitor. fwd_amplitude_before_monitor = fields.propagate_amplitude( amplitude=fwd_amplitude_before_start, - distance=spec.thickness_substrate_before_source - spec.offset_monitor_source, + distance=jnp.asarray( + spec.thickness_substrate_before_source - spec.offset_monitor_source + ), layer_solve_result=solve_result_substrate, ) bwd_amplitude_before_monitor = fields.propagate_amplitude( amplitude=bwd_amplitude_before_end, - distance=spec.offset_monitor_source, + distance=jnp.asarray(spec.offset_monitor_source), layer_solve_result=solve_result_substrate, ) fwd_flux_before_monitor, bwd_flux_before_monitor = fields.directional_poynting_flux( @@ -502,12 +504,14 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: # Compute the Poynting flux in the layer after the source, at the monitor. fwd_amplitude_after_monitor = fields.propagate_amplitude( amplitude=fwd_amplitude_after_start, - distance=spec.offset_monitor_source, + distance=jnp.asarray(spec.offset_monitor_source), layer_solve_result=solve_result_substrate, ) bwd_amplitude_after_monitor = fields.propagate_amplitude( amplitude=bwd_amplitude_after_end, - distance=spec.thickness_substrate_after_source - spec.offset_monitor_source, + distance=jnp.asarray( + spec.thickness_substrate_after_source - spec.offset_monitor_source + ), layer_solve_result=solve_result_substrate, ) fwd_flux_after_monitor, bwd_flux_after_monitor = fields.directional_poynting_flux( @@ -535,7 +539,7 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: # Compute the eigenmode amplitudes at the ambient flux monitor. bwd_amplitude_ambient_monitor = fields.propagate_amplitude( amplitude=bwd_amplitude_ambient_end, - distance=spec.offset_monitor_ambient, + distance=jnp.asarray(spec.offset_monitor_ambient), layer_solve_result=solve_result_ambient, ) _, bwd_flux_ambient_monitor = fields.directional_poynting_flux(