diff --git a/src/invrs_gym/challenges/extractor/component.py b/src/invrs_gym/challenges/extractor/component.py index 8beb16c..aed5cf1 100644 --- a/src/invrs_gym/challenges/extractor/component.py +++ b/src/invrs_gym/challenges/extractor/component.py @@ -453,12 +453,17 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult: s_matrix_before_source_no_substrate = scattering.stack_s_matrix( layer_solve_results=( - solve_result_ambient, # ambient - solve_result_ambient, # oxide - solve_result_ambient, # extractor + solve_result_ambient, # ambient + oxide + extractor solve_result_substrate, ), - layer_thicknesses=thicknesses_before_source, + layer_thicknesses=( + jnp.asarray( + spec.thickness_ambient + + spec.thickness_oxide + + spec.thickness_extractor + ), + jnp.asarray(spec.thickness_substrate_before_source), + ), ) # Generate the Fourier representation of x, y, and z-oriented point dipoles. @@ -561,10 +566,16 @@ def compute_power( # Total extracted power measured at a monitor above the extractor. # ------------------------------------------------------------------------- + with jax.ensure_compile_time_eval(): + print(s_matrix_before_source.start_layer_thickness) + # Compute the eigenmode amplitudes at the ambient flux monitor. bwd_amplitude_ambient_monitor = fields.propagate_amplitude( amplitude=bwd_amplitude_ambient_end, - distance=jnp.asarray(spec.offset_monitor_ambient), + distance=jnp.asarray( + s_matrix_before_source.start_layer_thickness + - (spec.thickness_ambient - spec.offset_monitor_ambient) + ), layer_solve_result=solve_result_ambient, ) _, bwd_flux_ambient_monitor = fields.directional_poynting_flux( diff --git a/tests/challenges/extractor/test_reference_devices.py b/tests/challenges/extractor/test_reference_devices.py index d2743d0..7b2c709 100644 --- a/tests/challenges/extractor/test_reference_devices.py +++ b/tests/challenges/extractor/test_reference_devices.py @@ -20,6 +20,54 @@ class ReferenceExtractorTest(unittest.TestCase): + @pytest.mark.slow + def test_bare_substrate_matches_expected(self): + spec = dataclasses.replace(challenge.EXTRACTOR_SPEC, fwhm_source=0.0) + sim_params = dataclasses.replace( + challenge.EXTRACTOR_SIM_PARAMS, approximate_num_terms=1200 + ) + + # Compute the response of the reference design. + pec = challenge.photon_extractor(spec=spec, sim_params=sim_params) + params = pec.component.init(jax.random.PRNGKey(0)) + params.array = jnp.zeros_like(params.array) + + response, _ = pec.component.response(params) + + expected_bare_substrate_emitted_power = (86.998184, 87.00018, 101.31145) + expected_bare_substrate_extracted_power = (4.6281686, 4.628636, 0.29498392) + expected_bare_substrate_collected_power = (2.6329136, 2.6330986, 0.13777894) + + with self.subTest("emitted_power"): + onp.testing.assert_allclose( + response.emitted_power, response.bare_substrate_emitted_power, rtol=1e-3 + ) + onp.testing.assert_allclose( + response.emitted_power, expected_bare_substrate_emitted_power, rtol=1e-3 + ) + with self.subTest("extracted_power"): + onp.testing.assert_allclose( + response.extracted_power, + response.bare_substrate_extracted_power, + rtol=1e-3, + ) + onp.testing.assert_allclose( + response.extracted_power, + expected_bare_substrate_extracted_power, + rtol=1e-3, + ) + with self.subTest("collected_power"): + onp.testing.assert_allclose( + response.collected_power, + response.bare_substrate_collected_power, + rtol=1e-3, + ) + onp.testing.assert_allclose( + response.collected_power, + expected_bare_substrate_collected_power, + rtol=1e-3, + ) + @pytest.mark.slow def test_boost_matches_expected(self): # Larger number of terms lets us model dipoles with narrower spatial @@ -83,21 +131,29 @@ def test_boost_matches_expected(self): expected_dos_boost_jx = expected_dos_boost_jy = 1.42 expected_dos_boost_jz = 1.35 - onp.testing.assert_allclose(flux_boost_jx, expected_flux_boost_jx, rtol=0.25) - onp.testing.assert_allclose(flux_boost_jy, expected_flux_boost_jy, rtol=0.25) - onp.testing.assert_allclose(flux_boost_jz, expected_flux_boost_jz, rtol=0.54) + with self.subTest("flux_boost"): + onp.testing.assert_allclose( + flux_boost_jx, expected_flux_boost_jx, rtol=0.25 + ) + onp.testing.assert_allclose( + flux_boost_jy, expected_flux_boost_jy, rtol=0.25 + ) + onp.testing.assert_allclose( + flux_boost_jz, expected_flux_boost_jz, rtol=0.54 + ) - self.assertLess(flux_boost_jx, expected_flux_boost_jx) - self.assertLess(flux_boost_jy, expected_flux_boost_jy) - self.assertLess(flux_boost_jz, expected_flux_boost_jz) + self.assertLess(flux_boost_jx, expected_flux_boost_jx) + self.assertLess(flux_boost_jy, expected_flux_boost_jy) + self.assertLess(flux_boost_jz, expected_flux_boost_jz) - onp.testing.assert_allclose(dos_boost_jx, expected_dos_boost_jx, rtol=0.08) - onp.testing.assert_allclose(dos_boost_jy, expected_dos_boost_jy, rtol=0.08) - onp.testing.assert_allclose(dos_boost_jz, expected_dos_boost_jz, rtol=0.12) + with self.subTest("dos_boost"): + onp.testing.assert_allclose(dos_boost_jx, expected_dos_boost_jx, rtol=0.08) + onp.testing.assert_allclose(dos_boost_jy, expected_dos_boost_jy, rtol=0.08) + onp.testing.assert_allclose(dos_boost_jz, expected_dos_boost_jz, rtol=0.12) - self.assertLess(dos_boost_jx, expected_dos_boost_jx) - self.assertLess(dos_boost_jy, expected_dos_boost_jy) - self.assertLess(dos_boost_jz, expected_dos_boost_jz) + self.assertLess(dos_boost_jx, expected_dos_boost_jx) + self.assertLess(dos_boost_jy, expected_dos_boost_jy) + self.assertLess(dos_boost_jz, expected_dos_boost_jz) @pytest.mark.slow def test_convergence(self):