-
Notifications
You must be signed in to change notification settings - Fork 615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Capture] Add a QmlPrimitive
class to differentiate between different types of primitives
#6847
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6847 +/- ##
=======================================
Coverage 99.60% 99.60%
=======================================
Files 476 477 +1
Lines 45182 45217 +35
=======================================
+ Hits 45002 45037 +35
Misses 180 180 ☔ View full report in Codecov by Sentry. |
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also consider adding it to the __init__
?
def __init__(self, name, prim_type):
self._prim_type = prim_type
super().__init__(name)
Good point. I personally like the idea of preserving the same signature as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing improvement 🚀
Just left a couple of totally optional suggestions
return self._prim_type.value | ||
|
||
@prim_type.setter | ||
def prim_type(self, value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def prim_type(self, value): | |
def prim_type(self, value: Union[str, PrimitiveType]): |
Just to emphasize that currently both of these work:
from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive
prim = QmlPrimitive("primitive")
prim.prim_type = "default"
prim.prim_type = PrimitiveType("default")
in the future, it would also be nice to raise a more informative error and list all the available options, but right now this seems perfectly fine to me
This submodule offers custom primitives for the PennyLane capture module. | ||
""" | ||
from enum import Enum | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from typing import Union |
if custom_handler: | ||
invals = [self.read(invar) for invar in eqn.invars] | ||
outvals = custom_handler(self, *invals, **eqn.params) | ||
elif isinstance(eqn.outvars[0].aval, AbstractOperator): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice replacement. There are also similar possible replacements in tests.capture.test_operators.py
(although they are just tests, so probably not relevant)
This PR adds a
QmlPrimitive
subclass ofjax.core.Primitive
. This class contains aprim_type
property set using a newPrimitiveType
enum.PrimitiveType
s currently available are "default", "operator", "measurement", "transform", and "higher_order". This can be made more or less fine grained as needed, but should be enough to differentiate between different types of primitives for now. Additionally, this PR:NonInterpPrimitive
to be a subclass ofQmlPrimitive
QmlPrimitive
orNonInterpPrimitive
. See this comment to see the logic used to determine whichPrimitive
subclass is used for each primitive.PlxprInterpreter.eval
andCancelInversesInterpreter.eval
to use thisprim_type
property.[sc-82420]