From 7100132fc15757c67e9c8f5f5c7f233c2fa31e8e Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:11:34 +0000 Subject: [PATCH 1/2] feat: Add compile-time Python expressions (#74) Co-authored-by: Kartik Singhal --- guppy/cfg/bb.py | 23 +++++----- guppy/cfg/builder.py | 20 ++++++++- guppy/checker/core.py | 53 +++++++++++++++++++++-- guppy/checker/expr_checker.py | 50 +++++++++++++++++++-- guppy/checker/func_checker.py | 2 +- guppy/compiler/expr_compiler.py | 21 ++++++--- guppy/module.py | 48 ++++++++++++++++++--- guppy/nodes.py | 8 ++++ tests/error/py_errors/__init__.py | 0 tests/error/py_errors/guppy_name1.err | 7 +++ tests/error/py_errors/guppy_name1.py | 6 +++ tests/error/py_errors/guppy_name2.err | 7 +++ tests/error/py_errors/guppy_name2.py | 9 ++++ tests/error/py_errors/no_args.err | 7 +++ tests/error/py_errors/no_args.py | 6 +++ tests/error/py_errors/python_err.err | 12 ++++++ tests/error/py_errors/python_err.py | 6 +++ tests/error/py_errors/unsupported.err | 7 +++ tests/error/py_errors/unsupported.py | 6 +++ tests/error/test_py_errors.py | 20 +++++++++ tests/integration/test_py.py | 62 +++++++++++++++++++++++++++ tests/integration/util.py | 7 +++ 22 files changed, 355 insertions(+), 32 deletions(-) create mode 100644 tests/error/py_errors/__init__.py create mode 100644 tests/error/py_errors/guppy_name1.err create mode 100644 tests/error/py_errors/guppy_name1.py create mode 100644 tests/error/py_errors/guppy_name2.err create mode 100644 tests/error/py_errors/guppy_name2.py create mode 100644 tests/error/py_errors/no_args.err create mode 100644 tests/error/py_errors/no_args.py create mode 100644 tests/error/py_errors/python_err.err create mode 100644 tests/error/py_errors/python_err.py create mode 100644 tests/error/py_errors/unsupported.err create mode 100644 tests/error/py_errors/unsupported.py create mode 100644 tests/error/test_py_errors.py create mode 100644 tests/integration/test_py.py diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index 99fd61a2..a91965ca 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -6,7 +6,7 @@ from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast -from guppy.nodes import NestedFunctionDef +from guppy.nodes import NestedFunctionDef, PyExpr if TYPE_CHECKING: from guppy.cfg.cfg import BaseCFG @@ -83,10 +83,9 @@ def compute_variable_stats(self) -> None: visitor = VariableVisitor(self) for s in self.statements: visitor.visit(s) - self._vars = visitor.stats - if self.branch_pred is not None: - self._vars.update_used(self.branch_pred) + visitor.visit(self.branch_pred) + self._vars = visitor.stats class VariableVisitor(ast.NodeVisitor): @@ -99,24 +98,31 @@ def __init__(self, bb: BB): self.bb = bb self.stats = VariableStats() + def visit_Name(self, node: ast.Name) -> None: + self.stats.update_used(node) + def visit_Assign(self, node: ast.Assign) -> None: - self.stats.update_used(node.value) + self.visit(node.value) for t in node.targets: for name in name_nodes_in_ast(t): self.stats.assigned[name.id] = node def visit_AugAssign(self, node: ast.AugAssign) -> None: - self.stats.update_used(node.value) + self.visit(node.value) self.stats.update_used(node.target) # The target is also used for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node def visit_AnnAssign(self, node: ast.AnnAssign) -> None: if node.value: - self.stats.update_used(node.value) + self.visit(node.value) for name in name_nodes_in_ast(node.target): self.stats.assigned[name.id] = node + def visit_PyExpr(self, node: PyExpr) -> None: + # Don't look into `py(...)` expressions + pass + def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # In order to compute the used external variables in a nested function # definition, we have to run live variable analysis first @@ -139,6 +145,3 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # The name of the function is now assigned self.stats.assigned[node.name] = node - - def generic_visit(self, node: ast.AST) -> None: - self.stats.update_used(node) diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index b83b0483..3296ec93 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -3,13 +3,13 @@ from collections.abc import Iterator from typing import NamedTuple -from guppy.ast_util import AstVisitor, set_location_from +from guppy.ast_util import AstVisitor, set_location_from, with_loc from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals from guppy.error import GuppyError, InternalGuppyError from guppy.gtypes import NoneType -from guppy.nodes import NestedFunctionDef +from guppy.nodes import NestedFunctionDef, PyExpr # In order to build expressions, need an endless stream of unique temporary variables # to store intermediate results @@ -258,6 +258,22 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Name: # The final value is stored in the temporary variable return self._make_var(tmp, node) + def visit_Call(self, node: ast.Call) -> ast.AST: + # Parse compile-time evaluated `py(...)` expression + if isinstance(node.func, ast.Name) and node.func.id == "py": + match node.args: + case []: + raise GuppyError( + "Compile-time `py(...)` expression requires an argument", + node, + ) + case [arg]: + pass + case args: + arg = with_loc(node, ast.Tuple(elts=args, ctx=ast.Load)) + return with_loc(node, PyExpr(value=arg)) + return self.generic_visit(node) + def generic_visit(self, node: ast.AST) -> ast.AST: # Short-circuit expressions must be built using the `BranchBuilder`. However, we # can turn them into regular expressions by assigning True/False to a temporary diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 600c3669..aa495311 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,9 +1,9 @@ import ast from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import NamedTuple +from typing import Any, NamedTuple -from guppy.ast_util import AstNode +from guppy.ast_util import AstNode, name_nodes_in_ast from guppy.gtypes import ( BoolType, FunctionType, @@ -52,6 +52,9 @@ class TypeVarDecl: linear: bool +PyScope = dict[str, Any] + + class Globals(NamedTuple): """Collection of names that are available on module-level. @@ -62,6 +65,7 @@ class Globals(NamedTuple): values: dict[str, Variable] types: dict[str, type[GuppyType]] type_vars: dict[str, TypeVarDecl] + python_scope: PyScope @staticmethod def default() -> "Globals": @@ -73,7 +77,7 @@ def default() -> "Globals": NoneType.name: NoneType, BoolType.name: BoolType, } - return Globals({}, tys, {}) + return Globals({}, tys, {}, {}) def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None: """Looks up an instance function with a given name for a type. @@ -92,6 +96,7 @@ def __or__(self, other: "Globals") -> "Globals": self.values | other.values, self.types | other.types, self.type_vars | other.type_vars, + self.python_scope | other.python_scope, ) def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 @@ -112,6 +117,48 @@ class Context(NamedTuple): locals: Locals +class DummyEvalDict(PyScope): + """A custom dict that can be passed to `eval` to give better error messages. + This class is used to implement the `py(...)` expression. If the user tries to + access a Guppy variable in the Python context, we give an informative error message. + """ + + ctx: Context + node: ast.expr + + @dataclass + class GuppyVarUsedError(BaseException): + """Error that is raised when the user tries to access a Guppy variable.""" + + var: str + node: ast.Name | None + + def __init__(self, ctx: Context, node: ast.expr): + super().__init__(**ctx.globals.python_scope) + self.ctx = ctx + self.node = node + + def _check_item(self, key: str) -> None: + # Catch the user trying to access Guppy variables + if key in self.ctx.locals: + # Find the name node in the AST where the usage occurs + n = next((n for n in name_nodes_in_ast(self.node) if n.id == key), None) + raise self.GuppyVarUsedError(key, n) + + def __getitem__(self, key: str) -> Any: + self._check_item(key) + return super().__getitem__(key) + + def __delitem__(self, key: str) -> None: + self._check_item(key) + super().__delitem__(key) + + def __contains__(self, key: object) -> bool: + if isinstance(key, str) and key in self.ctx.locals: + return True + return super().__contains__(key) + + def qualified_name(ty: type[GuppyType] | str, name: str) -> str: """Returns a qualified name for an instance function on a type.""" ty_name = ty if isinstance(ty, str) else ty.name diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 5ad23c8a..5ad876bf 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -21,11 +21,13 @@ """ import ast +import sys +import traceback from contextlib import suppress -from typing import Any, NoReturn +from typing import Any, NoReturn, cast from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type -from guppy.checker.core import CallableVariable, Context, Globals +from guppy.checker.core import CallableVariable, Context, DummyEvalDict, Globals from guppy.error import ( GuppyError, GuppyTypeError, @@ -42,7 +44,7 @@ TupleType, unify, ) -from guppy.nodes import GlobalName, LocalCall, LocalName, TypeApply +from guppy.nodes import GlobalName, LocalCall, LocalName, PyExpr, TypeApply # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -324,6 +326,43 @@ def visit_Call(self, node: ast.Call) -> tuple[ast.expr, GuppyType]: else: raise GuppyTypeError(f"Expected function type, got `{ty}`", node.func) + def visit_PyExpr(self, node: PyExpr) -> tuple[ast.expr, GuppyType]: + # The method we used for obtaining the Python variables in scope only works in + # CPython (see `get_py_scope()`). + if sys.implementation.name != "cpython": + raise GuppyError( + "Compile-time `py(...)` expressions are only supported in CPython", node + ) + + try: + python_val = eval( # noqa: S307, PGH001 + ast.unparse(node.value), + None, + DummyEvalDict(self.ctx, node.value), + ) + except DummyEvalDict.GuppyVarUsedError as e: + raise GuppyError( + f"Guppy variable `{e.var}` cannot be accessed in a compile-time " + "`py(...)` expression", + e.node or node, + ) from None + except Exception as e: # noqa: BLE001 + # Remove the top frame pointing to the `eval` call from the stack trace + tb = e.__traceback__.tb_next if e.__traceback__ else None + raise GuppyError( + "Error occurred while evaluating Python expression:\n\n" + + "".join(traceback.format_exception(type(e), e, tb)), + node, + ) from e + + if ty := python_value_to_guppy_type(python_val, node, self.ctx.globals): + return with_loc(node, ast.Constant(value=python_val)), ty + + raise GuppyError( + f"Python expression of type `{type(python_val)}` is not supported by Guppy", + node, + ) + def visit_NamedExpr(self, node: ast.NamedExpr) -> tuple[ast.expr, GuppyType]: raise InternalGuppyError( "BB contains `NamedExpr`. Should have been removed during CFG" @@ -614,5 +653,10 @@ def python_value_to_guppy_type( return globals.types["int"].build(node=node) case float(): return globals.types["float"].build(node=node) + case tuple(elts): + tys = [python_value_to_guppy_type(elt, node, globals) for elt in elts] + if any(ty is None for ty in tys): + return None + return TupleType(cast(list[GuppyType], tys)) case _: return None diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index dcef53bd..ea7f23b6 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -142,7 +142,7 @@ def check_nested_func_def( if not captured: # If there are no captured vars, we treat the function like a global name func = DefinedFunction(func_def.name, func_ty, func_def, None) - globals = ctx.globals | Globals({func_def.name: func}, {}, {}) + globals = ctx.globals | Globals({func_def.name: func}, {}, {}, {}) else: # Otherwise, we treat it like a local name diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 97b08945..cf01fc07 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -145,10 +145,17 @@ def python_value_to_hugr(v: Any) -> val.Value | None: """ from guppy.prelude._internal import bool_value, float_value, int_value - if isinstance(v, bool): - return bool_value(v) - elif isinstance(v, int): - return int_value(v) - elif isinstance(v, float): - return float_value(v) - return None + match v: + case bool(): + return bool_value(v) + case int(): + return int_value(v) + case float(): + return float_value(v) + case tuple(elts): + vs = [python_value_to_hugr(elt) for elt in elts] + if any(value is None for value in vs): + return None + return val.Tuple(vs=vs) + case _: + return None diff --git a/guppy/module.py b/guppy/module.py index 5dfd77c3..bd41d5d2 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -1,12 +1,13 @@ import ast import inspect +import sys import textwrap from collections.abc import Callable from types import ModuleType from typing import Any, Union from guppy.ast_util import AstNode, annotate_location -from guppy.checker.core import Globals, TypeVarDecl, qualified_name +from guppy.checker.core import Globals, PyScope, TypeVarDecl, qualified_name from guppy.checker.func_checker import DefinedFunction, check_global_func_def from guppy.compiler.core import CompiledGlobals from guppy.compiler.func_compiler import CompiledFunctionDef, compile_global_func_def @@ -37,7 +38,7 @@ class GuppyModule: _compiled_globals: CompiledGlobals # Mappings of functions defined in this module - _func_defs: dict[str, ast.FunctionDef] + _func_defs: dict[str, tuple[ast.FunctionDef, PyScope]] _func_decls: dict[str, ast.FunctionDef] _custom_funcs: dict[str, CustomFunction] @@ -48,7 +49,7 @@ class GuppyModule: def __init__(self, name: str, import_builtins: bool = True): self.name = name - self._globals = Globals({}, {}, {}) + self._globals = Globals({}, {}, {}, {}) self._compiled_globals = {} self._imported_globals = Globals.default() self._imported_compiled_globals = {} @@ -100,7 +101,7 @@ def register_func_def( qualified_name(instance, func_ast.name) if instance else func_ast.name ) self._check_name_available(name, func_ast) - self._func_defs[name] = func_ast + self._func_defs[name] = func_ast, get_py_scope(f) def register_func_decl(self, f: PyFunc) -> None: """Registers a Python function declaration as belonging to this Guppy module.""" @@ -159,7 +160,7 @@ def compile(self) -> Hugr | None: func.check_type(self._imported_globals | self._globals) defined_funcs = { x: DefinedFunction.from_ast(f, x, self._imported_globals | self._globals) - for x, f in self._func_defs.items() + for x, (f, _) in self._func_defs.items() } declared_funcs = { x: DeclaredFunction.from_ast(f, x, self._imported_globals | self._globals) @@ -171,7 +172,12 @@ def compile(self) -> Hugr | None: # Type check function definitions checked = { - x: check_global_func_def(f, self._imported_globals | self._globals) + x: check_global_func_def( + f, + self._imported_globals + | self._globals + | Globals({}, {}, {}, self._func_defs[x][1]), + ) for x, f in defined_funcs.items() } @@ -249,3 +255,33 @@ def parse_py_func(f: PyFunc) -> ast.FunctionDef: if not isinstance(func_ast, ast.FunctionDef): raise GuppyError("Expected a function definition", func_ast) return func_ast + + +def get_py_scope(f: PyFunc) -> PyScope: + """Returns a mapping of all variables captured by a Python function. + + Note that this function only works in CPython. On other platforms, an empty + dictionary is returned. + + Relies on inspecting the `__globals__` and `__closure__` attributes of the function. + See https://docs.python.org/3/reference/datamodel.html#special-read-only-attributes + """ + if sys.implementation.name != "cpython": + return {} + + if inspect.ismethod(f): + f = f.__func__ + code = f.__code__ + + nonlocals: PyScope = {} + if f.__closure__ is not None: + for var, cell in zip(code.co_freevars, f.__closure__): + try: + value = cell.cell_contents + except ValueError: + # The call to `cell_contents` will fail if `var` is a recursive + # reference to the decorated function + continue + nonlocals[var] = value + + return nonlocals | f.__globals__.copy() diff --git a/guppy/nodes.py b/guppy/nodes.py index a08d30da..ef50d84c 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -60,6 +60,14 @@ class TypeApply(ast.expr): ) +class PyExpr(ast.expr): + """A compile-time evaluated `py(...)` expression.""" + + value: ast.expr + + _fields = ("value",) + + class NestedFunctionDef(ast.FunctionDef): cfg: "CFG" ty: FunctionType diff --git a/tests/error/py_errors/__init__.py b/tests/error/py_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/py_errors/guppy_name1.err b/tests/error/py_errors/guppy_name1.err new file mode 100644 index 00000000..f870fe12 --- /dev/null +++ b/tests/error/py_errors/guppy_name1.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo(x: int) -> int: +6: return py(x + 1) + ^ +GuppyError: Guppy variable `x` cannot be accessed in a compile-time `py(...)` expression diff --git a/tests/error/py_errors/guppy_name1.py b/tests/error/py_errors/guppy_name1.py new file mode 100644 index 00000000..d1697d0a --- /dev/null +++ b/tests/error/py_errors/guppy_name1.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo(x: int) -> int: + return py(x + 1) diff --git a/tests/error/py_errors/guppy_name2.err b/tests/error/py_errors/guppy_name2.err new file mode 100644 index 00000000..fb0c8437 --- /dev/null +++ b/tests/error/py_errors/guppy_name2.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:9 + +7: @guppy +8: def foo(x: int) -> int: +9: return py(x + 1) + ^ +GuppyError: Guppy variable `x` cannot be accessed in a compile-time `py(...)` expression diff --git a/tests/error/py_errors/guppy_name2.py b/tests/error/py_errors/guppy_name2.py new file mode 100644 index 00000000..44ca9852 --- /dev/null +++ b/tests/error/py_errors/guppy_name2.py @@ -0,0 +1,9 @@ +from guppy.decorator import guppy + + +x = 42 + + +@guppy +def foo(x: int) -> int: + return py(x + 1) diff --git a/tests/error/py_errors/no_args.err b/tests/error/py_errors/no_args.err new file mode 100644 index 00000000..5f8f43fb --- /dev/null +++ b/tests/error/py_errors/no_args.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py() + ^^^^ +GuppyError: Compile-time `py(...)` expression requires an argument diff --git a/tests/error/py_errors/no_args.py b/tests/error/py_errors/no_args.py new file mode 100644 index 00000000..26bdda63 --- /dev/null +++ b/tests/error/py_errors/no_args.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo() -> int: + return py() diff --git a/tests/error/py_errors/python_err.err b/tests/error/py_errors/python_err.err new file mode 100644 index 00000000..9d4c1e70 --- /dev/null +++ b/tests/error/py_errors/python_err.err @@ -0,0 +1,12 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py(1 / 0) + ^^^^^^^^^ +GuppyError: Error occurred while evaluating Python expression: + +Traceback (most recent call last): + File "", line 1, in +ZeroDivisionError: division by zero + diff --git a/tests/error/py_errors/python_err.py b/tests/error/py_errors/python_err.py new file mode 100644 index 00000000..803ec227 --- /dev/null +++ b/tests/error/py_errors/python_err.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo() -> int: + return py(1 / 0) diff --git a/tests/error/py_errors/unsupported.err b/tests/error/py_errors/unsupported.err new file mode 100644 index 00000000..22343611 --- /dev/null +++ b/tests/error/py_errors/unsupported.err @@ -0,0 +1,7 @@ +Guppy compilation failed. Error in file $FILE:6 + +4: @guppy +5: def foo() -> int: +6: return py({1, 2, 3}) + ^^^^^^^^^^^^^ +GuppyError: Python expression of type `` is not supported by Guppy diff --git a/tests/error/py_errors/unsupported.py b/tests/error/py_errors/unsupported.py new file mode 100644 index 00000000..e238f300 --- /dev/null +++ b/tests/error/py_errors/unsupported.py @@ -0,0 +1,6 @@ +from guppy.decorator import guppy + + +@guppy +def foo() -> int: + return py({1, 2, 3}) diff --git a/tests/error/test_py_errors.py b/tests/error/test_py_errors.py new file mode 100644 index 00000000..0dacbd53 --- /dev/null +++ b/tests/error/test_py_errors.py @@ -0,0 +1,20 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "py_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_py_errors(file, capsys): + run_error_test(file, capsys) diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py new file mode 100644 index 00000000..2e896aee --- /dev/null +++ b/tests/integration/test_py.py @@ -0,0 +1,62 @@ +from guppy.decorator import guppy +from tests.integration.util import py + + +def test_basic(validate): + x = 42 + + @guppy + def foo() -> int: + return py(x + 1) + + validate(foo) + + +def test_builtin(validate): + @guppy + def foo() -> int: + return py(len({"a": 1337, "b": None})) + + validate(foo) + + +def test_if(validate): + b = True + + @guppy + def foo() -> int: + if py(b or 1 > 6): + return 0 + return 1 + + validate(foo) + + +def test_redeclare_after(validate): + x = 1 + + @guppy + def foo() -> int: + return py(x) + + x = False + + validate(foo) + + +def test_tuple(validate): + @guppy + def foo() -> int: + x, y = py((1, False)) + return x + + validate(foo) + + +def test_tuple_implicit(validate): + @guppy + def foo() -> int: + x, y = py(1, False) + return x + + validate(foo) diff --git a/tests/integration/util.py b/tests/integration/util.py index a99478e2..89343c73 100644 --- a/tests/integration/util.py +++ b/tests/integration/util.py @@ -1,3 +1,5 @@ +from typing import Any + import validator @@ -13,3 +15,8 @@ def __matmul__(self, other): # Dummy names to import to avoid errors for `_@functional` pseudo-decorator: functional = Decorator() _ = Decorator() + + +def py(*args: Any) -> Any: + """Dummy function to import to avoid errors for `py(...)` expressions""" + raise NotImplementedError From 7fac597cec6b51c5cdbf09037952076d9803b4c1 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 15 Jan 2024 11:34:18 +0000 Subject: [PATCH 2/2] fix: default input extensions should be None (#99) --- guppy/hugr/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index 970ee163..7d60a6d8 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -26,7 +26,7 @@ class BaseOp(ABC, BaseModel): # Parent node index of node the op belongs to, used only at serialization time parent: NodeID = 0 - input_extensions: ExtensionSet = Field(default_factory=list) + input_extensions: ExtensionSet | None = None def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: """Hook to insert type information from the input and output ports into the