Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] Add a QmlPrimitive class to differentiate between different types of primitives #6847

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

mudit2812
Copy link
Contributor

@mudit2812 mudit2812 commented Jan 16, 2025

This PR adds a QmlPrimitive subclass of jax.core.Primitive. This class contains a prim_type property set using a new PrimitiveType enum. PrimitiveTypes currently available are "default", "operator", "measurement", "transform", and "higher_order". This can be made more or less fine grained as needed, but should be enough to differentiate between different types of primitives for now. Additionally, this PR:

  • updates NonInterpPrimitive to be a subclass of QmlPrimitive
  • updates all existing PennyLane primitives to be either QmlPrimitive or NonInterpPrimitive. See this comment to see the logic used to determine which Primitive subclass is used for each primitive.
  • updates PlxprInterpreter.eval and CancelInversesInterpreter.eval to use this prim_type property.

[sc-82420]

@mudit2812 mudit2812 marked this pull request as draft January 16, 2025 20:09
Copy link

codecov bot commented Jan 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.60%. Comparing base (397273b) to head (828ad46).

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #6847   +/-   ##
=======================================
  Coverage   99.60%   99.60%           
=======================================
  Files         476      477    +1     
  Lines       45182    45217   +35     
=======================================
+ Hits        45002    45037   +35     
  Misses        180      180           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@mudit2812 mudit2812 marked this pull request as ready for review January 20, 2025 16:12
@mudit2812 mudit2812 requested a review from albi3ro January 20, 2025 16:14
@mudit2812
Copy link
Contributor Author

mudit2812 commented Jan 20, 2025

Subclasses of PlxprInterpreter still need to be updated. Done

Copy link
Contributor

@albi3ro albi3ro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also consider adding it to the __init__?

def __init__(self, name, prim_type):
    self._prim_type = prim_type
    super().__init__(name)

@mudit2812
Copy link
Contributor Author

We could also consider adding it to the __init__?

def __init__(self, name, prim_type):
    self._prim_type = prim_type
    super().__init__(name)

Good point. I personally like the idea of preserving the same signature as jax.core.Primitive. It's also consistent with how we currently set primitive.multiple_results. Happy to hear more thoughts and update the constructor as needed.

Copy link
Contributor

@PietropaoloFrisoni PietropaoloFrisoni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing improvement 🚀

Just left a couple of totally optional suggestions

return self._prim_type.value

@prim_type.setter
def prim_type(self, value):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def prim_type(self, value):
def prim_type(self, value: Union[str, PrimitiveType]):

Just to emphasize that currently both of these work:

from pennylane.capture.custom_primitives import PrimitiveType, QmlPrimitive

prim = QmlPrimitive("primitive")

prim.prim_type = "default"

prim.prim_type = PrimitiveType("default")

in the future, it would also be nice to raise a more informative error and list all the available options, but right now this seems perfectly fine to me

This submodule offers custom primitives for the PennyLane capture module.
"""
from enum import Enum

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from typing import Union

if custom_handler:
invals = [self.read(invar) for invar in eqn.invars]
outvals = custom_handler(self, *invals, **eqn.params)
elif isinstance(eqn.outvars[0].aval, AbstractOperator):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice replacement. There are also similar possible replacements in tests.capture.test_operators.py (although they are just tests, so probably not relevant)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants