Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Option type to standard library #696

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
import inspect
from collections.abc import Callable, KeysView
from collections.abc import Callable, KeysView, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -48,6 +48,8 @@
sphinx_running,
)
from guppylang.span import SourceMap
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType

Expand Down Expand Up @@ -197,29 +199,36 @@ def dec(c: type) -> type:
@pretty_errors
def type(
self,
hugr_ty: ht.Type,
hugr_ty: ht.Type | Callable[[Sequence[Argument]], ht.Type],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth adding to the docstring to explain what this callable is

name: str = "",
linear: bool = False,
bound: ht.TypeBound | None = None,
params: Sequence[Parameter] | None = None,
module: GuppyModule | None = None,
) -> OpaqueTypeDecorator:
"""Decorator to annotate a class definitions as Guppy types.

Requires the static Hugr translation of the type. Additionally, the type can be
marked as linear. All `@guppy` annotated functions on the class are turned into
instance functions.

For non-generic types, the Hugr representation can be passed as a static value.
For generic types, a callable may be passed that takes the type arguments of a
concrete instantiation.
"""
mod = module or self.get_module()
mod._instance_func_buffer = {}

mk_hugr_ty = (lambda _: hugr_ty) if isinstance(hugr_ty, ht.Type) else hugr_ty

def dec(c: type) -> OpaqueTypeDef:
defn = OpaqueTypeDef(
DefId.fresh(mod),
name or c.__name__,
None,
[],
params or [],
linear,
lambda _: hugr_ty,
mk_hugr_ty,
bound,
)
mod.register_def(defn)
Expand Down
8 changes: 7 additions & 1 deletion guppylang/definition/ty.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from hugr import tys

Expand Down Expand Up @@ -41,6 +41,12 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
to_hugr: Callable[[Sequence[Argument]], tys.Type]
bound: tys.TypeBound | None = None

def __getitem__(self, item: Any) -> "OpaqueTypeDef":
"""Dummy implementation to allow generic instantiations in type signatures that
are evaluated by the Python interpreter.
"""
return self

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
Expand Down
60 changes: 60 additions & 0 deletions guppylang/std/_internal/compiler/option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABC

from hugr import Wire, ops
from hugr import tys as ht
from hugr import val as hv

from guppylang.definition.custom import CustomCallCompiler, CustomInoutCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.prelude import build_unwrap
from guppylang.tys.arg import TypeArg


class OptionCompiler(CustomInoutCallCompiler, ABC):
"""Abstract base class for compilers for `Option` methods."""

@property
def option_ty(self) -> ht.Option:
match self.type_args:
case [TypeArg(ty)]:
return ht.Option(ty.to_hugr())
case _:
raise InternalGuppyError("Invalid type args for Option op")


class OptionConstructor(OptionCompiler, CustomCallCompiler):
"""Compiler for the `Option` constructors `nothing` and `some`."""

def __init__(self, tag: int):
self.tag = tag

def compile(self, args: list[Wire]) -> list[Wire]:
return [self.builder.add_op(ops.Tag(self.tag, self.option_ty), *args)]


class OptionTestCompiler(OptionCompiler):
"""Compiler for the `Option.is_nothing` and `Option.is_some` methods."""

def __init__(self, tag: int):
self.tag = tag

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
[opt] = args
cond = self.builder.add_conditional(opt)
for i in [0, 1]:
with cond.add_case(i) as case:
val = hv.TRUE if i == self.tag else hv.FALSE
opt = case.add_op(ops.Tag(i, self.option_ty), *case.inputs())
case.set_outputs(case.load(val), opt)
[res, opt] = cond.outputs()
return CallReturnWires(regular_returns=[res], inout_returns=[opt])


class OptionUnwrapCompiler(OptionCompiler, CustomCallCompiler):
"""Compiler for the `Option.unwrap` method."""

def compile(self, args: list[Wire]) -> list[Wire]:
[opt] = args
err = "Option.unwrap: value is `Nothing`"
return list(build_unwrap(self.builder, opt, err).outputs())
61 changes: 61 additions & 0 deletions guppylang/std/option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from collections.abc import Sequence
from typing import Generic, no_type_check

import hugr.tys as ht

from guppylang.decorator import guppy
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.option import (
OptionConstructor,
OptionTestCompiler,
OptionUnwrapCompiler,
)
from guppylang.std.builtins import owned
from guppylang.tys.arg import Argument, TypeArg
from guppylang.tys.param import TypeParam


def _option_to_hugr(args: Sequence[Argument]) -> ht.Type:
match args:
case [TypeArg(ty)]:
return ht.Option(ty.to_hugr())
case _:
raise InternalGuppyError("Invalid type args for Option")


T = guppy.type_var("T", linear=True)


@guppy.type(_option_to_hugr, params=[TypeParam(0, "T", can_be_linear=True)])
class Option(Generic[T]): # type: ignore[misc]
"""Represents an optional value."""

@guppy.custom(OptionTestCompiler(0))
@no_type_check
def is_nothing(self: "Option[T]") -> bool:
"""Returns `True` if the option is a `nothing` value."""

@guppy.custom(OptionTestCompiler(1))
@no_type_check
def is_some(self: "Option[T]") -> bool:
"""Returns `True` if the option is a `some` value."""

@guppy.custom(OptionUnwrapCompiler())
@no_type_check
def unwrap(self: "Option[T]" @ owned) -> T:
"""Returns the contained `some` value, consuming `self`.

Panics if the option is a `nothing` value.
"""


@guppy.custom(OptionConstructor(0))
@no_type_check
def nothing() -> Option[T]:
"""Constructs a `nothing` optional value."""


@guppy.custom(OptionConstructor(1))
@no_type_check
def some(value: T @ owned) -> Option[T]:
"""Constructs a `some` optional value."""
36 changes: 36 additions & 0 deletions tests/integration/test_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.option import Option, nothing, some


def test_none(validate, run_int_fn):
module = GuppyModule("test_range")
module.load(Option, nothing)

@guppy(module)
def main() -> int:
x: Option[int] = nothing()
is_none = 10 if x.is_nothing() else 0
is_some = 1 if x.is_some() else 0
return is_none + is_some

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=10)


def test_some_unwrap(validate, run_int_fn):
module = GuppyModule("test_range")
module.load(Option, some)

@guppy(module)
def main() -> int:
x: Option[int] = some(42)
is_none = 1 if x.is_nothing() else 0
is_some = x.unwrap() if x.is_some() else 0
return is_none + is_some

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=42)

Loading