From 64e8b7570a0820b0d59b6d588e7686c96266246d Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Fri, 8 Nov 2024 15:20:06 +0000 Subject: [PATCH] More unit_tests cleanup: type annotations, stricter error check The sweep of incident energies is cut down by some more (seeded) random sampling. At the moment we have some cases with invalid chopper frequencies; make sure the error is checked robustly when handling these. Ideally we should have smarter test-case generation and check the errors in a separate test --- tests/unit_tests/test_pychop.py | 59 +++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/tests/unit_tests/test_pychop.py b/tests/unit_tests/test_pychop.py index 9c9e0dd..a2e6a8b 100644 --- a/tests/unit_tests/test_pychop.py +++ b/tests/unit_tests/test_pychop.py @@ -1,5 +1,7 @@ """Validate PyChop_fit resolutions against reference PyChop library""" +from __future__ import annotations + import itertools import random @@ -14,7 +16,11 @@ from resolution_functions.instrument import Instrument from resolution_functions.models.pychop import * -from resolution_functions.models.model_base import InvalidInputError +from resolution_functions.models.model_base import InvalidInputError, ModelData + +if TYPE_CHECKING: + from jaxtyping import Float + random.seed(1) @@ -22,6 +28,8 @@ N_SAMPLES = 10 EINIT = np.arange(50, 2000, 50) +EINIT_SAMPLE = reservoir_sample(EINIT, k=N_SAMPLES) + CHOPPER_FREQ_FERMI = np.arange(50, 601, 50) MATRIX_FERMI = list(reservoir_sample( itertools.product(EINIT, CHOPPER_FREQ_FERMI), @@ -54,6 +62,7 @@ def matrix_nonfermi_id(matrix_row: tuple[int, int, int]) -> str: [('MERLIN', 'MERLIN')], [('SEQUOIA', 'SEQUOIA')], ] + INSTRUMENT_SETTINGS_FERMI = [ ['SEQ-100-2.0-AST', 'SEQ-700-3.5-AST', 'ARCS-100-1.5-AST', 'ARCS-700-1.5-AST', 'ARCS-700-0.5-AST', 'ARCS-100-1.5-SMI', 'ARCS-700-1.5-SMI'], @@ -65,7 +74,6 @@ def matrix_nonfermi_id(matrix_row: tuple[int, int, int]) -> str: 'ARCS-700-0.5-AST', 'ARCS-100-1.5-SMI', 'ARCS-700-1.5-SMI'], ] - INSTRUMENT_MATRIX_FERMI = list( itertools.chain.from_iterable( itertools.product(instr, settings) @@ -73,12 +81,9 @@ def matrix_nonfermi_id(matrix_row: tuple[int, int, int]) -> str: ) ) -INSTRUMENTS_NONFERMI = [ - [('CNCS', 'CNCS')] -] -INSTRUMENT_SETTINGS_NONFERMI = [ - ['High Flux', 'Intermediate', 'High Resolution'] -] +INSTRUMENTS_NONFERMI = [[('CNCS', 'CNCS')]] + +INSTRUMENT_SETTINGS_NONFERMI = [['High Flux', 'Intermediate', 'High Resolution']] INSTRUMENT_MATRIX_NONFERMI = list( itertools.chain.from_iterable( @@ -93,13 +98,16 @@ def instrument_id(matrix_row: tuple[tuple[str, str], str]) -> str: (instrument, _), setting = matrix_row return f"{instrument}_{setting}" +# id formatter for E_i input +format_ei = "ei={}".format -def get_fake_frequencies(e_init: float): + +def get_fake_frequencies(e_init: float) -> Float[np.ndarray]: return np.linspace(0, e_init, 40, endpoint=False) @pytest.fixture(scope="module", params=INSTRUMENT_MATRIX_FERMI, ids=instrument_id) -def pychop_fermi_data(request): +def pychop_fermi_data(request) -> tuple[ModelData, PyChopInstrument]: (name, version), setting = request.param maps = Instrument.from_default(name, version) rf = maps.get_model_data('PyChop_fit', chopper_package=setting) @@ -109,7 +117,7 @@ def pychop_fermi_data(request): @pytest.fixture(scope="module", params=INSTRUMENT_MATRIX_NONFERMI, ids=instrument_id) -def pychop_nonfermi_data(request): +def pychop_nonfermi_data(request) -> tuple[ModelData, PyChopInstrument]: (name, version), setting = request.param maps = Instrument.from_default(name, version) rf = maps.get_model_data('PyChop_fit', chopper_package=setting) @@ -153,7 +161,7 @@ def cncs_data(): def test_fermi_invalid_chopper_frequency( chopper_frequency, mari_data: tuple[PyChopModelDataFermi, PyChopInstrument] ): - with pytest.raises(InvalidInputError, match="The provided chopper frequency") as e: + with pytest.raises(InvalidInputError, match="The provided chopper frequency"): PyChopModelFermi(mari_data[0], chopper_frequency=chopper_frequency) @@ -163,7 +171,7 @@ def test_fermi_invalid_chopper_frequency( def test_fermi_invalid_e_init( e_init, mari_data: tuple[PyChopModelDataFermi, PyChopInstrument] ): - with pytest.raises(InvalidInputError, match="The provided incident energy") as e: + with pytest.raises(InvalidInputError, match="The provided incident energy"): PyChopModelFermi(mari_data[0], e_init=e_init) @pytest.mark.parametrize( @@ -184,7 +192,7 @@ def test_fermi_invalid_e_init( def test_nonfermi_invalid_chopper_frequency( chopper_frequency, cncs_data: PyChopModelDataNonFermi ): - with pytest.raises(InvalidInputError, match="The provided chopper frequency") as e: + with pytest.raises(InvalidInputError, match="The provided chopper frequency"): PyChopModelNonFermi(cncs_data, chopper_frequency=chopper_frequency) @@ -192,7 +200,7 @@ def test_nonfermi_invalid_chopper_frequency( "e_init", [-5, -0.00048, -np.inf, 2000.1, np.inf, 13554.1654, np.nan] ) def test_nonfermi_invalid_e_init(e_init, cncs_data: PyChopModelDataNonFermi): - with pytest.raises(InvalidInputError, match="The provided incident energy") as e: + with pytest.raises(InvalidInputError, match="The provided incident energy"): PyChopModelNonFermi(cncs_data, e_init=e_init) @@ -206,12 +214,12 @@ def test_distances(mari_data: tuple[PyChopModelData, PyChopInstrument]): assert xm == expected[-1] -@pytest.mark.parametrize('e_init', EINIT) +@pytest.mark.parametrize('e_init', EINIT_SAMPLE) def test_fermi_moderator_width_analytical(e_init, pychop_fermi_data): _test_moderator_width_analytical(e_init, *pychop_fermi_data, PyChopModelFermi) -@pytest.mark.parametrize('e_init', EINIT) +@pytest.mark.parametrize('e_init', EINIT_SAMPLE) def test_nonfermi_moderator_width_analytical(e_init, pychop_nonfermi_data): _test_moderator_width_analytical(e_init, *pychop_nonfermi_data, PyChopModelNonFermi) @@ -225,12 +233,12 @@ def _test_moderator_width_analytical(e_init, data, pychop, cls): assert_allclose(actual, expected, rtol=0, atol=1e-8) -@pytest.mark.parametrize('e_init', EINIT, ids="ei={}".format) +@pytest.mark.parametrize('e_init', EINIT_SAMPLE, ids=format_ei) def test_fermi_moderator_width(e_init, pychop_fermi_data): _test_moderator_width(e_init, PyChopModelFermi, *pychop_fermi_data) -@pytest.mark.parametrize('e_init', EINIT, ids="ei={}".format) +@pytest.mark.parametrize('e_init', EINIT_SAMPLE, ids=format_ei) def test_nonfermi_moderator_width(e_init, pychop_fermi_data): _test_moderator_width(e_init, PyChopModelNonFermi, *pychop_fermi_data) @@ -330,12 +338,12 @@ def test_he_detector_width_squared(): assert_allclose(actual, expected) -@pytest.mark.parametrize('e_init', EINIT, ids=[f'ei={ei}' for ei in EINIT]) +@pytest.mark.parametrize('e_init', EINIT, ids=format_ei) def test_fermi_detector_width_squared(e_init, pychop_fermi_data): _test_get_detector_width_squared(e_init, PyChopModelFermi, *pychop_fermi_data) -@pytest.mark.parametrize('e_init', EINIT, ids=[f'ei={ei}' for ei in EINIT]) +@pytest.mark.parametrize('e_init', EINIT, ids=format_ei) def test_nonfermi_detector_width_squared(e_init, pychop_nonfermi_data): _test_get_detector_width_squared(e_init, PyChopModelNonFermi, *pychop_nonfermi_data) @@ -460,8 +468,15 @@ def _test_precompute_resolution(e_init, chopper_frequency, cls, data, pychop): try: pychop.chopper_system.setFrequency(chopper_frequency) except ValueError as e: - if 'maximum allowed' in str(e): + if 'Value of frequencies outside maximum allowed' in str(e): + # Energy out of range for Pychop: make sure the model agrees + with pytest.raises( + InvalidInputError, + match=rf"The provided chopper frequency \(\[{chopper_frequency[0]}\]\) is not allowed"): + cls(model_data=data, chopper_frequency=chopper_frequency, e_init=e_init) + return + raise e fake_frequencies = np.linspace(0, e_init, 40, endpoint=False) expected_resolution = pychop.getResolution(Ei_in=e_init, Etrans=fake_frequencies)