Skip to content

Commit

Permalink
Simplify s-matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Jul 30, 2024
1 parent fb9972c commit ede6d04
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 17 deletions.
21 changes: 16 additions & 5 deletions src/invrs_gym/challenges/extractor/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,17 @@ def eigensolve_pml(permittivity: jnp.ndarray) -> fmm.LayerSolveResult:

s_matrix_before_source_no_substrate = scattering.stack_s_matrix(
layer_solve_results=(
solve_result_ambient, # ambient
solve_result_ambient, # oxide
solve_result_ambient, # extractor
solve_result_ambient, # ambient + oxide + extractor
solve_result_substrate,
),
layer_thicknesses=thicknesses_before_source,
layer_thicknesses=(
jnp.asarray(
spec.thickness_ambient
+ spec.thickness_oxide
+ spec.thickness_extractor
),
jnp.asarray(spec.thickness_substrate_before_source),
),
)

# Generate the Fourier representation of x, y, and z-oriented point dipoles.
Expand Down Expand Up @@ -561,10 +566,16 @@ def compute_power(
# Total extracted power measured at a monitor above the extractor.
# -------------------------------------------------------------------------

with jax.ensure_compile_time_eval():
print(s_matrix_before_source.start_layer_thickness)

# Compute the eigenmode amplitudes at the ambient flux monitor.
bwd_amplitude_ambient_monitor = fields.propagate_amplitude(
amplitude=bwd_amplitude_ambient_end,
distance=jnp.asarray(spec.offset_monitor_ambient),
distance=jnp.asarray(
s_matrix_before_source.start_layer_thickness
- (spec.thickness_ambient - spec.offset_monitor_ambient)
),
layer_solve_result=solve_result_ambient,
)
_, bwd_flux_ambient_monitor = fields.directional_poynting_flux(
Expand Down
80 changes: 68 additions & 12 deletions tests/challenges/extractor/test_reference_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,54 @@


class ReferenceExtractorTest(unittest.TestCase):
@pytest.mark.slow
def test_bare_substrate_matches_expected(self):
spec = dataclasses.replace(challenge.EXTRACTOR_SPEC, fwhm_source=0.0)
sim_params = dataclasses.replace(
challenge.EXTRACTOR_SIM_PARAMS, approximate_num_terms=1200
)

# Compute the response of the reference design.
pec = challenge.photon_extractor(spec=spec, sim_params=sim_params)
params = pec.component.init(jax.random.PRNGKey(0))
params.array = jnp.zeros_like(params.array)

response, _ = pec.component.response(params)

expected_bare_substrate_emitted_power = (86.998184, 87.00018, 101.31145)
expected_bare_substrate_extracted_power = (4.6281686, 4.628636, 0.29498392)
expected_bare_substrate_collected_power = (2.6329136, 2.6330986, 0.13777894)

with self.subTest("emitted_power"):
onp.testing.assert_allclose(
response.emitted_power, response.bare_substrate_emitted_power, rtol=1e-3
)
onp.testing.assert_allclose(
response.emitted_power, expected_bare_substrate_emitted_power, rtol=1e-3
)
with self.subTest("extracted_power"):
onp.testing.assert_allclose(
response.extracted_power,
response.bare_substrate_extracted_power,
rtol=1e-3,
)
onp.testing.assert_allclose(
response.extracted_power,
expected_bare_substrate_extracted_power,
rtol=1e-3,
)
with self.subTest("collected_power"):
onp.testing.assert_allclose(
response.collected_power,
response.bare_substrate_collected_power,
rtol=1e-3,
)
onp.testing.assert_allclose(
response.collected_power,
expected_bare_substrate_collected_power,
rtol=1e-3,
)

@pytest.mark.slow
def test_boost_matches_expected(self):
# Larger number of terms lets us model dipoles with narrower spatial
Expand Down Expand Up @@ -83,21 +131,29 @@ def test_boost_matches_expected(self):
expected_dos_boost_jx = expected_dos_boost_jy = 1.42
expected_dos_boost_jz = 1.35

onp.testing.assert_allclose(flux_boost_jx, expected_flux_boost_jx, rtol=0.25)
onp.testing.assert_allclose(flux_boost_jy, expected_flux_boost_jy, rtol=0.25)
onp.testing.assert_allclose(flux_boost_jz, expected_flux_boost_jz, rtol=0.54)
with self.subTest("flux_boost"):
onp.testing.assert_allclose(
flux_boost_jx, expected_flux_boost_jx, rtol=0.25
)
onp.testing.assert_allclose(
flux_boost_jy, expected_flux_boost_jy, rtol=0.25
)
onp.testing.assert_allclose(
flux_boost_jz, expected_flux_boost_jz, rtol=0.54
)

self.assertLess(flux_boost_jx, expected_flux_boost_jx)
self.assertLess(flux_boost_jy, expected_flux_boost_jy)
self.assertLess(flux_boost_jz, expected_flux_boost_jz)
self.assertLess(flux_boost_jx, expected_flux_boost_jx)
self.assertLess(flux_boost_jy, expected_flux_boost_jy)
self.assertLess(flux_boost_jz, expected_flux_boost_jz)

onp.testing.assert_allclose(dos_boost_jx, expected_dos_boost_jx, rtol=0.08)
onp.testing.assert_allclose(dos_boost_jy, expected_dos_boost_jy, rtol=0.08)
onp.testing.assert_allclose(dos_boost_jz, expected_dos_boost_jz, rtol=0.12)
with self.subTest("dos_boost"):
onp.testing.assert_allclose(dos_boost_jx, expected_dos_boost_jx, rtol=0.08)
onp.testing.assert_allclose(dos_boost_jy, expected_dos_boost_jy, rtol=0.08)
onp.testing.assert_allclose(dos_boost_jz, expected_dos_boost_jz, rtol=0.12)

self.assertLess(dos_boost_jx, expected_dos_boost_jx)
self.assertLess(dos_boost_jy, expected_dos_boost_jy)
self.assertLess(dos_boost_jz, expected_dos_boost_jz)
self.assertLess(dos_boost_jx, expected_dos_boost_jx)
self.assertLess(dos_boost_jy, expected_dos_boost_jy)
self.assertLess(dos_boost_jz, expected_dos_boost_jz)

@pytest.mark.slow
def test_convergence(self):
Expand Down

0 comments on commit ede6d04

Please sign in to comment.