Skip to content

Commit

Permalink
Merge pull request #228 from egraphs-good/fix-loopnest
Browse files Browse the repository at this point in the history
Add ability to subsume default definitions
  • Loading branch information
saulshanabrook authored Oct 29, 2024
2 parents f79dee8 + 2e089f2 commit 863f99b
Show file tree
Hide file tree
Showing 12 changed files with 540 additions and 59 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ _This project uses semantic versioning_

## UNRELEASED

- Fix pretty printing of lambda functions
- Add support for subsuming rewrite generated by default function and method definitions

## 8.0.1 (2024-10-24)

- Upgrade dependencies including [egglog](https://github.com/egraphs-good/egglog/compare/saulshanabrook:egg-smol:a555b2f5e82c684442775cc1a5da94b71930113c...b0db06832264c9b22694bd3de2bdacd55bbe9e32)
Expand Down
2 changes: 1 addition & 1 deletion docs/reference/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Then install the package in editable mode with the development dependencies:
uv sync --all-extras
```

Anytime you change the rust code, you can run `uv sync` to recompile the rust code.
Anytime you change the rust code, you can run `uv sync --reinstall-package egglog --all-extras` to force recompiling the rust code.

If you would like to download a new version of the visualizer source, run `make clean; make`. This will download
the most recent released version from the github actions artifact in the [egraph-visualizer](https://github.com/egraphs-good/egraph-visualizer) repo. It is checked in because it's a pain to get cargo to include only one git ignored file while ignoring the rest of the files that were ignored.
Expand Down
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@ array = [
"numba==0.59.1",
"llvmlite==0.42.0",
]
dev = ["ruff", "pre-commit", "mypy", "anywidget[dev]", "egglog[docs,test]"]
dev = [
"ruff",
"pre-commit",
"mypy",
"anywidget[dev]",
"egglog[docs,test]",
"jupyterlab",
]

test = [
"pytest",
Expand Down
1 change: 1 addition & 0 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,7 @@ class RuleDecl:
class DefaultRewriteDecl:
ref: CallableRef
expr: ExprDecl
subsume: bool


RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
Expand Down
58 changes: 44 additions & 14 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def method(
unextractable: bool = False,
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
return lambda fn: _WrappedMethod(
egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable
egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable, False
)

@overload
Expand Down Expand Up @@ -404,6 +404,7 @@ def method(
on_merge: Callable[[Any, Any], Iterable[ActionLike]] | None = None,
mutates_self: bool = False,
unextractable: bool = False,
subsume: bool = False,
) -> Callable[[CALLABLE], CALLABLE]: ...


Expand All @@ -417,6 +418,7 @@ def method(
on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
mutates_self: bool = False,
unextractable: bool = False,
subsume: bool = False,
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...


Expand All @@ -430,11 +432,14 @@ def method(
preserve: bool = False,
mutates_self: bool = False,
unextractable: bool = False,
subsume: bool = False,
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
"""
Any method can be decorated with this to customize it's behavior. This is only supported in classes which subclass :class:`Expr`.
"""
return lambda fn: _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable)
return lambda fn: _WrappedMethod(
egg_fn, cost, default, merge, on_merge, fn, preserve, mutates_self, unextractable, subsume
)


class _ExprMetaclass(type):
Expand Down Expand Up @@ -519,7 +524,9 @@ def _generate_class_decls( # noqa: C901,PLR0912
(inner_tp,) = v.__args__
type_ref = resolve_type_annotation(decls, inner_tp)
cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
_add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
_add_default_rewrite(
decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset, subsume=False
)
else:
msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
raise NotImplementedError(msg)
Expand All @@ -542,12 +549,12 @@ def _generate_class_decls( # noqa: C901,PLR0912
if is_init and cls_name in LIT_CLASS_NAMES:
continue
match method:
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable, subsume):
pass
case _:
egg_fn, cost, default, merge, on_merge = None, None, None, None, None
fn = method
unextractable, preserve = False, False
unextractable, preserve, subsume = False, False, False
mutates = method_name in ALWAYS_MUTATES_SELF
if preserve:
cls_decl.preserved_methods[method_name] = fn
Expand All @@ -572,7 +579,20 @@ def _generate_class_decls( # noqa: C901,PLR0912
continue

_, add_rewrite = _fn_decl(
decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, ruleset, unextractable
decls,
egg_fn,
ref,
fn,
locals,
default,
cost,
merge,
on_merge,
mutates,
builtin,
ruleset=ruleset,
unextractable=unextractable,
subsume=subsume,
)

if not builtin and not isinstance(ref, InitRef) and not mutates:
Expand Down Expand Up @@ -602,6 +622,7 @@ def function(
builtin: bool = False,
ruleset: Ruleset | None = None,
use_body_as_name: bool = False,
subsume: bool = False,
) -> Callable[[CALLABLE], CALLABLE]: ...


Expand All @@ -617,6 +638,7 @@ def function(
unextractable: bool = False,
ruleset: Ruleset | None = None,
use_body_as_name: bool = False,
subsume: bool = False,
) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...


Expand Down Expand Up @@ -649,6 +671,7 @@ class _FunctionConstructor:
unextractable: bool = False
ruleset: Ruleset | None = None
use_body_as_name: bool = False
subsume: bool = False

def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
return RuntimeFunction(*split_thunk(Thunk.fn(self.create_decls, fn)))
Expand All @@ -668,7 +691,8 @@ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> tuple[Declarations, Ca
self.on_merge,
self.mutates_first_arg,
self.builtin,
self.ruleset,
ruleset=self.ruleset,
subsume=self.subsume,
unextractable=self.unextractable,
)
add_rewrite()
Expand All @@ -690,6 +714,7 @@ def _fn_decl(
on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
mutates_first_arg: bool,
is_builtin: bool,
subsume: bool,
ruleset: Ruleset | None = None,
unextractable: bool = False,
) -> tuple[CallableRef, Callable[[], None]]:
Expand Down Expand Up @@ -804,7 +829,7 @@ def _fn_decl(
res_ref = ref
decls.set_function_decl(ref, decl)
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk)
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)


# Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
Expand Down Expand Up @@ -871,7 +896,7 @@ def _constant_thunk(
type_ref = resolve_type_annotation(decls, tp)
callable_ref = ConstantRef(name)
decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
_add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset, subsume=False)
return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))


Expand All @@ -898,15 +923,21 @@ def _add_default_rewrite_function(
res_type: TypeOrVarRef,
ruleset: Ruleset | None,
value_thunk: Callable[[], object],
subsume: bool,
) -> None:
"""
Helper functions that resolves a value thunk to create the default value.
"""
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset)
_add_default_rewrite(decls, ref, res_type, value_thunk(), ruleset, subsume)


def _add_default_rewrite(
decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
decls: Declarations,
ref: CallableRef,
type_ref: TypeOrVarRef,
default_rewrite: object,
ruleset: Ruleset | None,
subsume: bool,
) -> None:
"""
Adds a default rewrite for the callable, if the default rewrite is not None
Expand All @@ -916,7 +947,7 @@ def _add_default_rewrite(
if default_rewrite is None:
return
resolved_value = resolve_literal(type_ref, default_rewrite, Thunk.value(decls))
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr, subsume)
if ruleset:
ruleset_decls = ruleset._current_egg_decls
ruleset_decl = ruleset.__egg_ruleset__
Expand Down Expand Up @@ -1341,8 +1372,6 @@ def saturate(
from .visualizer_widget import VisualizerWidget

def to_json() -> str:
if expr:
print(self.extract(expr))
return self._serialize(**kwargs).to_json()

egraphs = [to_json()]
Expand Down Expand Up @@ -1407,6 +1436,7 @@ class _WrappedMethod(Generic[P, EXPR]):
preserve: bool
mutates_self: bool
unextractable: bool
subsume: bool

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR:
msg = "We should never call a wrapped method. Did you forget to wrap the class?"
Expand Down
4 changes: 2 additions & 2 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
)
return bindings.RuleCommand(name or "", ruleset, rule)
# TODO: Replace with just constants value and looking at REF of function
case DefaultRewriteDecl(ref, expr):
case DefaultRewriteDecl(ref, expr, subsume):
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
sig = decl.signature
assert isinstance(sig, FunctionSignature)
Expand All @@ -144,7 +144,7 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
)
rewrite_decl = RewriteDecl(
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), subsume
)
return self.command_to_egg(rewrite_decl, ruleset)
case _:
Expand Down
21 changes: 12 additions & 9 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import numpy as np

from egglog import *
from egglog.bindings import EggSmolError
from egglog.runtime import RuntimeExpr

from .program_gen import *
Expand Down Expand Up @@ -272,7 +271,6 @@ def var(cls, name: StringLike) -> TupleInt: ...

EMPTY: ClassVar[TupleInt]

@method(unextractable=True)
def __init__(self, length: IntLike, idx_fn: Callable[[Int], Int]) -> None: ...

@classmethod
Expand All @@ -287,6 +285,7 @@ def range(cls, stop: Int) -> TupleInt:
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
return TupleInt(vec.length(), partial(index_vec_int, vec))

@method(subsume=True)
def __add__(self, other: TupleInt) -> TupleInt:
return TupleInt(
self.length() + other.length(),
Expand All @@ -308,13 +307,13 @@ def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...

def fold_boolean(self, init: Boolean, f: Callable[[Boolean, Int], Boolean]) -> Boolean: ...

@method(subsume=True)
def contains(self, i: Int) -> Boolean:
return self.fold_boolean(FALSE, lambda acc, j: acc | (i == j))

@method(cost=100)
def filter(self, f: Callable[[Int], Boolean]) -> TupleInt: ...

@method(cost=100)
@method(subsume=True)
def map(self, f: Callable[[Int], Int]) -> TupleInt:
return TupleInt(self.length(), lambda i: f(self[i]))

Expand Down Expand Up @@ -372,7 +371,7 @@ def _tuple_int(
ne(k).to(i64(0)),
),
# Empty
rewrite(TupleInt.EMPTY).to(TupleInt(0, bottom_indexing)),
rewrite(TupleInt.EMPTY, subsume=True).to(TupleInt(0, bottom_indexing)),
# if_
rewrite(TupleInt.if_(TRUE, ti, ti2)).to(ti),
rewrite(TupleInt.if_(FALSE, ti, ti2)).to(ti2),
Expand All @@ -388,13 +387,16 @@ def var(cls, name: StringLike) -> TupleTupleInt: ...
def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

@classmethod
@method(subsume=True)
def single(cls, i: TupleInt) -> TupleTupleInt:
return TupleTupleInt(Int(1), lambda _: i)

@classmethod
@method(subsume=True)
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
return TupleInt(vec.length(), partial(index_vec_int, vec))

@method(subsume=True)
def __add__(self, other: TupleTupleInt) -> TupleTupleInt:
return TupleTupleInt(
self.length() + other.length(),
Expand Down Expand Up @@ -732,7 +734,7 @@ def _tuple_value(
rewrite(TupleValue.EMPTY.includes(v)).to(FALSE),
rewrite(TupleValue(v).includes(v)).to(TRUE),
rewrite(TupleValue(v).includes(v2)).to(FALSE, ne(v).to(v2)),
rewrite((ti + ti2).includes(v)).to(ti.includes(v) | ti2.includes(v)),
rewrite((ti + ti2).includes(v), subsume=True).to(ti.includes(v) | ti2.includes(v)),
]


Expand Down Expand Up @@ -1539,13 +1541,14 @@ def try_evaling(expr: Expr, prim_expr: i64 | Bool) -> int | bool:
egraph.run(array_api_schedule)
try:
extracted = egraph.extract(prim_expr)
except EggSmolError as exc:
# Catch base exceptions so that we catch rust panics which happen when trying to extract subsumed nodes
except BaseException as exc:
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
# Try giving some context, by showing the smallest version of the larger expression
try:
expr_extracted = egraph.extract(expr)
except EggSmolError as inner_exc:
except BaseException as inner_exc:
raise ValueError(f"Cannot simplify {expr}") from inner_exc
egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
msg = f"Cannot simplify to primitive {expr_extracted}"
raise ValueError(msg) from exc
return egraph.eval(extracted)
Expand Down
Loading

0 comments on commit 863f99b

Please sign in to comment.