diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..b40b54ec
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,65 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0 # Use the ref you want to point at
+ hooks:
+ - id: check-added-large-files
+ - id: check-case-conflict
+ - id: check-executables-have-shebangs
+ - id: check-merge-conflict
+ - id: check-toml
+ - id: check-vcs-permalinks
+ - id: check-yaml
+ - id: detect-private-key
+ - id: end-of-file-fixer
+ - id: fix-byte-order-marker
+ - id: mixed-line-ending
+ # - id: trailing-whitespace
+ # Python-specific
+ - id: check-ast
+ - id: check-docstring-first
+ - id: debug-statements
+
+- repo: https://github.com/crate-ci/typos
+ rev: v1.16.23
+ hooks:
+ - id: typos
+ args: []
+
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.1.7
+ hooks:
+ - id: ruff
+ args: [--fix, --exit-non-zero-on-fix]
+ - id: ruff-format
+
+- repo: https://github.com/pre-commit/mirrors-mypy
+ rev: 'v1.3.0'
+ hooks:
+ - id: mypy
+ pass_filenames: false
+ args: [--package=guppy]
+ additional_dependencies: [
+ ormsgpack,
+ pydantic,
+ ]
+
+- repo: local
+ hooks:
+ - id: cargo-check
+ name: Cargo check
+ entry: bash -c 'cd validator && exec cargo check'
+ pass_filenames: false
+ types: [file, rust]
+ language: system
+ - id: rust-linting
+ name: Rust linting
+ entry: bash -c 'cd validator && exec cargo fmt --all --'
+ pass_filenames: true
+ types: [file, rust]
+ language: system
+ - id: rust-clippy
+ name: Rust clippy
+ entry: bash -c 'cd validator && exec cargo clippy --all-targets --all-features -- -Dclippy::all'
+ pass_filenames: false
+ types: [file, rust]
+ language: system
diff --git a/.typos.toml b/.typos.toml
new file mode 100644
index 00000000..92413257
--- /dev/null
+++ b/.typos.toml
@@ -0,0 +1,3 @@
+[default.extend-words]
+inot = "inot"
+fle = "fle"
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..fb07d6b7
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,12 @@
+include *.toml
+include *.txt
+include *.yaml
+recursive-include examples *.ipynb
+recursive-include guppy *.py
+recursive-include tests *.err
+recursive-include tests *.py
+recursive-include tests *.sh
+recursive-exclude validator *.lock
+recursive-exclude validator *.rs
+recursive-exclude validator *.toml
+exclude .git-blame-ignore-revs
diff --git a/README.md b/README.md
index f5e76b91..44b8a55f 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,8 @@ pip install -e '.[dev]'
### Git blame
-You can configure Git to ignore formatting commits when using `git blame` by running
+You can configure Git to ignore formatting commits when using `git blame` by running
+
```sh
git config blame.ignoreRevsFile .git-blame-ignore-revs
```
@@ -41,14 +42,14 @@ TODO
## Testing
-First, build the PyO3 Hugr validation library using
+First, build the PyO3 Hugr validation library from the `validator` directory using
+
```sh
maturin develop
```
-from the `validator` directory.
-
Run tests using
+
```sh
pytest -v
```
@@ -59,6 +60,7 @@ Integration test cases can be exported to a directory using
pytest --export-test-cases=guppy-exports
```
+
which will create a directory `./guppy-exports` populated with hugr modules serialised in msgpack.
## Packaging
diff --git a/guppy/ast_util.py b/guppy/ast_util.py
index 8c4be4f7..5a9e66f4 100644
--- a/guppy/ast_util.py
+++ b/guppy/ast_util.py
@@ -1,19 +1,19 @@
import ast
-from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
if TYPE_CHECKING:
from guppy.gtypes import GuppyType
-AstNode = Union[
- ast.AST,
- ast.operator,
- ast.expr,
- ast.arg,
- ast.stmt,
- ast.Name,
- ast.keyword,
- ast.FunctionDef,
-]
+AstNode = (
+ ast.AST
+ | ast.operator
+ | ast.expr
+ | ast.arg
+ | ast.stmt
+ | ast.Name
+ | ast.keyword
+ | ast.FunctionDef
+)
T = TypeVar("T", covariant=True)
@@ -114,7 +114,9 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None:
node.end_col_offset = loc.end_col_offset
source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc)
- assert source is not None and file is not None and line_offset is not None
+ assert source is not None
+ assert file is not None
+ assert line_offset is not None
annotate_location(node, source, file, line_offset)
@@ -126,7 +128,7 @@ def annotate_location(
setattr(node, "source", source)
if recurse:
- for field, value in ast.iter_fields(node):
+ for _field, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
@@ -135,7 +137,7 @@ def annotate_location(
annotate_location(value, source, file, line_offset, recurse)
-def get_file(node: AstNode) -> Optional[str]:
+def get_file(node: AstNode) -> str | None:
"""Tries to retrieve a file annotation from an AST node."""
try:
file = getattr(node, "file")
@@ -144,7 +146,7 @@ def get_file(node: AstNode) -> Optional[str]:
return None
-def get_source(node: AstNode) -> Optional[str]:
+def get_source(node: AstNode) -> str | None:
"""Tries to retrieve a source annotation from an AST node."""
try:
source = getattr(node, "source")
@@ -153,7 +155,7 @@ def get_source(node: AstNode) -> Optional[str]:
return None
-def get_line_offset(node: AstNode) -> Optional[int]:
+def get_line_offset(node: AstNode) -> int | None:
"""Tries to retrieve a line offset annotation from an AST node."""
try:
line_offset = getattr(node, "line_offset")
diff --git a/guppy/cfg/analysis.py b/guppy/cfg/analysis.py
index 18fdc3f4..36ee8780 100644
--- a/guppy/cfg/analysis.py
+++ b/guppy/cfg/analysis.py
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
-from typing import TypeVar, Generic, Iterable
+from collections.abc import Iterable
+from typing import Generic, TypeVar
from guppy.cfg.bb import BB
-
# Type variable for the lattice domain
T = TypeVar("T")
@@ -21,12 +21,10 @@ def eq(self, t1: T, t2: T, /) -> bool:
@abstractmethod
def initial(self) -> T:
"""Initial lattice value"""
- pass
@abstractmethod
def join(self, *ts: T) -> T:
"""Lattice join operation"""
- pass
@abstractmethod
def run(self, bbs: Iterable[BB]) -> Result[T]:
@@ -34,7 +32,6 @@ def run(self, bbs: Iterable[BB]) -> Result[T]:
Returns a mapping from basic blocks to lattice values at the start of each BB.
"""
- pass
class ForwardAnalysis(Analysis[T], ABC, Generic[T]):
@@ -43,7 +40,6 @@ class ForwardAnalysis(Analysis[T], ABC, Generic[T]):
@abstractmethod
def apply_bb(self, val_before: T, bb: BB, /) -> T:
"""Transformation a basic block applies to a lattice value"""
- pass
def run(self, bbs: Iterable[BB]) -> Result[T]:
"""Runs the analysis pass.
@@ -69,7 +65,6 @@ class BackwardAnalysis(Analysis[T], ABC, Generic[T]):
@abstractmethod
def apply_bb(self, val_after: T, bb: BB, /) -> T:
"""Transformation a basic block applies to a lattice value"""
- pass
def run(self, bbs: Iterable[BB]) -> Result[T]:
"""Runs the analysis pass.
@@ -106,7 +101,7 @@ def eq(self, live1: LivenessDomain, live2: LivenessDomain) -> bool:
return live1.keys() == live2.keys()
def initial(self) -> LivenessDomain:
- return dict()
+ return {}
def join(self, *ts: LivenessDomain) -> LivenessDomain:
res: LivenessDomain = {}
diff --git a/guppy/cfg/bb.py b/guppy/cfg/bb.py
index 95c9f63a..99fd61a2 100644
--- a/guppy/cfg/bb.py
+++ b/guppy/cfg/bb.py
@@ -1,7 +1,8 @@
import ast
from abc import ABC
from dataclasses import dataclass, field
-from typing import Optional, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
+
from typing_extensions import Self
from guppy.ast_util import AstNode, name_nodes_in_ast
@@ -34,9 +35,14 @@ def update_used(self, node: ast.AST) -> None:
self.used[name.id] = name
-BBStatement = Union[
- ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef
-]
+BBStatement = (
+ ast.Assign
+ | ast.AugAssign
+ | ast.AnnAssign
+ | ast.Expr
+ | ast.Return
+ | NestedFunctionDef
+)
@dataclass(eq=False) # Disable equality to recover hash from `object`
@@ -57,10 +63,10 @@ class BB(ABC):
# If the BB has multiple successors, we need a predicate to decide to which one to
# jump to
- branch_pred: Optional[ast.expr] = None
+ branch_pred: ast.expr | None = None
# Information about assigned/used variables in the BB
- _vars: Optional[VariableStats] = None
+ _vars: VariableStats | None = None
@property
def vars(self) -> VariableStats:
@@ -123,9 +129,7 @@ def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# Only store used *external* variables: things defined in the current BB, as
# well as the function name and argument names should not be included
assigned_before_in_bb = (
- self.stats.assigned.keys()
- | {node.name}
- | set(a.arg for a in node.args.args)
+ self.stats.assigned.keys() | {node.name} | {a.arg for a in node.args.args}
)
self.stats.used |= {
x: using_bb.vars.used[x]
diff --git a/guppy/cfg/builder.py b/guppy/cfg/builder.py
index e937e144..878437bf 100644
--- a/guppy/cfg/builder.py
+++ b/guppy/cfg/builder.py
@@ -1,8 +1,9 @@
import ast
import itertools
-from typing import Optional, Iterator, NamedTuple
+from collections.abc import Iterator
+from typing import NamedTuple
-from guppy.ast_util import set_location_from, AstVisitor
+from guppy.ast_util import AstVisitor, set_location_from
from guppy.cfg.bb import BB, BBStatement
from guppy.cfg.cfg import CFG
from guppy.checker.core import Globals
@@ -24,11 +25,11 @@ class Jumps(NamedTuple):
"""Holds jump targets for return, continue, and break during CFG construction."""
return_bb: BB
- continue_bb: Optional[BB]
- break_bb: Optional[BB]
+ continue_bb: BB | None
+ break_bb: BB | None
-class CFGBuilder(AstVisitor[Optional[BB]]):
+class CFGBuilder(AstVisitor[BB | None]):
"""Constructs a CFG from ast nodes."""
cfg: CFG
@@ -57,8 +58,8 @@ def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) ->
return self.cfg
- def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[BB]:
- bb_opt: Optional[BB] = bb
+ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> BB | None:
+ bb_opt: BB | None = bb
next_functional = False
for node in nodes:
if bb_opt is None:
@@ -69,7 +70,7 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[B
if next_functional:
# TODO: This should be an assertion that the Hugr can be un-flattened
- raise NotImplementedError()
+ raise NotImplementedError
next_functional = False
else:
bb_opt = self.visit(node, bb_opt, jumps)
@@ -86,20 +87,16 @@ def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
bb.statements.append(node)
return bb
- def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_Assign(self, node: ast.Assign, bb: BB, jumps: Jumps) -> BB | None:
return self._build_node_value(node, bb)
- def visit_AugAssign(
- self, node: ast.AugAssign, bb: BB, jumps: Jumps
- ) -> Optional[BB]:
+ def visit_AugAssign(self, node: ast.AugAssign, bb: BB, jumps: Jumps) -> BB | None:
return self._build_node_value(node, bb)
- def visit_AnnAssign(
- self, node: ast.AnnAssign, bb: BB, jumps: Jumps
- ) -> Optional[BB]:
+ def visit_AnnAssign(self, node: ast.AnnAssign, bb: BB, jumps: Jumps) -> BB | None:
return self._build_node_value(node, bb)
- def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> BB | None:
# This is an expression statement where the value is discarded
node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
# We don't add it to the BB if it's just a temporary variable. This will be the
@@ -110,7 +107,7 @@ def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]:
bb.statements.append(node)
return bb
- def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> BB | None:
then_bb, else_bb = self.cfg.new_bb(), self.cfg.new_bb()
BranchBuilder.add_branch(node.test, self.cfg, bb, then_bb, else_bb)
then_bb = self.visit_stmts(node.body, then_bb, jumps)
@@ -127,7 +124,7 @@ def visit_If(self, node: ast.If, bb: BB, jumps: Jumps) -> Optional[BB]:
# No branch jumps: We have to merge the control flow
return self.cfg.new_bb(then_bb, else_bb)
- def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> BB | None:
head_bb = self.cfg.new_bb(bb)
body_bb, tail_bb = self.cfg.new_bb(), self.cfg.new_bb()
BranchBuilder.add_branch(node.test, self.cfg, head_bb, body_bb, tail_bb)
@@ -145,29 +142,29 @@ def visit_While(self, node: ast.While, bb: BB, jumps: Jumps) -> Optional[BB]:
# its own jumps since the body is not guaranteed to execute
return tail_bb
- def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_Continue(self, node: ast.Continue, bb: BB, jumps: Jumps) -> BB | None:
if not jumps.continue_bb:
raise InternalGuppyError("Continue BB not defined")
self.cfg.link(bb, jumps.continue_bb)
return None
- def visit_Break(self, node: ast.Break, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_Break(self, node: ast.Break, bb: BB, jumps: Jumps) -> BB | None:
if not jumps.break_bb:
raise InternalGuppyError("Break BB not defined")
self.cfg.link(bb, jumps.break_bb)
return None
- def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_Return(self, node: ast.Return, bb: BB, jumps: Jumps) -> BB | None:
bb = self._build_node_value(node, bb)
self.cfg.link(bb, jumps.return_bb)
return None
- def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> Optional[BB]:
+ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> BB | None:
return bb
def visit_FunctionDef(
self, node: ast.FunctionDef, bb: BB, jumps: Jumps
- ) -> Optional[BB]:
+ ) -> BB | None:
from guppy.checker.func_checker import check_signature
func_ty = check_signature(node, self.globals)
@@ -188,7 +185,7 @@ def visit_FunctionDef(
bb.statements.append(new_node)
return bb
- def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> Optional[BB]: # type: ignore
+ def generic_visit(self, node: ast.AST, bb: BB, jumps: Jumps) -> BB | None: # type: ignore[override]
# When adding support for new statements, we have to remember to use the
# ExprBuilder to transform all included expressions!
raise GuppyError("Statement is not supported", node)
@@ -215,7 +212,7 @@ def build(node: ast.expr, cfg: CFG, bb: BB) -> tuple[ast.expr, BB]:
return builder.visit(node), builder.bb
@classmethod
- def _make_var(cls, name: str, loc: Optional[ast.expr] = None) -> ast.Name:
+ def _make_var(cls, name: str, loc: ast.expr | None = None) -> ast.Name:
"""Creates an `ast.Name` node."""
node = ast.Name(id=name, ctx=ast.Load)
if loc is not None:
@@ -343,7 +340,7 @@ def visit_Compare(
# Support chained comparisons, e.g. `x <= 5 < y` by compiling to `x <= 5 and
# 5 < y`. This way we get short-circuit evaluation for free.
if len(node.comparators) > 1:
- comparators = [node.left] + node.comparators
+ comparators = [node.left, *node.comparators]
values = [
ast.Compare(
left=left,
@@ -368,7 +365,7 @@ def visit_IfExp(self, node: ast.IfExp, bb: BB, true_bb: BB, false_bb: BB) -> Non
self.visit(node.body, then_bb, true_bb, false_bb)
self.visit(node.orelse, else_bb, true_bb, false_bb)
- def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> None: # type: ignore
+ def generic_visit(self, node: ast.expr, bb: BB, true_bb: BB, false_bb: BB) -> None: # type: ignore[override]
# We can always fall back to building the node as a regular expression and using
# the result as a branch predicate
pred, bb = ExprBuilder.build(node, self.cfg, bb)
diff --git a/guppy/cfg/cfg.py b/guppy/cfg/cfg.py
index 2648f696..986988d5 100644
--- a/guppy/cfg/cfg.py
+++ b/guppy/cfg/cfg.py
@@ -1,16 +1,15 @@
-from typing import Optional, TypeVar, Generic
+from typing import Generic, TypeVar
from guppy.cfg.analysis import (
- LivenessDomain,
- LivenessAnalysis,
AssignmentAnalysis,
DefAssignmentDomain,
+ LivenessAnalysis,
+ LivenessDomain,
MaybeAssignmentDomain,
Result,
)
from guppy.cfg.bb import BB, BBStatement
-
T = TypeVar("T", bound=BB)
@@ -26,7 +25,7 @@ class BaseCFG(Generic[T]):
maybe_ass_before: Result[MaybeAssignmentDomain]
def __init__(
- self, bbs: list[T], entry_bb: Optional[T] = None, exit_bb: Optional[T] = None
+ self, bbs: list[T], entry_bb: T | None = None, exit_bb: T | None = None
):
self.bbs = bbs
if entry_bb:
@@ -54,7 +53,7 @@ def __init__(self) -> None:
self.entry_bb = self.new_bb()
self.exit_bb = self.new_bb()
- def new_bb(self, *preds: BB, statements: Optional[list[BBStatement]] = None) -> BB:
+ def new_bb(self, *preds: BB, statements: list[BBStatement] | None = None) -> BB:
"""Adds a new basic block to the CFG."""
bb = BB(
len(self.bbs), self, predecessors=list(preds), statements=statements or []
diff --git a/guppy/checker/cfg_checker.py b/guppy/checker/cfg_checker.py
index 787b2ccb..c5cf8a64 100644
--- a/guppy/checker/cfg_checker.py
+++ b/guppy/checker/cfg_checker.py
@@ -5,21 +5,18 @@
"""
import collections
+from collections.abc import Sequence
from dataclasses import dataclass
-from typing import Sequence
from guppy.ast_util import line_col
from guppy.cfg.bb import BB
from guppy.cfg.cfg import CFG, BaseCFG
-from guppy.checker.core import Globals, Context
-
-from guppy.checker.core import Variable
+from guppy.checker.core import Context, Globals, Variable
from guppy.checker.expr_checker import ExprSynthesizer, to_bool
from guppy.checker.stmt_checker import StmtChecker
from guppy.error import GuppyError
from guppy.gtypes import GuppyType
-
VarRow = Sequence[Variable]
@@ -42,7 +39,7 @@ def empty() -> "Signature":
class CheckedBB(BB):
"""Basic block annotated with an input and output type signature."""
- sig: Signature = Signature.empty()
+ sig: Signature = Signature.empty() # noqa: RUF009
class CheckedCFG(BaseCFG[CheckedBB]):
@@ -64,7 +61,7 @@ def check_cfg(
unreachable blocks.
"""
# First, we need to run program analysis
- ass_before = set(v.name for v in inputs)
+ ass_before = {v.name for v in inputs}
cfg.analyze(ass_before, ass_before)
# We start by compiling the entry BB
diff --git a/guppy/checker/core.py b/guppy/checker/core.py
index 9bab6375..25dd256d 100644
--- a/guppy/checker/core.py
+++ b/guppy/checker/core.py
@@ -1,16 +1,16 @@
import ast
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import NamedTuple, Optional, Union
+from typing import NamedTuple
from guppy.ast_util import AstNode
from guppy.gtypes import (
- GuppyType,
+ BoolType,
FunctionType,
- TupleType,
- SumType,
+ GuppyType,
NoneType,
- BoolType,
+ SumType,
+ TupleType,
)
@@ -20,8 +20,8 @@ class Variable:
name: str
ty: GuppyType
- defined_at: Optional[AstNode]
- used: Optional[AstNode]
+ defined_at: AstNode | None
+ used: AstNode | None
@dataclass
@@ -65,7 +65,7 @@ def default() -> "Globals":
}
return Globals({}, tys)
- def get_instance_func(self, ty: GuppyType, name: str) -> Optional[CallableVariable]:
+ def get_instance_func(self, ty: GuppyType, name: str) -> CallableVariable | None:
"""Looks up an instance function with a given name for a type.
Returns `None` if the name doesn't exist or isn't a function.
@@ -83,7 +83,7 @@ def __or__(self, other: "Globals") -> "Globals":
self.types | other.types,
)
- def __ior__(self, other: "Globals") -> "Globals":
+ def __ior__(self, other: "Globals") -> "Globals": # noqa: PYI034
self.values.update(other.values)
self.types.update(other.types)
return self
@@ -100,7 +100,7 @@ class Context(NamedTuple):
locals: Locals
-def qualified_name(ty: Union[type[GuppyType], str], name: str) -> str:
+def qualified_name(ty: type[GuppyType] | str, name: str) -> str:
"""Returns a qualified name for an instance function on a type."""
ty_name = ty if isinstance(ty, str) else ty.name
return f"{ty_name}.{name}"
diff --git a/guppy/checker/expr_checker.py b/guppy/checker/expr_checker.py
index 09c12774..731d8f7c 100644
--- a/guppy/checker/expr_checker.py
+++ b/guppy/checker/expr_checker.py
@@ -22,13 +22,13 @@
import ast
from contextlib import suppress
-from typing import Optional, Union, NoReturn, Any
+from typing import Any, NoReturn
-from guppy.ast_util import AstVisitor, with_loc, AstNode, with_type, get_type_opt
-from guppy.checker.core import Context, CallableVariable, Globals
+from guppy.ast_util import AstNode, AstVisitor, get_type_opt, with_loc, with_type
+from guppy.checker.core import CallableVariable, Context, Globals
from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError
-from guppy.gtypes import GuppyType, TupleType, FunctionType, BoolType
-from guppy.nodes import LocalName, GlobalName, LocalCall
+from guppy.gtypes import BoolType, FunctionType, GuppyType, TupleType
+from guppy.nodes import GlobalName, LocalCall, LocalName
# Mapping from unary AST op to dunder method and display name
unary_table: dict[type[ast.unaryop], tuple[str, str]] = {
@@ -38,7 +38,7 @@
} # fmt: skip
# Mapping from binary AST op to left dunder method, right dunder method and display name
-AstOp = Union[ast.operator, ast.cmpop]
+AstOp = ast.operator | ast.cmpop
binary_table: dict[type[AstOp], tuple[str, str, str]] = {
ast.Add: ("__add__", "__radd__", "+"),
ast.Sub: ("__sub__", "__rsub__", "-"),
@@ -78,8 +78,8 @@ def __init__(self, ctx: Context) -> None:
def _fail(
self,
expected: GuppyType,
- actual: Union[ast.expr, GuppyType],
- loc: Optional[AstNode] = None,
+ actual: ast.expr | GuppyType,
+ loc: AstNode | None = None,
) -> NoReturn:
"""Raises a type error indicating that the type doesn't match."""
if not isinstance(actual, GuppyType):
@@ -348,7 +348,7 @@ def to_bool(
def python_value_to_guppy_type(
v: Any, node: ast.expr, globals: Globals
-) -> Optional[GuppyType]:
+) -> GuppyType | None:
"""Turns a primitive Python value into a Guppy type.
Returns `None` if the Python value cannot be represented in Guppy.
diff --git a/guppy/checker/func_checker.py b/guppy/checker/func_checker.py
index e0695824..20113685 100644
--- a/guppy/checker/func_checker.py
+++ b/guppy/checker/func_checker.py
@@ -8,15 +8,15 @@
import ast
from dataclasses import dataclass
-from guppy.ast_util import return_nodes_in_ast, AstNode, with_loc
+from guppy.ast_util import AstNode, return_nodes_in_ast, with_loc
from guppy.cfg.bb import BB
from guppy.cfg.builder import CFGBuilder
-from guppy.checker.core import Variable, Globals, Context, CallableVariable
-from guppy.checker.cfg_checker import check_cfg, CheckedCFG
-from guppy.checker.expr_checker import synthesize_call, check_call
+from guppy.checker.cfg_checker import CheckedCFG, check_cfg
+from guppy.checker.core import CallableVariable, Context, Globals, Variable
+from guppy.checker.expr_checker import check_call, synthesize_call
from guppy.error import GuppyError
-from guppy.gtypes import FunctionType, type_from_ast, NoneType, GuppyType
-from guppy.nodes import GlobalCall, CheckedNestedFunctionDef, NestedFunctionDef
+from guppy.gtypes import FunctionType, GuppyType, NoneType, type_from_ast
+from guppy.nodes import CheckedNestedFunctionDef, GlobalCall, NestedFunctionDef
@dataclass
@@ -178,7 +178,7 @@ def check_signature(func_def: ast.FunctionDef, globals: Globals) -> FunctionType
arg_tys = []
arg_names = []
- for i, arg in enumerate(func_def.args.args):
+ for _i, arg in enumerate(func_def.args.args):
if arg.annotation is None:
raise GuppyError("Argument type must be annotated", arg)
ty = type_from_ast(arg.annotation, globals)
diff --git a/guppy/checker/stmt_checker.py b/guppy/checker/stmt_checker.py
index 5622ee55..61df3120 100644
--- a/guppy/checker/stmt_checker.py
+++ b/guppy/checker/stmt_checker.py
@@ -9,14 +9,14 @@
"""
import ast
-from typing import Sequence
+from collections.abc import Sequence
-from guppy.ast_util import with_loc, AstVisitor
+from guppy.ast_util import AstVisitor, with_loc
from guppy.cfg.bb import BB, BBStatement
-from guppy.checker.core import Variable, Context
-from guppy.checker.expr_checker import ExprSynthesizer, ExprChecker
+from guppy.checker.core import Context, Variable
+from guppy.checker.expr_checker import ExprChecker, ExprSynthesizer
from guppy.error import GuppyError, GuppyTypeError, InternalGuppyError
-from guppy.gtypes import GuppyType, TupleType, type_from_ast, NoneType
+from guppy.gtypes import GuppyType, NoneType, TupleType, type_from_ast
from guppy.nodes import NestedFunctionDef
diff --git a/guppy/compiler/cfg_compiler.py b/guppy/compiler/cfg_compiler.py
index 086236e9..d2f1abb9 100644
--- a/guppy/compiler/cfg_compiler.py
+++ b/guppy/compiler/cfg_compiler.py
@@ -1,19 +1,22 @@
import functools
-from typing import Sequence
+from typing import TYPE_CHECKING
-from guppy.checker.cfg_checker import CheckedBB, VarRow, CheckedCFG, Signature
+from guppy.checker.cfg_checker import CheckedBB, CheckedCFG, Signature, VarRow
from guppy.checker.core import Variable
from guppy.compiler.core import (
CompiledGlobals,
- is_return_var,
DFContainer,
- return_var,
PortVariable,
+ is_return_var,
+ return_var,
)
from guppy.compiler.expr_compiler import ExprCompiler
from guppy.compiler.stmt_compiler import StmtCompiler
-from guppy.gtypes import TupleType, SumType, type_to_row
-from guppy.hugr.hugr import Hugr, Node, CFNode, OutPortV
+from guppy.gtypes import SumType, TupleType, type_to_row
+from guppy.hugr.hugr import CFNode, Hugr, Node, OutPortV
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
def compile_cfg(
@@ -119,7 +122,8 @@ def insert_return_vars(cfg: CheckedCFG) -> None:
# Also patch the predecessors
for pred in cfg.exit_bb.predecessors:
# The exit BB will be the only successor
- assert len(pred.sig.output_rows) == 1 and len(pred.sig.output_rows[0]) == 0
+ assert len(pred.sig.output_rows) == 1
+ assert len(pred.sig.output_rows[0]) == 0
pred.sig = Signature(pred.sig.input_row, [return_vars])
@@ -144,7 +148,7 @@ def choose_vars_for_tuple_sum(
conditional = graph.add_conditional(
cond_input=unit_sum, inputs=tuples, parent=dfg.node
)
- for i, ty in enumerate(tys):
+ for i, _ty in enumerate(tys):
case = graph.add_case(conditional)
inp = graph.add_input(output_tys=tys, parent=case).out_port(i)
tag = graph.add_tag(variants=tys, tag=i, inp=inp, parent=case).out_port(0)
diff --git a/guppy/compiler/core.py b/guppy/compiler/core.py
index 2b6a4fb2..d3b16e59 100644
--- a/guppy/compiler/core.py
+++ b/guppy/compiler/core.py
@@ -1,11 +1,11 @@
from abc import ABC, abstractmethod
+from collections.abc import Iterator
from dataclasses import dataclass
-from typing import Optional, Iterator
from guppy.ast_util import AstNode
-from guppy.checker.core import Variable, CallableVariable
+from guppy.checker.core import CallableVariable, Variable
from guppy.gtypes import FunctionType
-from guppy.hugr.hugr import OutPortV, DFContainingNode, Hugr
+from guppy.hugr.hugr import DFContainingNode, Hugr, OutPortV
@dataclass
@@ -21,8 +21,8 @@ def __init__(
self,
name: str,
port: OutPortV,
- defined_at: Optional[AstNode],
- used: Optional[AstNode] = None,
+ defined_at: AstNode | None,
+ used: AstNode | None = None,
) -> None:
super().__init__(name, port.ty, defined_at, used)
object.__setattr__(self, "port", port)
@@ -89,7 +89,7 @@ def __copy__(self) -> "DFContainer":
# mutate our variable mapping
return DFContainer(self.node, self.locals.copy())
- def get_var(self, name: str) -> Optional[PortVariable]:
+ def get_var(self, name: str) -> PortVariable | None:
return self.locals.get(name, None)
diff --git a/guppy/compiler/expr_compiler.py b/guppy/compiler/expr_compiler.py
index 5fdc378b..f74c18d2 100644
--- a/guppy/compiler/expr_compiler.py
+++ b/guppy/compiler/expr_compiler.py
@@ -1,13 +1,13 @@
import ast
-from typing import Any, Optional
+from typing import Any
from guppy.ast_util import AstVisitor, get_type
-from guppy.compiler.core import CompilerBase, DFContainer, CompiledFunction
+from guppy.compiler.core import CompiledFunction, CompilerBase, DFContainer
from guppy.error import InternalGuppyError
-from guppy.gtypes import FunctionType, type_to_row, BoolType
+from guppy.gtypes import BoolType, FunctionType, type_to_row
from guppy.hugr import ops, val
from guppy.hugr.hugr import OutPortV
-from guppy.nodes import LocalName, GlobalName, GlobalCall, LocalCall
+from guppy.nodes import GlobalCall, GlobalName, LocalCall, LocalName
class ExprCompiler(CompilerBase, AstVisitor[OutPortV]):
@@ -100,12 +100,12 @@ def expr_to_row(expr: ast.expr) -> list[ast.expr]:
return expr.elts if isinstance(expr, ast.Tuple) else [expr]
-def python_value_to_hugr(v: Any) -> Optional[val.Value]:
+def python_value_to_hugr(v: Any) -> val.Value | None:
"""Turns a Python value into a Hugr value.
Returns None if the Python value cannot be represented in Guppy.
"""
- from guppy.prelude._internal import int_value, bool_value, float_value
+ from guppy.prelude._internal import bool_value, float_value, int_value
if isinstance(v, bool):
return bool_value(v)
diff --git a/guppy/compiler/func_compiler.py b/guppy/compiler/func_compiler.py
index 4631f0db..f4376b17 100644
--- a/guppy/compiler/func_compiler.py
+++ b/guppy/compiler/func_compiler.py
@@ -9,8 +9,8 @@
DFContainer,
PortVariable,
)
-from guppy.gtypes import type_to_row, FunctionType
-from guppy.hugr.hugr import Hugr, OutPortV, DFContainingVNode
+from guppy.gtypes import FunctionType, type_to_row
+from guppy.hugr.hugr import DFContainingVNode, Hugr, OutPortV
from guppy.nodes import CheckedNestedFunctionDef
diff --git a/guppy/compiler/stmt_compiler.py b/guppy/compiler/stmt_compiler.py
index db656b6f..18147b22 100644
--- a/guppy/compiler/stmt_compiler.py
+++ b/guppy/compiler/stmt_compiler.py
@@ -1,12 +1,12 @@
import ast
-from typing import Sequence
+from collections.abc import Sequence
from guppy.ast_util import AstVisitor
from guppy.checker.cfg_checker import CheckedBB
from guppy.compiler.core import (
+ CompiledGlobals,
CompilerBase,
DFContainer,
- CompiledGlobals,
PortVariable,
return_var,
)
diff --git a/guppy/custom.py b/guppy/custom.py
index a80fd134..4c20ea09 100644
--- a/guppy/custom.py
+++ b/guppy/custom.py
@@ -1,27 +1,26 @@
import ast
from abc import ABC, abstractmethod
-from typing import Optional
-from guppy.ast_util import AstNode, with_type, with_loc, get_type
+from guppy.ast_util import AstNode, get_type, with_loc, with_type
from guppy.checker.core import Context, Globals
from guppy.checker.expr_checker import check_call, synthesize_call
from guppy.checker.func_checker import check_signature
-from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals
+from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer
from guppy.error import (
GuppyError,
InternalGuppyError,
UnknownFunctionType,
)
-from guppy.gtypes import GuppyType, FunctionType, type_to_row
+from guppy.gtypes import FunctionType, GuppyType, type_to_row
from guppy.hugr import ops
-from guppy.hugr.hugr import OutPortV, Hugr, Node, DFContainingVNode
+from guppy.hugr.hugr import DFContainingVNode, Hugr, Node, OutPortV
from guppy.nodes import GlobalCall
class CustomFunction(CompiledFunction):
"""A function whose type checking and compilation behaviour can be customised."""
- defined_at: Optional[ast.FunctionDef]
+ defined_at: ast.FunctionDef | None
# Whether the function may be used as a higher-order value. This is only possible
# if a static type for the function is provided.
@@ -30,17 +29,17 @@ class CustomFunction(CompiledFunction):
call_checker: "CustomCallChecker"
call_compiler: "CustomCallCompiler"
- _ty: Optional[FunctionType] = None
- _defined: dict[Node, DFContainingVNode] = {}
+ _ty: FunctionType | None = None
+ _defined: dict[Node, DFContainingVNode] = {} # noqa: RUF012
def __init__(
self,
name: str,
- defined_at: Optional[ast.FunctionDef],
+ defined_at: ast.FunctionDef | None,
compiler: "CustomCallCompiler",
checker: "CustomCallChecker",
higher_order_value: bool = True,
- ty: Optional[FunctionType] = None,
+ ty: FunctionType | None = None,
):
self.name = name
self.defined_at = defined_at
@@ -51,7 +50,7 @@ def __init__(
self._ty = ty
self._defined = {}
- @property # type: ignore
+ @property # type: ignore[override]
def ty(self) -> FunctionType:
if self._ty is None:
return UnknownFunctionType()
@@ -77,11 +76,11 @@ def check_type(self, globals: Globals) -> None:
try:
self._ty = check_signature(self.defined_at, globals)
- except GuppyError as err:
+ except GuppyError:
# We can ignore the error if a custom call checker is provided and the
# function may not be used as a higher-order value
if self.call_checker is None or self.higher_order_value:
- raise err
+ raise
def check_call(
self, args: list[ast.expr], ty: GuppyType, node: AstNode, ctx: Context
diff --git a/guppy/declared.py b/guppy/declared.py
index 91b5283e..265ef1a6 100644
--- a/guppy/declared.py
+++ b/guppy/declared.py
@@ -1,15 +1,14 @@
import ast
from dataclasses import dataclass
-from typing import Optional
from guppy.ast_util import AstNode, has_empty_body
-from guppy.checker.core import Globals, Context
+from guppy.checker.core import Context, Globals
from guppy.checker.expr_checker import check_call, synthesize_call
from guppy.checker.func_checker import check_signature
-from guppy.compiler.core import CompiledFunction, DFContainer, CompiledGlobals
+from guppy.compiler.core import CompiledFunction, CompiledGlobals, DFContainer
from guppy.error import GuppyError
-from guppy.gtypes import type_to_row, GuppyType
-from guppy.hugr.hugr import VNode, Hugr, Node, OutPortV
+from guppy.gtypes import GuppyType, type_to_row
+from guppy.hugr.hugr import Hugr, Node, OutPortV, VNode
from guppy.nodes import GlobalCall
@@ -17,7 +16,7 @@
class DeclaredFunction(CompiledFunction):
"""A user-declared function that compiles to a Hugr function declaration."""
- node: Optional[VNode] = None
+ node: VNode | None = None
@staticmethod
def from_ast(
diff --git a/guppy/decorator.py b/guppy/decorator.py
index 493ddcb2..32f93ab1 100644
--- a/guppy/decorator.py
+++ b/guppy/decorator.py
@@ -1,19 +1,20 @@
import functools
+from collections.abc import Callable
from dataclasses import dataclass
-from typing import Optional, Union, Callable, Any
+from typing import Any
from guppy.ast_util import AstNode, has_empty_body
from guppy.custom import (
+ CustomCallChecker,
+ CustomCallCompiler,
CustomFunction,
- OpCompiler,
DefaultCallChecker,
- CustomCallCompiler,
- CustomCallChecker,
DefaultCallCompiler,
+ OpCompiler,
)
from guppy.error import GuppyError, pretty_errors
from guppy.gtypes import GuppyType
-from guppy.hugr import tys, ops
+from guppy.hugr import ops, tys
from guppy.hugr.hugr import Hugr
from guppy.module import GuppyModule, PyFunc, parse_py_func
@@ -26,7 +27,7 @@ class _Guppy:
"""Class for the `@guppy` decorator."""
# The current module
- _module: Optional[GuppyModule]
+ _module: GuppyModule | None
def __init__(self) -> None:
self._module = None
@@ -35,13 +36,11 @@ def set_module(self, module: GuppyModule) -> None:
self._module = module
@pretty_errors
- def __call__(
- self, arg: Union[PyFunc, GuppyModule]
- ) -> Union[Optional[Hugr], FuncDecorator]:
+ def __call__(self, arg: PyFunc | GuppyModule) -> Hugr | None | FuncDecorator:
"""Decorator to annotate Python functions as Guppy code.
- Optionally, the `GuppyModule` in which the function should be placed can be passed
- to the decorator.
+ Optionally, the `GuppyModule` in which the function should be placed can be
+ passed to the decorator.
"""
if isinstance(arg, GuppyModule):
@@ -98,9 +97,7 @@ class NewType(GuppyType):
name = _name
@staticmethod
- def build(
- *args: GuppyType, node: Optional[AstNode] = None
- ) -> "GuppyType":
+ def build(*args: GuppyType, node: AstNode | None = None) -> "GuppyType":
# At the moment, custom types don't support type arguments.
if len(args) > 0:
raise GuppyError(
@@ -131,8 +128,8 @@ def __str__(self) -> str:
def custom(
self,
module: GuppyModule,
- compiler: Optional[CustomCallCompiler] = None,
- checker: Optional[CustomCallChecker] = None,
+ compiler: CustomCallCompiler | None = None,
+ checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
name: str = "",
) -> CustomFuncDecorator:
@@ -168,7 +165,7 @@ def hugr_op(
self,
module: GuppyModule,
op: ops.OpType,
- checker: Optional[CustomCallChecker] = None,
+ checker: CustomCallChecker | None = None,
higher_order_value: bool = True,
name: str = "",
) -> CustomFuncDecorator:
diff --git a/guppy/error.py b/guppy/error.py
index b0ace141..bc6d38e7 100644
--- a/guppy/error.py
+++ b/guppy/error.py
@@ -2,13 +2,13 @@
import functools
import sys
import textwrap
+from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
-from typing import Optional, Any, Sequence, Callable, TypeVar, cast
-
-from guppy.ast_util import AstNode, get_line_offset, get_file, get_source
-from guppy.gtypes import GuppyType, FunctionType
-from guppy.hugr.hugr import OutPortV, Node
+from typing import Any, TypeVar, cast
+from guppy.ast_util import AstNode, get_file, get_line_offset, get_source
+from guppy.gtypes import FunctionType, GuppyType
+from guppy.hugr.hugr import Node, OutPortV
# Whether the interpreter should exit when a Guppy error occurs
EXIT_ON_ERROR: bool = True
@@ -25,12 +25,13 @@ class SourceLoc:
file: str
line: int
col: int
- ast_node: Optional[AstNode]
+ ast_node: AstNode | None
@staticmethod
def from_ast(node: AstNode) -> "SourceLoc":
file, line_offset = get_file(node), get_line_offset(node)
- assert file is not None and line_offset is not None
+ assert file is not None
+ assert line_offset is not None
return SourceLoc(file, line_offset + node.lineno - 1, node.col_offset, node)
def __str__(self) -> str:
@@ -50,9 +51,9 @@ class GuppyError(Exception):
`{1}`, etc. and passing the corresponding AST nodes to `locs_in_msg`."""
raw_msg: str
- location: Optional[AstNode] = None
+ location: AstNode | None = None
# The message can also refer to AST locations using format placeholders `{0}`, `{1}`
- locs_in_msg: Sequence[Optional[AstNode]] = field(default_factory=list)
+ locs_in_msg: Sequence[AstNode | None] = field(default_factory=list)
def get_msg(self) -> str:
"""Returns the message associated with this error.
@@ -70,14 +71,10 @@ def get_msg(self) -> str:
class GuppyTypeError(GuppyError):
"""Special Guppy exception for type errors."""
- pass
-
class InternalGuppyError(Exception):
"""Exception for internal problems during compilation."""
- pass
-
class UndefinedPort(OutPortV):
"""Dummy port for undefined variables.
@@ -119,7 +116,7 @@ def returns(self) -> GuppyType:
raise InternalGuppyError("Tried to access unknown function type")
@property
- def args_names(self) -> Optional[Sequence[str]]:
+ def args_names(self) -> Sequence[str] | None:
raise InternalGuppyError("Tried to access unknown function type")
@@ -130,7 +127,8 @@ def format_source_location(
) -> str:
"""Creates a pretty banner to show source locations for errors."""
source, line_offset = get_source(loc), get_line_offset(loc)
- assert source is not None and line_offset is not None
+ assert source is not None
+ assert line_offset is not None
source_lines = source.splitlines(keepends=True)
end_col_offset = loc.end_col_offset or len(source_lines[loc.lineno])
s = "".join(source_lines[max(loc.lineno - num_lines, 0) : loc.lineno]).rstrip()
@@ -160,12 +158,13 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
except GuppyError as err:
# Reraise if we're missing a location
if not err.location:
- raise err
+ raise
loc = err.location
file, line_offset = get_file(loc), get_line_offset(loc)
- assert file is not None and line_offset is not None
+ assert file is not None
+ assert line_offset is not None
line = line_offset + loc.lineno - 1
- print(
+ print( # noqa: T201
f"Guppy compilation failed. Error in file {file}:{line}\n\n"
f"{format_source_location(loc)}\n"
f"{err.__class__.__name__}: {err.get_msg()}",
diff --git a/guppy/gtypes.py b/guppy/gtypes.py
index 4d98e29b..a6a15c20 100644
--- a/guppy/gtypes.py
+++ b/guppy/gtypes.py
@@ -1,7 +1,8 @@
import ast
from abc import ABC, abstractmethod
+from collections.abc import Sequence
from dataclasses import dataclass, field
-from typing import Optional, Sequence, TYPE_CHECKING
+from typing import TYPE_CHECKING
import guppy.hugr.tys as tys
from guppy.ast_util import AstNode, set_location_from
@@ -20,7 +21,7 @@ class GuppyType(ABC):
@staticmethod
@abstractmethod
- def build(*args: "GuppyType", node: Optional[AstNode] = None) -> "GuppyType":
+ def build(*args: "GuppyType", node: AstNode | None = None) -> "GuppyType":
pass
@property
@@ -37,7 +38,7 @@ def to_hugr(self) -> tys.SimpleType:
class FunctionType(GuppyType):
args: Sequence[GuppyType]
returns: GuppyType
- arg_names: Optional[Sequence[str]] = field(
+ arg_names: Sequence[str] | None = field(
default=None,
compare=False, # Argument names are not taken into account for type equality
)
@@ -53,10 +54,10 @@ def __str__(self) -> str:
return f"({', '.join(str(a) for a in self.args)}) -> {self.returns}"
@staticmethod
- def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType:
+ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
# Function types cannot be constructed using `build`. The type parsing code
# has a special case for function types.
- raise NotImplementedError()
+ raise NotImplementedError
def to_hugr(self) -> tys.SimpleType:
ins = [t.to_hugr() for t in self.args]
@@ -72,7 +73,7 @@ class TupleType(GuppyType):
name: str = "tuple"
@staticmethod
- def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType:
+ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
from guppy.error import GuppyError
# TODO: Parse empty tuples via `tuple[()]`
@@ -97,10 +98,10 @@ class SumType(GuppyType):
element_types: Sequence[GuppyType]
@staticmethod
- def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType:
+ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
# Sum types cannot be parsed and constructed using `build` since they cannot be
# written by the user
- raise NotImplementedError()
+ raise NotImplementedError
def __str__(self) -> str:
return f"Sum({', '.join(str(e) for e in self.element_types)})"
@@ -124,7 +125,7 @@ class NoneType(GuppyType):
linear: bool = False
@staticmethod
- def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType:
+ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
if len(args) > 0:
from guppy.error import GuppyError
@@ -150,7 +151,7 @@ def __init__(self) -> None:
super().__init__([TupleType([]), TupleType([])])
@staticmethod
- def build(*args: GuppyType, node: Optional[AstNode] = None) -> GuppyType:
+ def build(*args: GuppyType, node: AstNode | None = None) -> GuppyType:
if len(args) > 0:
from guppy.error import GuppyError
@@ -161,7 +162,7 @@ def __str__(self) -> str:
return "bool"
-def _lookup_type(node: AstNode, globals: "Globals") -> Optional[type[GuppyType]]:
+def _lookup_type(node: AstNode, globals: "Globals") -> type[GuppyType] | None:
if isinstance(node, ast.Name) and node.id in globals.types:
return globals.types[node.id]
if isinstance(node, ast.Constant) and node.value is None:
diff --git a/guppy/hugr/hugr.py b/guppy/hugr/hugr.py
index 96a647f1..41e1f2bf 100644
--- a/guppy/hugr/hugr.py
+++ b/guppy/hugr/hugr.py
@@ -1,20 +1,21 @@
import itertools
-import networkx # type: ignore
-
from abc import ABC, abstractmethod
+from collections.abc import Iterator, Sequence
from contextlib import contextmanager
-from typing import Optional, Iterator, Tuple, Any, Sequence
-from dataclasses import field, dataclass
+from dataclasses import dataclass, field
+from typing import Any, Optional
+
+import networkx as nx # type: ignore[import]
import guppy.hugr.ops as ops
import guppy.hugr.raw as raw
from guppy.gtypes import (
- GuppyType,
- TupleType,
FunctionType,
+ GuppyType,
SumType,
- type_to_row,
+ TupleType,
row_to_type,
+ type_to_row,
)
from guppy.hugr import val
@@ -27,20 +28,16 @@ class Port(ABC):
"""Base class for ports on nodes."""
node: "Node"
- offset: Optional[PortOffset]
+ offset: PortOffset | None
class InPort(Port, ABC):
"""Base class for a port that incoming wires connect to."""
- pass
-
class OutPort(Port, ABC):
"""Base class for a port that outgoing wires come from."""
- pass
-
@dataclass(frozen=True)
class InPortV(InPort):
@@ -69,8 +66,6 @@ class InPortCF(InPort):
class OutPortCF(OutPort):
"""A control-flow output port."""
- pass
-
Edge = tuple[OutPort, InPort]
@@ -93,23 +88,19 @@ class Node(ABC):
@abstractmethod
def num_in_ports(self) -> int:
"""The number of input ports on this node."""
- pass
@property
@abstractmethod
def num_out_ports(self) -> int:
"""The number of output ports on this node."""
- pass
@abstractmethod
- def in_port(self, offset: Optional[PortOffset]) -> InPort:
+ def in_port(self, offset: PortOffset | None) -> InPort:
"""Returns the input port at the given offset."""
- pass
@abstractmethod
- def out_port(self, offset: Optional[PortOffset]) -> OutPort:
+ def out_port(self, offset: PortOffset | None) -> OutPort:
"""Returns the output port at the given offset."""
- pass
@abstractmethod
def update_op(self) -> None:
@@ -117,7 +108,6 @@ def update_op(self) -> None:
This should be called before serialisation.
"""
- pass
@property
def in_ports(self) -> Iterator[InPort]:
@@ -159,14 +149,16 @@ def add_out_port(self, ty: GuppyType) -> OutPortV:
self.out_port_types.append(ty)
return p
- def in_port(self, offset: Optional[PortOffset]) -> InPortV:
+ def in_port(self, offset: PortOffset | None) -> InPortV:
"""Returns the input port at the given offset."""
- assert offset is not None and offset < self.num_in_ports
+ assert offset is not None
+ assert offset < self.num_in_ports
return InPortV(self, offset, self.in_port_types[offset])
- def out_port(self, offset: Optional[PortOffset]) -> OutPortV:
+ def out_port(self, offset: PortOffset | None) -> OutPortV:
"""Returns the output port at the given offset."""
- assert offset is not None and offset < self.num_out_ports
+ assert offset is not None
+ assert offset < self.num_out_ports
return OutPortV(self, offset, self.out_port_types[offset])
@property
@@ -216,11 +208,11 @@ def add_out_port(self) -> OutPortCF:
self._num_out_ports += 1
return p
- def in_port(self, offset: Optional[PortOffset]) -> InPortCF:
+ def in_port(self, offset: PortOffset | None) -> InPortCF:
assert offset is None
return InPortCF(self)
- def out_port(self, offset: Optional[PortOffset]) -> OutPortCF:
+ def out_port(self, offset: PortOffset | None) -> OutPortCF:
"""Returns the output port at the given offset."""
assert offset is not None
assert offset < self.num_out_ports
@@ -246,7 +238,8 @@ def update_op(self) -> None:
Feeds type information from the signature of the contained dataflow graph to
the operation class to. This function must be called before serialisation.
"""
- assert self.input_child is not None and self.output_child is not None
+ assert self.input_child is not None
+ assert self.output_child is not None
# Input and output node may have extra order edges connected, so we filter
# `None`s here
ins = [ty.to_hugr() for ty in self.input_child.out_port_types]
@@ -272,15 +265,15 @@ class Hugr:
name: str
root: VNode
- _graph: networkx.MultiDiGraph # TODO: We probably don't need networkx.
+ _graph: nx.MultiDiGraph # TODO: We probably don't need networkx.
_children: dict[NodeIdx, list[Node]]
- _default_parent: Optional[Node]
+ _default_parent: Node | None
- def __init__(self, name: Optional[str] = None) -> None:
+ def __init__(self, name: str | None = None) -> None:
"""Creates a new Hugr."""
self.name = name or "Unnamed"
self._default_parent = None
- self._graph = networkx.MultiDiGraph()
+ self._graph = nx.MultiDiGraph()
self._children = {-1: []}
self.root = self.add_node(
op=ops.Module(), meta_data={"name": name}, parent=None
@@ -294,7 +287,7 @@ def parent(self, parent: Node) -> Iterator[None]:
yield
self._default_parent = old_default
- def _insert_node(self, node: Node, inputs: Optional[list[OutPortV]] = None) -> None:
+ def _insert_node(self, node: Node, inputs: list[OutPortV] | None = None) -> None:
"""Helper method to insert a node into the graph datastructure."""
self._graph.add_node(node.idx, data=node)
self._children[node.idx] = []
@@ -306,11 +299,11 @@ def _insert_node(self, node: Node, inputs: Optional[list[OutPortV]] = None) -> N
def add_node(
self,
op: ops.OpType,
- input_types: Optional[TypeList] = None,
- output_types: Optional[TypeList] = None,
- parent: Optional[Node] = None,
- inputs: Optional[list[OutPortV]] = None,
- meta_data: Optional[dict[str, Any]] = None,
+ input_types: TypeList | None = None,
+ output_types: TypeList | None = None,
+ parent: Node | None = None,
+ inputs: list[OutPortV] | None = None,
+ meta_data: dict[str, Any] | None = None,
) -> VNode:
"""Helper method to add a generic value node to the graph."""
input_types = input_types or []
@@ -330,11 +323,11 @@ def add_node(
def _add_dfg_node(
self,
op: ops.OpType,
- input_types: Optional[TypeList] = None,
- output_types: Optional[TypeList] = None,
- parent: Optional[Node] = None,
- inputs: Optional[list[OutPortV]] = None,
- meta_data: Optional[dict[str, Any]] = None,
+ input_types: TypeList | None = None,
+ output_types: TypeList | None = None,
+ parent: Node | None = None,
+ inputs: list[OutPortV] | None = None,
+ meta_data: dict[str, Any] | None = None,
) -> DFContainingVNode:
"""Helper method to add a generic dataflow containing value node to the
graph."""
@@ -358,7 +351,7 @@ def set_root_name(self, name: str) -> VNode:
return self.root
def add_constant(
- self, value: val.Value, ty: GuppyType, parent: Optional[Node] = None
+ self, value: val.Value, ty: GuppyType, parent: Node | None = None
) -> VNode:
"""Adds a constant node holding a given value to the graph."""
return self.add_node(
@@ -366,7 +359,7 @@ def add_constant(
)
def add_input(
- self, output_tys: Optional[TypeList] = None, parent: Optional[Node] = None
+ self, output_tys: TypeList | None = None, parent: Node | None = None
) -> VNode:
"""Adds an `Input` node to the graph."""
parent = parent or self._default_parent
@@ -376,7 +369,7 @@ def add_input(
return node
def add_input_with_ports(
- self, output_tys: Sequence[GuppyType], parent: Optional[Node] = None
+ self, output_tys: Sequence[GuppyType], parent: Node | None = None
) -> tuple[VNode, list[OutPortV]]:
"""Adds an `Input` node to the graph."""
node = self.add_input(None, parent)
@@ -385,9 +378,9 @@ def add_input_with_ports(
def add_output(
self,
- inputs: Optional[list[OutPortV]] = None,
- input_tys: Optional[TypeList] = None,
- parent: Optional[Node] = None,
+ inputs: list[OutPortV] | None = None,
+ input_tys: TypeList | None = None,
+ parent: Node | None = None,
) -> VNode:
"""Adds an `Output` node to the graph."""
node = self.add_node(ops.Output(), input_tys, [], parent, inputs)
@@ -395,7 +388,7 @@ def add_output(
parent.output_child = node
return node
- def add_block(self, parent: Optional[Node], num_successors: int = 0) -> BlockNode:
+ def add_block(self, parent: Node | None, num_successors: int = 0) -> BlockNode:
"""Adds a `Block` node to the graph."""
node = BlockNode(
idx=self._graph.number_of_nodes(), op=ops.DFB(), parent=parent, meta_data={}
@@ -433,27 +426,27 @@ def add_conditional(
self,
cond_input: OutPortV,
inputs: list[OutPortV],
- parent: Optional[Node] = None,
+ parent: Node | None = None,
) -> VNode:
"""Adds a `Conditional` node to the graph."""
- inputs = [cond_input] + inputs
+ inputs = [cond_input, *inputs]
return self.add_node(ops.Conditional(), None, None, parent, inputs)
def add_tail_loop(
- self, inputs: list[OutPortV], parent: Optional[Node] = None
+ self, inputs: list[OutPortV], parent: Node | None = None
) -> DFContainingVNode:
"""Adds a `TailLoop` node to the graph."""
return self._add_dfg_node(ops.TailLoop(), None, None, parent, inputs)
def add_make_tuple(
- self, inputs: list[OutPortV], parent: Optional[Node] = None
+ self, inputs: list[OutPortV], parent: Node | None = None
) -> VNode:
"""Adds a `MakeTuple` node to the graph."""
ty = TupleType([port.ty for port in inputs])
return self.add_node(ops.MakeTuple(), None, [ty], parent, inputs)
def add_unpack_tuple(
- self, input_tuple: OutPortV, parent: Optional[Node] = None
+ self, input_tuple: OutPortV, parent: Node | None = None
) -> VNode:
"""Adds an `UnpackTuple` node to the graph."""
assert isinstance(input_tuple.ty, TupleType)
@@ -466,7 +459,7 @@ def add_unpack_tuple(
)
def add_tag(
- self, variants: TypeList, tag: int, inp: OutPortV, parent: Optional[Node] = None
+ self, variants: TypeList, tag: int, inp: OutPortV, parent: Node | None = None
) -> VNode:
"""Adds a `Tag` node to the graph."""
types = [ty.to_hugr() for ty in variants]
@@ -476,7 +469,7 @@ def add_tag(
)
def add_load_constant(
- self, const_port: OutPortV, parent: Optional[Node] = None
+ self, const_port: OutPortV, parent: Node | None = None
) -> VNode:
"""Adds a `LoadConstant` node to the graph."""
return self.add_node(
@@ -488,7 +481,7 @@ def add_load_constant(
)
def add_call(
- self, def_port: OutPortV, args: list[OutPortV], parent: Optional[Node] = None
+ self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None
) -> VNode:
"""Adds a `Call` node to the graph."""
assert isinstance(def_port.ty, FunctionType)
@@ -497,11 +490,11 @@ def add_call(
None,
list(type_to_row(def_port.ty.returns)),
parent,
- args + [def_port],
+ [*args, def_port],
)
def add_indirect_call(
- self, fun_port: OutPortV, args: list[OutPortV], parent: Optional[Node] = None
+ self, fun_port: OutPortV, args: list[OutPortV], parent: Node | None = None
) -> VNode:
"""Adds an `IndirectCall` node to the graph."""
assert isinstance(fun_port.ty, FunctionType)
@@ -510,11 +503,11 @@ def add_indirect_call(
None,
list(type_to_row(fun_port.ty.returns)),
parent,
- [fun_port] + args,
+ [fun_port, *args],
)
def add_partial(
- self, def_port: OutPortV, args: list[OutPortV], parent: Optional[Node] = None
+ self, def_port: OutPortV, args: list[OutPortV], parent: Node | None = None
) -> VNode:
"""Adds a `Partial` evaluation node to the graph."""
assert isinstance(def_port.ty, FunctionType)
@@ -528,11 +521,11 @@ def add_partial(
else None,
)
return self.add_node(
- ops.DummyOp(name="partial"), None, [new_ty], parent, args + [def_port]
+ ops.DummyOp(name="partial"), None, [new_ty], parent, [*args, def_port]
)
def add_def(
- self, fun_ty: FunctionType, parent: Optional[Node], name: str
+ self, fun_ty: FunctionType, parent: Node | None, name: str
) -> DFContainingVNode:
"""Adds a `Def` node to the graph."""
return self._add_dfg_node(ops.FuncDefn(name=name), [], [fun_ty], parent, None)
@@ -544,13 +537,12 @@ def add_declare(self, fun_ty: FunctionType, parent: Node, name: str) -> VNode:
def add_edge(self, src_port: OutPort, tgt_port: InPort) -> None:
"""Adds an edge between two ports."""
if isinstance(src_port, OutPortV) or isinstance(tgt_port, InPortV):
- assert (
- isinstance(src_port, OutPortV)
- and isinstance(tgt_port, InPortV)
- and src_port.ty == tgt_port.ty
- )
+ assert isinstance(src_port, OutPortV)
+ assert isinstance(tgt_port, InPortV)
+ assert src_port.ty == tgt_port.ty
else:
- assert isinstance(src_port, OutPortCF) and isinstance(tgt_port, InPortCF)
+ assert isinstance(src_port, OutPortCF)
+ assert isinstance(tgt_port, InPortCF)
self._graph.add_edge(
src_port.node.idx, tgt_port.node.idx, key=(src_port.offset, tgt_port.offset)
)
@@ -565,7 +557,7 @@ def nodes(self) -> Iterator[Node]:
def get_node(self, idx: int) -> Node:
"""Returns the node corresponding to given index."""
- return self._graph.nodes[idx]["data"] # type: ignore
+ return self._graph.nodes[idx]["data"] # type: ignore[no-any-return]
def children(self, node: Node) -> list[Node]:
"""Returns list of a node's immediate children in the hierarchy."""
@@ -612,18 +604,18 @@ def out_edges(self, port: OutPort) -> Iterator[Edge]:
def order_successors(self, node: Node) -> Iterator[Node]:
"""Returns an iterator over all nodes that this node connects to via an
order edge."""
- for src, tgt, key in self._graph.out_edges(node.idx, keys=True):
+ for _src, tgt, key in self._graph.out_edges(node.idx, keys=True):
if key == ORDER_EDGE_KEY:
yield tgt
def order_predecessors(self, node: Node) -> Iterator[Node]:
"""Returns an iterator over all nodes that are connected to this node via an
order edge."""
- for src, tgt, key in self._graph.in_edges(node.idx, keys=True):
+ for src, _tgt, key in self._graph.in_edges(node.idx, keys=True):
if key == ORDER_EDGE_KEY:
yield src
- def _to_edge(self, src: int, tgt: int, key: Tuple[int, int]) -> Edge:
+ def _to_edge(self, src: int, tgt: int, key: tuple[int, int]) -> Edge:
src_node = self.get_node(src)
tgt_node = self.get_node(tgt)
return src_node.out_port(key[0]), tgt_node.in_port(key[1])
@@ -678,7 +670,7 @@ def insert_order_edges(self) -> "Hugr":
# Special case: Call ops for functions without any arguments are
# only connected to the top-level def/declare and also need an
# order edge
- if isinstance(n.op, ops.Call) and n.num_in_ports == 1:
+ if isinstance(n.op, ops.Call) and n.num_in_ports == 1: # noqa: SIM114
assert n.parent.input_child is not None
self.add_order_edge(n.parent.input_child, n)
# Special case: Load constant ops always need an order edge
diff --git a/guppy/hugr/ops.py b/guppy/hugr/ops.py
index 81a3fb7c..5af8a6d3 100644
--- a/guppy/hugr/ops.py
+++ b/guppy/hugr/ops.py
@@ -1,18 +1,20 @@
import inspect
import sys
from abc import ABC
-from typing import Annotated, Literal, Union, Optional, Any
-from pydantic import Field, BaseModel
+from typing import Annotated, Any, Literal
+
+from pydantic import BaseModel, Field
+
+import guppy.hugr.tys as tys
from .tys import (
- TypeRow,
- SimpleType,
- PolyFuncType,
- FunctionType,
ExtensionId,
ExtensionSet,
+ FunctionType,
+ PolyFuncType,
+ SimpleType,
+ TypeRow,
)
-import guppy.hugr.tys as tys
from .val import Value
NodeID = int
@@ -28,11 +30,9 @@ class BaseOp(ABC, BaseModel):
def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
"""Hook to insert type information from the input and output ports into the
op"""
- pass
def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
"""Hook to insert type information from a child dataflow graph"""
- pass
def display_name(self) -> str:
"""Name of the op for visualisation"""
@@ -152,7 +152,7 @@ class Exit(BasicBlock):
cfg_outputs: TypeRow
-BasicBlockOp = Annotated[Union[DFB, Exit], Field(discriminator="block")]
+BasicBlockOp = Annotated[DFB | Exit, Field(discriminator="block")]
# ---------------------------------------------
@@ -323,7 +323,7 @@ def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None:
)
-ControlFlowOp = Union[Conditional, TailLoop, CFG]
+ControlFlowOp = Conditional | TailLoop | CFG
# -----------------------------------------
@@ -338,7 +338,7 @@ class CustomOp(LeafOp):
lop: Literal["CustomOp"] = "CustomOp"
extension: ExtensionId
op_name: str
- signature: Optional[tys.FunctionType] = None
+ signature: tys.FunctionType | None = None
description: str = ""
args: list[tys.TypeArgUnion] = Field(default_factory=list)
@@ -353,77 +353,66 @@ class H(LeafOp):
"""A Hadamard gate."""
lop: Literal["H"] = "H"
- pass
class T(LeafOp):
"""A T gate."""
lop: Literal["T"] = "T"
- pass
class S(LeafOp):
"""An S gate."""
lop: Literal["S"] = "S"
- pass
class X(LeafOp):
"""A Pauli X gate."""
lop: Literal["X"] = "X"
- pass
class Y(LeafOp):
"""A Pauli Y gate."""
lop: Literal["Y"] = "Y"
- pass
class Z(LeafOp):
"""A Pauli Z gate."""
lop: Literal["Z"] = "Z"
- pass
class Tadj(LeafOp):
"""An adjoint T gate."""
lop: Literal["Tadj"] = "Tadj"
- pass
class Sadj(LeafOp):
"""An adjoint S gate."""
lop: Literal["Sadj"] = "Sadj"
- pass
class CX(LeafOp):
"""A controlled X gate."""
lop: Literal["CX"] = "CX"
- pass
class ZZMax(LeafOp):
"""A maximally entangling ZZ phase gate."""
lop: Literal["ZZMax"] = "ZZMax"
- pass
class Reset(LeafOp):
"""A qubit reset operation."""
lop: Literal["Reset"] = "Reset"
- pass
class Noop(LeafOp):
@@ -443,21 +432,18 @@ class Measure(LeafOp):
"""A qubit measurement operation."""
lop: Literal["Measure"] = "Measure"
- pass
class RzF64(LeafOp):
"""A rotation of a qubit about the Pauli Z axis by an input float angle."""
lop: Literal["RzF64"] = "RzF64"
- pass
class Xor(LeafOp):
"""A bitwise XOR operation."""
lop: Literal["Xor"] = "Xor"
- pass
class MakeTuple(LeafOp):
@@ -500,54 +486,54 @@ class Tag(LeafOp):
LeafOpUnion = Annotated[
- Union[
- CustomOp,
- H,
- S,
- T,
- X,
- Y,
- Z,
- Tadj,
- Sadj,
- CX,
- ZZMax,
- Reset,
- Noop,
- Measure,
- RzF64,
- Xor,
- MakeTuple,
- UnpackTuple,
- MakeNewType,
- Tag,
- ],
+ (
+ CustomOp
+ | H
+ | S
+ | T
+ | X
+ | Y
+ | Z
+ | Tadj
+ | Sadj
+ | CX
+ | ZZMax
+ | Reset
+ | Noop
+ | Measure
+ | RzF64
+ | Xor
+ | MakeTuple
+ | UnpackTuple
+ | MakeNewType
+ | Tag
+ ),
Field(discriminator="lop"),
]
OpType = Annotated[
- Union[
- Module,
- BasicBlock,
- Case,
- Module,
- FuncDefn,
- FuncDecl,
- Const,
- DummyOp,
- BasicBlockOp,
- Conditional,
- TailLoop,
- CFG,
- Input,
- Output,
- Call,
- CallIndirect,
- LoadConstant,
- LeafOpUnion,
- DFG,
- ],
+ (
+ Module
+ | BasicBlock
+ | Case
+ | Module
+ | FuncDefn
+ | FuncDecl
+ | Const
+ | DummyOp
+ | BasicBlockOp
+ | Conditional
+ | TailLoop
+ | CFG
+ | Input
+ | Output
+ | Call
+ | CallIndirect
+ | LoadConstant
+ | LeafOpUnion
+ | DFG
+ ),
Field(discriminator="op"),
]
@@ -562,10 +548,10 @@ class OpDef(BaseOp, allow_population_by_field_name=True):
name: str # Unique identifier of the operation.
description: str # Human readable description of the operation.
- inputs: list[tuple[Optional[str], SimpleType]]
- outputs: list[tuple[Optional[str], SimpleType]]
+ inputs: list[tuple[str | None, SimpleType]]
+ outputs: list[tuple[str | None, SimpleType]]
misc: dict[str, Any] # Miscellaneous data associated with the operation.
- def_: Optional[str] = Field(
+ def_: str | None = Field(
..., alias="def"
) # (YAML?)-encoded definition of the operation.
extension_reqs: ExtensionSet # Resources required to execute this operation.
diff --git a/guppy/hugr/raw.py b/guppy/hugr/raw.py
index a92e4469..a3249a3c 100644
--- a/guppy/hugr/raw.py
+++ b/guppy/hugr/raw.py
@@ -1,12 +1,11 @@
-from typing import Literal, Optional
+from typing import Literal
import ormsgpack
-
from pydantic import BaseModel
-from guppy.hugr.ops import NodeID, OpType
+from guppy.hugr.ops import NodeID, OpType
-Port = tuple[NodeID, Optional[int]] # (node, offset)
+Port = tuple[NodeID, int | None] # (node, offset)
Edge = tuple[Port, Port]
diff --git a/guppy/hugr/tys.py b/guppy/hugr/tys.py
index ebc032d6..fa01db41 100644
--- a/guppy/hugr/tys.py
+++ b/guppy/hugr/tys.py
@@ -2,9 +2,9 @@
import sys
from abc import ABC
from enum import Enum
-from typing import Literal, Union, Annotated, Optional
-from pydantic import Field, BaseModel
+from typing import Annotated, Literal
+from pydantic import BaseModel, Field
ExtensionId = str
ExtensionSet = list[ # TODO: Set not supported by MessagePack. Is list correct here?
@@ -24,7 +24,7 @@ class TypeParam(BaseModel):
class BoundedNatParam(BaseModel):
tp: Literal["BoundedNat"] = "BoundedNat"
- bound: Optional[int]
+ bound: int | None
class OpaqueParam(BaseModel):
@@ -43,7 +43,7 @@ class TupleParam(BaseModel):
TypeParamUnion = Annotated[
- Union[TypeParam, BoundedNatParam, OpaqueParam, ListParam, TupleParam],
+ TypeParam | BoundedNatParam | OpaqueParam | ListParam | TupleParam,
Field(discriminator="tp"),
]
@@ -84,7 +84,7 @@ class ExtensionsArg(BaseModel):
TypeArgUnion = Annotated[
- Union[TypeArg, BoundedNatArg, OpaqueArg, SequenceArg, ExtensionsArg],
+ TypeArg | BoundedNatArg | OpaqueArg | SequenceArg | ExtensionsArg,
Field(discriminator="tya"),
]
@@ -231,9 +231,17 @@ class Qubit(BaseModel):
SimpleType = Annotated[
- Union[
- Qubit, Variable, Int, F64, String, PolyFuncType, List, Array, Tuple, Sum, Opaque
- ],
+ Qubit
+ | Variable
+ | Int
+ | F64
+ | String
+ | PolyFuncType
+ | List
+ | Array
+ | Tuple
+ | Sum
+ | Opaque,
Field(discriminator="t"),
]
diff --git a/guppy/hugr/val.py b/guppy/hugr/val.py
index 493ec014..b8b5eabb 100644
--- a/guppy/hugr/val.py
+++ b/guppy/hugr/val.py
@@ -1,10 +1,9 @@
import inspect
import sys
-from typing import Literal, Any, Annotated, Union
+from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
-
CustomConst = Any # TODO
@@ -40,9 +39,7 @@ class Sum(BaseModel):
value: "Value"
-Value = Annotated[
- Union[ExtensionVal, FunctionVal, Tuple, Sum], Field(discriminator="v")
-]
+Value = Annotated[ExtensionVal | FunctionVal | Tuple | Sum, Field(discriminator="v")]
# Now that all classes are defined, we need to update the ForwardRefs in all type
diff --git a/guppy/hugr/visualise.py b/guppy/hugr/visualise.py
index eba0bca6..450b5293 100644
--- a/guppy/hugr/visualise.py
+++ b/guppy/hugr/visualise.py
@@ -1,16 +1,17 @@
"""Visualise HUGR using graphviz."""
import ast
+from collections.abc import Iterable
+from typing import TYPE_CHECKING
-import graphviz as gv # type: ignore
-from typing import Iterable, TYPE_CHECKING
+import graphviz as gv # type: ignore[import]
from guppy.cfg.analysis import (
- LivenessDomain,
DefAssignmentDomain,
+ LivenessDomain,
MaybeAssignmentDomain,
)
from guppy.cfg.bb import BB
-from guppy.hugr.hugr import InPort, OutPort, Node, Hugr, OutPortV
+from guppy.hugr.hugr import Hugr, InPort, Node, OutPort, OutPortV
if TYPE_CHECKING:
from guppy.cfg.cfg import CFG
@@ -55,18 +56,23 @@
_FONTFACE = "monospace"
_HTML_LABEL_TEMPLATE = """
-
-{inputs_row}
-
-
-
-
- {node_label}{node_data} |
-
-
- |
-
-{outputs_row}
+
+ {inputs_row}
+
+
+
+
+
+
+ {node_label}{node_data}
+
+ |
+
+
+ |
+
+ {outputs_row}
"""
diff --git a/guppy/module.py b/guppy/module.py
index e07a15ac..3061bafe 100644
--- a/guppy/module.py
+++ b/guppy/module.py
@@ -1,15 +1,15 @@
import ast
import inspect
import textwrap
+from collections.abc import Callable
from types import ModuleType
+from typing import Any, Union
-from typing import Callable, Any, Optional, Union
-
-from guppy.ast_util import annotate_location, AstNode
+from guppy.ast_util import AstNode, annotate_location
from guppy.checker.core import Globals, qualified_name
from guppy.checker.func_checker import DefinedFunction, check_global_func_def
from guppy.compiler.core import CompiledGlobals
-from guppy.compiler.func_compiler import compile_global_func_def, CompiledFunctionDef
+from guppy.compiler.func_compiler import CompiledFunctionDef, compile_global_func_def
from guppy.custom import CustomFunction
from guppy.declared import DeclaredFunction
from guppy.error import GuppyError, pretty_errors
@@ -44,7 +44,7 @@ class GuppyModule:
# When `_instance_buffer` is not `None`, then all registered functions will be
# buffered in this list. They only get properly registered, once
# `_register_buffered_instance_funcs` is called. This way, we can associate
- _instance_func_buffer: Optional[dict[str, Union[PyFunc, CustomFunction]]]
+ _instance_func_buffer: dict[str, PyFunc | CustomFunction] | None
def __init__(self, name: str, import_builtins: bool = True):
self.name = name
@@ -88,7 +88,7 @@ def load(self, m: Union[ModuleType, "GuppyModule"]) -> None:
self.load(val)
def register_func_def(
- self, f: PyFunc, instance: Optional[type[GuppyType]] = None
+ self, f: PyFunc, instance: type[GuppyType] | None = None
) -> None:
"""Registers a Python function definition as belonging to this Guppy module."""
self._check_not_yet_compiled()
@@ -110,7 +110,7 @@ def register_func_decl(self, f: PyFunc) -> None:
self._func_decls[func_ast.name] = func_ast
def register_custom_func(
- self, func: CustomFunction, instance: Optional[type[GuppyType]] = None
+ self, func: CustomFunction, instance: type[GuppyType] | None = None
) -> None:
"""Registers a custom function as belonging to this Guppy module."""
self._check_not_yet_compiled()
@@ -130,7 +130,7 @@ def _register_buffered_instance_funcs(self, instance: type[GuppyType]) -> None:
assert self._instance_func_buffer is not None
buffer = self._instance_func_buffer
self._instance_func_buffer = None
- for name, f in buffer.items():
+ for f in buffer.values():
if isinstance(f, CustomFunction):
self.register_custom_func(f, instance)
else:
@@ -141,7 +141,7 @@ def compiled(self) -> bool:
return self._compiled
@pretty_errors
- def compile(self) -> Optional[Hugr]:
+ def compile(self) -> Hugr | None:
"""Compiles the module and returns the final Hugr."""
if self.compiled:
raise GuppyError("Module has already been compiled")
@@ -200,7 +200,7 @@ def _check_not_yet_compiled(self) -> None:
if self._compiled:
raise GuppyError(f"The module `{self.name}` has already been compiled")
- def _check_name_available(self, name: str, node: Optional[AstNode]) -> None:
+ def _check_name_available(self, name: str, node: AstNode | None) -> None:
if name in self._func_defs or name in self._custom_funcs:
raise GuppyError(
f"Module `{self.name}` already contains a function named `{name}`",
diff --git a/guppy/nodes.py b/guppy/nodes.py
index dfd02349..22730db5 100644
--- a/guppy/nodes.py
+++ b/guppy/nodes.py
@@ -1,14 +1,15 @@
"""Custom AST nodes used by Guppy"""
import ast
-from typing import TYPE_CHECKING, Any, Mapping
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any
from guppy.gtypes import FunctionType
if TYPE_CHECKING:
from guppy.cfg.cfg import CFG
- from guppy.checker.core import Variable, CallableVariable
from guppy.checker.cfg_checker import CheckedCFG
+ from guppy.checker.core import CallableVariable, Variable
class LocalName(ast.expr):
diff --git a/guppy/prelude/_internal.py b/guppy/prelude/_internal.py
index d8a64c32..6262b1e8 100644
--- a/guppy/prelude/_internal.py
+++ b/guppy/prelude/_internal.py
@@ -1,24 +1,23 @@
import ast
-from typing import Optional, Literal
+from typing import Literal
from pydantic import BaseModel
-from guppy.ast_util import with_type, AstNode, with_loc, get_type
-from guppy.checker.core import Context, CallableVariable
+from guppy.ast_util import AstNode, get_type, with_loc, with_type
+from guppy.checker.core import CallableVariable, Context
from guppy.checker.expr_checker import ExprSynthesizer, check_num_args
from guppy.custom import (
CustomCallChecker,
- DefaultCallChecker,
- CustomFunction,
CustomCallCompiler,
+ CustomFunction,
+ DefaultCallChecker,
)
-from guppy.error import GuppyTypeError, GuppyError
-from guppy.gtypes import GuppyType, FunctionType, BoolType
+from guppy.error import GuppyError, GuppyTypeError
+from guppy.gtypes import BoolType, FunctionType, GuppyType
from guppy.hugr import ops, tys, val
from guppy.hugr.hugr import OutPortV
from guppy.nodes import GlobalCall
-
INT_WIDTH = 6 # 2^6 = 64 bit
@@ -68,7 +67,7 @@ def float_value(f: float) -> val.Value:
return val.ExtensionVal(c=(ConstF64(value=f),))
-def logic_op(op_name: str, args: Optional[list[tys.TypeArgUnion]] = None) -> ops.OpType:
+def logic_op(op_name: str, args: list[tys.TypeArgUnion] | None = None) -> ops.OpType:
"""Utility method to create Hugr logic ops."""
return ops.CustomOp(extension="logic", op_name=op_name, args=args or [])
@@ -110,7 +109,7 @@ class ReversingChecker(CustomCallChecker):
base_checker: CustomCallChecker
- def __init__(self, base_checker: Optional[CustomCallChecker] = None):
+ def __init__(self, base_checker: CustomCallChecker | None = None):
self.base_checker = base_checker or DefaultCallChecker()
def _setup(self, ctx: Context, node: AstNode, func: CustomFunction) -> None:
@@ -209,7 +208,7 @@ class IntTruedivCompiler(CustomCallCompiler):
"""Compiler for the `int.__truediv__` method."""
def compile(self, args: list[OutPortV]) -> list[OutPortV]:
- from .builtins import Int, Float
+ from .builtins import Float, Int
# Compile `truediv` using float arithmetic
[left, right] = args
diff --git a/guppy/prelude/builtins.py b/guppy/prelude/builtins.py
index 25dfc741..57c22920 100644
--- a/guppy/prelude/builtins.py
+++ b/guppy/prelude/builtins.py
@@ -2,30 +2,29 @@
# mypy: disable-error-code="empty-body, misc, override, no-untyped-def"
-from guppy.custom import NoopCompiler, DefaultCallChecker
+from guppy.custom import DefaultCallChecker, NoopCompiler
from guppy.decorator import guppy
from guppy.gtypes import BoolType
-from guppy.hugr import tys, ops
+from guppy.hugr import ops, tys
from guppy.module import GuppyModule
from guppy.prelude._internal import (
- logic_op,
- int_op,
- hugr_int_type,
- hugr_float_type,
- float_op,
+ CallableChecker,
CoercingChecker,
- ReversingChecker,
- IntTruedivCompiler,
+ DunderChecker,
FloatBoolCompiler,
FloatDivmodCompiler,
FloatFloordivCompiler,
FloatModCompiler,
- DunderChecker,
- CallableChecker,
+ IntTruedivCompiler,
+ ReversingChecker,
UnsupportedChecker,
+ float_op,
+ hugr_float_type,
+ hugr_int_type,
+ int_op,
+ logic_op,
)
-
builtins = GuppyModule("builtins", import_builtins=False)
diff --git a/guppy/prelude/quantum.py b/guppy/prelude/quantum.py
index 3b40af88..dcfc794b 100644
--- a/guppy/prelude/quantum.py
+++ b/guppy/prelude/quantum.py
@@ -3,11 +3,10 @@
# mypy: disable-error-code=empty-body
from guppy.decorator import guppy
-from guppy.hugr import tys, ops
+from guppy.hugr import ops, tys
from guppy.hugr.tys import TypeBound
from guppy.module import GuppyModule
-
quantum = GuppyModule("quantum")
diff --git a/pyproject.toml b/pyproject.toml
index df4fc7d7..6d7333c8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,12 +16,13 @@ dependencies = [
[project.optional-dependencies]
dev = [
- "pip-tools", # provides pip-compile
"build",
- "pytest",
+ "maturin>=1.1,<2.0",
"mypy==1.3.0",
+ "pip-tools", # provides pip-compile
+ "pre-commit",
+ "pytest",
"ruff",
- "maturin>=1.1,<2.0",
]
[tool.setuptools]
diff --git a/requirements.txt b/requirements.txt
index d64d15c3..846d2fa0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,19 +1,33 @@
#
-# This file is autogenerated by pip-compile with Python 3.10
+# This file is autogenerated by pip-compile with Python 3.12
# by the following command:
#
-# pip-compile --extra=dev --no-annotate --output-file=requirements.txt pyproject.toml
+# pip-compile --extra=dev --no-annotate --output-file=requirements.txt --strip-extras pyproject.toml
#
-build==0.10.0
-exceptiongroup==1.1.1
+build==1.0.3
+cfgv==3.4.0
+click==8.1.7
+distlib==0.3.7
+filelock==3.13.1
graphviz==0.20.1
+identify==2.5.32
iniconfig==2.0.0
+maturin==1.3.2
+mypy==1.3.0
+mypy-extensions==1.0.0
networkx==3.0
-ormsgpack==1.2.5
-packaging==23.1
-pluggy==1.0.0
+nodeenv==1.8.0
+ormsgpack==1.4.1
+packaging==23.2
+pip-tools==7.3.0
+platformdirs==4.1.0
+pluggy==1.3.0
+pre-commit==3.5.0
pydantic==1.10.8
pyproject-hooks==1.0.0
-pytest==7.3.1
-tomli==2.0.1
-typing-extensions==4.6.2
+pytest==7.4.3
+pyyaml==6.0.1
+ruff==0.1.7
+typing-extensions==4.8.0
+virtualenv==20.25.0
+wheel==0.42.0
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 00000000..8c9e618a
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,83 @@
+# See https://docs.astral.sh/ruff/rules/
+target-version = "py310"
+
+line-length = 88
+
+exclude = ["tests/error"]
+
+select = [
+ "F", # pyflakes
+ "E", # pycodestyle Errors
+ "W", # pycodestyle Warnings
+
+ # "A", # flake8-builtins
+ # "ANN", # flake8-annotations
+ # "ARG", # flake8-unused-arguments
+ "B", # flake8-Bugbear
+ "BLE", # flake8-blind-except
+ "C4", # flake8-comprehensions
+ # "C90", # mccabe
+ # "COM", # flake8-commas
+ # "CPY", # flake8-copyright
+ # "D", # pydocstyle
+ "EM", # flake8-errmsg
+ # "ERA", # eradicate
+ "EXE", # flake8-executable
+ "FA", # flake8-future-annotations
+ # "FBT", # flake8-boolean-trap
+ # "FIX", # flake8-fixme
+ "FLY", # flynt
+ # "FURB", # refurb
+ "G", # flake8-logging-format
+ "I", # isort
+ "ICN", # flake8-import-conventions
+ "INP", # flake8-no-pep420
+ "INT", # flake8-gettext
+ # "ISC", # flake8-implicit-str-concat
+ # "LOG", # flake8-logging
+ # "N", # pep8-Naming
+ "NPY", # NumPy-specific
+ "PERF", # Perflint
+ "PGH", # pygrep-hooks
+ "PIE", # flake8-pie
+ # "PL", # pylint
+ "PT", # flake8-pytest-style
+ "PTH", # flake8-use-pathlib
+ "PYI", # flake8-pyi
+ "Q", # flake8-quotes
+ # "RET", # flake8-return
+ "RSE", # flake8-raise
+ "RUF", # Ruff-specific
+ "S", # flake8-bandit (Security)
+ "SIM", # flake8-simplify
+ # "SLF", # flake8-self
+ "SLOT", # flake8-slots
+ "T10", # flake8-debugger
+ "T20", # flake8-print
+ "TCH", # flake8-type-checking
+ # "TD", # flake8-todos
+ "TID", # flake8-tidy-imports
+ "TRY", # tryceratops
+ "UP", # pyupgrade
+ "YTT", # flake8-2020
+]
+
+ignore = [
+ "COM812", "ISC001", # conflicting with the formatter
+ "EM101", "EM102", # Exception must not use a string (an f-string) literal, assign to variable first
+ "S101", # Use of `assert` detected
+ "TRY003", # Avoid specifying long messages outside the exception class
+ "B905", # `zip()` without an explicit `strict=` parameter
+]
+
+[per-file-ignores]
+"guppy/ast_util.py" = ["B009", "B010"]
+"guppy/decorator.py" = ["B010"]
+"tests/integration/*" = ["F841"]
+"tests/{hugr,integration}/*" = ["B", "FBT", "SIM", "I"]
+
+# [pydocstyle]
+# convention = "google"
+
+# [flake8-copyright]
+# author = "Quantinuum"
diff --git a/tests/conftest.py b/tests/conftest.py
index 8d1eedc3..f5d7a394 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,12 +1,18 @@
-import pytest
-from pathlib import Path
import argparse
+from pathlib import Path
+
def pytest_addoption(parser):
def dir_path(s):
path = Path(s)
if not path.exists() or path.is_dir():
return path
- raise argparse.ArgumentTypeError(f"export-test-cases dir:{path} exists and is not a directory")
+ msg = f"export-test-cases dir:{path} exists and is not a directory"
+ raise argparse.ArgumentTypeError(msg)
- parser.addoption("--export-test-cases", action="store", type=dir_path, help="A directory to which to export test cases")
+ parser.addoption(
+ "--export-test-cases",
+ action="store",
+ type=dir_path,
+ help="A directory to which to export test cases",
+ )
diff --git a/tests/error/errors_on_usage/if_expr_cond_type_change.py b/tests/error/errors_on_usage/if_expr_cond_type_change.py
index c9ec27ab..d108fb78 100644
--- a/tests/error/errors_on_usage/if_expr_cond_type_change.py
+++ b/tests/error/errors_on_usage/if_expr_cond_type_change.py
@@ -6,4 +6,4 @@ def foo(x: bool) -> int:
y = 4
0 if (y := x) else (y := 6)
z = y
- return 42
\ No newline at end of file
+ return 42
diff --git a/tests/error/generate_test.sh b/tests/error/generate_test.sh
index 0cc01a2e..5150167b 100755
--- a/tests/error/generate_test.sh
+++ b/tests/error/generate_test.sh
@@ -1,3 +1,3 @@
#!/bin/bash
-python -m tests.error.$1.$2 2> tests/error/$1/$2.err
\ No newline at end of file
+python -m tests.error.$1.$2 2> tests/error/$1/$2.err
diff --git a/tests/error/nested_errors/different_types_if.py b/tests/error/nested_errors/different_types_if.py
index 501de616..2eebce38 100644
--- a/tests/error/nested_errors/different_types_if.py
+++ b/tests/error/nested_errors/different_types_if.py
@@ -11,4 +11,3 @@ def bar() -> bool:
return False
return bar()
-
diff --git a/tests/error/nested_errors/not_defined_if.py b/tests/error/nested_errors/not_defined_if.py
index 67715f3e..8798212d 100644
--- a/tests/error/nested_errors/not_defined_if.py
+++ b/tests/error/nested_errors/not_defined_if.py
@@ -7,4 +7,3 @@ def foo(b: bool) -> int:
def bar() -> int:
return 0
return bar()
-
diff --git a/tests/error/nested_errors/reassign_capture_1.py b/tests/error/nested_errors/reassign_capture_1.py
index 1a770cd0..bf8da199 100644
--- a/tests/error/nested_errors/reassign_capture_1.py
+++ b/tests/error/nested_errors/reassign_capture_1.py
@@ -10,4 +10,3 @@ def bar() -> None:
bar()
return y
-
diff --git a/tests/error/nested_errors/reassign_capture_2.py b/tests/error/nested_errors/reassign_capture_2.py
index 269699e9..7317bdf5 100644
--- a/tests/error/nested_errors/reassign_capture_2.py
+++ b/tests/error/nested_errors/reassign_capture_2.py
@@ -12,4 +12,3 @@ def bar() -> None:
bar()
return y
-
diff --git a/tests/error/nested_errors/reassign_capture_3.py b/tests/error/nested_errors/reassign_capture_3.py
index 719fd41a..4648f379 100644
--- a/tests/error/nested_errors/reassign_capture_3.py
+++ b/tests/error/nested_errors/reassign_capture_3.py
@@ -14,4 +14,3 @@ def baz() -> None:
bar()
return y
-
diff --git a/tests/error/nested_errors/var_not_defined.py b/tests/error/nested_errors/var_not_defined.py
index a82af6a8..d18b305e 100644
--- a/tests/error/nested_errors/var_not_defined.py
+++ b/tests/error/nested_errors/var_not_defined.py
@@ -7,4 +7,3 @@ def bar() -> int:
return x
return bar()
-
diff --git a/tests/error/test_errors_on_usage.py b/tests/error/test_errors_on_usage.py
index 4f53575d..4999cb81 100644
--- a/tests/error/test_errors_on_usage.py
+++ b/tests/error/test_errors_on_usage.py
@@ -4,7 +4,12 @@
from tests.error.util import run_error_test
path = pathlib.Path(__file__).parent.resolve() / "errors_on_usage"
-files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"]
+files = [
+ x
+ for x in path.iterdir()
+ if x.is_file()
+ if x.suffix == ".py" and x.name != "__init__.py"
+]
# TODO: Skip functional tests for now
files = [f for f in files if "functional" not in f.name]
diff --git a/tests/error/test_linear_errors.py b/tests/error/test_linear_errors.py
index d334724d..71615330 100644
--- a/tests/error/test_linear_errors.py
+++ b/tests/error/test_linear_errors.py
@@ -4,7 +4,12 @@
from tests.error.util import run_error_test
path = pathlib.Path(__file__).parent.resolve() / "linear_errors"
-files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"]
+files = [
+ x
+ for x in path.iterdir()
+ if x.is_file()
+ if x.suffix == ".py" and x.name != "__init__.py"
+]
# TODO: Skip functional tests for now
files = [f for f in files if "functional" not in f.name]
diff --git a/tests/error/test_misc_errors.py b/tests/error/test_misc_errors.py
index d3115ff4..03803e65 100644
--- a/tests/error/test_misc_errors.py
+++ b/tests/error/test_misc_errors.py
@@ -4,7 +4,12 @@
from tests.error.util import run_error_test
path = pathlib.Path(__file__).parent.resolve() / "misc_errors"
-files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"]
+files = [
+ x
+ for x in path.iterdir()
+ if x.is_file()
+ if x.suffix == ".py" and x.name != "__init__.py"
+]
# TODO: Skip functional tests for now
files = [f for f in files if "functional" not in f.name]
diff --git a/tests/error/test_nested_errors.py b/tests/error/test_nested_errors.py
index ba059c26..523a6574 100644
--- a/tests/error/test_nested_errors.py
+++ b/tests/error/test_nested_errors.py
@@ -4,7 +4,12 @@
from tests.error.util import run_error_test
path = pathlib.Path(__file__).parent.resolve() / "nested_errors"
-files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"]
+files = [
+ x
+ for x in path.iterdir()
+ if x.is_file()
+ if x.suffix == ".py" and x.name != "__init__.py"
+]
# Turn paths into strings, otherwise pytest doesn't display the names
files = [str(f) for f in files]
diff --git a/tests/error/test_type_errors.py b/tests/error/test_type_errors.py
index 619438c9..864f15c7 100644
--- a/tests/error/test_type_errors.py
+++ b/tests/error/test_type_errors.py
@@ -4,7 +4,12 @@
from tests.error.util import run_error_test
path = pathlib.Path(__file__).parent.resolve() / "type_errors"
-files = [x for x in path.iterdir() if x.is_file() if x.suffix == ".py" and x.name != "__init__.py"]
+files = [
+ x
+ for x in path.iterdir()
+ if x.is_file()
+ if x.suffix == ".py" and x.name != "__init__.py"
+]
# TODO: Skip functional tests for now
files = [f for f in files if "functional" not in f.name]
diff --git a/tests/error/type_errors/fun_ty_mismatch_1.py b/tests/error/type_errors/fun_ty_mismatch_1.py
index 66445957..159825b6 100644
--- a/tests/error/type_errors/fun_ty_mismatch_1.py
+++ b/tests/error/type_errors/fun_ty_mismatch_1.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from collections.abc import Callable
from tests.error.util import guppy
diff --git a/tests/error/type_errors/fun_ty_mismatch_2.py b/tests/error/type_errors/fun_ty_mismatch_2.py
index ab421bed..8d98dc01 100644
--- a/tests/error/type_errors/fun_ty_mismatch_2.py
+++ b/tests/error/type_errors/fun_ty_mismatch_2.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from collections.abc import Callable
from tests.error.util import guppy
diff --git a/tests/error/type_errors/fun_ty_mismatch_3.py b/tests/error/type_errors/fun_ty_mismatch_3.py
index 2e35f0ce..bbb4f428 100644
--- a/tests/error/type_errors/fun_ty_mismatch_3.py
+++ b/tests/error/type_errors/fun_ty_mismatch_3.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from collections.abc import Callable
from tests.error.util import guppy
diff --git a/tests/error/util.py b/tests/error/util.py
index a3e39191..c7d207d8 100644
--- a/tests/error/util.py
+++ b/tests/error/util.py
@@ -2,7 +2,8 @@
import pathlib
import pytest
-from typing import Callable, Optional, Any
+from typing import Any
+from collections.abc import Callable
from guppy.hugr import tys
from guppy.hugr.tys import TypeBound
@@ -12,8 +13,8 @@
import guppy.decorator as decorator
-def guppy(f: Callable[..., Any]) -> Optional[Hugr]:
- """ Decorator to compile functions outside of modules for testing. """
+def guppy(f: Callable[..., Any]) -> Hugr | None:
+ """Decorator to compile functions outside of modules for testing."""
module = GuppyModule("module")
module.register_func_def(f)
return module.compile()
@@ -29,7 +30,7 @@ def run_error_test(file, capsys):
err = capsys.readouterr().err
- with open(file.with_suffix(".err")) as f:
+ with pathlib.Path(file.with_suffix(".err")).open() as f:
exp_err = f.read()
exp_err = exp_err.replace("$FILE", str(file))
@@ -39,7 +40,8 @@ def run_error_test(file, capsys):
util = GuppyModule("test")
-@decorator.guppy.type(util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable))
+@decorator.guppy.type(
+ util, tys.Opaque(extension="", id="", args=[], bound=TypeBound.Copyable)
+)
class NonBool:
pass
-
diff --git a/tests/hugr/test_dummy_nodes.py b/tests/hugr/test_dummy_nodes.py
index e1efaab9..e33537cd 100644
--- a/tests/hugr/test_dummy_nodes.py
+++ b/tests/hugr/test_dummy_nodes.py
@@ -20,7 +20,9 @@ def test_single_dummy():
def test_unique_names():
g = Hugr()
- defn = g.add_def(FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test")
+ defn = g.add_def(
+ FunctionType([BoolType()], TupleType([BoolType(), BoolType()])), g.root, "test"
+ )
dfg = g.add_dfg(defn)
inp = g.add_input([BoolType()], dfg).out_port(0)
dummy1 = g.add_node(
@@ -34,4 +36,3 @@ def test_unique_names():
g.remove_dummy_nodes()
[decl1, decl2] = [n for n in g.nodes() if isinstance(n.op, ops.FuncDecl)]
assert {decl1.op.name, decl2.op.name} == {"dummy", "dummy$1"}
-
diff --git a/tests/hugr/test_ports.py b/tests/hugr/test_ports.py
index 77e07ef4..299b310e 100644
--- a/tests/hugr/test_ports.py
+++ b/tests/hugr/test_ports.py
@@ -12,4 +12,3 @@ def test_undefined_port():
p.node
with pytest.raises(InternalGuppyError, match="Tried to access undefined Port"):
p.offset
-
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index af4f558c..9841036a 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -2,21 +2,23 @@
from . import util
-@pytest.fixture
+
+@pytest.fixture()
def export_test_cases_dir(request):
- r = request.config.getoption('--export-test-cases')
+ r = request.config.getoption("--export-test-cases")
if r and not r.exists():
r.mkdir(parents=True)
return r
-@pytest.fixture
+@pytest.fixture()
def validate(request, export_test_cases_dir):
- def validate_impl(hugr,name=None):
+ def validate_impl(hugr, name=None):
bs = hugr.serialize()
util.validate_bytes(bs)
if export_test_cases_dir:
file_name = f"{request.node.name}{f'_{name}' if name else ''}.msgpack"
export_file = export_test_cases_dir / file_name
export_file.write_bytes(bs)
+
return validate_impl
diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py
index cde2801a..487f0197 100644
--- a/tests/integration/test_basic.py
+++ b/tests/integration/test_basic.py
@@ -77,5 +77,7 @@ def test_func_decl_name():
def func_name() -> None:
...
- [def_op] = [n.op for n in module.compile().nodes() if isinstance(n.op, ops.FuncDecl)]
+ [def_op] = [
+ n.op for n in module.compile().nodes() if isinstance(n.op, ops.FuncDecl)
+ ]
assert def_op.name == "func_name"
diff --git a/tests/integration/test_call.py b/tests/integration/test_call.py
index bad4d824..2cb7ec80 100644
--- a/tests/integration/test_call.py
+++ b/tests/integration/test_call.py
@@ -50,6 +50,3 @@ def bar(x: int) -> int:
return foo(x)
validate(module.compile())
-
-
-
diff --git a/tests/integration/test_functional.py b/tests/integration/test_functional.py
index 8c462d6b..ae61a30d 100644
--- a/tests/integration/test_functional.py
+++ b/tests/integration/test_functional.py
@@ -8,7 +8,7 @@
def test_if_no_else(validate):
@guppy
def foo(x: bool, y: int) -> int:
- _@functional
+ _ @ functional
if x:
y += 1
return y
@@ -20,7 +20,7 @@ def foo(x: bool, y: int) -> int:
def test_if_else(validate):
@guppy
def foo(x: bool, y: int) -> int:
- _@functional
+ _ @ functional
if x:
y += 1
else:
@@ -34,7 +34,7 @@ def foo(x: bool, y: int) -> int:
def test_if_elif(validate):
@guppy
def foo(x: bool, y: int) -> int:
- _@functional
+ _ @ functional
if x:
y += 1
elif y > 4:
@@ -48,7 +48,7 @@ def foo(x: bool, y: int) -> int:
def test_if_elif_else(validate):
@guppy
def foo(x: bool, y: int) -> int:
- _@functional
+ _ @ functional
if x:
y += 1
elif y > 4:
@@ -87,7 +87,7 @@ def test_nested_loop(validate):
@guppy
def foo(x: int, y: int) -> int:
p = 0
- _@functional
+ _ @ functional
while x > 0:
s = 0
while y > 0:
diff --git a/tests/integration/test_higher_order.py b/tests/integration/test_higher_order.py
index b406c932..fd8856d8 100644
--- a/tests/integration/test_higher_order.py
+++ b/tests/integration/test_higher_order.py
@@ -1,4 +1,4 @@
-from typing import Callable
+from collections.abc import Callable
from guppy.decorator import guppy
from guppy.module import GuppyModule
@@ -73,13 +73,18 @@ def curry(f: Callable[[int, int], bool]) -> Callable[[int], Callable[[int], bool
def g(x: int) -> Callable[[int], bool]:
def h(y: int) -> bool:
return f(x, y)
+
return h
+
return g
@guppy(module)
- def uncurry(f: Callable[[int], Callable[[int], bool]]) -> Callable[[int, int], bool]:
+ def uncurry(
+ f: Callable[[int], Callable[[int], bool]],
+ ) -> Callable[[int, int], bool]:
def g(x: int, y: int) -> bool:
return f(x)(y)
+
return g
@guppy(module)
diff --git a/tests/integration/test_if.py b/tests/integration/test_if.py
index 4f09b295..ac4f5f10 100644
--- a/tests/integration/test_if.py
+++ b/tests/integration/test_if.py
@@ -1,7 +1,6 @@
import pytest
from guppy.decorator import guppy
-from guppy.module import GuppyModule
def test_if_no_else(validate):
@@ -55,7 +54,7 @@ def foo(x: bool, y: int) -> int:
def test_if_expr(validate):
@guppy
def foo(x: bool, y: int) -> int:
- return y+1 if x else 42
+ return y + 1 if x else 42
validate(foo)
@@ -265,4 +264,3 @@ def foo(x: int) -> int:
return z
validate(foo)
-
diff --git a/tests/integration/test_linear.py b/tests/integration/test_linear.py
index 0d744e06..35091b61 100644
--- a/tests/integration/test_linear.py
+++ b/tests/integration/test_linear.py
@@ -55,7 +55,9 @@ def g(q1: Qubit, q2: Qubit) -> tuple[Qubit, Qubit]:
...
@guppy(module)
- def test(a: Qubit, b: Qubit, c: Qubit, d: Qubit) -> tuple[Qubit, Qubit, Qubit, Qubit]:
+ def test(
+ a: Qubit, b: Qubit, c: Qubit, d: Qubit
+ ) -> tuple[Qubit, Qubit, Qubit, Qubit]:
a, b = f(a, b)
c, d = f(c, d)
b, c = g(b, c)
diff --git a/tests/integration/test_nested.py b/tests/integration/test_nested.py
index fe58523f..23698f59 100644
--- a/tests/integration/test_nested.py
+++ b/tests/integration/test_nested.py
@@ -6,6 +6,7 @@ def test_basic(validate):
def foo(x: int) -> int:
def bar(y: int) -> int:
return y
+
return bar(x + 1)
validate(foo)
@@ -16,6 +17,7 @@ def test_call_twice(validate):
def foo(x: int) -> int:
def bar(y: int) -> int:
return y + 3
+
if x > 5:
return bar(x)
else:
@@ -45,11 +47,14 @@ def test_define_twice(validate):
@guppy
def foo(x: int) -> int:
if x == 0:
+
def bar(y: int) -> int:
return y + 3
else:
+
def bar(y: int) -> int:
return y - 42
+
return bar(x)
validate(foo)
@@ -61,7 +66,9 @@ def foo(x: int) -> int:
def bar(y: int) -> int:
def baz(z: int) -> int:
return z - 1
- return baz(5*y)
+
+ return baz(5 * y)
+
return bar(x + 1)
validate(foo)
@@ -74,6 +81,7 @@ def bar(y: int) -> int:
if y == 0:
return 0
return 2 * bar(y - 1)
+
return bar(x)
validate(foo)
@@ -84,6 +92,7 @@ def test_capture_arg(validate):
def foo(x: int) -> int:
def bar() -> int:
return 1 + x
+
return bar()
validate(foo)
@@ -96,6 +105,7 @@ def foo(x: int) -> int:
def bar() -> int:
return y
+
return bar()
validate(foo)
@@ -113,6 +123,7 @@ def foo(x: int) -> int:
def bar() -> int:
q = y
return q + z
+
return bar()
validate(foo)
@@ -127,6 +138,7 @@ def foo(x: int) -> int:
def bar() -> int:
return x + y + a
+
return bar()
return 4
@@ -159,6 +171,7 @@ def bar(y: int, z: int) -> int:
if y == 0:
return z
return bar(z, z * x)
+
return bar(x, 0)
validate(foo)
@@ -190,7 +203,7 @@ def foo(x: int) -> int:
while x > 0:
def bar() -> int:
- return x*x
+ return x * x
a += bar()
x -= 1
diff --git a/tests/integration/test_programs.py b/tests/integration/test_programs.py
index 01c2a6fd..ee7c87cc 100644
--- a/tests/integration/test_programs.py
+++ b/tests/integration/test_programs.py
@@ -1,6 +1,5 @@
from guppy.decorator import guppy
from guppy.module import GuppyModule
-from tests.integration.util import functional, _
def test_factorial(validate):
@@ -24,19 +23,9 @@ def factorial3(x: int, acc: int) -> int:
return acc
return factorial3(x - 1, acc * x)
- # @guppy
- # def factorial4(x: int) -> int:
- # acc = 1
- # _@functional
- # while x > 0:
- # acc *= x
- # x -= 1
- # return acc
-
validate(factorial1, name="factorial1")
validate(factorial2, name="factorial2")
validate(factorial3, name="factorial3")
- # validate(factorial4)
def test_even_odd(validate):
diff --git a/tests/integration/test_unused.py b/tests/integration/test_unused.py
index 772e6b60..b8592bc4 100644
--- a/tests/integration/test_unused.py
+++ b/tests/integration/test_unused.py
@@ -1,9 +1,7 @@
-import pytest
+""" All sorts of weird stuff is allowed when variables are not used. """
from guppy.decorator import guppy
-""" All sorts of weird stuff is allowed when variables are not used. """
-
def test_not_always_defined1(validate):
@guppy
diff --git a/tests/integration/util.py b/tests/integration/util.py
index 52075019..a99478e2 100644
--- a/tests/integration/util.py
+++ b/tests/integration/util.py
@@ -1,12 +1,10 @@
-from typing import TypeVar
-
import validator
-from guppy.hugr.hugr import Hugr
def validate_bytes(hugr: bytes):
validator.validate(hugr)
+
class Decorator:
def __matmul__(self, other):
return None
diff --git a/validator/src/lib.rs b/validator/src/lib.rs
index 702b0dc2..4d31a136 100644
--- a/validator/src/lib.rs
+++ b/validator/src/lib.rs
@@ -1,10 +1,8 @@
-use pyo3::prelude::*;
-use hugr;
use hugr::extension::{ExtensionRegistry, PRELUDE};
-use hugr::std_extensions::arithmetic::{int_ops, int_types, float_ops, float_types};
+use hugr::std_extensions::arithmetic::{float_ops, float_types, int_ops, int_types};
use hugr::std_extensions::logic;
use lazy_static::lazy_static;
-
+use pyo3::prelude::*;
lazy_static! {
pub static ref REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
@@ -14,7 +12,8 @@ lazy_static! {
int_ops::EXTENSION.to_owned(),
float_types::extension(),
float_ops::extension()
- ]).unwrap();
+ ])
+ .unwrap();
}
#[pyfunction]
@@ -28,4 +27,4 @@ fn validate(hugr: Vec) -> PyResult<()> {
fn validator(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(validate, m)?)?;
Ok(())
-}
\ No newline at end of file
+}