Skip to content

Commit

Permalink
add a --disable-backend option that adds backends to be disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
kratsg authored and matthewfeickert committed Oct 3, 2023
1 parent 6ebf9a3 commit 7286340
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
import pyhf


def pytest_addoption(parser):
parser.addoption(
"--disable-backend",
action="append",
type=str,
default=[],
choices=["tensorflow", "pytorch", "jax", "minuit"],
help="list of backends to disable in tests",
)


# Factory as fixture pattern
@pytest.fixture
def get_json_from_tarfile():
Expand Down Expand Up @@ -59,14 +70,14 @@ def reset_backend():
@pytest.fixture(
scope='function',
params=[
(pyhf.tensor.numpy_backend(), None),
(pyhf.tensor.pytorch_backend(), None),
(pyhf.tensor.pytorch_backend(precision='64b'), None),
(pyhf.tensor.tensorflow_backend(), None),
(pyhf.tensor.jax_backend(), None),
(("numpy_backend", dict()), ("scipy_optimizer", dict())),
(("pytorch_backend", dict()), ("scipy_optimizer", dict())),
(("pytorch_backend", dict(precision="64b")), ("scipy_optimizer", dict())),
(("tensorflow_backend", dict()), ("scipy_optimizer", dict())),
(("jax_backend", dict()), ("scipy_optimizer", dict())),
(
pyhf.tensor.numpy_backend(poisson_from_normal=True),
pyhf.optimize.minuit_optimizer(),
("numpy_backend", dict(poisson_from_normal=True)),
("minuit_optimizer", dict()),
),
],
ids=['numpy', 'pytorch', 'pytorch64', 'tensorflow', 'jax', 'numpy_minuit'],
Expand All @@ -87,13 +98,20 @@ def backend(request):
only_backends = [
pid for pid in param_ids if request.node.get_closest_marker(f'only_{pid}')
]
disable_backend = any(
backend in param_id for backend in request.config.disable_backend
)

if skip_backend and (param_id in only_backends):
raise ValueError(
f"Must specify skip_{param_id} or only_{param_id} but not both!"
)

if skip_backend:
if disable_backend:
pytest.skip(
f"skipping {func_name} as the backend is disabled: {request.config.disable_backend}"
)
elif skip_backend:
pytest.skip(f"skipping {func_name} as specified")
elif only_backends and param_id not in only_backends:
pytest.skip(
Expand Down

0 comments on commit 7286340

Please sign in to comment.