Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix array execution bugs #731

Merged
merged 18 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions execute_llvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ crate-type = ["cdylib"]

[dependencies]
hugr = {workspace = true, features = ["llvm"]}
hugr-passes = "0.14.0"
inkwell.workspace = true
pyo3.workspace = true
serde_json.workspace = true
8 changes: 7 additions & 1 deletion execute_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use hugr::llvm::utils::fat::FatExt;
use hugr::Hugr;
use hugr::{self, ops, std_extensions, HugrView};
use hugr_passes;
use inkwell::{context::Context, module::Module, values::GenericValue};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -38,6 +39,10 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult<hugr::Node>
}
}

fn guppy_pass(hugr: Hugr) -> Hugr {
hugr_passes::monomorphize(hugr)
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
}

fn compile_module<'a>(
hugr: &'a hugr::Hugr,
ctx: &'a Context,
Expand Down Expand Up @@ -77,7 +82,8 @@ fn run_function<T>(
fn_name: &str,
parse_result: impl FnOnce(&Context, GenericValue) -> PyResult<T>,
) -> PyResult<T> {
let hugr = parse_hugr(hugr_json)?;
let mut hugr = parse_hugr(hugr_json)?;
hugr = guppy_pass(hugr);
let ctx = Context::create();

let namer = hugr::llvm::emit::Namer::default();
Expand Down
20 changes: 13 additions & 7 deletions guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from guppylang.checker.errors.generic import UnsupportedError
from guppylang.checker.linearity_checker import contains_subscript
from guppylang.compiler.core import CompilerBase, DFContainer
from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp
from guppylang.compiler.hugr_extension import PartialOp
from guppylang.definition.custom import CustomFunctionDef
from guppylang.definition.value import (
CallReturnWires,
Expand All @@ -46,6 +46,7 @@
TensorCall,
TypeApply,
)
from guppylang.std._internal.compiler.arithmetic import convert_ifromusize
from guppylang.std._internal.compiler.array import array_repeat
from guppylang.std._internal.compiler.list import (
list_new,
Expand Down Expand Up @@ -206,11 +207,16 @@ def visit_GlobalName(self, node: GlobalName) -> Wire:
return defn.load(self.dfg, self.globals, node)

def visit_GenericParamValue(self, node: GenericParamValue) -> Wire:
# TODO: We need a way to look up the concrete value of a generic type arg in
# Hugr. For example, a new op that captures the value during monomorphisation
return self.builder.add_op(
UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op
)
match node.param.ty:
case NumericType(NumericType.Kind.Nat):
arg = node.param.to_bound().to_hugr()
load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate(
[arg], ht.FunctionType([], [ht.USize()])
)
usize = self.builder.add_op(load_nat)
return self.builder.add_op(convert_ifromusize(), usize)
case _:
raise NotImplementedError

def visit_Name(self, node: ast.Name) -> Wire:
raise InternalGuppyError("Node should have been removed during type checking.")
Expand Down Expand Up @@ -606,7 +612,7 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None:
assert is_array_type(exp_ty)
vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts]
if doesnt_contain_none(vs):
# TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1497
# TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1771
return hv.Extension(
name="ArrayValue",
typ=exp_ty.to_hugr(),
Expand Down
43 changes: 30 additions & 13 deletions guppylang/std/_internal/compiler/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from hugr import tys as ht
from hugr.std.collections.array import EXTENSION

from guppylang.compiler.hugr_extension import UnsupportedOp
from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
Expand Down Expand Up @@ -72,24 +71,42 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
)


def array_scan(
elem_ty: ht.Type,
length: ht.TypeArg,
new_elem_ty: ht.Type,
accumulators: list[ht.Type],
) -> ops.ExtOp:
"""Returns an operation that maps and folds a function across an array."""
ty_args = [
length,
ht.TypeTypeArg(elem_ty),
ht.TypeTypeArg(new_elem_ty),
ht.SequenceArg([ht.TypeTypeArg(acc) for acc in accumulators]),
ht.ExtensionsArg([]),
]
ins = [
array_type(elem_ty, length),
ht.FunctionType([elem_ty, *accumulators], [new_elem_ty, *accumulators]),
*accumulators,
]
outs = [array_type(new_elem_ty, length), *accumulators]
return EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs))


def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp:
"""Returns an operation that maps a function across an array."""
# TODO
return UnsupportedOp(
op_name="array_map",
inputs=[array_type(elem_ty, length), ht.FunctionType([elem_ty], [new_elem_ty])],
outputs=[array_type(new_elem_ty, length)],
).ext_op
return array_scan(elem_ty, length, new_elem_ty, accumulators=[])


def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp:
"""Returns an array `repeat` operation."""
# TODO
return UnsupportedOp(
op_name="array.repeat",
inputs=[ht.FunctionType([], [elem_ty])],
outputs=[array_type(elem_ty, length)],
).ext_op
return EXTENSION.get_op("repeat").instantiate(
[length, ht.TypeTypeArg(elem_ty), ht.ExtensionsArg([])],
ht.FunctionType(
[ht.FunctionType([], [elem_ty])], [array_type(elem_ty, length)]
),
)


# ------------------------------------------------------
Expand Down
3 changes: 1 addition & 2 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ def main() -> int:
package = module.compile()
validate(package)

# TODO: Enable execution once lowering for missing ops is implemented
# run_int_fn(package, expected=9)
run_int_fn(package, expected=9)


def test_mem_swap(validate):
Expand Down
Loading