Skip to content

Commit

Permalink
[new] Wire up the remaining integration tests for export
Browse files Browse the repository at this point in the history
We also fix the ycombinator test case
  • Loading branch information
doug-q committed Sep 27, 2023
1 parent e1c4cd5 commit aec1216
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 105 deletions.
5 changes: 3 additions & 2 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ def export_test_cases_dir(request):

@pytest.fixture
def validate(request, export_test_cases_dir):
def validate_impl(hugr):
def validate_impl(hugr,name=None):
bs = hugr.serialize()
util.validate_bytes(bs)
if export_test_cases_dir:
export_file = export_test_cases_dir / f"{request.node.name}.msgpack"
file_name = f"{request.node.name}{f'_{name}' if name else ''}.msgpack"
export_file = export_test_cases_dir / file_name
export_file.write_bytes(bs)
return validate_impl
19 changes: 9 additions & 10 deletions tests/integration/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
from guppy.compiler import guppy
from tests.integration.util import validate


def test_arith_basic():
def test_arith_basic(validate):
@guppy
def add(x: int, y: int) -> int:
return x + y

validate(add)


def test_constant():
def test_constant(validate):
@guppy
def const() -> float:
return 42.0

validate(const)


def test_ann_assign():
def test_ann_assign(validate):
@guppy
def add(x: int) -> int:
x += 1
Expand All @@ -27,15 +26,15 @@ def add(x: int) -> int:
validate(add)


def test_float_coercion():
def test_float_coercion(validate):
@guppy
def coerce(x: int, y: float) -> float:
return x * y

validate(coerce)


def test_arith_big():
def test_arith_big(validate):
@guppy
def arith(x: int, y: float, z: int) -> bool:
a = x // y + 3 * z
Expand All @@ -45,7 +44,7 @@ def arith(x: int, y: float, z: int) -> bool:
validate(arith)


def test_shortcircuit_assign1():
def test_shortcircuit_assign1(validate):
@guppy
def foo(x: bool, y: int) -> bool:
if (z := x) and y > 0:
Expand All @@ -55,7 +54,7 @@ def foo(x: bool, y: int) -> bool:
validate(foo)


def test_shortcircuit_assign2():
def test_shortcircuit_assign2(validate):
@guppy
def foo(x: bool, y: int) -> bool:
if y > 0 and (z := x):
Expand All @@ -65,7 +64,7 @@ def foo(x: bool, y: int) -> bool:
validate(foo)


def test_shortcircuit_assign3():
def test_shortcircuit_assign3(validate):
@guppy
def foo(x: bool, y: int) -> bool:
if (z := x) or y > 0:
Expand All @@ -75,7 +74,7 @@ def foo(x: bool, y: int) -> bool:
validate(foo)


def test_shortcircuit_assign4():
def test_shortcircuit_assign4(validate):
@guppy
def foo(x: bool, y: int) -> bool:
if y > 0 or (z := x):
Expand Down
9 changes: 4 additions & 5 deletions tests/integration/test_call.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from guppy.compiler import guppy, GuppyModule
from tests.integration.util import validate


def test_call():
def test_call(validate):
module = GuppyModule("module")

@module
Expand All @@ -16,7 +15,7 @@ def bar() -> int:
validate(module.compile(exit_on_error=True))


def test_call_back(tmp_path):
def test_call_back(validate):
module = GuppyModule("module")

@module
Expand All @@ -30,15 +29,15 @@ def bar(x: int) -> int:
validate(module.compile(exit_on_error=True))


def test_recursion():
def test_recursion(validate):
@guppy
def main(x: int) -> int:
return main(x)

validate(main)


def test_mutual_recursion():
def test_mutual_recursion(validate):
module = GuppyModule("module")

@module
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest

from guppy.compiler import guppy
from tests.integration.util import validate, functional, _
from tests.integration.util import functional, _


@pytest.mark.skip()
def test_if_no_else():
def test_if_no_else(validate):
@guppy
def foo(x: bool, y: int) -> int:
_@functional
Expand All @@ -17,7 +17,7 @@ def foo(x: bool, y: int) -> int:


@pytest.mark.skip()
def test_if_else():
def test_if_else(validate):
@guppy
def foo(x: bool, y: int) -> int:
_@functional
Expand All @@ -31,7 +31,7 @@ def foo(x: bool, y: int) -> int:


@pytest.mark.skip()
def test_if_elif():
def test_if_elif(validate):
@guppy
def foo(x: bool, y: int) -> int:
_@functional
Expand All @@ -45,7 +45,7 @@ def foo(x: bool, y: int) -> int:


@pytest.mark.skip()
def test_if_elif_else():
def test_if_elif_else(validate):
@guppy
def foo(x: bool, y: int) -> int:
_@functional
Expand All @@ -61,7 +61,7 @@ def foo(x: bool, y: int) -> int:


@pytest.mark.skip()
def test_infinite_loop():
def test_infinite_loop(validate):
@guppy
def foo() -> int:
while True:
Expand All @@ -72,7 +72,7 @@ def foo() -> int:


@pytest.mark.skip()
def test_counting_loop():
def test_counting_loop(validate):
@guppy
def foo(i: int) -> int:
while i > 0:
Expand All @@ -83,7 +83,7 @@ def foo(i: int) -> int:


@pytest.mark.skip()
def test_nested_loop():
def test_nested_loop(validate):
@guppy
def foo(x: int, y: int) -> int:
p = 0
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/test_higher_order.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Callable

from guppy.compiler import guppy, GuppyModule
from tests.integration.util import validate


def test_basic():
def test_basic(validate):
module = GuppyModule("test")

@module
Expand All @@ -18,7 +17,7 @@ def foo() -> Callable[[int], bool]:
validate(module.compile())


def test_call_1():
def test_call_1(validate):
module = GuppyModule("test")

@module
Expand All @@ -36,7 +35,7 @@ def baz() -> bool:
validate(module.compile())


def test_call_2():
def test_call_2(validate):
module = GuppyModule("test")

@module
Expand All @@ -54,7 +53,7 @@ def baz(y: int) -> None:
validate(module.compile())


def test_nested():
def test_nested(validate):
@guppy
def foo(x: int) -> Callable[[int], bool]:
def bar(y: int) -> bool:
Expand All @@ -65,7 +64,7 @@ def bar(y: int) -> bool:
validate(foo)


def test_curry():
def test_curry(validate):
module = GuppyModule("curry")

@module
Expand All @@ -78,7 +77,7 @@ def h(y: int) -> bool:

@module
def uncurry(f: Callable[[int], Callable[[int], bool]]) -> Callable[[int, int], bool]:
def g(x: int, y: int):
def g(x: int, y: int) -> bool:
return f(x)(y)
return g

Expand All @@ -94,8 +93,9 @@ def main(x: int, y: int) -> None:
uncurried(x, y)
curry(uncurry(curry(gt)))(y)(x)

validate(module.compile())

def test_y_combinator():
def test_y_combinator(validate):
module = GuppyModule("fib")

@module
Expand Down
Loading

0 comments on commit aec1216

Please sign in to comment.