diff --git a/jax_triton/experimental/fusion/jaxpr_rewriter.py b/jax_triton/experimental/fusion/jaxpr_rewriter.py index 55d2a16..bdde241 100644 --- a/jax_triton/experimental/fusion/jaxpr_rewriter.py +++ b/jax_triton/experimental/fusion/jaxpr_rewriter.py @@ -35,7 +35,8 @@ class Node(matcher.Pattern, metaclass=abc.ABCMeta): - @abc.abstractproperty + @property + @abc.abstractmethod def parents(self) -> list[Node]: ...