Skip to content

Commit

Permalink
More unit_tests cleanup: type annotations, stricter error check
Browse files Browse the repository at this point in the history
The sweep of incident energies is cut down by some more (seeded)
random sampling.

At the moment we have some cases with invalid chopper frequencies;
make sure the error is checked robustly when handling these.

Ideally we should have smarter test-case generation and check the
errors in a separate test
  • Loading branch information
ajjackson committed Nov 8, 2024
1 parent bd5d48e commit 64e8b75
Showing 1 changed file with 37 additions and 22 deletions.
59 changes: 37 additions & 22 deletions tests/unit_tests/test_pychop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Validate PyChop_fit resolutions against reference PyChop library"""

from __future__ import annotations

import itertools
import random

Expand All @@ -14,14 +16,20 @@

from resolution_functions.instrument import Instrument
from resolution_functions.models.pychop import *
from resolution_functions.models.model_base import InvalidInputError
from resolution_functions.models.model_base import InvalidInputError, ModelData

if TYPE_CHECKING:
from jaxtyping import Float


random.seed(1)

DEBUG = False
N_SAMPLES = 10

EINIT = np.arange(50, 2000, 50)
EINIT_SAMPLE = reservoir_sample(EINIT, k=N_SAMPLES)

CHOPPER_FREQ_FERMI = np.arange(50, 601, 50)
MATRIX_FERMI = list(reservoir_sample(
itertools.product(EINIT, CHOPPER_FREQ_FERMI),
Expand Down Expand Up @@ -54,6 +62,7 @@ def matrix_nonfermi_id(matrix_row: tuple[int, int, int]) -> str:
[('MERLIN', 'MERLIN')],
[('SEQUOIA', 'SEQUOIA')],
]

INSTRUMENT_SETTINGS_FERMI = [
['SEQ-100-2.0-AST', 'SEQ-700-3.5-AST', 'ARCS-100-1.5-AST', 'ARCS-700-1.5-AST',
'ARCS-700-0.5-AST', 'ARCS-100-1.5-SMI', 'ARCS-700-1.5-SMI'],
Expand All @@ -65,20 +74,16 @@ def matrix_nonfermi_id(matrix_row: tuple[int, int, int]) -> str:
'ARCS-700-0.5-AST', 'ARCS-100-1.5-SMI', 'ARCS-700-1.5-SMI'],
]


INSTRUMENT_MATRIX_FERMI = list(
itertools.chain.from_iterable(
itertools.product(instr, settings)
for instr, settings in zip(INSTRUMENTS_FERMI, INSTRUMENT_SETTINGS_FERMI)
)
)

INSTRUMENTS_NONFERMI = [
[('CNCS', 'CNCS')]
]
INSTRUMENT_SETTINGS_NONFERMI = [
['High Flux', 'Intermediate', 'High Resolution']
]
INSTRUMENTS_NONFERMI = [[('CNCS', 'CNCS')]]

INSTRUMENT_SETTINGS_NONFERMI = [['High Flux', 'Intermediate', 'High Resolution']]

INSTRUMENT_MATRIX_NONFERMI = list(
itertools.chain.from_iterable(
Expand All @@ -93,13 +98,16 @@ def instrument_id(matrix_row: tuple[tuple[str, str], str]) -> str:
(instrument, _), setting = matrix_row
return f"{instrument}_{setting}"

# id formatter for E_i input
format_ei = "ei={}".format

def get_fake_frequencies(e_init: float):

def get_fake_frequencies(e_init: float) -> Float[np.ndarray]:
return np.linspace(0, e_init, 40, endpoint=False)


@pytest.fixture(scope="module", params=INSTRUMENT_MATRIX_FERMI, ids=instrument_id)
def pychop_fermi_data(request):
def pychop_fermi_data(request) -> tuple[ModelData, PyChopInstrument]:
(name, version), setting = request.param
maps = Instrument.from_default(name, version)
rf = maps.get_model_data('PyChop_fit', chopper_package=setting)
Expand All @@ -109,7 +117,7 @@ def pychop_fermi_data(request):


@pytest.fixture(scope="module", params=INSTRUMENT_MATRIX_NONFERMI, ids=instrument_id)
def pychop_nonfermi_data(request):
def pychop_nonfermi_data(request) -> tuple[ModelData, PyChopInstrument]:
(name, version), setting = request.param
maps = Instrument.from_default(name, version)
rf = maps.get_model_data('PyChop_fit', chopper_package=setting)
Expand Down Expand Up @@ -153,7 +161,7 @@ def cncs_data():
def test_fermi_invalid_chopper_frequency(
chopper_frequency, mari_data: tuple[PyChopModelDataFermi, PyChopInstrument]
):
with pytest.raises(InvalidInputError, match="The provided chopper frequency") as e:
with pytest.raises(InvalidInputError, match="The provided chopper frequency"):
PyChopModelFermi(mari_data[0], chopper_frequency=chopper_frequency)


Expand All @@ -163,7 +171,7 @@ def test_fermi_invalid_chopper_frequency(
def test_fermi_invalid_e_init(
e_init, mari_data: tuple[PyChopModelDataFermi, PyChopInstrument]
):
with pytest.raises(InvalidInputError, match="The provided incident energy") as e:
with pytest.raises(InvalidInputError, match="The provided incident energy"):
PyChopModelFermi(mari_data[0], e_init=e_init)

@pytest.mark.parametrize(
Expand All @@ -184,15 +192,15 @@ def test_fermi_invalid_e_init(
def test_nonfermi_invalid_chopper_frequency(
chopper_frequency, cncs_data: PyChopModelDataNonFermi
):
with pytest.raises(InvalidInputError, match="The provided chopper frequency") as e:
with pytest.raises(InvalidInputError, match="The provided chopper frequency"):
PyChopModelNonFermi(cncs_data, chopper_frequency=chopper_frequency)


@pytest.mark.parametrize(
"e_init", [-5, -0.00048, -np.inf, 2000.1, np.inf, 13554.1654, np.nan]
)
def test_nonfermi_invalid_e_init(e_init, cncs_data: PyChopModelDataNonFermi):
with pytest.raises(InvalidInputError, match="The provided incident energy") as e:
with pytest.raises(InvalidInputError, match="The provided incident energy"):
PyChopModelNonFermi(cncs_data, e_init=e_init)


Expand All @@ -206,12 +214,12 @@ def test_distances(mari_data: tuple[PyChopModelData, PyChopInstrument]):
assert xm == expected[-1]


@pytest.mark.parametrize('e_init', EINIT)
@pytest.mark.parametrize('e_init', EINIT_SAMPLE)
def test_fermi_moderator_width_analytical(e_init, pychop_fermi_data):
_test_moderator_width_analytical(e_init, *pychop_fermi_data, PyChopModelFermi)


@pytest.mark.parametrize('e_init', EINIT)
@pytest.mark.parametrize('e_init', EINIT_SAMPLE)
def test_nonfermi_moderator_width_analytical(e_init, pychop_nonfermi_data):
_test_moderator_width_analytical(e_init, *pychop_nonfermi_data, PyChopModelNonFermi)

Expand All @@ -225,12 +233,12 @@ def _test_moderator_width_analytical(e_init, data, pychop, cls):
assert_allclose(actual, expected, rtol=0, atol=1e-8)


@pytest.mark.parametrize('e_init', EINIT, ids="ei={}".format)
@pytest.mark.parametrize('e_init', EINIT_SAMPLE, ids=format_ei)
def test_fermi_moderator_width(e_init, pychop_fermi_data):
_test_moderator_width(e_init, PyChopModelFermi, *pychop_fermi_data)


@pytest.mark.parametrize('e_init', EINIT, ids="ei={}".format)
@pytest.mark.parametrize('e_init', EINIT_SAMPLE, ids=format_ei)
def test_nonfermi_moderator_width(e_init, pychop_fermi_data):
_test_moderator_width(e_init, PyChopModelNonFermi, *pychop_fermi_data)

Expand Down Expand Up @@ -330,12 +338,12 @@ def test_he_detector_width_squared():
assert_allclose(actual, expected)


@pytest.mark.parametrize('e_init', EINIT, ids=[f'ei={ei}' for ei in EINIT])
@pytest.mark.parametrize('e_init', EINIT, ids=format_ei)
def test_fermi_detector_width_squared(e_init, pychop_fermi_data):
_test_get_detector_width_squared(e_init, PyChopModelFermi, *pychop_fermi_data)


@pytest.mark.parametrize('e_init', EINIT, ids=[f'ei={ei}' for ei in EINIT])
@pytest.mark.parametrize('e_init', EINIT, ids=format_ei)
def test_nonfermi_detector_width_squared(e_init, pychop_nonfermi_data):
_test_get_detector_width_squared(e_init, PyChopModelNonFermi, *pychop_nonfermi_data)

Expand Down Expand Up @@ -460,8 +468,15 @@ def _test_precompute_resolution(e_init, chopper_frequency, cls, data, pychop):
try:
pychop.chopper_system.setFrequency(chopper_frequency)
except ValueError as e:
if 'maximum allowed' in str(e):
if 'Value of frequencies outside maximum allowed' in str(e):
# Energy out of range for Pychop: make sure the model agrees
with pytest.raises(
InvalidInputError,
match=rf"The provided chopper frequency \(\[{chopper_frequency[0]}\]\) is not allowed"):
cls(model_data=data, chopper_frequency=chopper_frequency, e_init=e_init)

return
raise e

fake_frequencies = np.linspace(0, e_init, 40, endpoint=False)
expected_resolution = pychop.getResolution(Ei_in=e_init, Etrans=fake_frequencies)
Expand Down

0 comments on commit 64e8b75

Please sign in to comment.