Skip to content

Commit

Permalink
fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro committed Jan 21, 2025
1 parent eb1e468 commit 4bea166
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 3 additions & 1 deletion tests/devices/default_qubit/test_default_qubit_native_mcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def func(x, y, z):
results1 = func1(*params)

jaxpr = str(jax.make_jaxpr(func)(*params))
assert "pure_callback" not in jaxpr
# assert "pure_callback" not in jaxpr
assert "pure_callback" in jaxpr
# will change once we solve compilation overhead issue

func2 = jax.jit(func)
results2 = func2(*params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,16 @@ def circuit(x):
assert dev.tracker.totals["execute_and_derivative_batches"] == 1

@pytest.mark.parametrize("interface", ("jax", "jax-jit"))
def test_not_convert_to_numpy_with_jax(self, interface):
def test_convert_to_numpy_with_jax(self, interface):
"""Test that we will not convert to numpy when working with jax."""

# separate test so we can easily update it once we solve the
# compilation overhead issue
dev = qml.device("default.qubit")
config = qml.devices.ExecutionConfig(
gradient_method=qml.gradients.param_shift, interface=interface
)
processed = dev.setup_execution_config(config)
assert not processed.convert_to_numpy
assert processed.convert_to_numpy

def test_convert_to_numpy_with_adjoint(self):
"""Test that we will convert to numpy with adjoint."""
Expand Down

0 comments on commit 4bea166

Please sign in to comment.