From 2b40a65276c372e5830bb85555d49e952083cc6b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 15:06:17 +0000 Subject: [PATCH 01/17] fix: Use proper ops from the Hugr 0.14 release --- Cargo.lock | 1 + execute_llvm/Cargo.toml | 1 + execute_llvm/src/lib.rs | 8 ++++- guppylang/compiler/expr_compiler.py | 20 +++++++---- guppylang/std/_internal/compiler/array.py | 43 ++++++++++++++++------- tests/integration/test_array.py | 3 +- 6 files changed, 53 insertions(+), 23 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddb2de16..4fc890da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -452,6 +452,7 @@ name = "execute_llvm" version = "0.2.0" dependencies = [ "hugr", + "hugr-passes", "inkwell", "pyo3", "serde_json", diff --git a/execute_llvm/Cargo.toml b/execute_llvm/Cargo.toml index e67ea79a..039984d2 100644 --- a/execute_llvm/Cargo.toml +++ b/execute_llvm/Cargo.toml @@ -16,6 +16,7 @@ crate-type = ["cdylib"] [dependencies] hugr = {workspace = true, features = ["llvm"]} +hugr-passes = "0.14.0" inkwell.workspace = true pyo3.workspace = true serde_json.workspace = true diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index 5eb06261..541b90a0 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -2,6 +2,7 @@ use hugr::llvm::utils::fat::FatExt; use hugr::Hugr; use hugr::{self, ops, std_extensions, HugrView}; +use hugr_passes; use inkwell::{context::Context, module::Module, values::GenericValue}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -38,6 +39,10 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult } } +fn guppy_pass(hugr: Hugr) -> Hugr { + hugr_passes::monomorphize(hugr) +} + fn compile_module<'a>( hugr: &'a hugr::Hugr, ctx: &'a Context, @@ -77,7 +82,8 @@ fn run_function( fn_name: &str, parse_result: impl FnOnce(&Context, GenericValue) -> PyResult, ) -> PyResult { - let hugr = parse_hugr(hugr_json)?; + let mut hugr = parse_hugr(hugr_json)?; + hugr = guppy_pass(hugr); let ctx = Context::create(); let namer = hugr::llvm::emit::Namer::default(); diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index b3d8184f..cbdbf584 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -21,7 +21,7 @@ from guppylang.checker.errors.generic import UnsupportedError from guppylang.checker.linearity_checker import contains_subscript from guppylang.compiler.core import CompilerBase, DFContainer -from guppylang.compiler.hugr_extension import PartialOp, UnsupportedOp +from guppylang.compiler.hugr_extension import PartialOp from guppylang.definition.custom import CustomFunctionDef from guppylang.definition.value import ( CallReturnWires, @@ -46,6 +46,7 @@ TensorCall, TypeApply, ) +from guppylang.std._internal.compiler.arithmetic import convert_ifromusize from guppylang.std._internal.compiler.array import array_repeat from guppylang.std._internal.compiler.list import ( list_new, @@ -206,11 +207,16 @@ def visit_GlobalName(self, node: GlobalName) -> Wire: return defn.load(self.dfg, self.globals, node) def visit_GenericParamValue(self, node: GenericParamValue) -> Wire: - # TODO: We need a way to look up the concrete value of a generic type arg in - # Hugr. For example, a new op that captures the value during monomorphisation - return self.builder.add_op( - UnsupportedOp("load_type_param", [], [node.param.ty.to_hugr()]).ext_op - ) + match node.param.ty: + case NumericType(NumericType.Kind.Nat): + arg = node.param.to_bound().to_hugr() + load_nat = hugr.std.PRELUDE.get_op("load_nat").instantiate( + [arg], ht.FunctionType([], [ht.USize()]) + ) + usize = self.builder.add_op(load_nat) + return self.builder.add_op(convert_ifromusize(), usize) + case _: + raise NotImplementedError def visit_Name(self, node: ast.Name) -> Wire: raise InternalGuppyError("Node should have been removed during type checking.") @@ -606,7 +612,7 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: assert is_array_type(exp_ty) vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts] if doesnt_contain_none(vs): - # TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1497 + # TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1771 return hv.Extension( name="ArrayValue", typ=exp_ty.to_hugr(), diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index 66e602cd..21a9e22e 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -6,7 +6,6 @@ from hugr import tys as ht from hugr.std.collections.array import EXTENSION -from guppylang.compiler.hugr_extension import UnsupportedOp from guppylang.definition.custom import CustomCallCompiler from guppylang.definition.value import CallReturnWires from guppylang.error import InternalGuppyError @@ -72,24 +71,42 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) +def array_scan( + elem_ty: ht.Type, + length: ht.TypeArg, + new_elem_ty: ht.Type, + accumulators: list[ht.Type], +) -> ops.ExtOp: + """Returns an operation that maps and folds a function across an array.""" + ty_args = [ + length, + ht.TypeTypeArg(elem_ty), + ht.TypeTypeArg(new_elem_ty), + ht.SequenceArg([ht.TypeTypeArg(acc) for acc in accumulators]), + ht.ExtensionsArg([]), + ] + ins = [ + array_type(elem_ty, length), + ht.FunctionType([elem_ty, *accumulators], [new_elem_ty, *accumulators]), + *accumulators, + ] + outs = [array_type(new_elem_ty, length), *accumulators] + return EXTENSION.get_op("scan").instantiate(ty_args, ht.FunctionType(ins, outs)) + + def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp: """Returns an operation that maps a function across an array.""" - # TODO - return UnsupportedOp( - op_name="array_map", - inputs=[array_type(elem_ty, length), ht.FunctionType([elem_ty], [new_elem_ty])], - outputs=[array_type(new_elem_ty, length)], - ).ext_op + return array_scan(elem_ty, length, new_elem_ty, accumulators=[]) def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: """Returns an array `repeat` operation.""" - # TODO - return UnsupportedOp( - op_name="array.repeat", - inputs=[ht.FunctionType([], [elem_ty])], - outputs=[array_type(elem_ty, length)], - ).ext_op + return EXTENSION.get_op("repeat").instantiate( + [length, ht.TypeTypeArg(elem_ty), ht.ExtensionsArg([])], + ht.FunctionType( + [ht.FunctionType([], [elem_ty])], [array_type(elem_ty, length)] + ), + ) # ------------------------------------------------------ diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index ab462cff..9973539a 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -291,8 +291,7 @@ def main() -> int: package = module.compile() validate(package) - # TODO: Enable execution once lowering for missing ops is implemented - # run_int_fn(package, expected=9) + run_int_fn(package, expected=9) def test_mem_swap(validate): From 27b22b863da4c7e4ba98c753080c888f519a52d3 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:07:34 +0000 Subject: [PATCH 02/17] Use exported hugr_passes --- Cargo.lock | 1 - execute_llvm/Cargo.toml | 1 - execute_llvm/src/lib.rs | 3 +-- 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4fc890da..ddb2de16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -452,7 +452,6 @@ name = "execute_llvm" version = "0.2.0" dependencies = [ "hugr", - "hugr-passes", "inkwell", "pyo3", "serde_json", diff --git a/execute_llvm/Cargo.toml b/execute_llvm/Cargo.toml index 039984d2..e67ea79a 100644 --- a/execute_llvm/Cargo.toml +++ b/execute_llvm/Cargo.toml @@ -16,7 +16,6 @@ crate-type = ["cdylib"] [dependencies] hugr = {workspace = true, features = ["llvm"]} -hugr-passes = "0.14.0" inkwell.workspace = true pyo3.workspace = true serde_json.workspace = true diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index 541b90a0..a0694415 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -2,7 +2,6 @@ use hugr::llvm::utils::fat::FatExt; use hugr::Hugr; use hugr::{self, ops, std_extensions, HugrView}; -use hugr_passes; use inkwell::{context::Context, module::Module, values::GenericValue}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -40,7 +39,7 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult } fn guppy_pass(hugr: Hugr) -> Hugr { - hugr_passes::monomorphize(hugr) + hugr::algorithms::monomorphize(hugr) } fn compile_module<'a>( From c0c2516660b239212b3f7169e65fedb6db3bf71e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:19:53 +0000 Subject: [PATCH 03/17] Switch to git dependency --- Cargo.lock | 18 +++++------------- Cargo.toml | 6 +++--- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ddb2de16..ee6ed534 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,8 +525,7 @@ dependencies = [ [[package]] name = "hugr" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f209c7cd671de29be8bdf0725e09b2e9d386387f439b13975e158f095e5a0fe" +source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" dependencies = [ "hugr-core", "hugr-llvm", @@ -536,24 +535,20 @@ dependencies = [ [[package]] name = "hugr-cli" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab6a94a980d47788908d7f93165846164f8b623b7f382cd3813bd0c0d1188e65" +source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" dependencies = [ "clap", "clap-verbosity-flag", "clio", "derive_more", "hugr", - "serde", "serde_json", - "thiserror 2.0.7", ] [[package]] name = "hugr-core" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60c3d5422f76dbec1d6948e68544b134562ec9ec087e8e6a599555b716f555dc" +source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" dependencies = [ "bitvec", "bumpalo", @@ -585,12 +580,10 @@ dependencies = [ [[package]] name = "hugr-llvm" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce4117f4f934b1033b82d8cb672b3c33c3a7f8f541c50f7cc7ff53cebb5816d1" +source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" dependencies = [ "anyhow", "delegate", - "downcast-rs", "hugr-core", "inkwell", "itertools 0.13.0", @@ -602,8 +595,7 @@ dependencies = [ [[package]] name = "hugr-passes" version = "0.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec2591767b6fe03074d38de7c4e61d52b37cb2e73b7340bf4ff957ad4554022a" +source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" dependencies = [ "ascent", "hugr-core", diff --git a/Cargo.toml b/Cargo.toml index 21c65419..8b1094a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,6 @@ inkwell = "0.5.0" [patch.crates-io] # Uncomment these to test the latest dependency version during development -# hugr = { git = "https://github.com/CQCL/hugr", rev = "861183e" } -# hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "861183e" } -# hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "1091755" } + hugr = { git = "https://github.com/CQCL/hugr", rev = "e40b6c7" } + hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "e40b6c7" } + hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "e40b6c7" } From 76a110dfee329eb0ea5148afae6df7f1590b0d2f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:20:06 +0000 Subject: [PATCH 04/17] Run remove_polyfuncs pass --- execute_llvm/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index a0694415..1b3d2111 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -39,7 +39,8 @@ fn find_funcdef_node(hugr: impl HugrView, fn_name: &str) -> PyResult } fn guppy_pass(hugr: Hugr) -> Hugr { - hugr::algorithms::monomorphize(hugr) + let hugr = hugr::algorithms::monomorphize(hugr); + hugr::algorithms::remove_polyfuncs(hugr) } fn compile_module<'a>( From 96caecc4f3cff7ba2291f4a5fa4243ed91c3a367 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:20:20 +0000 Subject: [PATCH 05/17] Replace iu_to_s with noop --- guppylang/std/builtins.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index d9843bc2..85a9be7d 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -161,7 +161,9 @@ def __ge__(self: nat, other: nat) -> bool: ... @guppy.hugr_op(int_op("igt_u")) def __gt__(self: nat, other: nat) -> bool: ... - @guppy.hugr_op(int_op("iu_to_s")) + # TODO: Use "iu_to_s" once we have lowering: + # https://github.com/CQCL/hugr/issues/1806 + @guppy.custom(NoopCompiler()) def __int__(self: nat) -> int: ... @guppy.hugr_op(int_op("inot")) From 4a8eb32fd99deae84ecbe5a808c0704a1a7b9db8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:29:01 +0000 Subject: [PATCH 06/17] Use git dependency in uv --- pyproject.toml | 2 +- uv.lock | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 473993fb..29ceb926 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ members = ["execute_llvm"] execute-llvm = { workspace = true } # Uncomment these to test the latest dependency version during development -# hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "861183e" } + hugr = { git = "https://github.com/CQCL/hugr", subdirectory = "hugr-py", rev = "e40b6c7" } # tket2-exts = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-exts", rev = "eb7cc63"} # tket2 = { git = "https://github.com/CQCL/tket2", subdirectory = "tket2-py", rev = "eb7cc63"} diff --git a/uv.lock b/uv.lock index 8b5ef284..c61fbdd1 100644 --- a/uv.lock +++ b/uv.lock @@ -614,7 +614,7 @@ test = [ [package.metadata] requires-dist = [ { name = "graphviz", specifier = ">=0.20.1,<0.21" }, - { name = "hugr", specifier = ">=0.10.0,<0.11" }, + { name = "hugr", git = "https://github.com/CQCL/hugr?subdirectory=hugr-py&rev=e40b6c7" }, { name = "networkx", specifier = ">=3.2.1,<4" }, { name = "pydantic", specifier = ">=2.7.0b1,<3" }, { name = "pytket", marker = "extra == 'pytket'", specifier = ">=1.34" }, @@ -679,17 +679,13 @@ test = [ [[package]] name = "hugr" version = "0.10.0" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/CQCL/hugr?subdirectory=hugr-py&rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" } dependencies = [ { name = "graphviz" }, { name = "pydantic" }, { name = "pydantic-extra-types" }, { name = "semver" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/ef/9cd410ff0e3a92c5e88da2ef3c0e051dd971f4f6c5577873c7901ed31dd5/hugr-0.10.0.tar.gz", hash = "sha256:11e5a80ebd4e31cad0cb04d408cdd93a094e6fb817dd81481eedac5a58f86ff7", size = 129441 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f2/71/83556457cfe27f4a1613cd49041cfe4c6e9e087a53b5beec48a8d709c36d/hugr-0.10.0-py3-none-any.whl", hash = "sha256:591e252ef3e4182fd0de99274ebb4999ddd9572a0ec823519de154e4bd9f14ec", size = 83000 }, -] [[package]] name = "identify" From ed898c7ae0c906940c56f06510c14214867f6d7f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:29:22 +0000 Subject: [PATCH 07/17] Use ArrayVal --- guppylang/compiler/expr_compiler.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index cbdbf584..dd8398a0 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -13,6 +13,7 @@ from hugr import val as hv from hugr.build.cond_loop import Conditional from hugr.build.dfg import DP, DfBase +import hugr.std.collections.array from typing_extensions import assert_never from guppylang.ast_util import AstNode, AstVisitor, get_type @@ -610,17 +611,12 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: return hv.Tuple(*vs) case list(elts): assert is_array_type(exp_ty) - vs = [python_value_to_hugr(elt, get_element_type(exp_ty)) for elt in elts] + elem_ty = get_element_type(exp_ty) + vs = [python_value_to_hugr(elt, elem_ty) for elt in elts] if doesnt_contain_none(vs): - # TODO: Use proper array value: https://github.com/CQCL/hugr/issues/1771 - return hv.Extension( - name="ArrayValue", - typ=exp_ty.to_hugr(), - # The value list must be serialized at this point, otherwise the - # `Extension` value would not be serializable. - val=[v._to_serial_root() for v in vs], - extensions=["unsupported"], - ) + opt_ty = ht.Option(elem_ty.to_hugr()) + opt_vs = [hv.Sum(1, opt_ty, [v]) for v in vs] + return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty) case _: # TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563 # Pytket conversion is an experimental feature From 7b3d7d298e183870b4b2b31b4c46df9f0fb595b8 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:34:26 +0000 Subject: [PATCH 08/17] Fix mypy --- guppylang/compiler/expr_compiler.py | 2 +- guppylang/tys/builtin.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 12b6759a..94a74ca1 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -615,7 +615,7 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: vs = [python_value_to_hugr(elt, elem_ty) for elt in elts] if doesnt_contain_none(vs): opt_ty = ht.Option(elem_ty.to_hugr()) - opt_vs = [hv.Sum(1, opt_ty, [v]) for v in vs] + opt_vs: list[hv.Value] = [hv.Sum(1, opt_ty, [v]) for v in vs] return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty) case _: # TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563 diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index f9785b21..740869dc 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -138,8 +138,7 @@ def _array_to_hugr(args: Sequence[Argument]) -> ht.Type: elem_ty = ht.Option(ty_arg.ty.to_hugr()) hugr_arg = len_arg.to_hugr() - # TODO remove type ignore after Array type annotation fixed to include VariableArg - return hugr.std.collections.array.Array(elem_ty, hugr_arg) # type:ignore[arg-type] + return hugr.std.collections.array.Array(elem_ty, hugr_arg) def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type: From e0d0e748a061fe5f48ae2768c4fdc4ba2381b93e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 17:57:37 +0000 Subject: [PATCH 09/17] Add more execution tests --- tests/integration/test_array_comprehension.py | 68 +++++++++++++++---- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/tests/integration/test_array_comprehension.py b/tests/integration/test_array_comprehension.py index e012b3a5..d2407f91 100644 --- a/tests/integration/test_array_comprehension.py +++ b/tests/integration/test_array_comprehension.py @@ -9,12 +9,23 @@ from tests.util import compile_guppy -def test_basic(validate): - @compile_guppy +def test_basic_exec(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test() -> array[int, 10]: return array(i + 1 for i in range(10)) - validate(test) + @guppy(module) + def main() -> int: + s = 0 + for x in test(): + s += x + return s + + package = module.compile() + validate(package) + run_int_fn(package, expected=sum(i + 1 for i in range(10))) def test_basic_linear(validate): @@ -29,23 +40,42 @@ def test() -> array[qubit, 42]: validate(module.compile()) -def test_zero_length(validate): - @compile_guppy +def test_zero_length(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test() -> array[float, 0]: return array(i / 0 for i in range(0)) - validate(test) + @guppy(module) + def main() -> int: + test() + return 0 + package = module.compile() + validate(package) + run_int_fn(package, expected=0) -def test_capture(validate): - @compile_guppy + +def test_capture(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test(x: int) -> array[int, 42]: return array(i + x for i in range(42)) - validate(test) + @guppy(module) + def main() -> int: + s = 0 + for x in test(3): + s += x + return s + + package = module.compile() + validate(package) + run_int_fn(package, expected=0) -@pytest.mark.skip("See https://github.com/CQCL/hugr/issues/1625") def test_capture_struct(validate): module = GuppyModule("test") @@ -71,12 +101,24 @@ def test() -> float: validate(test) -def test_nested_left(validate): - @compile_guppy +def test_nested_left(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) def test() -> array[array[int, 10], 20]: return array(array(x + y for y in range(10)) for x in range(20)) - validate(test) + @guppy(module) + def main() -> int: + s = 0 + for xs in test(): + for x in xs: + s += x + return s + + package = module.compile() + validate(package) + run_int_fn(package, expected=sum(x + y for y in range(10) for x in range(20))) def test_generic(validate): From 7ff6e35da5591f0e691c55cd38e3e196db0f9aa1 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 13:33:05 +0000 Subject: [PATCH 10/17] Bump hugr --- Cargo.lock | 10 +++++----- Cargo.toml | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ee6ed534..0cad9946 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -525,7 +525,7 @@ dependencies = [ [[package]] name = "hugr" version = "0.14.0" -source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "hugr-core", "hugr-llvm", @@ -535,7 +535,7 @@ dependencies = [ [[package]] name = "hugr-cli" version = "0.14.0" -source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "clap", "clap-verbosity-flag", @@ -548,7 +548,7 @@ dependencies = [ [[package]] name = "hugr-core" version = "0.14.0" -source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "bitvec", "bumpalo", @@ -580,7 +580,7 @@ dependencies = [ [[package]] name = "hugr-llvm" version = "0.14.0" -source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "anyhow", "delegate", @@ -595,7 +595,7 @@ dependencies = [ [[package]] name = "hugr-passes" version = "0.14.0" -source = "git+https://github.com/CQCL/hugr?rev=e40b6c7#e40b6c7057a15ead78bb18aa837e5b84e12a3722" +source = "git+https://github.com/CQCL/hugr?rev=ab94518#ab94518ed2812abca615bfbfb5a822f67c115be8" dependencies = [ "ascent", "hugr-core", diff --git a/Cargo.toml b/Cargo.toml index 8b1094a6..6d600d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,6 @@ inkwell = "0.5.0" [patch.crates-io] # Uncomment these to test the latest dependency version during development - hugr = { git = "https://github.com/CQCL/hugr", rev = "e40b6c7" } - hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "e40b6c7" } - hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "e40b6c7" } + hugr = { git = "https://github.com/CQCL/hugr", rev = "ab94518" } + hugr-cli = { git = "https://github.com/CQCL/hugr", rev = "ab94518" } + hugr-llvm = { git = "https://github.com/CQCL/hugr", rev = "ab94518" } From 25e3c4e7302b7085c266bb98d41106a8daa68d7d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 13:33:41 +0000 Subject: [PATCH 11/17] Run guppy_pass in compile_module_to_string --- execute_llvm/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index 1b3d2111..8737a0e6 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -69,9 +69,10 @@ fn compile_module<'a>( #[pyfunction] fn compile_module_to_string(hugr_json: &str) -> PyResult { - let hugr = parse_hugr(hugr_json)?; + let mut hugr = parse_hugr(hugr_json)?; let ctx = Context::create(); + hugr = guppy_pass(hugr); let module = compile_module(&hugr, &ctx, Default::default())?; Ok(module.print_to_string().to_str().unwrap().to_string()) From 6c7b4397e095101467fa6376387892a986a4a02e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 13:38:18 +0000 Subject: [PATCH 12/17] Add logic extension to execute_llvm --- execute_llvm/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index 8737a0e6..3ff6d2d7 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -52,6 +52,7 @@ fn compile_module<'a>( // TODO: Handle tket2 codegen extension let extensions = hugr::llvm::custom::CodegenExtsBuilder::default() .add_int_extensions() + .add_logic_extensions() .add_default_prelude_extensions() .add_default_array_extensions() .add_float_extensions() From 8cf96921f53285e3e831e90a959a54196ead532c Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 13:39:01 +0000 Subject: [PATCH 13/17] Fix bool tags --- guppylang/compiler/expr_compiler.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 94a74ca1..5dbae38d 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -125,7 +125,7 @@ def _new_dfcontainer( def _new_loop( self, loop_vars: list[PlaceNode], - branch: PlaceNode, + continue_predicate: PlaceNode, ) -> Iterator[None]: """Context manager to build a graph inside a new `TailLoop` node. @@ -136,13 +136,12 @@ def _new_loop( loop = self.builder.add_tail_loop([], loop_inputs) with self._new_dfcontainer(loop_vars, loop): yield - # Output the branch predicate and the inputs for the next iteration - loop.set_loop_outputs( - # Note that we have to do fresh calls to `self.visit` here since we're - # in a new context - self.visit(branch), - *(self.visit(name) for name in loop_vars), - ) + # Output the branch predicate and the inputs for the next iteration. Note + # that we have to do fresh calls to `self.visit` here since we're in a new + # context + do_continue = self.visit(continue_predicate) + do_break = loop.add_op(hugr.std.logic.Not, do_continue) + loop.set_loop_outputs(do_break, *(self.visit(name) for name in loop_vars)) # Update the DFG with the outputs from the loop for node, wire in zip(loop_vars, loop, strict=True): self.dfg[node.place] = wire @@ -174,12 +173,12 @@ def _if_true(self, cond: ast.expr, inputs: list[PlaceNode]) -> Iterator[None]: conditional = self.builder.add_conditional( self.visit(cond), *(self.visit(inp) for inp in inputs) ) - # If the condition is true, we enter the `with` block - with self._new_case(inputs, inputs, conditional, 0): - yield # If the condition is false, output the inputs as is - with self._new_case(inputs, inputs, conditional, 1): + with self._new_case(inputs, inputs, conditional, 0): pass + # If the condition is true, we enter the `with` block + with self._new_case(inputs, inputs, conditional, 1): + yield # Update the DFG with the outputs from the Conditional node for node, wire in zip(inputs, conditional, strict=True): self.dfg[node.place] = wire From 51c262c105bd6ce01c6e4b070a0e3ec3e98f7b35 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 13:39:15 +0000 Subject: [PATCH 14/17] Fix test --- tests/integration/test_array_comprehension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_array_comprehension.py b/tests/integration/test_array_comprehension.py index d2407f91..d967c705 100644 --- a/tests/integration/test_array_comprehension.py +++ b/tests/integration/test_array_comprehension.py @@ -73,7 +73,7 @@ def main() -> int: package = module.compile() validate(package) - run_int_fn(package, expected=0) + run_int_fn(package, expected=sum(i + 3 for i in range(42))) def test_capture_struct(validate): From 768b2e440d3b1bb3a72d514f17563dc9332fb406 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 13:40:00 +0000 Subject: [PATCH 15/17] Enable execution test --- tests/integration/test_unpack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration/test_unpack.py b/tests/integration/test_unpack.py index f399f22d..91bd6da6 100644 --- a/tests/integration/test_unpack.py +++ b/tests/integration/test_unpack.py @@ -69,8 +69,7 @@ def main() -> int: compiled = module.compile() validate(compiled) - # TODO: Enable execution test once array lowering is fully supported - # run_int_fn(compiled, expected=9) + run_int_fn(compiled, expected=10) def test_unpack_tuple_starred(validate, run_int_fn): From 80e9a5736e54d724df2bd7bb7c5194f4d7a823aa Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 14:19:22 +0000 Subject: [PATCH 16/17] Fix unpacking --- guppylang/compiler/stmt_compiler.py | 19 +++++++++++++++---- tests/integration/test_unpack.py | 21 ++++++++++++++++++++- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/guppylang/compiler/stmt_compiler.py b/guppylang/compiler/stmt_compiler.py index 6f92854e..9132268c 100644 --- a/guppylang/compiler/stmt_compiler.py +++ b/guppylang/compiler/stmt_compiler.py @@ -115,15 +115,26 @@ def pop( array: Wire, length: int, pats: list[ast.expr], from_left: bool ) -> tuple[Wire, int]: err = "Internal error: unpacking of iterable failed" - for pat in pats: + num_pats = len(pats) + # Pop the number of requested elements from the array + elts = [] + for i in range(num_pats): res = self.builder.add_op( - array_pop(opt_elt_ty, length, from_left), array + array_pop(opt_elt_ty, length - i, from_left), array ) [elt_opt, array] = build_unwrap(self.builder, res, err) [elt] = build_unwrap(self.builder, elt_opt, err) + elts.append(elt) + # Assign elements to the given patterns + for pat, elt in zip( + pats, + # Assignments are evaluated from left to right, so we need to assign in + # reverse order if we popped from the right + elts if from_left else reversed(elts), + strict=True, + ): self._assign(pat, elt) - length -= 1 - return array, length + return array, length - num_pats self.dfg[lhs.rhs_var.place] = port array = self.expr_compiler.visit_DesugaredArrayComp(lhs.compr) diff --git a/tests/integration/test_unpack.py b/tests/integration/test_unpack.py index 91bd6da6..69f019f7 100644 --- a/tests/integration/test_unpack.py +++ b/tests/integration/test_unpack.py @@ -69,7 +69,7 @@ def main() -> int: compiled = module.compile() validate(compiled) - run_int_fn(compiled, expected=10) + run_int_fn(compiled, expected=9) def test_unpack_tuple_starred(validate, run_int_fn): @@ -101,3 +101,22 @@ def main( return x, y, z, a, b, c validate(module.compile()) + + +def test_left_to_right(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) + def left() -> int: + [x, x, *_] = range(10) + return x + + @guppy(module) + def right() -> int: + [*_, x, x] = range(10) + return x + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, fn_name="left", expected=1) + run_int_fn(compiled, fn_name="right", expected=9) From 752f922cfacf6fa6606473557c65f298d6cec870 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 18 Dec 2024 16:08:58 +0000 Subject: [PATCH 17/17] Use Some value --- guppylang/compiler/expr_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index 5dbae38d..e724cdec 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -614,7 +614,7 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: vs = [python_value_to_hugr(elt, elem_ty) for elt in elts] if doesnt_contain_none(vs): opt_ty = ht.Option(elem_ty.to_hugr()) - opt_vs: list[hv.Value] = [hv.Sum(1, opt_ty, [v]) for v in vs] + opt_vs: list[hv.Value] = [hv.Some(v) for v in vs] return hugr.std.collections.array.ArrayVal(opt_vs, opt_ty) case _: # TODO replace with hugr protocol handling: https://github.com/CQCL/guppylang/issues/563