diff --git a/docs/changelog.md b/docs/changelog.md index 730051bc..d8b07ec1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ - Adds ability to use anonymous functions where callables are needed. These are automatically transformed to egglog functions with default rewrites. +- Add `sort` and `fn` mid level commands ## 7.2.0 (2024-05-23) diff --git a/docs/explanation/2024_06_18_midlevel.ipynb b/docs/explanation/2024_06_18_midlevel.ipynb new file mode 100644 index 00000000..b41b5203 --- /dev/null +++ b/docs/explanation/2024_06_18_midlevel.ipynb @@ -0,0 +1,615 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{post} 2024-06-18\n", + ":author: Saul\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mid Level IR\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Philip Zucker recently wrote a blog post about using egglog to do simplifications by converting from the Z3 API, [\"Conditional Simplification of Z3py Expressions with Egglog\"](https://www.philipzucker.com/egglog_z3_simp/).\n", + "\n", + "In it, he uses the egglog bindings, but only the low level bindings, as a way to execute egglog programs as strings. So he constructs expressions and rewrites using the Z3 API, manually converts them to egglog program strings, runs them in egglog, and then gets them back.\n", + "\n", + "He prefers to use the Z3 API over the high level egglog API, because it is a simpler functional interface, creating sorts and functions through functions, instead of by using Python classes:\n", + "\n", + "> Saul has been making the egglog python bindings https://egglog-python.readthedocs.io/latest/ taking a very meta highly integrated approach. I kind of just want it to look like z3 though. It’s very interesting and I’m haunted by the idea that I am a stodgy old man and they’re right. I think it’s biggest demerit is that it is very novel. I’ve never seen an interface like it. From a research perspective this is a plus. It is very cool that they are getting the python typechecker and the embedded dsl to play ball. I dunno https://egraphs.zulipchat.com/#narrow/stream/375765-egg.2Fegglog/topic/egglog.20python.20midlevel.20api/near/421919681\n", + ">\n", + "> It turns out, it is simple enough to have my cake too. The pyegglog supports the raw bindings and I’ve been spending a decent amount of time serializing Z3 ASTs to other smt or tptp solvers. Translating to egglog programs is easy.\n", + "\n", + "In order to continue this conversation, I added two functions `sort` and `fn`, to the egglog bindings, to see how those could be used to do something similar. The overall hope is that we could re-use some of the existing mechanisms in the high level API, even with a slightly different external interface. Here is an example of using them to re-create the example Phil wrote up:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
one\n",
+ "
TypedExprDecl(\n", + "│ tp=JustTypeRef(name='Math', args=()),\n", + "│ expr=CallDecl(\n", + "│ │ callable=FunctionRef(name='add'),\n", + "│ │ args=(\n", + "│ │ │ TypedExprDecl(\n", + "│ │ │ │ tp=JustTypeRef(name='Math', args=()),\n", + "│ │ │ │ expr=CallDecl(callable=ConstantRef(name='zero'), args=(), bound_tp_params=None)\n", + "│ │ │ ),\n", + "│ │ │ TypedExprDecl(\n", + "│ │ │ │ tp=JustTypeRef(name='Math', args=()),\n", + "│ │ │ │ expr=CallDecl(callable=ConstantRef(name='one'), args=(), bound_tp_params=None)\n", + "│ │ │ )\n", + "│ │ ),\n", + "│ │ bound_tp_params=None\n", + "│ )\n", + ")\n", + "\n" + ], + "text/plain": [ + "\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mConstantRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'zero'\u001b[0m\u001b[1m)\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mConstantRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'one'\u001b[0m\u001b[1m)\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Declarations(\n", + "│ _functions={\n", + "│ │ 'add': FunctionDecl(\n", + "│ │ │ signature=FunctionSignature(\n", + "│ │ │ │ arg_types=(TypeRefWithVars(name='Math', args=()), TypeRefWithVars(name='Math', args=())),\n", + "│ │ │ │ arg_names=('__0', '__1'),\n", + "│ │ │ │ arg_defaults=(None, None),\n", + "│ │ │ │ return_type=TypeRefWithVars(name='Math', args=()),\n", + "│ │ │ │ var_arg_type=None\n", + "│ │ │ ),\n", + "│ │ │ builtin=False,\n", + "│ │ │ egg_name=None,\n", + "│ │ │ cost=None,\n", + "│ │ │ default=None,\n", + "│ │ │ on_merge=(),\n", + "│ │ │ merge=None,\n", + "│ │ │ unextractable=False\n", + "│ │ )\n", + "│ },\n", + "│ _constants={\n", + "│ │ 'zero': ConstantDecl(type_ref=JustTypeRef(name='Math', args=()), egg_name=None),\n", + "│ │ 'one': ConstantDecl(type_ref=JustTypeRef(name='Math', args=()), egg_name=None)\n", + "│ },\n", + "│ _classes={\n", + "│ │ 'Math': ClassDecl(\n", + "│ │ │ egg_name=None,\n", + "│ │ │ type_vars=(),\n", + "│ │ │ builtin=False,\n", + "│ │ │ init=None,\n", + "│ │ │ class_methods={},\n", + "│ │ │ class_variables={},\n", + "│ │ │ methods={},\n", + "│ │ │ properties={},\n", + "│ │ │ preserved_methods={}\n", + "│ │ )\n", + "│ },\n", + "│ _rulesets={'': RulesetDecl(rules=[])}\n", + ")\n", + "\n" + ], + "text/plain": [ + "\u001b[1;35mDeclarations\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_functions\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'add'\u001b[0m: \u001b[1;35mFunctionDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33msignature\u001b[0m=\u001b[1;35mFunctionSignature\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_types\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_names\u001b[0m=\u001b[1m(\u001b[0m\u001b[32m'__0'\u001b[0m, \u001b[32m'__1'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_defaults\u001b[0m=\u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mreturn_type\u001b[0m=\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mvar_arg_type\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mcost\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdefault\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mon_merge\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmerge\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33munextractable\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_constants\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'zero'\u001b[0m: \u001b[1;35mConstantDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtype_ref\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'one'\u001b[0m: \u001b[1;35mConstantDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtype_ref\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_classes\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'Math'\u001b[0m: \u001b[1;35mClassDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mtype_vars\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33minit\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_variables\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmethods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mproperties\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mpreserved_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_rulesets\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m''\u001b[0m: \u001b[1;35mRulesetDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mrules\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from rich.pretty import pprint\n", + "\n", + "x = add(zero, one)\n", + "\n", + "pprint(x.__egg_typed_expr__)\n", + "pprint(x.__egg_decls__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The same is true of the rulesets:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
'ruleset_4743265072'\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[32m'ruleset_4743265072'\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "Declarations(\n", + "│ _functions={\n", + "│ │ 'add': FunctionDecl(\n", + "│ │ │ signature=FunctionSignature(\n", + "│ │ │ │ arg_types=(TypeRefWithVars(name='Math', args=()), TypeRefWithVars(name='Math', args=())),\n", + "│ │ │ │ arg_names=('__0', '__1'),\n", + "│ │ │ │ arg_defaults=(None, None),\n", + "│ │ │ │ return_type=TypeRefWithVars(name='Math', args=()),\n", + "│ │ │ │ var_arg_type=None\n", + "│ │ │ ),\n", + "│ │ │ builtin=False,\n", + "│ │ │ egg_name=None,\n", + "│ │ │ cost=None,\n", + "│ │ │ default=None,\n", + "│ │ │ on_merge=(),\n", + "│ │ │ merge=None,\n", + "│ │ │ unextractable=False\n", + "│ │ )\n", + "│ },\n", + "│ _constants={'zero': ConstantDecl(type_ref=JustTypeRef(name='Math', args=()), egg_name=None)},\n", + "│ _classes={\n", + "│ │ 'Math': ClassDecl(\n", + "│ │ │ egg_name=None,\n", + "│ │ │ type_vars=(),\n", + "│ │ │ builtin=False,\n", + "│ │ │ init=None,\n", + "│ │ │ class_methods={},\n", + "│ │ │ class_variables={},\n", + "│ │ │ methods={},\n", + "│ │ │ properties={},\n", + "│ │ │ preserved_methods={}\n", + "│ │ )\n", + "│ },\n", + "│ _rulesets={\n", + "│ │ '': RulesetDecl(rules=[]),\n", + "│ │ 'ruleset_4743265072': RulesetDecl(\n", + "│ │ │ rules=[\n", + "│ │ │ │ RewriteDecl(\n", + "│ │ │ │ │ tp=JustTypeRef(name='Math', args=()),\n", + "│ │ │ │ │ lhs=CallDecl(\n", + "│ │ │ │ │ │ callable=FunctionRef(name='add'),\n", + "│ │ │ │ │ │ args=(\n", + "│ │ │ │ │ │ │ TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='x')),\n", + "│ │ │ │ │ │ │ TypedExprDecl(\n", + "│ │ │ │ │ │ │ │ tp=JustTypeRef(name='Math', args=()),\n", + "│ │ │ │ │ │ │ │ expr=CallDecl(callable=ConstantRef(name='zero'), args=(), bound_tp_params=None)\n", + "│ │ │ │ │ │ │ )\n", + "│ │ │ │ │ │ ),\n", + "│ │ │ │ │ │ bound_tp_params=None\n", + "│ │ │ │ │ ),\n", + "│ │ │ │ │ rhs=VarDecl(name='x'),\n", + "│ │ │ │ │ conditions=(),\n", + "│ │ │ │ │ subsume=False\n", + "│ │ │ │ ),\n", + "│ │ │ │ RewriteDecl(\n", + "│ │ │ │ │ tp=JustTypeRef(name='Math', args=()),\n", + "│ │ │ │ │ lhs=CallDecl(\n", + "│ │ │ │ │ │ callable=FunctionRef(name='add'),\n", + "│ │ │ │ │ │ args=(\n", + "│ │ │ │ │ │ │ TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='x')),\n", + "│ │ │ │ │ │ │ TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='y'))\n", + "│ │ │ │ │ │ ),\n", + "│ │ │ │ │ │ bound_tp_params=None\n", + "│ │ │ │ │ ),\n", + "│ │ │ │ │ rhs=CallDecl(\n", + "│ │ │ │ │ │ callable=FunctionRef(name='add'),\n", + "│ │ │ │ │ │ args=(\n", + "│ │ │ │ │ │ │ TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='y')),\n", + "│ │ │ │ │ │ │ TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='x'))\n", + "│ │ │ │ │ │ ),\n", + "│ │ │ │ │ │ bound_tp_params=None\n", + "│ │ │ │ │ ),\n", + "│ │ │ │ │ conditions=(),\n", + "│ │ │ │ │ subsume=False\n", + "│ │ │ │ )\n", + "│ │ │ ]\n", + "│ │ )\n", + "│ }\n", + ")\n", + "\n" + ], + "text/plain": [ + "\u001b[1;35mDeclarations\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_functions\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'add'\u001b[0m: \u001b[1;35mFunctionDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33msignature\u001b[0m=\u001b[1;35mFunctionSignature\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_types\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_names\u001b[0m=\u001b[1m(\u001b[0m\u001b[32m'__0'\u001b[0m, \u001b[32m'__1'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_defaults\u001b[0m=\u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mreturn_type\u001b[0m=\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mvar_arg_type\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mcost\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdefault\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mon_merge\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmerge\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33munextractable\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_constants\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'zero'\u001b[0m: \u001b[1;35mConstantDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtype_ref\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_classes\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'Math'\u001b[0m: \u001b[1;35mClassDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mtype_vars\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33minit\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_variables\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmethods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mproperties\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mpreserved_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_rulesets\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m''\u001b[0m: \u001b[1;35mRulesetDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mrules\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'ruleset_4743265072'\u001b[0m: \u001b[1;35mRulesetDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mrules\u001b[0m=\u001b[1m[\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1;35mRewriteDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mlhs\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ │ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mConstantRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'zero'\u001b[0m\u001b[1m)\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mrhs\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mconditions\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33msubsume\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1;35mRewriteDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mlhs\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'y'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mrhs\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'y'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mconditions\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33msubsume\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pprint(rules.__egg_name__)\n", + "pprint(rules.__egg_decls__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When we then run a ruleset, or add an expression, the declerations are converted to egglog strings and run. If we want to manually parse the expression, we could match on the different types of declerations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'cast(Math, add(cast(Math, zero()), cast(Math, one())))'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from egglog.declarations import *\n", + "\n", + "\n", + "def expr_to_string(expr: TypedExprDecl) -> str:\n", + " \"\"\"\n", + " Recursively convert a typed expression into a string, to show how to traverse them\n", + " \"\"\"\n", + " tp = tp_to_string(expr.tp)\n", + " match expr.expr:\n", + " case CallDecl(f, args, _):\n", + " f_str = callable_ref_to_string(f)\n", + " args_str = \", \".join(expr_to_string(a) for a in args)\n", + " expr_str = f\"{f_str}({args_str})\"\n", + " case _:\n", + " raise NotImplementedError(f\"TypedExprDecl {type(expr)} not implemented\")\n", + "\n", + " return f\"cast({tp}, {expr_str})\"\n", + "\n", + "\n", + "def tp_to_string(tp: JustTypeRef) -> str:\n", + " name = tp.name\n", + " if tp.args:\n", + " args = \", \".join(tp_to_string(a) for a in tp.args)\n", + " return f\"{name}[{args}]\"\n", + " return name\n", + "\n", + "\n", + "def callable_ref_to_string(call: CallableRef) -> str:\n", + " match call:\n", + " case FunctionRef(name):\n", + " return name\n", + " case ConstantRef(name):\n", + " return name\n", + " case _:\n", + " raise NotImplementedError(f\"CallableRef {call} not implemented\")\n", + "\n", + "\n", + "expr_to_string(x.__egg_typed_expr__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This gives a sense of how you can use metaprogramming to construct or deconstruct egglog objects. If there are other protocols like z3, we could add support for using metaprogramming to go back and forth to and from egglog and those as well.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "egg-smol-python", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 58cff32a..1dc2884e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ docs = [ "line-profiler", "sphinxcontrib-mermaid", "ablog", + "rich", ] [tool.ruff] diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 55f4b9b5..a927a0f9 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -91,6 +91,8 @@ "Action", "Command", "check_eq", + "sort", + "fn", ] T = TypeVar("T") @@ -2002,3 +2004,42 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]: yield finally: _CURRENT_RULESET.reset(token) + + +## +# Mid level IR +## + + +def sort(name: str) -> type[Expr]: + """ + Create a new sort with the given name. + + Similar to subclassing `Expr`, but doesn't work as well with static type checking. + """ + res = RuntimeClass(Thunk.value(Declarations(_classes={name: ClassDecl()})), TypeRefWithVars(name)) + return cast(type[Expr], res) + + +def fn( + name: str, *args_and_return_tp: type[Expr], cost: int | None = None, unextractable: bool = False +) -> Callable[..., EXPR]: + """ + Create a new function with the given name and argument types. + + Similar to using the `@function` decorator, but more dynamic and so doesn't play as well with static typing. + """ + *args, return_type = args_and_return_tp + decls = Declarations() + decls._functions[name] = FunctionDecl( + FunctionSignature( + arg_types=tuple(resolve_type_annotation(decls, a) for a in args), + arg_names=tuple(f"__{i}" for i in range(len(args))), + arg_defaults=tuple(None for _ in args), + return_type=resolve_type_annotation(decls, return_type), + ), + cost=cost, + unextractable=unextractable, + ) + res = RuntimeFunction(Thunk.value(decls), Thunk.value(FunctionRef(name))) + return cast(Callable[..., EXPR], res)