diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..b40b54ec --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,65 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 # Use the ref you want to point at + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-merge-conflict + - id: check-toml + - id: check-vcs-permalinks + - id: check-yaml + - id: detect-private-key + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: mixed-line-ending + # - id: trailing-whitespace + # Python-specific + - id: check-ast + - id: check-docstring-first + - id: debug-statements + +- repo: https://github.com/crate-ci/typos + rev: v1.16.23 + hooks: + - id: typos + args: [] + +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.7 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + - id: ruff-format + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.3.0' + hooks: + - id: mypy + pass_filenames: false + args: [--package=guppy] + additional_dependencies: [ + ormsgpack, + pydantic, + ] + +- repo: local + hooks: + - id: cargo-check + name: Cargo check + entry: bash -c 'cd validator && exec cargo check' + pass_filenames: false + types: [file, rust] + language: system + - id: rust-linting + name: Rust linting + entry: bash -c 'cd validator && exec cargo fmt --all --' + pass_filenames: true + types: [file, rust] + language: system + - id: rust-clippy + name: Rust clippy + entry: bash -c 'cd validator && exec cargo clippy --all-targets --all-features -- -Dclippy::all' + pass_filenames: false + types: [file, rust] + language: system diff --git a/.typos.toml b/.typos.toml new file mode 100644 index 00000000..92413257 --- /dev/null +++ b/.typos.toml @@ -0,0 +1,3 @@ +[default.extend-words] +inot = "inot" +fle = "fle" diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..fb07d6b7 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,12 @@ +include *.toml +include *.txt +include *.yaml +recursive-include examples *.ipynb +recursive-include guppy *.py +recursive-include tests *.err +recursive-include tests *.py +recursive-include tests *.sh +recursive-exclude validator *.lock +recursive-exclude validator *.rs +recursive-exclude validator *.toml +exclude .git-blame-ignore-revs diff --git a/README.md b/README.md index f5e76b91..44b8a55f 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,8 @@ pip install -e '.[dev]' ### Git blame -You can configure Git to ignore formatting commits when using `git blame` by running +You can configure Git to ignore formatting commits when using `git blame` by running + ```sh git config blame.ignoreRevsFile .git-blame-ignore-revs ``` @@ -41,14 +42,14 @@ TODO ## Testing -First, build the PyO3 Hugr validation library using +First, build the PyO3 Hugr validation library from the `validator` directory using + ```sh maturin develop ``` -from the `validator` directory. - Run tests using + ```sh pytest -v ``` @@ -59,6 +60,7 @@ Integration test cases can be exported to a directory using pytest --export-test-cases=guppy-exports ``` + which will create a directory `./guppy-exports` populated with hugr modules serialised in msgpack. ## Packaging diff --git a/guppy/ast_util.py b/guppy/ast_util.py index 8c4be4f7..5a9e66f4 100644 --- a/guppy/ast_util.py +++ b/guppy/ast_util.py @@ -1,19 +1,19 @@ import ast -from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar if TYPE_CHECKING: from guppy.gtypes import GuppyType -AstNode = Union[ - ast.AST, - ast.operator, - ast.expr, - ast.arg, - ast.stmt, - ast.Name, - ast.keyword, - ast.FunctionDef, -] +AstNode = ( + ast.AST + | ast.operator + | ast.expr + | ast.arg + | ast.stmt + | ast.Name + | ast.keyword + | ast.FunctionDef +) T = TypeVar("T", covariant=True) @@ -114,7 +114,9 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None: node.end_col_offset = loc.end_col_offset source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc) - assert source is not None and file is not None and line_offset is not None + assert source is not None + assert file is not None + assert line_offset is not None annotate_location(node, source, file, line_offset) @@ -126,7 +128,7 @@ def annotate_location( setattr(node, "source", source) if recurse: - for field, value in ast.iter_fields(node): + for _field, value in ast.iter_fields(node): if isinstance(value, list): for item in value: if isinstance(item, ast.AST): @@ -135,7 +137,7 @@ def annotate_location( annotate_location(value, source, file, line_offset, recurse) -def get_file(node: AstNode) -> Optional[str]: +def get_file(node: AstNode) -> str | None: """Tries to retrieve a file annotation from an AST node.""" try: file = getattr(node, "file") @@ -144,7 +146,7 @@ def get_file(node: AstNode) -> Optional[str]: return None -def get_source(node: AstNode) -> Optional[str]: +def get_source(node: AstNode) -> str | None: """Tries to retrieve a source annotation from an AST node.""" try: source = getattr(node, "source") @@ -153,7 +155,7 @@ def get_source(node: AstNode) -> Optional[str]: return None -def get_line_offset(node: AstNode) -> Optional[int]: +def get_line_offset(node: AstNode) -> int | None: """Tries to retrieve a line offset annotation from an AST node.""" try: line_offset = getattr(node, "line_offset") diff --git a/guppy/cfg/analysis.py b/guppy/cfg/analysis.py index 18fdc3f4..36ee8780 100644 --- a/guppy/cfg/analysis.py +++ b/guppy/cfg/analysis.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import TypeVar, Generic, Iterable +from collections.abc import Iterable +from typing import Generic, TypeVar from guppy.cfg.bb import BB - # Type variable for the lattice domain T = TypeVar("T") @@ -21,12 +21,10 @@ def eq(self, t1: T, t2: T, /) -> bool: @abstractmethod def initial(self) -> T: """Initial lattice value""" - pass @abstractmethod def join(self, *ts: T) -> T: """Lattice join operation""" - pass @abstractmethod def run(self, bbs: Iterable[BB]) -> Result[T]: @@ -34,7 +32,6 @@ def run(self, bbs: Iterable[BB]) -> Result[T]: Returns a mapping from basic blocks to lattice values at the start of each BB. """ - pass class ForwardAnalysis(Analysis[T], ABC, Generic[T]): @@ -43,7 +40,6 @@ class ForwardAnalysis(Analysis[T], ABC, Generic[T]): @abstractmethod def apply_bb(self, val_before: T, bb: BB, /) -> T: """Transformation a basic block applies to a lattice value""" - pass def run(self, bbs: Iterable[BB]) -> Result[T]: """Runs the analysis pass. @@ -69,7 +65,6 @@ class BackwardAnalysis(Analysis[T], ABC, Generic[T]): @abstractmethod def apply_bb(self, val_after: T, bb: BB, /) -> T: """Transformation a basic block applies to a lattice value""" - pass def run(self, bbs: Iterable[BB]) -> Result[T]: """Runs the analysis pass. @@ -106,7 +101,7 @@ def eq(self, live1: LivenessDomain, live2: LivenessDomain) -> bool: return live1.keys() == live2.keys() def initial(self) -> LivenessDomain: - return dict() + return {} def join(self, *ts: LivenessDomain) -> LivenessDomain: res: LivenessDomain = {} diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py index 95c9f63a..99fd61a2 100644 --- a/guppy/cfg/bb.py +++ b/guppy/cfg/bb.py @@ -1,7 +1,8 @@ import ast from abc import ABC from dataclasses import dataclass, field -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING + from typing_extensions import Self from guppy.ast_util import AstNode, name_nodes_in_ast @@ -34,9 +35,14 @@ def update_used(self, node: ast.AST) -> None: self.used[name.id] = name -BBStatement = Union[ - ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef -] +BBStatement = ( + ast.Assign + | ast.AugAssign + | ast.AnnAssign + | ast.Expr + | ast.Return + | NestedFunctionDef +) @dataclass(eq=False) # Disable equality to recover hash from `object` @@ -57,10 +63,10 @@ class BB(ABC): # If the BB has multiple successors, we need a predicate to decide to which one to # jump to - branch_pred: Optional[ast.expr] = None + branch_pred: ast.expr | None = None # Information about assigned/used variables in the BB - _vars: Optional[VariableStats] = None + _vars: VariableStats | None = None @property def vars(self) -> VariableStats: @@ -123,9 +129,7 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None: # Only store used *external* variables: things defined in the current BB, as # well as the function name and argument names should not be included assigned_before_in_bb = ( - self.stats.assigned.keys() - | {node.name} - | set(a.arg for a in node.args.args) + self.stats.assigned.keys() | {node.name} | {a.arg for a in node.args.args} ) self.stats.used |= { x: using_bb.vars.used[x] diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py index e937e144..878437bf 100644 --- a/guppy/cfg/builder.py +++ b/guppy/cfg/builder.py @@ -1,8 +1,9 @@ import ast import itertools -from typing import Optional, Iterator, NamedTuple +from collections.abc import Iterator +from typing import NamedTuple -from guppy.ast_util import set_location_from, AstVisitor +from guppy.ast_util import AstVisitor, set_location_from from guppy.cfg.bb import BB, BBStatement from guppy.cfg.cfg import CFG from guppy.checker.core import Globals @@ -24,11 +25,11 @@ class Jumps(NamedTuple): """Holds jump targets for return, continue, and break during CFG construction.""" return_bb: BB - continue_bb: Optional[BB] - break_bb: Optional[BB] + continue_bb: BB | None + break_bb: BB | None -class CFGBuilder(AstVisitor[Optional[BB]]): +class CFGBuilder(AstVisitor[BB | None]): """Constructs a CFG from ast nodes.""" cfg: CFG @@ -57,8 +58,8 @@ def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> return self.cfg - def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[BB]: - bb_opt: Optional[BB] = bb + def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None: + bb_opt: BB | None = bb next_functional = False for node in nodes: if bb_opt is None: @@ -69,7 +70,7 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[B if next_functional: # TODO: This should be an assertion that the Hugr can be un-flattened - raise NotImplementedError() + raise NotImplementedError next_functional = False else: bb_opt = self.visit(node, bb_opt, jumps) @@ -86,20 +87,16 @@ def _build_node_value(self, node: BBStatement, bb: BB) -> BB: bb.statements.append(node) return bb - def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> BB | None: return self._build_node_value(node, bb) - def visit_AugAssign( - self, node: ast.AugAssign, bb: BB, jumps: Jumps - ) -> Optional[BB]: + def visit_AugAssign(self, node: ast.AugAssign, bb: BB, jumps: Jumps) -> BB | None: return self._build_node_value(node, bb) - def visit_AnnAssign( - self, node: ast.AnnAssign, bb: BB, jumps: Jumps - ) -> Optional[BB]: + def visit_AnnAssign(self, node: ast.AnnAssign, bb: BB, jumps: Jumps) -> BB | None: return self._build_node_value(node, bb) - def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> BB | None: # This is an expression statement where the value is discarded node.value, bb = ExprBuilder.build(node.value, self.cfg, bb) # We don't add it to the BB if it's just a temporary variable. This will be the @@ -110,7 +107,7 @@ def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]: bb.statements.append(node) return bb - def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> BB | None: then_bb, else_bb = self.cfg.new_bb(), self.cfg.new_bb() BranchBuilder.add_branch(node.test, self.cfg, bb, then_bb, else_bb) then_bb = self.visit_stmts(node.body, then_bb, jumps) @@ -127,7 +124,7 @@ def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> Optional[BB]: # No branch jumps: We have to merge the control flow return self.cfg.new_bb(then_bb, else_bb) - def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None: head_bb = self.cfg.new_bb(bb) body_bb, tail_bb = self.cfg.new_bb(), self.cfg.new_bb() BranchBuilder.add_branch(node.test, self.cfg, head_bb, body_bb, tail_bb) @@ -145,29 +142,29 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> Optional[BB]: # its own jumps since the body is not guaranteed to execute return tail_bb - def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None: if not jumps.continue_bb: raise InternalGuppyError("Continue BB not defined") self.cfg.link(bb, jumps.continue_bb) return None - def visit_Break(self, node: ast.Break, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_Break(self, node: ast.Break, bb: BB, jumps: Jumps) -> BB | None: if not jumps.break_bb: raise InternalGuppyError("Break BB not defined") self.cfg.link(bb, jumps.break_bb) return None - def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> BB | None: bb = self._build_node_value(node, bb) self.cfg.link(bb, jumps.return_bb) return None - def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> Optional[BB]: + def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None: return bb def visit_FunctionDef( self, node: ast.FunctionDef, bb: BB, jumps: Jumps - ) -> Optional[BB]: + ) -> BB | None: from guppy.checker.func_checker import check_signature func_ty = check_signature(node, self.globals) @@ -188,7 +185,7 @@ def visit_FunctionDef( bb.statements.append(new_node) return bb - def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> Optional[BB]: # type: ignore + def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None: # type: ignore[override] # When adding support for new statements, we have to remember to use the # ExprBuilder to transform all included expressions! raise GuppyError("Statement is not supported", node) @@ -215,7 +212,7 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]: return builder.visit(node), builder.bb @classmethod - def _make_var(cls, name: str, loc: Optional[ast.expr] = None) -> ast.Name: + def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name: """Creates an `ast.Name` node.""" node = ast.Name(id=name, ctx=ast.Load) if loc is not None: @@ -343,7 +340,7 @@ def visit_Compare( # Support chained comparisons, e.g. `x <= 5 < y` by compiling to `x <= 5 and # 5 < y`. This way we get short-circuit evaluation for free. if len(node.comparators) > 1: - comparators = [node.left] + node.comparators + comparators = [node.left, *node.comparators] values = [ ast.Compare( left=left, @@ -368,7 +365,7 @@ def visit_IfExp(self, node: ast.IfExp, bb: BB, true_bb: BB, false_bb: BB) -> Non self.visit(node.body, then_bb, true_bb, false_bb) self.visit(node.orelse, else_bb, true_bb, false_bb) - def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> None: # type: ignore + def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> None: # type: ignore[override] # We can always fall back to building the node as a regular expression and using # the result as a branch predicate pred, bb = ExprBuilder.build(node, self.cfg, bb) diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py index 2648f696..986988d5 100644 --- a/guppy/cfg/cfg.py +++ b/guppy/cfg/cfg.py @@ -1,16 +1,15 @@ -from typing import Optional, TypeVar, Generic +from typing import Generic, TypeVar from guppy.cfg.analysis import ( - LivenessDomain, - LivenessAnalysis, AssignmentAnalysis, DefAssignmentDomain, + LivenessAnalysis, + LivenessDomain, MaybeAssignmentDomain, Result, ) from guppy.cfg.bb import BB, BBStatement - T = TypeVar("T", bound=BB) @@ -26,7 +25,7 @@ class BaseCFG(Generic[T]): maybe_ass_before: Result[MaybeAssignmentDomain] def __init__( - self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None + self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None ): self.bbs = bbs if entry_bb: @@ -54,7 +53,7 @@ def __init__(self) -> None: self.entry_bb = self.new_bb() self.exit_bb = self.new_bb() - def new_bb(self, *preds: BB, statements: Optional[list[BBStatement]] = None) -> BB: + def new_bb(self, *preds: BB, statements: list[BBStatement] | None = None) -> BB: """Adds a new basic block to the CFG.""" bb = BB( len(self.bbs), self, predecessors=list(preds), statements=statements or [] diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py index 787b2ccb..c5cf8a64 100644 --- a/guppy/checker/cfg_checker.py +++ b/guppy/checker/cfg_checker.py @@ -5,21 +5,18 @@ """ import collections +from collections.abc import Sequence from dataclasses import dataclass -from typing import Sequence from guppy.ast_util import line_col from guppy.cfg.bb import BB from guppy.cfg.cfg import CFG, BaseCFG -from guppy.checker.core import Globals, Context - -from guppy.checker.core import Variable +from guppy.checker.core import Context, Globals, Variable from guppy.checker.expr_checker import ExprSynthesizer, to_bool from guppy.checker.stmt_checker import StmtChecker from guppy.error import GuppyError from guppy.gtypes import GuppyType - VarRow = Sequence[Variable] @@ -42,7 +39,7 @@ def empty() -> "Signature": class CheckedBB(BB): """Basic block annotated with an input and output type signature.""" - sig: Signature = Signature.empty() + sig: Signature = Signature.empty() # noqa: RUF009 class CheckedCFG(BaseCFG[CheckedBB]): @@ -64,7 +61,7 @@ def check_cfg( unreachable blocks. """ # First, we need to run program analysis - ass_before = set(v.name for v in inputs) + ass_before = {v.name for v in inputs} cfg.analyze(ass_before, ass_before) # We start by compiling the entry BB diff --git a/guppy/checker/core.py b/guppy/checker/core.py index 9bab6375..25dd256d 100644 --- a/guppy/checker/core.py +++ b/guppy/checker/core.py @@ -1,16 +1,16 @@ import ast from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import NamedTuple, Optional, Union +from typing import NamedTuple from guppy.ast_util import AstNode from guppy.gtypes import ( - GuppyType, + BoolType, FunctionType, - TupleType, - SumType, + GuppyType, NoneType, - BoolType, + SumType, + TupleType, ) @@ -20,8 +20,8 @@ class Variable: name: str ty: GuppyType - defined_at: Optional[AstNode] - used: Optional[AstNode] + defined_at: AstNode | None + used: AstNode | None @dataclass @@ -65,7 +65,7 @@ def default() -> "Globals": } return Globals({}, tys) - def get_instance_func(self, ty: GuppyType, name: str) -> Optional[CallableVariable]: + def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None: """Looks up an instance function with a given name for a type. Returns `None` if the name doesn't exist or isn't a function. @@ -83,7 +83,7 @@ def __or__(self, other: "Globals") -> "Globals": self.types | other.types, ) - def __ior__(self, other: "Globals") -> "Globals": + def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034 self.values.update(other.values) self.types.update(other.types) return self @@ -100,7 +100,7 @@ class Context(NamedTuple): locals: Locals -def qualified_name(ty: Union[type[GuppyType], str], name: str) -> str: +def qualified_name(ty: type[GuppyType] | str, name: str) -> str: """Returns a qualified name for an instance function on a type.""" ty_name = ty if isinstance(ty, str) else ty.name return f"{ty_name}.{name}" diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py index 09c12774..731d8f7c 100644 --- a/guppy/checker/expr_checker.py +++ b/guppy/checker/expr_checker.py @@ -22,13 +22,13 @@ import ast from contextlib import suppress -from typing import Optional, Union, NoReturn, Any +from typing import Any, NoReturn -from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt -from guppy.checker.core import Context, CallableVariable, Globals +from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type +from guppy.checker.core import CallableVariable, Context, Globals from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType -from guppy.nodes import LocalName, GlobalName, LocalCall +from guppy.gtypes import BoolType, FunctionType, GuppyType, TupleType +from guppy.nodes import GlobalName, LocalCall, LocalName # Mapping from unary AST op to dunder method and display name unary_table: dict[type[ast.unaryop], tuple[str, str]] = { @@ -38,7 +38,7 @@ } # fmt: skip # Mapping from binary AST op to left dunder method, right dunder method and display name -AstOp = Union[ast.operator, ast.cmpop] +AstOp = ast.operator | ast.cmpop binary_table: dict[type[AstOp], tuple[str, str, str]] = { ast.Add: ("__add__", "__radd__", "+"), ast.Sub: ("__sub__", "__rsub__", "-"), @@ -78,8 +78,8 @@ def __init__(self, ctx: Context) -> None: def _fail( self, expected: GuppyType, - actual: Union[ast.expr, GuppyType], - loc: Optional[AstNode] = None, + actual: ast.expr | GuppyType, + loc: AstNode | None = None, ) -> NoReturn: """Raises a type error indicating that the type doesn't match.""" if not isinstance(actual, GuppyType): @@ -348,7 +348,7 @@ def to_bool( def python_value_to_guppy_type( v: Any, node: ast.expr, globals: Globals -) -> Optional[GuppyType]: +) -> GuppyType | None: """Turns a primitive Python value into a Guppy type. Returns `None` if the Python value cannot be represented in Guppy. diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py index e0695824..20113685 100644 --- a/guppy/checker/func_checker.py +++ b/guppy/checker/func_checker.py @@ -8,15 +8,15 @@ import ast from dataclasses import dataclass -from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc +from guppy.ast_util import AstNode, return_nodes_in_ast, with_loc from guppy.cfg.bb import BB from guppy.cfg.builder import CFGBuilder -from guppy.checker.core import Variable, Globals, Context, CallableVariable -from guppy.checker.cfg_checker import check_cfg, CheckedCFG -from guppy.checker.expr_checker import synthesize_call, check_call +from guppy.checker.cfg_checker import CheckedCFG, check_cfg +from guppy.checker.core import CallableVariable, Context, Globals, Variable +from guppy.checker.expr_checker import check_call, synthesize_call from guppy.error import GuppyError -from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType -from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef +from guppy.gtypes import FunctionType, GuppyType, NoneType, type_from_ast +from guppy.nodes import CheckedNestedFunctionDef, GlobalCall, NestedFunctionDef @dataclass @@ -178,7 +178,7 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType arg_tys = [] arg_names = [] - for i, arg in enumerate(func_def.args.args): + for _i, arg in enumerate(func_def.args.args): if arg.annotation is None: raise GuppyError("Argument type must be annotated", arg) ty = type_from_ast(arg.annotation, globals) diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py index 5622ee55..61df3120 100644 --- a/guppy/checker/stmt_checker.py +++ b/guppy/checker/stmt_checker.py @@ -9,14 +9,14 @@ """ import ast -from typing import Sequence +from collections.abc import Sequence -from guppy.ast_util import with_loc, AstVisitor +from guppy.ast_util import AstVisitor, with_loc from guppy.cfg.bb import BB, BBStatement -from guppy.checker.core import Variable, Context -from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker +from guppy.checker.core import Context, Variable +from guppy.checker.expr_checker import ExprChecker, ExprSynthesizer from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError -from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType +from guppy.gtypes import GuppyType, NoneType, TupleType, type_from_ast from guppy.nodes import NestedFunctionDef diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py index 086236e9..d2f1abb9 100644 --- a/guppy/compiler/cfg_compiler.py +++ b/guppy/compiler/cfg_compiler.py @@ -1,19 +1,22 @@ import functools -from typing import Sequence +from typing import TYPE_CHECKING -from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature +from guppy.checker.cfg_checker import CheckedBB, CheckedCFG, Signature, VarRow from guppy.checker.core import Variable from guppy.compiler.core import ( CompiledGlobals, - is_return_var, DFContainer, - return_var, PortVariable, + is_return_var, + return_var, ) from guppy.compiler.expr_compiler import ExprCompiler from guppy.compiler.stmt_compiler import StmtCompiler -from guppy.gtypes import TupleType, SumType, type_to_row -from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV +from guppy.gtypes import SumType, TupleType, type_to_row +from guppy.hugr.hugr import CFNode, Hugr, Node, OutPortV + +if TYPE_CHECKING: + from collections.abc import Sequence def compile_cfg( @@ -119,7 +122,8 @@ def insert_return_vars(cfg: CheckedCFG) -> None: # Also patch the predecessors for pred in cfg.exit_bb.predecessors: # The exit BB will be the only successor - assert len(pred.sig.output_rows) == 1 and len(pred.sig.output_rows[0]) == 0 + assert len(pred.sig.output_rows) == 1 + assert len(pred.sig.output_rows[0]) == 0 pred.sig = Signature(pred.sig.input_row, [return_vars]) @@ -144,7 +148,7 @@ def choose_vars_for_tuple_sum( conditional = graph.add_conditional( cond_input=unit_sum, inputs=tuples, parent=dfg.node ) - for i, ty in enumerate(tys): + for i, _ty in enumerate(tys): case = graph.add_case(conditional) inp = graph.add_input(output_tys=tys, parent=case).out_port(i) tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0) diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py index 2b6a4fb2..d3b16e59 100644 --- a/guppy/compiler/core.py +++ b/guppy/compiler/core.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod +from collections.abc import Iterator from dataclasses import dataclass -from typing import Optional, Iterator from guppy.ast_util import AstNode -from guppy.checker.core import Variable, CallableVariable +from guppy.checker.core import CallableVariable, Variable from guppy.gtypes import FunctionType -from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr +from guppy.hugr.hugr import DFContainingNode, Hugr, OutPortV @dataclass @@ -21,8 +21,8 @@ def __init__( self, name: str, port: OutPortV, - defined_at: Optional[AstNode], - used: Optional[AstNode] = None, + defined_at: AstNode | None, + used: AstNode | None = None, ) -> None: super().__init__(name, port.ty, defined_at, used) object.__setattr__(self, "port", port) @@ -89,7 +89,7 @@ def __copy__(self) -> "DFContainer": # mutate our variable mapping return DFContainer(self.node, self.locals.copy()) - def get_var(self, name: str) -> Optional[PortVariable]: + def get_var(self, name: str) -> PortVariable | None: return self.locals.get(name, None) diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py index 5fdc378b..f74c18d2 100644 --- a/guppy/compiler/expr_compiler.py +++ b/guppy/compiler/expr_compiler.py @@ -1,13 +1,13 @@ import ast -from typing import Any, Optional +from typing import Any from guppy.ast_util import AstVisitor, get_type -from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction +from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer from guppy.error import InternalGuppyError -from guppy.gtypes import FunctionType, type_to_row, BoolType +from guppy.gtypes import BoolType, FunctionType, type_to_row from guppy.hugr import ops, val from guppy.hugr.hugr import OutPortV -from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall +from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName class ExprCompiler(CompilerBase, AstVisitor[OutPortV]): @@ -100,12 +100,12 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]: return expr.elts if isinstance(expr, ast.Tuple) else [expr] -def python_value_to_hugr(v: Any) -> Optional[val.Value]: +def python_value_to_hugr(v: Any) -> val.Value | None: """Turns a Python value into a Hugr value. Returns None if the Python value cannot be represented in Guppy. """ - from guppy.prelude._internal import int_value, bool_value, float_value + from guppy.prelude._internal import bool_value, float_value, int_value if isinstance(v, bool): return bool_value(v) diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py index 4631f0db..f4376b17 100644 --- a/guppy/compiler/func_compiler.py +++ b/guppy/compiler/func_compiler.py @@ -9,8 +9,8 @@ DFContainer, PortVariable, ) -from guppy.gtypes import type_to_row, FunctionType -from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode +from guppy.gtypes import FunctionType, type_to_row +from guppy.hugr.hugr import DFContainingVNode, Hugr, OutPortV from guppy.nodes import CheckedNestedFunctionDef diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py index db656b6f..18147b22 100644 --- a/guppy/compiler/stmt_compiler.py +++ b/guppy/compiler/stmt_compiler.py @@ -1,12 +1,12 @@ import ast -from typing import Sequence +from collections.abc import Sequence from guppy.ast_util import AstVisitor from guppy.checker.cfg_checker import CheckedBB from guppy.compiler.core import ( + CompiledGlobals, CompilerBase, DFContainer, - CompiledGlobals, PortVariable, return_var, ) diff --git a/guppy/custom.py b/guppy/custom.py index a80fd134..4c20ea09 100644 --- a/guppy/custom.py +++ b/guppy/custom.py @@ -1,27 +1,26 @@ import ast from abc import ABC, abstractmethod -from typing import Optional -from guppy.ast_util import AstNode, with_type, with_loc, get_type +from guppy.ast_util import AstNode, get_type, with_loc, with_type from guppy.checker.core import Context, Globals from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature -from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals +from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer from guppy.error import ( GuppyError, InternalGuppyError, UnknownFunctionType, ) -from guppy.gtypes import GuppyType, FunctionType, type_to_row +from guppy.gtypes import FunctionType, GuppyType, type_to_row from guppy.hugr import ops -from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode +from guppy.hugr.hugr import DFContainingVNode, Hugr, Node, OutPortV from guppy.nodes import GlobalCall class CustomFunction(CompiledFunction): """A function whose type checking and compilation behaviour can be customised.""" - defined_at: Optional[ast.FunctionDef] + defined_at: ast.FunctionDef | None # Whether the function may be used as a higher-order value. This is only possible # if a static type for the function is provided. @@ -30,17 +29,17 @@ class CustomFunction(CompiledFunction): call_checker: "CustomCallChecker" call_compiler: "CustomCallCompiler" - _ty: Optional[FunctionType] = None - _defined: dict[Node, DFContainingVNode] = {} + _ty: FunctionType | None = None + _defined: dict[Node, DFContainingVNode] = {} # noqa: RUF012 def __init__( self, name: str, - defined_at: Optional[ast.FunctionDef], + defined_at: ast.FunctionDef | None, compiler: "CustomCallCompiler", checker: "CustomCallChecker", higher_order_value: bool = True, - ty: Optional[FunctionType] = None, + ty: FunctionType | None = None, ): self.name = name self.defined_at = defined_at @@ -51,7 +50,7 @@ def __init__( self._ty = ty self._defined = {} - @property # type: ignore + @property # type: ignore[override] def ty(self) -> FunctionType: if self._ty is None: return UnknownFunctionType() @@ -77,11 +76,11 @@ def check_type(self, globals: Globals) -> None: try: self._ty = check_signature(self.defined_at, globals) - except GuppyError as err: + except GuppyError: # We can ignore the error if a custom call checker is provided and the # function may not be used as a higher-order value if self.call_checker is None or self.higher_order_value: - raise err + raise def check_call( self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context diff --git a/guppy/declared.py b/guppy/declared.py index 91b5283e..265ef1a6 100644 --- a/guppy/declared.py +++ b/guppy/declared.py @@ -1,15 +1,14 @@ import ast from dataclasses import dataclass -from typing import Optional from guppy.ast_util import AstNode, has_empty_body -from guppy.checker.core import Globals, Context +from guppy.checker.core import Context, Globals from guppy.checker.expr_checker import check_call, synthesize_call from guppy.checker.func_checker import check_signature -from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals +from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer from guppy.error import GuppyError -from guppy.gtypes import type_to_row, GuppyType -from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV +from guppy.gtypes import GuppyType, type_to_row +from guppy.hugr.hugr import Hugr, Node, OutPortV, VNode from guppy.nodes import GlobalCall @@ -17,7 +16,7 @@ class DeclaredFunction(CompiledFunction): """A user-declared function that compiles to a Hugr function declaration.""" - node: Optional[VNode] = None + node: VNode | None = None @staticmethod def from_ast( diff --git a/guppy/decorator.py b/guppy/decorator.py index 493ddcb2..32f93ab1 100644 --- a/guppy/decorator.py +++ b/guppy/decorator.py @@ -1,19 +1,20 @@ import functools +from collections.abc import Callable from dataclasses import dataclass -from typing import Optional, Union, Callable, Any +from typing import Any from guppy.ast_util import AstNode, has_empty_body from guppy.custom import ( + CustomCallChecker, + CustomCallCompiler, CustomFunction, - OpCompiler, DefaultCallChecker, - CustomCallCompiler, - CustomCallChecker, DefaultCallCompiler, + OpCompiler, ) from guppy.error import GuppyError, pretty_errors from guppy.gtypes import GuppyType -from guppy.hugr import tys, ops +from guppy.hugr import ops, tys from guppy.hugr.hugr import Hugr from guppy.module import GuppyModule, PyFunc, parse_py_func @@ -26,7 +27,7 @@ class _Guppy: """Class for the `@guppy` decorator.""" # The current module - _module: Optional[GuppyModule] + _module: GuppyModule | None def __init__(self) -> None: self._module = None @@ -35,13 +36,11 @@ def set_module(self, module: GuppyModule) -> None: self._module = module @pretty_errors - def __call__( - self, arg: Union[PyFunc, GuppyModule] - ) -> Union[Optional[Hugr], FuncDecorator]: + def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator: """Decorator to annotate Python functions as Guppy code. - Optionally, the `GuppyModule` in which the function should be placed can be passed - to the decorator. + Optionally, the `GuppyModule` in which the function should be placed can be + passed to the decorator. """ if isinstance(arg, GuppyModule): @@ -98,9 +97,7 @@ class NewType(GuppyType): name = _name @staticmethod - def build( - *args: GuppyType, node: Optional[AstNode] = None - ) -> "GuppyType": + def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType": # At the moment, custom types don't support type arguments. if len(args) > 0: raise GuppyError( @@ -131,8 +128,8 @@ def __str__(self) -> str: def custom( self, module: GuppyModule, - compiler: Optional[CustomCallCompiler] = None, - checker: Optional[CustomCallChecker] = None, + compiler: CustomCallCompiler | None = None, + checker: CustomCallChecker | None = None, higher_order_value: bool = True, name: str = "", ) -> CustomFuncDecorator: @@ -168,7 +165,7 @@ def hugr_op( self, module: GuppyModule, op: ops.OpType, - checker: Optional[CustomCallChecker] = None, + checker: CustomCallChecker | None = None, higher_order_value: bool = True, name: str = "", ) -> CustomFuncDecorator: diff --git a/guppy/error.py b/guppy/error.py index b0ace141..bc6d38e7 100644 --- a/guppy/error.py +++ b/guppy/error.py @@ -2,13 +2,13 @@ import functools import sys import textwrap +from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import Optional, Any, Sequence, Callable, TypeVar, cast - -from guppy.ast_util import AstNode, get_line_offset, get_file, get_source -from guppy.gtypes import GuppyType, FunctionType -from guppy.hugr.hugr import OutPortV, Node +from typing import Any, TypeVar, cast +from guppy.ast_util import AstNode, get_file, get_line_offset, get_source +from guppy.gtypes import FunctionType, GuppyType +from guppy.hugr.hugr import Node, OutPortV # Whether the interpreter should exit when a Guppy error occurs EXIT_ON_ERROR: bool = True @@ -25,12 +25,13 @@ class SourceLoc: file: str line: int col: int - ast_node: Optional[AstNode] + ast_node: AstNode | None @staticmethod def from_ast(node: AstNode) -> "SourceLoc": file, line_offset = get_file(node), get_line_offset(node) - assert file is not None and line_offset is not None + assert file is not None + assert line_offset is not None return SourceLoc(file, line_offset + node.lineno - 1, node.col_offset, node) def __str__(self) -> str: @@ -50,9 +51,9 @@ class GuppyError(Exception): `{1}`, etc. and passing the corresponding AST nodes to `locs_in_msg`.""" raw_msg: str - location: Optional[AstNode] = None + location: AstNode | None = None # The message can also refer to AST locations using format placeholders `{0}`, `{1}` - locs_in_msg: Sequence[Optional[AstNode]] = field(default_factory=list) + locs_in_msg: Sequence[AstNode | None] = field(default_factory=list) def get_msg(self) -> str: """Returns the message associated with this error. @@ -70,14 +71,10 @@ def get_msg(self) -> str: class GuppyTypeError(GuppyError): """Special Guppy exception for type errors.""" - pass - class InternalGuppyError(Exception): """Exception for internal problems during compilation.""" - pass - class UndefinedPort(OutPortV): """Dummy port for undefined variables. @@ -119,7 +116,7 @@ def returns(self) -> GuppyType: raise InternalGuppyError("Tried to access unknown function type") @property - def args_names(self) -> Optional[Sequence[str]]: + def args_names(self) -> Sequence[str] | None: raise InternalGuppyError("Tried to access unknown function type") @@ -130,7 +127,8 @@ def format_source_location( ) -> str: """Creates a pretty banner to show source locations for errors.""" source, line_offset = get_source(loc), get_line_offset(loc) - assert source is not None and line_offset is not None + assert source is not None + assert line_offset is not None source_lines = source.splitlines(keepends=True) end_col_offset = loc.end_col_offset or len(source_lines[loc.lineno]) s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip() @@ -160,12 +158,13 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: except GuppyError as err: # Reraise if we're missing a location if not err.location: - raise err + raise loc = err.location file, line_offset = get_file(loc), get_line_offset(loc) - assert file is not None and line_offset is not None + assert file is not None + assert line_offset is not None line = line_offset + loc.lineno - 1 - print( + print( # noqa: T201 f"Guppy compilation failed. Error in file {file}:{line}\n\n" f"{format_source_location(loc)}\n" f"{err.__class__.__name__}: {err.get_msg()}", diff --git a/guppy/gtypes.py b/guppy/gtypes.py index 4d98e29b..a6a15c20 100644 --- a/guppy/gtypes.py +++ b/guppy/gtypes.py @@ -1,7 +1,8 @@ import ast from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Optional, Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING import guppy.hugr.tys as tys from guppy.ast_util import AstNode, set_location_from @@ -20,7 +21,7 @@ class GuppyType(ABC): @staticmethod @abstractmethod - def build(*args: "GuppyType", node: Optional[AstNode] = None) -> "GuppyType": + def build(*args: "GuppyType", node: AstNode | None = None) -> "GuppyType": pass @property @@ -37,7 +38,7 @@ def to_hugr(self) -> tys.SimpleType: class FunctionType(GuppyType): args: Sequence[GuppyType] returns: GuppyType - arg_names: Optional[Sequence[str]] = field( + arg_names: Sequence[str] | None = field( default=None, compare=False, # Argument names are not taken into account for type equality ) @@ -53,10 +54,10 @@ def __str__(self) -> str: return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}" @staticmethod - def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: # Function types cannot be constructed using `build`. The type parsing code # has a special case for function types. - raise NotImplementedError() + raise NotImplementedError def to_hugr(self) -> tys.SimpleType: ins = [t.to_hugr() for t in self.args] @@ -72,7 +73,7 @@ class TupleType(GuppyType): name: str = "tuple" @staticmethod - def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: from guppy.error import GuppyError # TODO: Parse empty tuples via `tuple[()]` @@ -97,10 +98,10 @@ class SumType(GuppyType): element_types: Sequence[GuppyType] @staticmethod - def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: # Sum types cannot be parsed and constructed using `build` since they cannot be # written by the user - raise NotImplementedError() + raise NotImplementedError def __str__(self) -> str: return f"Sum({', '.join(str(e) for e in self.element_types)})" @@ -124,7 +125,7 @@ class NoneType(GuppyType): linear: bool = False @staticmethod - def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: if len(args) > 0: from guppy.error import GuppyError @@ -150,7 +151,7 @@ def __init__(self) -> None: super().__init__([TupleType([]), TupleType([])]) @staticmethod - def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType: + def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType: if len(args) > 0: from guppy.error import GuppyError @@ -161,7 +162,7 @@ def __str__(self) -> str: return "bool" -def _lookup_type(node: AstNode, globals: "Globals") -> Optional[type[GuppyType]]: +def _lookup_type(node: AstNode, globals: "Globals") -> type[GuppyType] | None: if isinstance(node, ast.Name) and node.id in globals.types: return globals.types[node.id] if isinstance(node, ast.Constant) and node.value is None: diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py index 96a647f1..41e1f2bf 100644 --- a/guppy/hugr/hugr.py +++ b/guppy/hugr/hugr.py @@ -1,20 +1,21 @@ import itertools -import networkx # type: ignore - from abc import ABC, abstractmethod +from collections.abc import Iterator, Sequence from contextlib import contextmanager -from typing import Optional, Iterator, Tuple, Any, Sequence -from dataclasses import field, dataclass +from dataclasses import dataclass, field +from typing import Any, Optional + +import networkx as nx # type: ignore[import] import guppy.hugr.ops as ops import guppy.hugr.raw as raw from guppy.gtypes import ( - GuppyType, - TupleType, FunctionType, + GuppyType, SumType, - type_to_row, + TupleType, row_to_type, + type_to_row, ) from guppy.hugr import val @@ -27,20 +28,16 @@ class Port(ABC): """Base class for ports on nodes.""" node: "Node" - offset: Optional[PortOffset] + offset: PortOffset | None class InPort(Port, ABC): """Base class for a port that incoming wires connect to.""" - pass - class OutPort(Port, ABC): """Base class for a port that outgoing wires come from.""" - pass - @dataclass(frozen=True) class InPortV(InPort): @@ -69,8 +66,6 @@ class InPortCF(InPort): class OutPortCF(OutPort): """A control-flow output port.""" - pass - Edge = tuple[OutPort, InPort] @@ -93,23 +88,19 @@ class Node(ABC): @abstractmethod def num_in_ports(self) -> int: """The number of input ports on this node.""" - pass @property @abstractmethod def num_out_ports(self) -> int: """The number of output ports on this node.""" - pass @abstractmethod - def in_port(self, offset: Optional[PortOffset]) -> InPort: + def in_port(self, offset: PortOffset | None) -> InPort: """Returns the input port at the given offset.""" - pass @abstractmethod - def out_port(self, offset: Optional[PortOffset]) -> OutPort: + def out_port(self, offset: PortOffset | None) -> OutPort: """Returns the output port at the given offset.""" - pass @abstractmethod def update_op(self) -> None: @@ -117,7 +108,6 @@ def update_op(self) -> None: This should be called before serialisation. """ - pass @property def in_ports(self) -> Iterator[InPort]: @@ -159,14 +149,16 @@ def add_out_port(self, ty: GuppyType) -> OutPortV: self.out_port_types.append(ty) return p - def in_port(self, offset: Optional[PortOffset]) -> InPortV: + def in_port(self, offset: PortOffset | None) -> InPortV: """Returns the input port at the given offset.""" - assert offset is not None and offset < self.num_in_ports + assert offset is not None + assert offset < self.num_in_ports return InPortV(self, offset, self.in_port_types[offset]) - def out_port(self, offset: Optional[PortOffset]) -> OutPortV: + def out_port(self, offset: PortOffset | None) -> OutPortV: """Returns the output port at the given offset.""" - assert offset is not None and offset < self.num_out_ports + assert offset is not None + assert offset < self.num_out_ports return OutPortV(self, offset, self.out_port_types[offset]) @property @@ -216,11 +208,11 @@ def add_out_port(self) -> OutPortCF: self._num_out_ports += 1 return p - def in_port(self, offset: Optional[PortOffset]) -> InPortCF: + def in_port(self, offset: PortOffset | None) -> InPortCF: assert offset is None return InPortCF(self) - def out_port(self, offset: Optional[PortOffset]) -> OutPortCF: + def out_port(self, offset: PortOffset | None) -> OutPortCF: """Returns the output port at the given offset.""" assert offset is not None assert offset < self.num_out_ports @@ -246,7 +238,8 @@ def update_op(self) -> None: Feeds type information from the signature of the contained dataflow graph to the operation class to. This function must be called before serialisation. """ - assert self.input_child is not None and self.output_child is not None + assert self.input_child is not None + assert self.output_child is not None # Input and output node may have extra order edges connected, so we filter # `None`s here ins = [ty.to_hugr() for ty in self.input_child.out_port_types] @@ -272,15 +265,15 @@ class Hugr: name: str root: VNode - _graph: networkx.MultiDiGraph # TODO: We probably don't need networkx. + _graph: nx.MultiDiGraph # TODO: We probably don't need networkx. _children: dict[NodeIdx, list[Node]] - _default_parent: Optional[Node] + _default_parent: Node | None - def __init__(self, name: Optional[str] = None) -> None: + def __init__(self, name: str | None = None) -> None: """Creates a new Hugr.""" self.name = name or "Unnamed" self._default_parent = None - self._graph = networkx.MultiDiGraph() + self._graph = nx.MultiDiGraph() self._children = {-1: []} self.root = self.add_node( op=ops.Module(), meta_data={"name": name}, parent=None @@ -294,7 +287,7 @@ def parent(self, parent: Node) -> Iterator[None]: yield self._default_parent = old_default - def _insert_node(self, node: Node, inputs: Optional[list[OutPortV]] = None) -> None: + def _insert_node(self, node: Node, inputs: list[OutPortV] | None = None) -> None: """Helper method to insert a node into the graph datastructure.""" self._graph.add_node(node.idx, data=node) self._children[node.idx] = [] @@ -306,11 +299,11 @@ def _insert_node(self, node: Node, inputs: Optional[list[OutPortV]] = None) -> N def add_node( self, op: ops.OpType, - input_types: Optional[TypeList] = None, - output_types: Optional[TypeList] = None, - parent: Optional[Node] = None, - inputs: Optional[list[OutPortV]] = None, - meta_data: Optional[dict[str, Any]] = None, + input_types: TypeList | None = None, + output_types: TypeList | None = None, + parent: Node | None = None, + inputs: list[OutPortV] | None = None, + meta_data: dict[str, Any] | None = None, ) -> VNode: """Helper method to add a generic value node to the graph.""" input_types = input_types or [] @@ -330,11 +323,11 @@ def add_node( def _add_dfg_node( self, op: ops.OpType, - input_types: Optional[TypeList] = None, - output_types: Optional[TypeList] = None, - parent: Optional[Node] = None, - inputs: Optional[list[OutPortV]] = None, - meta_data: Optional[dict[str, Any]] = None, + input_types: TypeList | None = None, + output_types: TypeList | None = None, + parent: Node | None = None, + inputs: list[OutPortV] | None = None, + meta_data: dict[str, Any] | None = None, ) -> DFContainingVNode: """Helper method to add a generic dataflow containing value node to the graph.""" @@ -358,7 +351,7 @@ def set_root_name(self, name: str) -> VNode: return self.root def add_constant( - self, value: val.Value, ty: GuppyType, parent: Optional[Node] = None + self, value: val.Value, ty: GuppyType, parent: Node | None = None ) -> VNode: """Adds a constant node holding a given value to the graph.""" return self.add_node( @@ -366,7 +359,7 @@ def add_constant( ) def add_input( - self, output_tys: Optional[TypeList] = None, parent: Optional[Node] = None + self, output_tys: TypeList | None = None, parent: Node | None = None ) -> VNode: """Adds an `Input` node to the graph.""" parent = parent or self._default_parent @@ -376,7 +369,7 @@ def add_input( return node def add_input_with_ports( - self, output_tys: Sequence[GuppyType], parent: Optional[Node] = None + self, output_tys: Sequence[GuppyType], parent: Node | None = None ) -> tuple[VNode, list[OutPortV]]: """Adds an `Input` node to the graph.""" node = self.add_input(None, parent) @@ -385,9 +378,9 @@ def add_input_with_ports( def add_output( self, - inputs: Optional[list[OutPortV]] = None, - input_tys: Optional[TypeList] = None, - parent: Optional[Node] = None, + inputs: list[OutPortV] | None = None, + input_tys: TypeList | None = None, + parent: Node | None = None, ) -> VNode: """Adds an `Output` node to the graph.""" node = self.add_node(ops.Output(), input_tys, [], parent, inputs) @@ -395,7 +388,7 @@ def add_output( parent.output_child = node return node - def add_block(self, parent: Optional[Node], num_successors: int = 0) -> BlockNode: + def add_block(self, parent: Node | None, num_successors: int = 0) -> BlockNode: """Adds a `Block` node to the graph.""" node = BlockNode( idx=self._graph.number_of_nodes(), op=ops.DFB(), parent=parent, meta_data={} @@ -433,27 +426,27 @@ def add_conditional( self, cond_input: OutPortV, inputs: list[OutPortV], - parent: Optional[Node] = None, + parent: Node | None = None, ) -> VNode: """Adds a `Conditional` node to the graph.""" - inputs = [cond_input] + inputs + inputs = [cond_input, *inputs] return self.add_node(ops.Conditional(), None, None, parent, inputs) def add_tail_loop( - self, inputs: list[OutPortV], parent: Optional[Node] = None + self, inputs: list[OutPortV], parent: Node | None = None ) -> DFContainingVNode: """Adds a `TailLoop` node to the graph.""" return self._add_dfg_node(ops.TailLoop(), None, None, parent, inputs) def add_make_tuple( - self, inputs: list[OutPortV], parent: Optional[Node] = None + self, inputs: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds a `MakeTuple` node to the graph.""" ty = TupleType([port.ty for port in inputs]) return self.add_node(ops.MakeTuple(), None, [ty], parent, inputs) def add_unpack_tuple( - self, input_tuple: OutPortV, parent: Optional[Node] = None + self, input_tuple: OutPortV, parent: Node | None = None ) -> VNode: """Adds an `UnpackTuple` node to the graph.""" assert isinstance(input_tuple.ty, TupleType) @@ -466,7 +459,7 @@ def add_unpack_tuple( ) def add_tag( - self, variants: TypeList, tag: int, inp: OutPortV, parent: Optional[Node] = None + self, variants: TypeList, tag: int, inp: OutPortV, parent: Node | None = None ) -> VNode: """Adds a `Tag` node to the graph.""" types = [ty.to_hugr() for ty in variants] @@ -476,7 +469,7 @@ def add_tag( ) def add_load_constant( - self, const_port: OutPortV, parent: Optional[Node] = None + self, const_port: OutPortV, parent: Node | None = None ) -> VNode: """Adds a `LoadConstant` node to the graph.""" return self.add_node( @@ -488,7 +481,7 @@ def add_load_constant( ) def add_call( - self, def_port: OutPortV, args: list[OutPortV], parent: Optional[Node] = None + self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds a `Call` node to the graph.""" assert isinstance(def_port.ty, FunctionType) @@ -497,11 +490,11 @@ def add_call( None, list(type_to_row(def_port.ty.returns)), parent, - args + [def_port], + [*args, def_port], ) def add_indirect_call( - self, fun_port: OutPortV, args: list[OutPortV], parent: Optional[Node] = None + self, fun_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds an `IndirectCall` node to the graph.""" assert isinstance(fun_port.ty, FunctionType) @@ -510,11 +503,11 @@ def add_indirect_call( None, list(type_to_row(fun_port.ty.returns)), parent, - [fun_port] + args, + [fun_port, *args], ) def add_partial( - self, def_port: OutPortV, args: list[OutPortV], parent: Optional[Node] = None + self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None ) -> VNode: """Adds a `Partial` evaluation node to the graph.""" assert isinstance(def_port.ty, FunctionType) @@ -528,11 +521,11 @@ def add_partial( else None, ) return self.add_node( - ops.DummyOp(name="partial"), None, [new_ty], parent, args + [def_port] + ops.DummyOp(name="partial"), None, [new_ty], parent, [*args, def_port] ) def add_def( - self, fun_ty: FunctionType, parent: Optional[Node], name: str + self, fun_ty: FunctionType, parent: Node | None, name: str ) -> DFContainingVNode: """Adds a `Def` node to the graph.""" return self._add_dfg_node(ops.FuncDefn(name=name), [], [fun_ty], parent, None) @@ -544,13 +537,12 @@ def add_declare(self, fun_ty: FunctionType, parent: Node, name: str) -> VNode: def add_edge(self, src_port: OutPort, tgt_port: InPort) -> None: """Adds an edge between two ports.""" if isinstance(src_port, OutPortV) or isinstance(tgt_port, InPortV): - assert ( - isinstance(src_port, OutPortV) - and isinstance(tgt_port, InPortV) - and src_port.ty == tgt_port.ty - ) + assert isinstance(src_port, OutPortV) + assert isinstance(tgt_port, InPortV) + assert src_port.ty == tgt_port.ty else: - assert isinstance(src_port, OutPortCF) and isinstance(tgt_port, InPortCF) + assert isinstance(src_port, OutPortCF) + assert isinstance(tgt_port, InPortCF) self._graph.add_edge( src_port.node.idx, tgt_port.node.idx, key=(src_port.offset, tgt_port.offset) ) @@ -565,7 +557,7 @@ def nodes(self) -> Iterator[Node]: def get_node(self, idx: int) -> Node: """Returns the node corresponding to given index.""" - return self._graph.nodes[idx]["data"] # type: ignore + return self._graph.nodes[idx]["data"] # type: ignore[no-any-return] def children(self, node: Node) -> list[Node]: """Returns list of a node's immediate children in the hierarchy.""" @@ -612,18 +604,18 @@ def out_edges(self, port: OutPort) -> Iterator[Edge]: def order_successors(self, node: Node) -> Iterator[Node]: """Returns an iterator over all nodes that this node connects to via an order edge.""" - for src, tgt, key in self._graph.out_edges(node.idx, keys=True): + for _src, tgt, key in self._graph.out_edges(node.idx, keys=True): if key == ORDER_EDGE_KEY: yield tgt def order_predecessors(self, node: Node) -> Iterator[Node]: """Returns an iterator over all nodes that are connected to this node via an order edge.""" - for src, tgt, key in self._graph.in_edges(node.idx, keys=True): + for src, _tgt, key in self._graph.in_edges(node.idx, keys=True): if key == ORDER_EDGE_KEY: yield src - def _to_edge(self, src: int, tgt: int, key: Tuple[int, int]) -> Edge: + def _to_edge(self, src: int, tgt: int, key: tuple[int, int]) -> Edge: src_node = self.get_node(src) tgt_node = self.get_node(tgt) return src_node.out_port(key[0]), tgt_node.in_port(key[1]) @@ -678,7 +670,7 @@ def insert_order_edges(self) -> "Hugr": # Special case: Call ops for functions without any arguments are # only connected to the top-level def/declare and also need an # order edge - if isinstance(n.op, ops.Call) and n.num_in_ports == 1: + if isinstance(n.op, ops.Call) and n.num_in_ports == 1: # noqa: SIM114 assert n.parent.input_child is not None self.add_order_edge(n.parent.input_child, n) # Special case: Load constant ops always need an order edge diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py index 81a3fb7c..5af8a6d3 100644 --- a/guppy/hugr/ops.py +++ b/guppy/hugr/ops.py @@ -1,18 +1,20 @@ import inspect import sys from abc import ABC -from typing import Annotated, Literal, Union, Optional, Any -from pydantic import Field, BaseModel +from typing import Annotated, Any, Literal + +from pydantic import BaseModel, Field + +import guppy.hugr.tys as tys from .tys import ( - TypeRow, - SimpleType, - PolyFuncType, - FunctionType, ExtensionId, ExtensionSet, + FunctionType, + PolyFuncType, + SimpleType, + TypeRow, ) -import guppy.hugr.tys as tys from .val import Value NodeID = int @@ -28,11 +30,9 @@ class BaseOp(ABC, BaseModel): def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: """Hook to insert type information from the input and output ports into the op""" - pass def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: """Hook to insert type information from a child dataflow graph""" - pass def display_name(self) -> str: """Name of the op for visualisation""" @@ -152,7 +152,7 @@ class Exit(BasicBlock): cfg_outputs: TypeRow -BasicBlockOp = Annotated[Union[DFB, Exit], Field(discriminator="block")] +BasicBlockOp = Annotated[DFB | Exit, Field(discriminator="block")] # --------------------------------------------- @@ -323,7 +323,7 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: ) -ControlFlowOp = Union[Conditional, TailLoop, CFG] +ControlFlowOp = Conditional | TailLoop | CFG # ----------------------------------------- @@ -338,7 +338,7 @@ class CustomOp(LeafOp): lop: Literal["CustomOp"] = "CustomOp" extension: ExtensionId op_name: str - signature: Optional[tys.FunctionType] = None + signature: tys.FunctionType | None = None description: str = "" args: list[tys.TypeArgUnion] = Field(default_factory=list) @@ -353,77 +353,66 @@ class H(LeafOp): """A Hadamard gate.""" lop: Literal["H"] = "H" - pass class T(LeafOp): """A T gate.""" lop: Literal["T"] = "T" - pass class S(LeafOp): """An S gate.""" lop: Literal["S"] = "S" - pass class X(LeafOp): """A Pauli X gate.""" lop: Literal["X"] = "X" - pass class Y(LeafOp): """A Pauli Y gate.""" lop: Literal["Y"] = "Y" - pass class Z(LeafOp): """A Pauli Z gate.""" lop: Literal["Z"] = "Z" - pass class Tadj(LeafOp): """An adjoint T gate.""" lop: Literal["Tadj"] = "Tadj" - pass class Sadj(LeafOp): """An adjoint S gate.""" lop: Literal["Sadj"] = "Sadj" - pass class CX(LeafOp): """A controlled X gate.""" lop: Literal["CX"] = "CX" - pass class ZZMax(LeafOp): """A maximally entangling ZZ phase gate.""" lop: Literal["ZZMax"] = "ZZMax" - pass class Reset(LeafOp): """A qubit reset operation.""" lop: Literal["Reset"] = "Reset" - pass class Noop(LeafOp): @@ -443,21 +432,18 @@ class Measure(LeafOp): """A qubit measurement operation.""" lop: Literal["Measure"] = "Measure" - pass class RzF64(LeafOp): """A rotation of a qubit about the Pauli Z axis by an input float angle.""" lop: Literal["RzF64"] = "RzF64" - pass class Xor(LeafOp): """A bitwise XOR operation.""" lop: Literal["Xor"] = "Xor" - pass class MakeTuple(LeafOp): @@ -500,54 +486,54 @@ class Tag(LeafOp): LeafOpUnion = Annotated[ - Union[ - CustomOp, - H, - S, - T, - X, - Y, - Z, - Tadj, - Sadj, - CX, - ZZMax, - Reset, - Noop, - Measure, - RzF64, - Xor, - MakeTuple, - UnpackTuple, - MakeNewType, - Tag, - ], + ( + CustomOp + | H + | S + | T + | X + | Y + | Z + | Tadj + | Sadj + | CX + | ZZMax + | Reset + | Noop + | Measure + | RzF64 + | Xor + | MakeTuple + | UnpackTuple + | MakeNewType + | Tag + ), Field(discriminator="lop"), ] OpType = Annotated[ - Union[ - Module, - BasicBlock, - Case, - Module, - FuncDefn, - FuncDecl, - Const, - DummyOp, - BasicBlockOp, - Conditional, - TailLoop, - CFG, - Input, - Output, - Call, - CallIndirect, - LoadConstant, - LeafOpUnion, - DFG, - ], + ( + Module + | BasicBlock + | Case + | Module + | FuncDefn + | FuncDecl + | Const + | DummyOp + | BasicBlockOp + | Conditional + | TailLoop + | CFG + | Input + | Output + | Call + | CallIndirect + | LoadConstant + | LeafOpUnion + | DFG + ), Field(discriminator="op"), ] @@ -562,10 +548,10 @@ class OpDef(BaseOp, allow_population_by_field_name=True): name: str # Unique identifier of the operation. description: str # Human readable description of the operation. - inputs: list[tuple[Optional[str], SimpleType]] - outputs: list[tuple[Optional[str], SimpleType]] + inputs: list[tuple[str | None, SimpleType]] + outputs: list[tuple[str | None, SimpleType]] misc: dict[str, Any] # Miscellaneous data associated with the operation. - def_: Optional[str] = Field( + def_: str | None = Field( ..., alias="def" ) # (YAML?)-encoded definition of the operation. extension_reqs: ExtensionSet # Resources required to execute this operation. diff --git a/guppy/hugr/raw.py b/guppy/hugr/raw.py index a92e4469..a3249a3c 100644 --- a/guppy/hugr/raw.py +++ b/guppy/hugr/raw.py @@ -1,12 +1,11 @@ -from typing import Literal, Optional +from typing import Literal import ormsgpack - from pydantic import BaseModel -from guppy.hugr.ops import NodeID, OpType +from guppy.hugr.ops import NodeID, OpType -Port = tuple[NodeID, Optional[int]] # (node, offset) +Port = tuple[NodeID, int | None] # (node, offset) Edge = tuple[Port, Port] diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py index ebc032d6..fa01db41 100644 --- a/guppy/hugr/tys.py +++ b/guppy/hugr/tys.py @@ -2,9 +2,9 @@ import sys from abc import ABC from enum import Enum -from typing import Literal, Union, Annotated, Optional -from pydantic import Field, BaseModel +from typing import Annotated, Literal +from pydantic import BaseModel, Field ExtensionId = str ExtensionSet = list[ # TODO: Set not supported by MessagePack. Is list correct here? @@ -24,7 +24,7 @@ class TypeParam(BaseModel): class BoundedNatParam(BaseModel): tp: Literal["BoundedNat"] = "BoundedNat" - bound: Optional[int] + bound: int | None class OpaqueParam(BaseModel): @@ -43,7 +43,7 @@ class TupleParam(BaseModel): TypeParamUnion = Annotated[ - Union[TypeParam, BoundedNatParam, OpaqueParam, ListParam, TupleParam], + TypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam, Field(discriminator="tp"), ] @@ -84,7 +84,7 @@ class ExtensionsArg(BaseModel): TypeArgUnion = Annotated[ - Union[TypeArg, BoundedNatArg, OpaqueArg, SequenceArg, ExtensionsArg], + TypeArg | BoundedNatArg | OpaqueArg | SequenceArg | ExtensionsArg, Field(discriminator="tya"), ] @@ -231,9 +231,17 @@ class Qubit(BaseModel): SimpleType = Annotated[ - Union[ - Qubit, Variable, Int, F64, String, PolyFuncType, List, Array, Tuple, Sum, Opaque - ], + Qubit + | Variable + | Int + | F64 + | String + | PolyFuncType + | List + | Array + | Tuple + | Sum + | Opaque, Field(discriminator="t"), ] diff --git a/guppy/hugr/val.py b/guppy/hugr/val.py index 493ec014..b8b5eabb 100644 --- a/guppy/hugr/val.py +++ b/guppy/hugr/val.py @@ -1,10 +1,9 @@ import inspect import sys -from typing import Literal, Any, Annotated, Union +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field - CustomConst = Any # TODO @@ -40,9 +39,7 @@ class Sum(BaseModel): value: "Value" -Value = Annotated[ - Union[ExtensionVal, FunctionVal, Tuple, Sum], Field(discriminator="v") -] +Value = Annotated[ExtensionVal | FunctionVal | Tuple | Sum, Field(discriminator="v")] # Now that all classes are defined, we need to update the ForwardRefs in all type diff --git a/guppy/hugr/visualise.py b/guppy/hugr/visualise.py index eba0bca6..450b5293 100644 --- a/guppy/hugr/visualise.py +++ b/guppy/hugr/visualise.py @@ -1,16 +1,17 @@ """Visualise HUGR using graphviz.""" import ast +from collections.abc import Iterable +from typing import TYPE_CHECKING -import graphviz as gv # type: ignore -from typing import Iterable, TYPE_CHECKING +import graphviz as gv # type: ignore[import] from guppy.cfg.analysis import ( - LivenessDomain, DefAssignmentDomain, + LivenessDomain, MaybeAssignmentDomain, ) from guppy.cfg.bb import BB -from guppy.hugr.hugr import InPort, OutPort, Node, Hugr, OutPortV +from guppy.hugr.hugr import Hugr, InPort, Node, OutPort, OutPortV if TYPE_CHECKING: from guppy.cfg.cfg import CFG @@ -55,18 +56,23 @@ _FONTFACE = "monospace" _HTML_LABEL_TEMPLATE = """ - -{inputs_row} - - - -{outputs_row} +
- - - - -
{node_label}{node_data}
-
+ {inputs_row} + + + + {outputs_row}
+ + + + +
+ + {node_label}{node_data} + +
+
""" diff --git a/guppy/module.py b/guppy/module.py index e07a15ac..3061bafe 100644 --- a/guppy/module.py +++ b/guppy/module.py @@ -1,15 +1,15 @@ import ast import inspect import textwrap +from collections.abc import Callable from types import ModuleType +from typing import Any, Union -from typing import Callable, Any, Optional, Union - -from guppy.ast_util import annotate_location, AstNode +from guppy.ast_util import AstNode, annotate_location from guppy.checker.core import Globals, qualified_name from guppy.checker.func_checker import DefinedFunction, check_global_func_def from guppy.compiler.core import CompiledGlobals -from guppy.compiler.func_compiler import compile_global_func_def, CompiledFunctionDef +from guppy.compiler.func_compiler import CompiledFunctionDef, compile_global_func_def from guppy.custom import CustomFunction from guppy.declared import DeclaredFunction from guppy.error import GuppyError, pretty_errors @@ -44,7 +44,7 @@ class GuppyModule: # When `_instance_buffer` is not `None`, then all registered functions will be # buffered in this list. They only get properly registered, once # `_register_buffered_instance_funcs` is called. This way, we can associate - _instance_func_buffer: Optional[dict[str, Union[PyFunc, CustomFunction]]] + _instance_func_buffer: dict[str, PyFunc | CustomFunction] | None def __init__(self, name: str, import_builtins: bool = True): self.name = name @@ -88,7 +88,7 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None: self.load(val) def register_func_def( - self, f: PyFunc, instance: Optional[type[GuppyType]] = None + self, f: PyFunc, instance: type[GuppyType] | None = None ) -> None: """Registers a Python function definition as belonging to this Guppy module.""" self._check_not_yet_compiled() @@ -110,7 +110,7 @@ def register_func_decl(self, f: PyFunc) -> None: self._func_decls[func_ast.name] = func_ast def register_custom_func( - self, func: CustomFunction, instance: Optional[type[GuppyType]] = None + self, func: CustomFunction, instance: type[GuppyType] | None = None ) -> None: """Registers a custom function as belonging to this Guppy module.""" self._check_not_yet_compiled() @@ -130,7 +130,7 @@ def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None: assert self._instance_func_buffer is not None buffer = self._instance_func_buffer self._instance_func_buffer = None - for name, f in buffer.items(): + for f in buffer.values(): if isinstance(f, CustomFunction): self.register_custom_func(f, instance) else: @@ -141,7 +141,7 @@ def compiled(self) -> bool: return self._compiled @pretty_errors - def compile(self) -> Optional[Hugr]: + def compile(self) -> Hugr | None: """Compiles the module and returns the final Hugr.""" if self.compiled: raise GuppyError("Module has already been compiled") @@ -200,7 +200,7 @@ def _check_not_yet_compiled(self) -> None: if self._compiled: raise GuppyError(f"The module `{self.name}` has already been compiled") - def _check_name_available(self, name: str, node: Optional[AstNode]) -> None: + def _check_name_available(self, name: str, node: AstNode | None) -> None: if name in self._func_defs or name in self._custom_funcs: raise GuppyError( f"Module `{self.name}` already contains a function named `{name}`", diff --git a/guppy/nodes.py b/guppy/nodes.py index dfd02349..22730db5 100644 --- a/guppy/nodes.py +++ b/guppy/nodes.py @@ -1,14 +1,15 @@ """Custom AST nodes used by Guppy""" import ast -from typing import TYPE_CHECKING, Any, Mapping +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any from guppy.gtypes import FunctionType if TYPE_CHECKING: from guppy.cfg.cfg import CFG - from guppy.checker.core import Variable, CallableVariable from guppy.checker.cfg_checker import CheckedCFG + from guppy.checker.core import CallableVariable, Variable class LocalName(ast.expr): diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py index d8a64c32..6262b1e8 100644 --- a/guppy/prelude/_internal.py +++ b/guppy/prelude/_internal.py @@ -1,24 +1,23 @@ import ast -from typing import Optional, Literal +from typing import Literal from pydantic import BaseModel -from guppy.ast_util import with_type, AstNode, with_loc, get_type -from guppy.checker.core import Context, CallableVariable +from guppy.ast_util import AstNode, get_type, with_loc, with_type +from guppy.checker.core import CallableVariable, Context from guppy.checker.expr_checker import ExprSynthesizer, check_num_args from guppy.custom import ( CustomCallChecker, - DefaultCallChecker, - CustomFunction, CustomCallCompiler, + CustomFunction, + DefaultCallChecker, ) -from guppy.error import GuppyTypeError, GuppyError -from guppy.gtypes import GuppyType, FunctionType, BoolType +from guppy.error import GuppyError, GuppyTypeError +from guppy.gtypes import BoolType, FunctionType, GuppyType from guppy.hugr import ops, tys, val from guppy.hugr.hugr import OutPortV from guppy.nodes import GlobalCall - INT_WIDTH = 6 # 2^6 = 64 bit @@ -68,7 +67,7 @@ def float_value(f: float) -> val.Value: return val.ExtensionVal(c=(ConstF64(value=f),)) -def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops.OpType: +def logic_op(op_name: str, args: list[tys.TypeArgUnion] | None = None) -> ops.OpType: """Utility method to create Hugr logic ops.""" return ops.CustomOp(extension="logic", op_name=op_name, args=args or []) @@ -110,7 +109,7 @@ class ReversingChecker(CustomCallChecker): base_checker: CustomCallChecker - def __init__(self, base_checker: Optional[CustomCallChecker] = None): + def __init__(self, base_checker: CustomCallChecker | None = None): self.base_checker = base_checker or DefaultCallChecker() def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None: @@ -209,7 +208,7 @@ class IntTruedivCompiler(CustomCallCompiler): """Compiler for the `int.__truediv__` method.""" def compile(self, args: list[OutPortV]) -> list[OutPortV]: - from .builtins import Int, Float + from .builtins import Float, Int # Compile `truediv` using float arithmetic [left, right] = args diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py index 25dfc741..57c22920 100644 --- a/guppy/prelude/builtins.py +++ b/guppy/prelude/builtins.py @@ -2,30 +2,29 @@ # mypy: disable-error-code="empty-body, misc, override, no-untyped-def" -from guppy.custom import NoopCompiler, DefaultCallChecker +from guppy.custom import DefaultCallChecker, NoopCompiler from guppy.decorator import guppy from guppy.gtypes import BoolType -from guppy.hugr import tys, ops +from guppy.hugr import ops, tys from guppy.module import GuppyModule from guppy.prelude._internal import ( - logic_op, - int_op, - hugr_int_type, - hugr_float_type, - float_op, + CallableChecker, CoercingChecker, - ReversingChecker, - IntTruedivCompiler, + DunderChecker, FloatBoolCompiler, FloatDivmodCompiler, FloatFloordivCompiler, FloatModCompiler, - DunderChecker, - CallableChecker, + IntTruedivCompiler, + ReversingChecker, UnsupportedChecker, + float_op, + hugr_float_type, + hugr_int_type, + int_op, + logic_op, ) - builtins = GuppyModule("builtins", import_builtins=False) diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py index 3b40af88..dcfc794b 100644 --- a/guppy/prelude/quantum.py +++ b/guppy/prelude/quantum.py @@ -3,11 +3,10 @@ # mypy: disable-error-code=empty-body from guppy.decorator import guppy -from guppy.hugr import tys, ops +from guppy.hugr import ops, tys from guppy.hugr.tys import TypeBound from guppy.module import GuppyModule - quantum = GuppyModule("quantum") diff --git a/pyproject.toml b/pyproject.toml index df4fc7d7..6d7333c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,12 +16,13 @@ dependencies = [ [project.optional-dependencies] dev = [ - "pip-tools", # provides pip-compile "build", - "pytest", + "maturin>=1.1,<2.0", "mypy==1.3.0", + "pip-tools", # provides pip-compile + "pre-commit", + "pytest", "ruff", - "maturin>=1.1,<2.0", ] [tool.setuptools] diff --git a/requirements.txt b/requirements.txt index d64d15c3..846d2fa0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,33 @@ # -# This file is autogenerated by pip-compile with Python 3.10 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile --extra=dev --no-annotate --output-file=requirements.txt pyproject.toml +# pip-compile --extra=dev --no-annotate --output-file=requirements.txt --strip-extras pyproject.toml # -build==0.10.0 -exceptiongroup==1.1.1 +build==1.0.3 +cfgv==3.4.0 +click==8.1.7 +distlib==0.3.7 +filelock==3.13.1 graphviz==0.20.1 +identify==2.5.32 iniconfig==2.0.0 +maturin==1.3.2 +mypy==1.3.0 +mypy-extensions==1.0.0 networkx==3.0 -ormsgpack==1.2.5 -packaging==23.1 -pluggy==1.0.0 +nodeenv==1.8.0 +ormsgpack==1.4.1 +packaging==23.2 +pip-tools==7.3.0 +platformdirs==4.1.0 +pluggy==1.3.0 +pre-commit==3.5.0 pydantic==1.10.8 pyproject-hooks==1.0.0 -pytest==7.3.1 -tomli==2.0.1 -typing-extensions==4.6.2 +pytest==7.4.3 +pyyaml==6.0.1 +ruff==0.1.7 +typing-extensions==4.8.0 +virtualenv==20.25.0 +wheel==0.42.0 diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..8c9e618a --- /dev/null +++ b/ruff.toml @@ -0,0 +1,83 @@ +# See https://docs.astral.sh/ruff/rules/ +target-version = "py310" + +line-length = 88 + +exclude = ["tests/error"] + +select = [ + "F", # pyflakes + "E", # pycodestyle Errors + "W", # pycodestyle Warnings + + # "A", # flake8-builtins + # "ANN", # flake8-annotations + # "ARG", # flake8-unused-arguments + "B", # flake8-Bugbear + "BLE", # flake8-blind-except + "C4", # flake8-comprehensions + # "C90", # mccabe + # "COM", # flake8-commas + # "CPY", # flake8-copyright + # "D", # pydocstyle + "EM", # flake8-errmsg + # "ERA", # eradicate + "EXE", # flake8-executable + "FA", # flake8-future-annotations + # "FBT", # flake8-boolean-trap + # "FIX", # flake8-fixme + "FLY", # flynt + # "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + # "ISC", # flake8-implicit-str-concat + # "LOG", # flake8-logging + # "N", # pep8-Naming + "NPY", # NumPy-specific + "PERF", # Perflint + "PGH", # pygrep-hooks + "PIE", # flake8-pie + # "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "PYI", # flake8-pyi + "Q", # flake8-quotes + # "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # Ruff-specific + "S", # flake8-bandit (Security) + "SIM", # flake8-simplify + # "SLF", # flake8-self + "SLOT", # flake8-slots + "T10", # flake8-debugger + "T20", # flake8-print + "TCH", # flake8-type-checking + # "TD", # flake8-todos + "TID", # flake8-tidy-imports + "TRY", # tryceratops + "UP", # pyupgrade + "YTT", # flake8-2020 +] + +ignore = [ + "COM812", "ISC001", # conflicting with the formatter + "EM101", "EM102", # Exception must not use a string (an f-string) literal, assign to variable first + "S101", # Use of `assert` detected + "TRY003", # Avoid specifying long messages outside the exception class + "B905", # `zip()` without an explicit `strict=` parameter +] + +[per-file-ignores] +"guppy/ast_util.py" = ["B009", "B010"] +"guppy/decorator.py" = ["B010"] +"tests/integration/*" = ["F841"] +"tests/{hugr,integration}/*" = ["B", "FBT", "SIM", "I"] + +# [pydocstyle] +# convention = "google" + +# [flake8-copyright] +# author = "Quantinuum" diff --git a/tests/conftest.py b/tests/conftest.py index 8d1eedc3..f5d7a394 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,18 @@ -import pytest -from pathlib import Path import argparse +from pathlib import Path + def pytest_addoption(parser): def dir_path(s): path = Path(s) if not path.exists() or path.is_dir(): return path - raise argparse.ArgumentTypeError(f"export-test-cases dir:{path} exists and is not a directory") + msg = f"export-test-cases dir:{path} exists and is not a directory" + raise argparse.ArgumentTypeError(msg) - parser.addoption("--export-test-cases", action="store", type=dir_path, help="A directory to which to export test cases") + parser.addoption( + "--export-test-cases", + action="store", + type=dir_path, + help="A directory to which to export test cases", + ) diff --git a/tests/error/errors_on_usage/if_expr_cond_type_change.py b/tests/error/errors_on_usage/if_expr_cond_type_change.py index c9ec27ab..d108fb78 100644 --- a/tests/error/errors_on_usage/if_expr_cond_type_change.py +++ b/tests/error/errors_on_usage/if_expr_cond_type_change.py @@ -6,4 +6,4 @@ def foo(x: bool) -> int: y = 4 0 if (y := x) else (y := 6) z = y - return 42 \ No newline at end of file + return 42 diff --git a/tests/error/generate_test.sh b/tests/error/generate_test.sh index 0cc01a2e..5150167b 100755 --- a/tests/error/generate_test.sh +++ b/tests/error/generate_test.sh @@ -1,3 +1,3 @@ #!/bin/bash -python -m tests.error.$1.$2 2> tests/error/$1/$2.err \ No newline at end of file +python -m tests.error.$1.$2 2> tests/error/$1/$2.err diff --git a/tests/error/nested_errors/different_types_if.py b/tests/error/nested_errors/different_types_if.py index 501de616..2eebce38 100644 --- a/tests/error/nested_errors/different_types_if.py +++ b/tests/error/nested_errors/different_types_if.py @@ -11,4 +11,3 @@ def bar() -> bool: return False return bar() - diff --git a/tests/error/nested_errors/not_defined_if.py b/tests/error/nested_errors/not_defined_if.py index 67715f3e..8798212d 100644 --- a/tests/error/nested_errors/not_defined_if.py +++ b/tests/error/nested_errors/not_defined_if.py @@ -7,4 +7,3 @@ def foo(b: bool) -> int: def bar() -> int: return 0 return bar() - diff --git a/tests/error/nested_errors/reassign_capture_1.py b/tests/error/nested_errors/reassign_capture_1.py index 1a770cd0..bf8da199 100644 --- a/tests/error/nested_errors/reassign_capture_1.py +++ b/tests/error/nested_errors/reassign_capture_1.py @@ -10,4 +10,3 @@ def bar() -> None: bar() return y - diff --git a/tests/error/nested_errors/reassign_capture_2.py b/tests/error/nested_errors/reassign_capture_2.py index 269699e9..7317bdf5 100644 --- a/tests/error/nested_errors/reassign_capture_2.py +++ b/tests/error/nested_errors/reassign_capture_2.py @@ -12,4 +12,3 @@ def bar() -> None: bar() return y - diff --git a/tests/error/nested_errors/reassign_capture_3.py b/tests/error/nested_errors/reassign_capture_3.py index 719fd41a..4648f379 100644 --- a/tests/error/nested_errors/reassign_capture_3.py +++ b/tests/error/nested_errors/reassign_capture_3.py @@ -14,4 +14,3 @@ def baz() -> None: bar() return y - diff --git a/tests/error/nested_errors/var_not_defined.py b/tests/error/nested_errors/var_not_defined.py index a82af6a8..d18b305e 100644 --- a/tests/error/nested_errors/var_not_defined.py +++ b/tests/error/nested_errors/var_not_defined.py @@ -7,4 +7,3 @@ def bar() -> int: return x return bar() - diff --git a/tests/error/test_errors_on_usage.py b/tests/error/test_errors_on_usage.py index 4f53575d..4999cb81 100644 --- a/tests/error/test_errors_on_usage.py +++ b/tests/error/test_errors_on_usage.py @@ -4,7 +4,12 @@ from tests.error.util import run_error_test path = pathlib.Path(__file__).parent.resolve() / "errors_on_usage" -files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] # TODO: Skip functional tests for now files = [f for f in files if "functional" not in f.name] diff --git a/tests/error/test_linear_errors.py b/tests/error/test_linear_errors.py index d334724d..71615330 100644 --- a/tests/error/test_linear_errors.py +++ b/tests/error/test_linear_errors.py @@ -4,7 +4,12 @@ from tests.error.util import run_error_test path = pathlib.Path(__file__).parent.resolve() / "linear_errors" -files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] # TODO: Skip functional tests for now files = [f for f in files if "functional" not in f.name] diff --git a/tests/error/test_misc_errors.py b/tests/error/test_misc_errors.py index d3115ff4..03803e65 100644 --- a/tests/error/test_misc_errors.py +++ b/tests/error/test_misc_errors.py @@ -4,7 +4,12 @@ from tests.error.util import run_error_test path = pathlib.Path(__file__).parent.resolve() / "misc_errors" -files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] # TODO: Skip functional tests for now files = [f for f in files if "functional" not in f.name] diff --git a/tests/error/test_nested_errors.py b/tests/error/test_nested_errors.py index ba059c26..523a6574 100644 --- a/tests/error/test_nested_errors.py +++ b/tests/error/test_nested_errors.py @@ -4,7 +4,12 @@ from tests.error.util import run_error_test path = pathlib.Path(__file__).parent.resolve() / "nested_errors" -files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] # Turn paths into strings, otherwise pytest doesn't display the names files = [str(f) for f in files] diff --git a/tests/error/test_type_errors.py b/tests/error/test_type_errors.py index 619438c9..864f15c7 100644 --- a/tests/error/test_type_errors.py +++ b/tests/error/test_type_errors.py @@ -4,7 +4,12 @@ from tests.error.util import run_error_test path = pathlib.Path(__file__).parent.resolve() / "type_errors" -files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"] +files = [ + x + for x in path.iterdir() + if x.is_file() + if x.suffix == ".py" and x.name != "__init__.py" +] # TODO: Skip functional tests for now files = [f for f in files if "functional" not in f.name] diff --git a/tests/error/type_errors/fun_ty_mismatch_1.py b/tests/error/type_errors/fun_ty_mismatch_1.py index 66445957..159825b6 100644 --- a/tests/error/type_errors/fun_ty_mismatch_1.py +++ b/tests/error/type_errors/fun_ty_mismatch_1.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from tests.error.util import guppy diff --git a/tests/error/type_errors/fun_ty_mismatch_2.py b/tests/error/type_errors/fun_ty_mismatch_2.py index ab421bed..8d98dc01 100644 --- a/tests/error/type_errors/fun_ty_mismatch_2.py +++ b/tests/error/type_errors/fun_ty_mismatch_2.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from tests.error.util import guppy diff --git a/tests/error/type_errors/fun_ty_mismatch_3.py b/tests/error/type_errors/fun_ty_mismatch_3.py index 2e35f0ce..bbb4f428 100644 --- a/tests/error/type_errors/fun_ty_mismatch_3.py +++ b/tests/error/type_errors/fun_ty_mismatch_3.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from tests.error.util import guppy diff --git a/tests/error/util.py b/tests/error/util.py index a3e39191..c7d207d8 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -2,7 +2,8 @@ import pathlib import pytest -from typing import Callable, Optional, Any +from typing import Any +from collections.abc import Callable from guppy.hugr import tys from guppy.hugr.tys import TypeBound @@ -12,8 +13,8 @@ import guppy.decorator as decorator -def guppy(f: Callable[..., Any]) -> Optional[Hugr]: - """ Decorator to compile functions outside of modules for testing. """ +def guppy(f: Callable[..., Any]) -> Hugr | None: + """Decorator to compile functions outside of modules for testing.""" module = GuppyModule("module") module.register_func_def(f) return module.compile() @@ -29,7 +30,7 @@ def run_error_test(file, capsys): err = capsys.readouterr().err - with open(file.with_suffix(".err")) as f: + with pathlib.Path(file.with_suffix(".err")).open() as f: exp_err = f.read() exp_err = exp_err.replace("$FILE", str(file)) @@ -39,7 +40,8 @@ def run_error_test(file, capsys): util = GuppyModule("test") -@decorator.guppy.type(util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)) +@decorator.guppy.type( + util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable) +) class NonBool: pass - diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py index e1efaab9..e33537cd 100644 --- a/tests/hugr/test_dummy_nodes.py +++ b/tests/hugr/test_dummy_nodes.py @@ -20,7 +20,9 @@ def test_single_dummy(): def test_unique_names(): g = Hugr() - defn = g.add_def(FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test") + defn = g.add_def( + FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test" + ) dfg = g.add_dfg(defn) inp = g.add_input([BoolType()], dfg).out_port(0) dummy1 = g.add_node( @@ -34,4 +36,3 @@ def test_unique_names(): g.remove_dummy_nodes() [decl1, decl2] = [n for n in g.nodes() if isinstance(n.op, ops.FuncDecl)] assert {decl1.op.name, decl2.op.name} == {"dummy", "dummy$1"} - diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py index 77e07ef4..299b310e 100644 --- a/tests/hugr/test_ports.py +++ b/tests/hugr/test_ports.py @@ -12,4 +12,3 @@ def test_undefined_port(): p.node with pytest.raises(InternalGuppyError, match="Tried to access undefined Port"): p.offset - diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index af4f558c..9841036a 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -2,21 +2,23 @@ from . import util -@pytest.fixture + +@pytest.fixture() def export_test_cases_dir(request): - r = request.config.getoption('--export-test-cases') + r = request.config.getoption("--export-test-cases") if r and not r.exists(): r.mkdir(parents=True) return r -@pytest.fixture +@pytest.fixture() def validate(request, export_test_cases_dir): - def validate_impl(hugr,name=None): + def validate_impl(hugr, name=None): bs = hugr.serialize() util.validate_bytes(bs) if export_test_cases_dir: file_name = f"{request.node.name}{f'_{name}' if name else ''}.msgpack" export_file = export_test_cases_dir / file_name export_file.write_bytes(bs) + return validate_impl diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index cde2801a..487f0197 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -77,5 +77,7 @@ def test_func_decl_name(): def func_name() -> None: ... - [def_op] = [n.op for n in module.compile().nodes() if isinstance(n.op, ops.FuncDecl)] + [def_op] = [ + n.op for n in module.compile().nodes() if isinstance(n.op, ops.FuncDecl) + ] assert def_op.name == "func_name" diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py index bad4d824..2cb7ec80 100644 --- a/tests/integration/test_call.py +++ b/tests/integration/test_call.py @@ -50,6 +50,3 @@ def bar(x: int) -> int: return foo(x) validate(module.compile()) - - - diff --git a/tests/integration/test_functional.py b/tests/integration/test_functional.py index 8c462d6b..ae61a30d 100644 --- a/tests/integration/test_functional.py +++ b/tests/integration/test_functional.py @@ -8,7 +8,7 @@ def test_if_no_else(validate): @guppy def foo(x: bool, y: int) -> int: - _@functional + _ @ functional if x: y += 1 return y @@ -20,7 +20,7 @@ def foo(x: bool, y: int) -> int: def test_if_else(validate): @guppy def foo(x: bool, y: int) -> int: - _@functional + _ @ functional if x: y += 1 else: @@ -34,7 +34,7 @@ def foo(x: bool, y: int) -> int: def test_if_elif(validate): @guppy def foo(x: bool, y: int) -> int: - _@functional + _ @ functional if x: y += 1 elif y > 4: @@ -48,7 +48,7 @@ def foo(x: bool, y: int) -> int: def test_if_elif_else(validate): @guppy def foo(x: bool, y: int) -> int: - _@functional + _ @ functional if x: y += 1 elif y > 4: @@ -87,7 +87,7 @@ def test_nested_loop(validate): @guppy def foo(x: int, y: int) -> int: p = 0 - _@functional + _ @ functional while x > 0: s = 0 while y > 0: diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py index b406c932..fd8856d8 100644 --- a/tests/integration/test_higher_order.py +++ b/tests/integration/test_higher_order.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from guppy.decorator import guppy from guppy.module import GuppyModule @@ -73,13 +73,18 @@ def curry(f: Callable[[int, int], bool]) -> Callable[[int], Callable[[int], bool def g(x: int) -> Callable[[int], bool]: def h(y: int) -> bool: return f(x, y) + return h + return g @guppy(module) - def uncurry(f: Callable[[int], Callable[[int], bool]]) -> Callable[[int, int], bool]: + def uncurry( + f: Callable[[int], Callable[[int], bool]], + ) -> Callable[[int, int], bool]: def g(x: int, y: int) -> bool: return f(x)(y) + return g @guppy(module) diff --git a/tests/integration/test_if.py b/tests/integration/test_if.py index 4f09b295..ac4f5f10 100644 --- a/tests/integration/test_if.py +++ b/tests/integration/test_if.py @@ -1,7 +1,6 @@ import pytest from guppy.decorator import guppy -from guppy.module import GuppyModule def test_if_no_else(validate): @@ -55,7 +54,7 @@ def foo(x: bool, y: int) -> int: def test_if_expr(validate): @guppy def foo(x: bool, y: int) -> int: - return y+1 if x else 42 + return y + 1 if x else 42 validate(foo) @@ -265,4 +264,3 @@ def foo(x: int) -> int: return z validate(foo) - diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py index 0d744e06..35091b61 100644 --- a/tests/integration/test_linear.py +++ b/tests/integration/test_linear.py @@ -55,7 +55,9 @@ def g(q1: Qubit, q2: Qubit) -> tuple[Qubit, Qubit]: ... @guppy(module) - def test(a: Qubit, b: Qubit, c: Qubit, d: Qubit) -> tuple[Qubit, Qubit, Qubit, Qubit]: + def test( + a: Qubit, b: Qubit, c: Qubit, d: Qubit + ) -> tuple[Qubit, Qubit, Qubit, Qubit]: a, b = f(a, b) c, d = f(c, d) b, c = g(b, c) diff --git a/tests/integration/test_nested.py b/tests/integration/test_nested.py index fe58523f..23698f59 100644 --- a/tests/integration/test_nested.py +++ b/tests/integration/test_nested.py @@ -6,6 +6,7 @@ def test_basic(validate): def foo(x: int) -> int: def bar(y: int) -> int: return y + return bar(x + 1) validate(foo) @@ -16,6 +17,7 @@ def test_call_twice(validate): def foo(x: int) -> int: def bar(y: int) -> int: return y + 3 + if x > 5: return bar(x) else: @@ -45,11 +47,14 @@ def test_define_twice(validate): @guppy def foo(x: int) -> int: if x == 0: + def bar(y: int) -> int: return y + 3 else: + def bar(y: int) -> int: return y - 42 + return bar(x) validate(foo) @@ -61,7 +66,9 @@ def foo(x: int) -> int: def bar(y: int) -> int: def baz(z: int) -> int: return z - 1 - return baz(5*y) + + return baz(5 * y) + return bar(x + 1) validate(foo) @@ -74,6 +81,7 @@ def bar(y: int) -> int: if y == 0: return 0 return 2 * bar(y - 1) + return bar(x) validate(foo) @@ -84,6 +92,7 @@ def test_capture_arg(validate): def foo(x: int) -> int: def bar() -> int: return 1 + x + return bar() validate(foo) @@ -96,6 +105,7 @@ def foo(x: int) -> int: def bar() -> int: return y + return bar() validate(foo) @@ -113,6 +123,7 @@ def foo(x: int) -> int: def bar() -> int: q = y return q + z + return bar() validate(foo) @@ -127,6 +138,7 @@ def foo(x: int) -> int: def bar() -> int: return x + y + a + return bar() return 4 @@ -159,6 +171,7 @@ def bar(y: int, z: int) -> int: if y == 0: return z return bar(z, z * x) + return bar(x, 0) validate(foo) @@ -190,7 +203,7 @@ def foo(x: int) -> int: while x > 0: def bar() -> int: - return x*x + return x * x a += bar() x -= 1 diff --git a/tests/integration/test_programs.py b/tests/integration/test_programs.py index 01c2a6fd..ee7c87cc 100644 --- a/tests/integration/test_programs.py +++ b/tests/integration/test_programs.py @@ -1,6 +1,5 @@ from guppy.decorator import guppy from guppy.module import GuppyModule -from tests.integration.util import functional, _ def test_factorial(validate): @@ -24,19 +23,9 @@ def factorial3(x: int, acc: int) -> int: return acc return factorial3(x - 1, acc * x) - # @guppy - # def factorial4(x: int) -> int: - # acc = 1 - # _@functional - # while x > 0: - # acc *= x - # x -= 1 - # return acc - validate(factorial1, name="factorial1") validate(factorial2, name="factorial2") validate(factorial3, name="factorial3") - # validate(factorial4) def test_even_odd(validate): diff --git a/tests/integration/test_unused.py b/tests/integration/test_unused.py index 772e6b60..b8592bc4 100644 --- a/tests/integration/test_unused.py +++ b/tests/integration/test_unused.py @@ -1,9 +1,7 @@ -import pytest +""" All sorts of weird stuff is allowed when variables are not used. """ from guppy.decorator import guppy -""" All sorts of weird stuff is allowed when variables are not used. """ - def test_not_always_defined1(validate): @guppy diff --git a/tests/integration/util.py b/tests/integration/util.py index 52075019..a99478e2 100644 --- a/tests/integration/util.py +++ b/tests/integration/util.py @@ -1,12 +1,10 @@ -from typing import TypeVar - import validator -from guppy.hugr.hugr import Hugr def validate_bytes(hugr: bytes): validator.validate(hugr) + class Decorator: def __matmul__(self, other): return None diff --git a/validator/src/lib.rs b/validator/src/lib.rs index 702b0dc2..4d31a136 100644 --- a/validator/src/lib.rs +++ b/validator/src/lib.rs @@ -1,10 +1,8 @@ -use pyo3::prelude::*; -use hugr; use hugr::extension::{ExtensionRegistry, PRELUDE}; -use hugr::std_extensions::arithmetic::{int_ops, int_types, float_ops, float_types}; +use hugr::std_extensions::arithmetic::{float_ops, float_types, int_ops, int_types}; use hugr::std_extensions::logic; use lazy_static::lazy_static; - +use pyo3::prelude::*; lazy_static! { pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ @@ -14,7 +12,8 @@ lazy_static! { int_ops::EXTENSION.to_owned(), float_types::extension(), float_ops::extension() - ]).unwrap(); + ]) + .unwrap(); } #[pyfunction] @@ -28,4 +27,4 @@ fn validate(hugr: Vec) -> PyResult<()> { fn validator(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(validate, m)?)?; Ok(()) -} \ No newline at end of file +}