Skip to content

Commit

Permalink
fix: Update a bunch of tests lol
Browse files Browse the repository at this point in the history
  • Loading branch information
andrijapau committed Jan 21, 2025
1 parent c69309c commit 8e9d687
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 22 deletions.
55 changes: 34 additions & 21 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,10 @@ def test_simple_qnode_expval(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1415,7 +1418,10 @@ def test_simple_qnode_expval_two_evolves(self, num_split_times, shots, tol, seed
ham_y = qml.pulse.constant * qml.PauliX(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_x)(params[0], T_x)
Expand Down Expand Up @@ -1444,7 +1450,10 @@ def test_simple_qnode_probs(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand All @@ -1471,7 +1480,10 @@ def test_simple_qnode_probs_expval(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1503,7 +1515,10 @@ def test_simple_qnode_jit(self, num_split_times, time_interface):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params, T=None):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1542,8 +1557,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
sampler_seed=seed,
gradient_kwargs={"num_split_times": num_split_times, "sampler_seed": seed},
)
qnode_backprop = qml.QNode(ansatz, dev, interface="jax")

Expand Down Expand Up @@ -1574,8 +1588,7 @@ def test_qnode_probs_expval_broadcasting(self, num_split_times, shots, tol, seed
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
use_broadcasting=True,
gradient_kwargs={"num_split_times": num_split_times, "use_broadcasting": True},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1619,18 +1632,22 @@ def ansatz(params):
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
use_broadcasting=True,
sampler_seed=seed,
gradient_kwargs={
"num_split_times": num_split_times,
"use_broadcasting": True,
"sampler_seed": seed,
},
)
circuit_no_bc = qml.QNode(
ansatz,
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
use_broadcasting=False,
sampler_seed=seed,
gradient_kwargs={
"num_split_times": num_split_times,
"use_broadcasting": False,
"sampler_seed": seed,
},
)
params = [jnp.array(0.4)]
jac_bc = jax.jacobian(circuit_bc)(params)
Expand Down Expand Up @@ -1684,9 +1701,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
num_split_times=7,
use_broadcasting=True,
sampler_seed=seed,
gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True},
)
cost_jax = qml.QNode(ansatz, dev, interface="jax")
params = (0.42,)
Expand Down Expand Up @@ -1729,9 +1744,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
num_split_times=7,
use_broadcasting=True,
sampler_seed=seed,
gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True},
)
cost_jax = qml.QNode(ansatz, dev, interface="jax")

Expand Down
2 changes: 1 addition & 1 deletion tests/workflow/test_construct_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_get_transform_program_diff_method_transform(self):
@partial(qml.transforms.compile, num_passes=2)
@partial(qml.transforms.merge_rotations, atol=1e-5)
@qml.transforms.cancel_inverses
@qml.qnode(dev, diff_method="parameter-shift", shifts=2)
@qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": 2})
def circuit():
return qml.expval(qml.PauliZ(0))

Expand Down

0 comments on commit 8e9d687

Please sign in to comment.