Skip to content

Commit

Permalink
Merge pull request #35 from invrs-io/issues
Browse files Browse the repository at this point in the history
Remove duplicate parameters
  • Loading branch information
mfschubert authored Oct 26, 2023
2 parents dbb46da + ded2783 commit eb25efa
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
11 changes: 6 additions & 5 deletions src/invrs_gym/challenges/diffract/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class GratingSpec:
permittivity_encapsulation: complex
permittivity_substrate: complex

thickness_grating: float
thickness_grating: float | jnp.ndarray

period_x: float
period_y: float
Expand Down Expand Up @@ -152,7 +152,6 @@ def index_for_order(

def grating_efficiency(
density_array: jnp.ndarray,
thickness: jnp.ndarray,
spec: GratingSpec,
wavelength: jnp.ndarray,
polarization: str,
Expand All @@ -166,8 +165,6 @@ def grating_efficiency(
Args:
density_array: Defines the pattern of the grating layer.
thickness: The thickness of the grating layer. This overrides the grating
layer thickness given in `spec`.
spec: Defines the physical specifcation of the grating.
wavelength: The wavelength of the excitation.
polarization: The polarization of the excitation, TE or TM.
Expand Down Expand Up @@ -210,7 +207,11 @@ def grating_efficiency(

# Layer thicknesses for the ambient and substrate are set to zero; these do not
# affect the result of the calculation.
layer_thicknesses = (jnp.zeros(()), jnp.asarray(thickness), jnp.zeros(()))
layer_thicknesses = (
jnp.zeros(()),
jnp.asarray(spec.thickness_grating),
jnp.zeros(()),
)

s_matrix = scattering.stack_s_matrix(layer_solve_results, layer_thicknesses)

Expand Down
1 change: 0 additions & 1 deletion src/invrs_gym/challenges/diffract/metagrating_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def response(
wavelength = self.sim_params.wavelength
transmission_efficiency, reflection_efficiency = common.grating_efficiency(
density_array=params.array, # type: ignore[arg-type]
thickness=jnp.asarray(self.spec.thickness_grating),
spec=self.spec,
wavelength=jnp.asarray(wavelength),
polarization=self.sim_params.polarization,
Expand Down
7 changes: 5 additions & 2 deletions src/invrs_gym/challenges/diffract/splitter_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,13 @@ def response(
expansion = self.expansion
if wavelength is None:
wavelength = self.sim_params.wavelength
spec = dataclasses.replace(
self.spec,
thickness_grating=jnp.asarray(params[THICKNESS].array),
)
transmission_efficiency, reflection_efficiency = common.grating_efficiency(
density_array=params[DENSITY].array, # type: ignore[arg-type]
thickness=params[THICKNESS].array, # type: ignore[arg-type]
spec=self.spec,
spec=spec,
wavelength=jnp.asarray(wavelength),
polarization=self.sim_params.polarization,
expansion=expansion,
Expand Down

0 comments on commit eb25efa

Please sign in to comment.