Skip to content

Commit

Permalink
refactor: Separate type checking from compilation code (#57)
Browse files Browse the repository at this point in the history
* Type checking code moved to guppy/checker (#58)
* Graph generation code moved to guppy/compiler (#59)
* Unified extension system with regular Guppy modules and added new builtins module (#60)
  • Loading branch information
mark-koch authored Dec 5, 2023
1 parent edc978a commit b701c06
Show file tree
Hide file tree
Showing 78 changed files with 4,011 additions and 2,968 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:

strategy:
matrix:
python-version: [3.9]
python-version: ['3.10']

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion guppy/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__all__ = ["guppy_types"]
__all__ = ["types.py"]
90 changes: 88 additions & 2 deletions guppy/ast_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import ast
from typing import Any, TypeVar, Generic, Union
from typing import Any, TypeVar, Generic, Union, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from guppy.gtypes import GuppyType

AstNode = Union[
ast.AST,
Expand Down Expand Up @@ -111,8 +113,92 @@ def set_location_from(node: ast.AST, loc: ast.AST) -> None:
node.end_lineno = loc.end_lineno
node.end_col_offset = loc.end_col_offset

source, file, line_offset = get_source(loc), get_file(loc), get_line_offset(loc)
assert source is not None and file is not None and line_offset is not None
annotate_location(node, source, file, line_offset)

def is_empty_body(func_ast: ast.FunctionDef) -> bool:

def annotate_location(
    node: ast.AST, source: str, file: str, line_offset: int, recurse: bool = True
) -> None:
    """Attach source-location metadata to an AST node.

    Sets the `line_offset`, `file` and `source` attributes on `node`. If
    `recurse` is true, every AST node reachable from `node` through its fields
    (directly or inside list-valued fields) receives the same three attributes.
    """
    # `ast.walk` yields `node` itself followed by all of its descendants. It is
    # iterative, so deeply nested trees cannot exhaust the interpreter stack.
    targets = ast.walk(node) if recurse else (node,)
    for target in targets:
        # `setattr` is used because these attributes are not declared on `ast.AST`
        setattr(target, "line_offset", line_offset)
        setattr(target, "file", file)
        setattr(target, "source", source)


def get_file(node: AstNode) -> Optional[str]:
    """Look up the file annotation stored on an AST node.

    Returns the node's `file` attribute when it is a string, and `None` when
    the attribute is missing or holds anything else.
    """
    annotated = getattr(node, "file", None)
    if isinstance(annotated, str):
        return annotated
    return None


def get_source(node: AstNode) -> Optional[str]:
    """Look up the source annotation stored on an AST node.

    Returns the node's `source` attribute when it is a string, and `None` when
    the attribute is missing or holds anything else.
    """
    annotated = getattr(node, "source", None)
    if isinstance(annotated, str):
        return annotated
    return None


def get_line_offset(node: AstNode) -> Optional[int]:
    """Look up the line-offset annotation stored on an AST node.

    Returns the node's `line_offset` attribute when it is an `int`, and `None`
    when the attribute is missing or holds anything else.
    """
    annotated = getattr(node, "line_offset", None)
    if isinstance(annotated, int):
        return annotated
    return None


A = TypeVar("A", bound=ast.AST)


def with_loc(loc: ast.AST, node: A) -> A:
    """Copy the source location of `loc` onto `node` and return `node`.

    The copying is delegated to `set_location_from`. `node` is mutated in place
    and handed back so that the call can be used inline in an expression.
    """
    set_location_from(node, loc)
    return node


def with_type(ty: "GuppyType", node: A) -> A:
    """Annotate an AST node with a type and return the node.

    The type is stored in the node's `type` attribute; `node` is mutated in
    place.
    """
    setattr(node, "type", ty)
    return node


def get_type_opt(node: AstNode) -> Optional["GuppyType"]:
    """Look up the type annotation stored on an AST node, if there is one.

    Returns the node's `type` attribute when it is a `GuppyType` instance, and
    `None` when the attribute is missing or holds anything else.
    """
    # Imported locally because the class is needed at runtime for `isinstance`
    from guppy.gtypes import GuppyType

    annotated = getattr(node, "type", None)
    if isinstance(annotated, GuppyType):
        return annotated
    return None


def get_type(node: AstNode) -> "GuppyType":
    """Retrieve a type annotation from an AST node.

    Fails with an `AssertionError` if `get_type_opt` finds no type on the node.
    """
    ty = get_type_opt(node)
    # Also narrows `Optional[GuppyType]` to `GuppyType` for the return below
    assert ty is not None
    return ty


def has_empty_body(func_ast: ast.FunctionDef) -> bool:
"""Returns `True` if the body of a function definition is empty.
This is the case if the body only contains a single `pass` statement or an ellipsis
Expand Down
95 changes: 21 additions & 74 deletions guppy/cfg/bb.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import ast
from abc import ABC
from dataclasses import dataclass, field
from typing import Optional, Sequence, TYPE_CHECKING, Union, Any
from typing import Optional, TYPE_CHECKING, Union
from typing_extensions import Self

from guppy.ast_util import AstNode, name_nodes_in_ast
from guppy.compiler_base import RawVariable, return_var
from guppy.guppy_types import FunctionType
from guppy.hugr.hugr import CFNode
from guppy.nodes import NestedFunctionDef

if TYPE_CHECKING:
from guppy.cfg.cfg import CFG
from guppy.cfg.cfg import BaseCFG


@dataclass
Expand All @@ -34,61 +34,26 @@ def update_used(self, node: ast.AST) -> None:
self.used[name.id] = name


VarRow = Sequence[RawVariable]


@dataclass(frozen=True)
class Signature:
    """The signature of a basic block.

    Stores the input/output variables with their types.
    """

    # Variables flowing into the basic block
    input_row: VarRow
    output_rows: Sequence[VarRow]  # One for each successor


@dataclass(frozen=True)
class CompiledBB:
    """The result of compiling a basic block.

    Besides the corresponding node in the graph, we also store the signature of the
    basic block with type information.
    """

    # Node in the graph that corresponds to the basic block
    node: CFNode
    # The basic block that was compiled
    bb: "BB"
    # Signature of the basic block, carrying the type information
    sig: Signature


class NestedFunctionDef(ast.FunctionDef):
    """An `ast.FunctionDef` node that additionally carries a CFG and a function type."""

    cfg: "CFG"
    ty: FunctionType

    def __init__(self, cfg: "CFG", ty: FunctionType, *args: Any, **kwargs: Any) -> None:
        """Store `cfg` and `ty` on the node.

        All remaining positional and keyword arguments are forwarded unchanged to
        the `ast.FunctionDef` constructor.
        """
        super().__init__(*args, **kwargs)
        self.cfg = cfg
        self.ty = ty


BBStatement = Union[ast.Assign, ast.AugAssign, ast.Expr, ast.Return, NestedFunctionDef]
BBStatement = Union[
ast.Assign, ast.AugAssign, ast.AnnAssign, ast.Expr, ast.Return, NestedFunctionDef
]


@dataclass(eq=False) # Disable equality to recover hash from `object`
class BB:
class BB(ABC):
"""A basic block in a control flow graph."""

idx: int

# Pointer to the CFG that contains this node
cfg: "CFG"
containing_cfg: "BaseCFG[Self]"

# AST statements contained in this BB
statements: list[BBStatement] = field(default_factory=list)

# Predecessor and successor BBs
predecessors: list["BB"] = field(default_factory=list)
successors: list["BB"] = field(default_factory=list)
predecessors: list[Self] = field(default_factory=list)
successors: list[Self] = field(default_factory=list)

# If the BB has multiple successors, we need a predicate to decide to which one to
# jump to
Expand All @@ -107,40 +72,25 @@ def vars(self) -> VariableStats:
assert self._vars is not None
return self._vars

def compute_variable_stats(self, num_returns: int) -> None:
"""Determines which variables are assigned/used in this BB.
This also requires the expected number of returns of the whole CFG in order to
process `return` statements.
"""
visitor = VariableVisitor(self, num_returns)
def compute_variable_stats(self) -> None:
"""Determines which variables are assigned/used in this BB."""
visitor = VariableVisitor(self)
for s in self.statements:
visitor.visit(s)
self._vars = visitor.stats

if self.branch_pred is not None:
self._vars.update_used(self.branch_pred)

# In the `StatementCompiler`, we're going to turn return statements into
# assignments of dummy variables `%ret_xxx`. Thus, we have to register those
# variables as being used in the exit BB
if len(self.successors) == 0:
self._vars.used |= {
return_var(i): ast.Name(return_var(i), ast.Load)
for i in range(num_returns)
}


class VariableVisitor(ast.NodeVisitor):
"""Visitor that computes used and assigned variables in a BB."""

bb: BB
stats: VariableStats
num_returns: int

def __init__(self, bb: BB, num_returns: int):
def __init__(self, bb: BB):
self.bb = bb
self.num_returns = num_returns
self.stats = VariableStats()

def visit_Assign(self, node: ast.Assign) -> None:
Expand All @@ -155,22 +105,19 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None:
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_Return(self, node: ast.Return) -> None:
if node.value is not None:
def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
if node.value:
self.stats.update_used(node.value)

# In the `StatementCompiler`, we're going to turn return statements into
# assignments of dummy variables `%ret_xxx`. To make the liveness analysis work,
# we have to register those variables as being assigned here
self.stats.assigned |= {return_var(i): node for i in range(self.num_returns)}
for name in name_nodes_in_ast(node.target):
self.stats.assigned[name.id] = node

def visit_NestedFunctionDef(self, node: NestedFunctionDef) -> None:
# In order to compute the used external variables in a nested function
# definition, we have to run live variable analysis first
from guppy.cfg.analysis import LivenessAnalysis

for bb in node.cfg.bbs:
bb.compute_variable_stats(len(node.ty.returns))
bb.compute_variable_stats()
live = LivenessAnalysis().run(node.cfg.bbs)

# Only store used *external* variables: things defined in the current BB, as
Expand Down
33 changes: 18 additions & 15 deletions guppy/cfg/builder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import ast
import itertools
from typing import Optional, Iterator, Union, NamedTuple
from typing import Optional, Iterator, NamedTuple

from guppy.ast_util import set_location_from, AstVisitor
from guppy.cfg.bb import BB, NestedFunctionDef
from guppy.cfg.bb import BB, BBStatement
from guppy.cfg.cfg import CFG
from guppy.compiler_base import Globals
from guppy.checker.core import Globals
from guppy.error import GuppyError, InternalGuppyError

from guppy.gtypes import NoneType
from guppy.nodes import NestedFunctionDef

# In order to build expressions, need an endless stream of unique temporary variables
# to store intermediate results
Expand All @@ -31,18 +32,16 @@ class CFGBuilder(AstVisitor[Optional[BB]]):
"""Constructs a CFG from ast nodes."""

cfg: CFG
num_returns: int
globals: Globals

def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CFG:
def build(self, nodes: list[ast.stmt], returns_none: bool, globals: Globals) -> CFG:
"""Builds a CFG from a list of ast nodes.
We also require the expected number of return ports for the whole CFG. This is
needed to translate return statements into assignments of dummy return
variables.
"""
self.cfg = CFG()
self.num_returns = num_returns
self.globals = globals

final_bb = self.visit_stmts(
Expand All @@ -52,7 +51,7 @@ def build(self, nodes: list[ast.stmt], num_returns: int, globals: Globals) -> CF
# If we're still in a basic block after compiling the whole body, we have to add
# an implicit void return
if final_bb is not None:
if num_returns > 0:
if not returns_none:
raise GuppyError("Expected return statement", nodes[-1])
self.cfg.link(final_bb, self.cfg.exit_bb)

Expand All @@ -76,15 +75,13 @@ def visit_stmts(self, nodes: list[ast.stmt], bb: BB, jumps: Jumps) -> Optional[B
bb_opt = self.visit(node, bb_opt, jumps)
return bb_opt

def _build_node_value(
self, node: Union[ast.Assign, ast.AugAssign, ast.Return, ast.Expr], bb: BB
) -> BB:
def _build_node_value(self, node: BBStatement, bb: BB) -> BB:
"""Utility method for building a node containing a `value` expression.
Builds the expression and mutates `node.value` to point to the built expression.
Returns the BB in which the expression is available and adds the node to it.
"""
if node.value is not None:
if not isinstance(node, NestedFunctionDef) and node.value is not None:
node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
bb.statements.append(node)
return bb
Expand All @@ -97,6 +94,11 @@ def visit_AugAssign(
) -> Optional[BB]:
return self._build_node_value(node, bb)

def visit_AnnAssign(
self, node: ast.AnnAssign, bb: BB, jumps: Jumps
) -> Optional[BB]:
return self._build_node_value(node, bb)

def visit_Expr(self, node: ast.Expr, bb: BB, jumps: Jumps) -> Optional[BB]:
# This is an expression statement where the value is discarded
node.value, bb = ExprBuilder.build(node.value, self.cfg, bb)
Expand Down Expand Up @@ -166,10 +168,11 @@ def visit_Pass(self, node: ast.Pass, bb: BB, jumps: Jumps) -> Optional[BB]:
def visit_FunctionDef(
self, node: ast.FunctionDef, bb: BB, jumps: Jumps
) -> Optional[BB]:
from guppy.function import FunctionDefCompiler
from guppy.checker.func_checker import check_signature

func_ty = FunctionDefCompiler.validate_signature(node, self.globals)
cfg = CFGBuilder().build(node.body, len(func_ty.returns), self.globals)
func_ty = check_signature(node, self.globals)
returns_none = isinstance(func_ty.returns, NoneType)
cfg = CFGBuilder().build(node.body, returns_none, self.globals)

new_node = NestedFunctionDef(
cfg,
Expand Down
Loading

0 comments on commit b701c06

Please sign in to comment.