Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace the type checking using property return_type with direct is… #1044

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

### Improvements

* Replace the type checking using the property `return_type` of `MeasurementProcess` with direct `isinstance` checks.
[(#1044)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1044)

* Update `qml.MultiControlledX` tests following the latest updates in PennyLane.
[(#1040)](https://github.com/PennyLaneAI/pennylane-lightning/pull/1040)

Expand Down
14 changes: 7 additions & 7 deletions pennylane_lightning/core/_adjoint_jacobian_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import pennylane as qml
from pennylane import BasisState, QuantumFunctionError, StatePrep
from pennylane.measurements import Expectation, MeasurementProcess, State
from pennylane.measurements import ExpectationMP, MeasurementProcess, StateMP
from pennylane.operation import Operation
from pennylane.tape import QuantumTape

Expand Down Expand Up @@ -84,10 +84,10 @@
if not measurements:
return None

if len(measurements) == 1 and measurements[0].return_type is State:
return State
if len(measurements) == 1 and isinstance(measurements[0], StateMP):
return "state"

Check warning on line 88 in pennylane_lightning/core/_adjoint_jacobian_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/_adjoint_jacobian_base.py#L87-L88

Added lines #L87 - L88 were not covered by tests

return Expectation
return "expval"

Check warning on line 90 in pennylane_lightning/core/_adjoint_jacobian_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/_adjoint_jacobian_base.py#L90

Added line #L90 was not covered by tests

def _process_jacobian_tape(self, tape: QuantumTape, split_obs: bool = False):
"""Process a tape, serializing and building a dictionary proper for
Expand Down Expand Up @@ -184,7 +184,7 @@
# the tape does not have measurements
return True

if tape_return_type is State:
if tape_return_type is "state":

Check warning on line 187 in pennylane_lightning/core/_adjoint_jacobian_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/_adjoint_jacobian_base.py#L187

Added line #L187 was not covered by tests
raise QuantumFunctionError(
"Adjoint differentiation method does not support measurement StateMP."
)
Expand All @@ -194,12 +194,12 @@
# the tape does not have measurements or the gradient is 0.0
return True

if tape_return_type is State:
if tape_return_type is "state":

Check warning on line 197 in pennylane_lightning/core/_adjoint_jacobian_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/_adjoint_jacobian_base.py#L197

Added line #L197 was not covered by tests
raise QuantumFunctionError(
"Adjoint differentiation does not support State measurements."
)

if any(m.return_type is not Expectation for m in tape.measurements):
if any(not isinstance(m, ExpectationMP) for m in tape.measurements):

Check warning on line 202 in pennylane_lightning/core/_adjoint_jacobian_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/_adjoint_jacobian_base.py#L202

Added line #L202 was not covered by tests
raise QuantumFunctionError(
"Adjoint differentiation method does not support expectation return type "
"mixed with other return types"
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.41.0-dev5"
__version__ = "0.41.0-dev6"
8 changes: 4 additions & 4 deletions pennylane_lightning/core/lightning_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import pennylane as qml
from pennylane import BasisState, StatePrep
from pennylane.devices import QubitDevice
from pennylane.measurements import Expectation, MeasurementProcess, State
from pennylane.measurements import ExpectationMP, MeasurementProcess, StateMP

Check warning on line 26 in pennylane_lightning/core/lightning_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/lightning_base.py#L26

Added line #L26 was not covered by tests
from pennylane.operation import Operation
from pennylane.ops import Prod, Projector, SProd, Sum
from pennylane.wires import Wires
Expand Down Expand Up @@ -348,22 +348,22 @@
if not measurements:
return None

if len(measurements) == 1 and measurements[0].return_type is State:
if len(measurements) == 1 and isinstance(measurements[0], StateMP):

Check warning on line 351 in pennylane_lightning/core/lightning_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/lightning_base.py#L351

Added line #L351 was not covered by tests
# return State
raise qml.QuantumFunctionError(
"Adjoint differentiation does not support State measurements."
)

# The return_type of measurement processes must be expectation
if any(m.return_type is not Expectation for m in measurements):
if any(not isinstance(m, ExpectationMP) for m in measurements):

Check warning on line 358 in pennylane_lightning/core/lightning_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/lightning_base.py#L358

Added line #L358 was not covered by tests
raise qml.QuantumFunctionError(
"Adjoint differentiation method does not support expectation return type "
"mixed with other return types"
)

for measurement in measurements:
LightningBase._assert_adjdiff_no_projectors(measurement.obs)
return Expectation
return "expval"

Check warning on line 366 in pennylane_lightning/core/lightning_base.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/core/lightning_base.py#L366

Added line #L366 was not covered by tests

@staticmethod
def _adjoint_jacobian_processing(jac):
Expand Down
10 changes: 5 additions & 5 deletions tests/test_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from conftest import LightningDevice as ld
from conftest import device_name, lightning_ops, validate_measurements
from flaky import flaky
from pennylane.measurements import Expectation, Shots, Variance
from pennylane.measurements import ExpectationMP, Shots, VarianceMP

if not ld._CPP_BINARY_AVAILABLE:
pytest.skip("No binary module found. Skipping.", allow_module_level=True)
Expand Down Expand Up @@ -387,12 +387,12 @@ def circuit():
circuit()

def test_observable_return_type_is_expectation(self, dev):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Expectation`"""
"""Test that the return type of the observable is :class:`ExpectationMP`"""

@qml.qnode(dev)
def circuit():
res = qml.expval(qml.PauliZ(0))
assert res.return_type is Expectation
assert isinstance(res, ExpectationMP)
return res

circuit()
Expand Down Expand Up @@ -488,12 +488,12 @@ def circuit():
circuit()

def test_observable_return_type_is_variance(self, dev):
"""Test that the return type of the observable is :attr:`ObservableReturnTypes.Variance`"""
"""Test that the return type is :class:`VarianceMP`"""

@qml.qnode(dev)
def circuit():
res = qml.var(qml.PauliZ(0))
assert res.return_type is Variance
assert isinstance(res, VarianceMP)
return res

circuit()
Expand Down
Loading