Skip to content

Commit

Permalink
feat: Add Option type to standard library
Browse files Browse the repository at this point in the history
  • Loading branch information
mark-koch committed Dec 9, 2024
1 parent 7f24264 commit 10eac8a
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 5 deletions.
13 changes: 9 additions & 4 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ast
import inspect
from collections.abc import Callable, KeysView
from collections.abc import Callable, KeysView, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -48,6 +48,8 @@
sphinx_running,
)
from guppylang.span import SourceMap
from guppylang.tys.arg import Argument
from guppylang.tys.param import Parameter
from guppylang.tys.subst import Inst
from guppylang.tys.ty import NumericType

Expand Down Expand Up @@ -197,10 +199,11 @@ def dec(c: type) -> type:
@pretty_errors
def type(
self,
hugr_ty: ht.Type,
hugr_ty: ht.Type | Callable[[Sequence[Argument]], ht.Type],
name: str = "",
linear: bool = False,
bound: ht.TypeBound | None = None,
params: Sequence[Parameter] | None = None,
module: GuppyModule | None = None,
) -> OpaqueTypeDecorator:
"""Decorator to annotate a class definitions as Guppy types.
Expand All @@ -212,14 +215,16 @@ def type(
mod = module or self.get_module()
mod._instance_func_buffer = {}

mk_hugr_ty = (lambda _: hugr_ty) if isinstance(hugr_ty, ht.Type) else hugr_ty

def dec(c: type) -> OpaqueTypeDef:
defn = OpaqueTypeDef(
DefId.fresh(mod),
name or c.__name__,
None,
[],
params or [],
linear,
lambda _: hugr_ty,
mk_hugr_ty,
bound,
)
mod.register_def(defn)
Expand Down
8 changes: 7 additions & 1 deletion guppylang/definition/ty.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from hugr import tys

Expand Down Expand Up @@ -41,6 +41,12 @@ class OpaqueTypeDef(TypeDef, CompiledDef):
to_hugr: Callable[[Sequence[Argument]], tys.Type]
bound: tys.TypeBound | None = None

def __getitem__(self, item: Any) -> "OpaqueTypeDef":
"""Dummy implementation to allow generic instantiations in type signatures that
are evaluated by the Python interpreter.
"""
return self

def check_instantiate(
self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None
) -> OpaqueType:
Expand Down
61 changes: 61 additions & 0 deletions guppylang/std/_internal/compiler/option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from abc import ABC

from hugr import Wire, ops
from hugr import tys as ht
from hugr import val as hv

from guppylang.definition.custom import CustomCallCompiler, CustomInoutCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.prelude import build_unwrap
from guppylang.tys.arg import TypeArg


class OptionCompiler(CustomInoutCallCompiler, ABC):
"""Abstract base class for compilers for `Option` methods."""

@property
def option_ty(self) -> ht.Option:
match self.type_args:
case [TypeArg(ty)]:
return ht.Option(ty.to_hugr())
case _:
raise InternalGuppyError("Invalid type args for Option op")


class OptionConstructor(OptionCompiler, CustomCallCompiler):
"""Compiler for the `Option` constructors `none` and `some`."""

def __init__(self, tag: int):
self.tag = tag

def compile(self, args: list[Wire]) -> list[Wire]:
return [self.builder.add_op(ops.Tag(self.tag, self.option_ty), *args)]


class OptionTestCompiler(OptionCompiler):
"""Compiler for the `Option.is_none` and `Option.is_none` methods."""

def __init__(self, tag: int):
self.tag = tag

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
[opt] = args
cond = self.builder.add_conditional(opt)
for i in [0, 1]:
with cond.add_case(i) as case:
val = hv.TRUE if i == self.tag else hv.FALSE
opt = case.add_op(ops.Tag(i, self.option_ty), *case.inputs())
case.set_outputs(case.load(val), opt)
[res, opt] = cond.outputs()
return CallReturnWires(regular_returns=[res], inout_returns=[opt])


class OptionUnwrapCompiler(OptionCompiler, CustomCallCompiler):
"""Compiler for the `Option.unwrap` method."""

def compile(self, args: list[Wire]) -> list[Wire]:
[opt] = args
return list(
build_unwrap(self.builder, opt, "Option.unwrap: value is `None`").outputs()
)
61 changes: 61 additions & 0 deletions guppylang/std/option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from collections.abc import Sequence
from typing import Generic, no_type_check

import hugr.tys as ht

from guppylang.decorator import guppy
from guppylang.error import InternalGuppyError
from guppylang.std._internal.compiler.option import (
OptionConstructor,
OptionTestCompiler,
OptionUnwrapCompiler,
)
from guppylang.std.builtins import owned
from guppylang.tys.arg import Argument, TypeArg
from guppylang.tys.param import TypeParam


def _option_to_hugr(args: Sequence[Argument]) -> ht.Type:
match args:
case [TypeArg(ty)]:
return ht.Option(ty.to_hugr())
case _:
raise InternalGuppyError("Invalid type args for Option")


T = guppy.type_var("T", linear=True)


@guppy.type(_option_to_hugr, params=[TypeParam(0, "T", can_be_linear=True)])
class Option(Generic[T]): # type: ignore[misc]
"""Represents an optional value."""

@guppy.custom(OptionTestCompiler(0))
@no_type_check
def is_none(self: "Option[T]") -> bool:
"""Returns `True` if the option is a `none` value."""

@guppy.custom(OptionTestCompiler(1))
@no_type_check
def is_some(self: "Option[T]") -> bool:
"""Returns `True` if the option is a `some` value."""

@guppy.custom(OptionUnwrapCompiler())
@no_type_check
def unwrap(self: "Option[T]" @ owned) -> T:
"""Returns the contained `some` value, consuming `self`.
Panics if the option is a `none` value.
"""


@guppy.custom(OptionConstructor(0))
@no_type_check
def none() -> Option[T]:
"""Constructs a `none` optional value."""


@guppy.custom(OptionConstructor(1))
@no_type_check
def some(value: T @ owned) -> Option[T]:
"""Constructs a `some` optional value."""
36 changes: 36 additions & 0 deletions tests/integration/test_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.option import Option, none, some


def test_none(validate, run_int_fn):
module = GuppyModule("test_range")
module.load(Option, none)

@guppy(module)
def main() -> int:
x: Option[int] = none()
is_none = 10 if x.is_none() else 0
is_some = 1 if x.is_some() else 0
return is_none + is_some

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=10)


def test_some_unwrap(validate, run_int_fn):
module = GuppyModule("test_range")
module.load(Option, some)

@guppy(module)
def main() -> int:
x: Option[int] = some(42)
is_none = 1 if x.is_none() else 0
is_some = x.unwrap() if x.is_some() else 0
return is_none + is_some

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=42)

0 comments on commit 10eac8a

Please sign in to comment.