From add1f2fb08e5eba547fd24952f088db7cdbc1e22 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Thu, 26 Oct 2023 13:27:16 -0700 Subject: [PATCH 1/2] Remove duplcate parameters --- src/invrs_gym/challenges/diffract/common.py | 5 +---- src/invrs_gym/challenges/diffract/metagrating_challenge.py | 1 - src/invrs_gym/challenges/diffract/splitter_challenge.py | 7 +++++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/invrs_gym/challenges/diffract/common.py b/src/invrs_gym/challenges/diffract/common.py index 04fe08b..e01f290 100644 --- a/src/invrs_gym/challenges/diffract/common.py +++ b/src/invrs_gym/challenges/diffract/common.py @@ -152,7 +152,6 @@ def index_for_order( def grating_efficiency( density_array: jnp.ndarray, - thickness: jnp.ndarray, spec: GratingSpec, wavelength: jnp.ndarray, polarization: str, @@ -166,8 +165,6 @@ def grating_efficiency( Args: density_array: Defines the pattern of the grating layer. - thickness: The thickness of the grating layer. This overrides the grating - layer thickness given in `spec`. spec: Defines the physical specifcation of the grating. wavelength: The wavelength of the excitation. polarization: The polarization of the excitation, TE or TM. @@ -210,7 +207,7 @@ def grating_efficiency( # Layer thicknesses for the ambient and substrate are set to zero; these do not # affect the result of the calculation. - layer_thicknesses = (jnp.zeros(()), jnp.asarray(thickness), jnp.zeros(())) + layer_thicknesses = (jnp.zeros(()), jnp.asarray(spec.thickness_grating), jnp.zeros(())) s_matrix = scattering.stack_s_matrix(layer_solve_results, layer_thicknesses) diff --git a/src/invrs_gym/challenges/diffract/metagrating_challenge.py b/src/invrs_gym/challenges/diffract/metagrating_challenge.py index ba70ece..6e30a45 100644 --- a/src/invrs_gym/challenges/diffract/metagrating_challenge.py +++ b/src/invrs_gym/challenges/diffract/metagrating_challenge.py @@ -88,7 +88,6 @@ def response( wavelength = self.sim_params.wavelength transmission_efficiency, reflection_efficiency = common.grating_efficiency( density_array=params.array, # type: ignore[arg-type] - thickness=jnp.asarray(self.spec.thickness_grating), spec=self.spec, wavelength=jnp.asarray(wavelength), polarization=self.sim_params.polarization, diff --git a/src/invrs_gym/challenges/diffract/splitter_challenge.py b/src/invrs_gym/challenges/diffract/splitter_challenge.py index 7b0b117..05be8e0 100644 --- a/src/invrs_gym/challenges/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenges/diffract/splitter_challenge.py @@ -114,10 +114,13 @@ def response( expansion = self.expansion if wavelength is None: wavelength = self.sim_params.wavelength + spec = dataclasses.replace( + self.spec, + thickness_grating=params[THICKNESS].array, + ) transmission_efficiency, reflection_efficiency = common.grating_efficiency( density_array=params[DENSITY].array, # type: ignore[arg-type] - thickness=params[THICKNESS].array, # type: ignore[arg-type] - spec=self.spec, + spec=spec, wavelength=jnp.asarray(wavelength), polarization=self.sim_params.polarization, expansion=expansion, From ded2783b34fb5af877c1e12c21c4771137dafeb8 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Thu, 26 Oct 2023 13:50:25 -0700 Subject: [PATCH 2/2] Fix types --- src/invrs_gym/challenges/diffract/common.py | 8 ++++++-- src/invrs_gym/challenges/diffract/splitter_challenge.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/invrs_gym/challenges/diffract/common.py b/src/invrs_gym/challenges/diffract/common.py index e01f290..6e1392a 100644 --- a/src/invrs_gym/challenges/diffract/common.py +++ b/src/invrs_gym/challenges/diffract/common.py @@ -48,7 +48,7 @@ class GratingSpec: permittivity_encapsulation: complex permittivity_substrate: complex - thickness_grating: float + thickness_grating: float | jnp.ndarray period_x: float period_y: float @@ -207,7 +207,11 @@ def grating_efficiency( # Layer thicknesses for the ambient and substrate are set to zero; these do not # affect the result of the calculation. - layer_thicknesses = (jnp.zeros(()), jnp.asarray(spec.thickness_grating), jnp.zeros(())) + layer_thicknesses = ( + jnp.zeros(()), + jnp.asarray(spec.thickness_grating), + jnp.zeros(()), + ) s_matrix = scattering.stack_s_matrix(layer_solve_results, layer_thicknesses) diff --git a/src/invrs_gym/challenges/diffract/splitter_challenge.py b/src/invrs_gym/challenges/diffract/splitter_challenge.py index 05be8e0..b7de8b1 100644 --- a/src/invrs_gym/challenges/diffract/splitter_challenge.py +++ b/src/invrs_gym/challenges/diffract/splitter_challenge.py @@ -116,7 +116,7 @@ def response( wavelength = self.sim_params.wavelength spec = dataclasses.replace( self.spec, - thickness_grating=params[THICKNESS].array, + thickness_grating=jnp.asarray(params[THICKNESS].array), ) transmission_efficiency, reflection_efficiency = common.grating_efficiency( density_array=params[DENSITY].array, # type: ignore[arg-type]