diff --git a/docs/changelog.md b/docs/changelog.md index 730051bc..d8b07ec1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ - Adds ability to use anonymous functions where callables are needed. These are automatically transformed to egglog functions with default rewrites. +- Add `sort` and `fn` mid level commands ## 7.2.0 (2024-05-23) diff --git a/docs/explanation/2024_06_18_midlevel.ipynb b/docs/explanation/2024_06_18_midlevel.ipynb new file mode 100644 index 00000000..b41b5203 --- /dev/null +++ b/docs/explanation/2024_06_18_midlevel.ipynb @@ -0,0 +1,615 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{post} 2024-06-18\n", + ":author: Saul\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mid Level IR\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Philip Zucker recently wrote a blog post about using egglog to do simplifications by converting from the Z3 API, [\"Conditional Simplification of Z3py Expressions with Egglog\"](https://www.philipzucker.com/egglog_z3_simp/).\n", + "\n", + "In it, he uses the egglog bindings, but only the low level bindings, as a way to execute egglog programs as strings. So he constructs expressions and rewrites using the Z3 API, manually converts them to egglog program strings, runs them in egglog, and then gets them back.\n", + "\n", + "He prefers to use the Z3 API over the high level egglog API, because it is a simpler functional interface, creating sorts and functions through functions, instead of by using Python classes:\n", + "\n", + "> Saul has been making the egglog python bindings https://egglog-python.readthedocs.io/latest/ taking a very meta highly integrated approach. I kind of just want it to look like z3 though. It’s very interesting and I’m haunted by the idea that I am a stodgy old man and they’re right. I think it’s biggest demerit is that it is very novel. I’ve never seen an interface like it. From a research perspective this is a plus. It is very cool that they are getting the python typechecker and the embedded dsl to play ball. I dunno https://egraphs.zulipchat.com/#narrow/stream/375765-egg.2Fegglog/topic/egglog.20python.20midlevel.20api/near/421919681\n", + ">\n", + "> It turns out, it is simple enough to have my cake too. The pyegglog supports the raw bindings and I’ve been spending a decent amount of time serializing Z3 ASTs to other smt or tptp solvers. Translating to egglog programs is easy.\n", + "\n", + "In order to continue this conversation, I added two functions `sort` and `fn`, to the egglog bindings, to see how those could be used to do something similar. The overall hope is that we could re-use some of the existing mechanisms in the high level API, even with a slightly different external interface. Here is an example of using them to re-create the example Phil wrote up:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
one\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{one}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "one" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from egglog import *\n", + "\n", + "Math = sort(\"Math\")\n", + "add = fn(\"add\", Math, Math, Math)\n", + "zero = constant(\"zero\", Math)\n", + "one = constant(\"one\", Math)\n", + "\n", + "sig = [Math, add, zero, one]\n", + "\n", + "x, y = vars_(\"x y\", Math)\n", + "\n", + "rules = ruleset(\n", + " rewrite(add(x, zero)).to(x),\n", + " rewrite(add(x, y)).to(add(y, x)),\n", + ")\n", + "\n", + "res = simplify(add(zero, one), rules.saturate())\n", + "res" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Metaprogramming in egglog\n", + "\n", + "What about if you want to do more generic metaprogramming in egglog? So far, I have tried to keep the internal details of how egglog objects are stored hidden from external users, so that they can be changed without breaking the API. However, if you do want to try doing some metaprogramming, the API has been written in such a way that each egglog object will be first converted into data only \"declerations\" before then being converted into the egglog string.\n", + "\n", + "These \"declerations\" are what are stored at runtime. So it's actually quite easy to manually create whatever objects (expressions, sorts, functions, rewrite, etc) that you want just using the normal data classes, and then convert them to egglog strings. Or to go the other way, and match on them to convert them to another syntax. Again, this intentionally has not been part of the public API yet, because things are still unstable, but was designed in such a way to make metaprogramming easier.\n", + "\n", + "For example, we can start by looking at the expression object. At runtime, it's simply a dataclass with two fields, the `Declerations`, like the total state of the e-graph we need to know, and the `expr`, a pointer into that state of the expressions:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
TypedExprDecl(\n",
+       "tp=JustTypeRef(name='Math', args=()),\n",
+       "expr=CallDecl(\n",
+       "│   │   callable=FunctionRef(name='add'),\n",
+       "│   │   args=(\n",
+       "│   │   │   TypedExprDecl(\n",
+       "│   │   │   │   tp=JustTypeRef(name='Math', args=()),\n",
+       "│   │   │   │   expr=CallDecl(callable=ConstantRef(name='zero'), args=(), bound_tp_params=None)\n",
+       "│   │   │   ),\n",
+       "│   │   │   TypedExprDecl(\n",
+       "│   │   │   │   tp=JustTypeRef(name='Math', args=()),\n",
+       "│   │   │   │   expr=CallDecl(callable=ConstantRef(name='one'), args=(), bound_tp_params=None)\n",
+       "│   │   │   )\n",
+       "│   │   ),\n",
+       "│   │   bound_tp_params=None\n",
+       ")\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mConstantRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'zero'\u001b[0m\u001b[1m)\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mConstantRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'one'\u001b[0m\u001b[1m)\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Declarations(\n",
+       "_functions={\n",
+       "│   │   'add': FunctionDecl(\n",
+       "│   │   │   signature=FunctionSignature(\n",
+       "│   │   │   │   arg_types=(TypeRefWithVars(name='Math', args=()), TypeRefWithVars(name='Math', args=())),\n",
+       "│   │   │   │   arg_names=('__0', '__1'),\n",
+       "│   │   │   │   arg_defaults=(None, None),\n",
+       "│   │   │   │   return_type=TypeRefWithVars(name='Math', args=()),\n",
+       "│   │   │   │   var_arg_type=None\n",
+       "│   │   │   ),\n",
+       "│   │   │   builtin=False,\n",
+       "│   │   │   egg_name=None,\n",
+       "│   │   │   cost=None,\n",
+       "│   │   │   default=None,\n",
+       "│   │   │   on_merge=(),\n",
+       "│   │   │   merge=None,\n",
+       "│   │   │   unextractable=False\n",
+       "│   │   )\n",
+       "},\n",
+       "_constants={\n",
+       "│   │   'zero': ConstantDecl(type_ref=JustTypeRef(name='Math', args=()), egg_name=None),\n",
+       "│   │   'one': ConstantDecl(type_ref=JustTypeRef(name='Math', args=()), egg_name=None)\n",
+       "},\n",
+       "_classes={\n",
+       "│   │   'Math': ClassDecl(\n",
+       "│   │   │   egg_name=None,\n",
+       "│   │   │   type_vars=(),\n",
+       "│   │   │   builtin=False,\n",
+       "│   │   │   init=None,\n",
+       "│   │   │   class_methods={},\n",
+       "│   │   │   class_variables={},\n",
+       "│   │   │   methods={},\n",
+       "│   │   │   properties={},\n",
+       "│   │   │   preserved_methods={}\n",
+       "│   │   )\n",
+       "},\n",
+       "_rulesets={'': RulesetDecl(rules=[])}\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mDeclarations\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_functions\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'add'\u001b[0m: \u001b[1;35mFunctionDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33msignature\u001b[0m=\u001b[1;35mFunctionSignature\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_types\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_names\u001b[0m=\u001b[1m(\u001b[0m\u001b[32m'__0'\u001b[0m, \u001b[32m'__1'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_defaults\u001b[0m=\u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mreturn_type\u001b[0m=\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mvar_arg_type\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mcost\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdefault\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mon_merge\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmerge\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33munextractable\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_constants\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'zero'\u001b[0m: \u001b[1;35mConstantDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtype_ref\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'one'\u001b[0m: \u001b[1;35mConstantDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtype_ref\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_classes\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'Math'\u001b[0m: \u001b[1;35mClassDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mtype_vars\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33minit\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_variables\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmethods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mproperties\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mpreserved_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_rulesets\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m''\u001b[0m: \u001b[1;35mRulesetDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mrules\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from rich.pretty import pprint\n", + "\n", + "x = add(zero, one)\n", + "\n", + "pprint(x.__egg_typed_expr__)\n", + "pprint(x.__egg_decls__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The same is true of the rulesets:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
'ruleset_4743265072'\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32m'ruleset_4743265072'\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Declarations(\n",
+       "_functions={\n",
+       "│   │   'add': FunctionDecl(\n",
+       "│   │   │   signature=FunctionSignature(\n",
+       "│   │   │   │   arg_types=(TypeRefWithVars(name='Math', args=()), TypeRefWithVars(name='Math', args=())),\n",
+       "│   │   │   │   arg_names=('__0', '__1'),\n",
+       "│   │   │   │   arg_defaults=(None, None),\n",
+       "│   │   │   │   return_type=TypeRefWithVars(name='Math', args=()),\n",
+       "│   │   │   │   var_arg_type=None\n",
+       "│   │   │   ),\n",
+       "│   │   │   builtin=False,\n",
+       "│   │   │   egg_name=None,\n",
+       "│   │   │   cost=None,\n",
+       "│   │   │   default=None,\n",
+       "│   │   │   on_merge=(),\n",
+       "│   │   │   merge=None,\n",
+       "│   │   │   unextractable=False\n",
+       "│   │   )\n",
+       "},\n",
+       "_constants={'zero': ConstantDecl(type_ref=JustTypeRef(name='Math', args=()), egg_name=None)},\n",
+       "_classes={\n",
+       "│   │   'Math': ClassDecl(\n",
+       "│   │   │   egg_name=None,\n",
+       "│   │   │   type_vars=(),\n",
+       "│   │   │   builtin=False,\n",
+       "│   │   │   init=None,\n",
+       "│   │   │   class_methods={},\n",
+       "│   │   │   class_variables={},\n",
+       "│   │   │   methods={},\n",
+       "│   │   │   properties={},\n",
+       "│   │   │   preserved_methods={}\n",
+       "│   │   )\n",
+       "},\n",
+       "_rulesets={\n",
+       "│   │   '': RulesetDecl(rules=[]),\n",
+       "│   │   'ruleset_4743265072': RulesetDecl(\n",
+       "│   │   │   rules=[\n",
+       "│   │   │   │   RewriteDecl(\n",
+       "│   │   │   │   │   tp=JustTypeRef(name='Math', args=()),\n",
+       "│   │   │   │   │   lhs=CallDecl(\n",
+       "│   │   │   │   │   │   callable=FunctionRef(name='add'),\n",
+       "│   │   │   │   │   │   args=(\n",
+       "│   │   │   │   │   │   │   TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='x')),\n",
+       "│   │   │   │   │   │   │   TypedExprDecl(\n",
+       "│   │   │   │   │   │   │   │   tp=JustTypeRef(name='Math', args=()),\n",
+       "│   │   │   │   │   │   │   │   expr=CallDecl(callable=ConstantRef(name='zero'), args=(), bound_tp_params=None)\n",
+       "│   │   │   │   │   │   │   )\n",
+       "│   │   │   │   │   │   ),\n",
+       "│   │   │   │   │   │   bound_tp_params=None\n",
+       "│   │   │   │   │   ),\n",
+       "│   │   │   │   │   rhs=VarDecl(name='x'),\n",
+       "│   │   │   │   │   conditions=(),\n",
+       "│   │   │   │   │   subsume=False\n",
+       "│   │   │   │   ),\n",
+       "│   │   │   │   RewriteDecl(\n",
+       "│   │   │   │   │   tp=JustTypeRef(name='Math', args=()),\n",
+       "│   │   │   │   │   lhs=CallDecl(\n",
+       "│   │   │   │   │   │   callable=FunctionRef(name='add'),\n",
+       "│   │   │   │   │   │   args=(\n",
+       "│   │   │   │   │   │   │   TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='x')),\n",
+       "│   │   │   │   │   │   │   TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='y'))\n",
+       "│   │   │   │   │   │   ),\n",
+       "│   │   │   │   │   │   bound_tp_params=None\n",
+       "│   │   │   │   │   ),\n",
+       "│   │   │   │   │   rhs=CallDecl(\n",
+       "│   │   │   │   │   │   callable=FunctionRef(name='add'),\n",
+       "│   │   │   │   │   │   args=(\n",
+       "│   │   │   │   │   │   │   TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='y')),\n",
+       "│   │   │   │   │   │   │   TypedExprDecl(tp=JustTypeRef(name='Math', args=()), expr=VarDecl(name='x'))\n",
+       "│   │   │   │   │   │   ),\n",
+       "│   │   │   │   │   │   bound_tp_params=None\n",
+       "│   │   │   │   │   ),\n",
+       "│   │   │   │   │   conditions=(),\n",
+       "│   │   │   │   │   subsume=False\n",
+       "│   │   │   │   )\n",
+       "│   │   │   ]\n",
+       "│   │   )\n",
+       "}\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mDeclarations\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_functions\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'add'\u001b[0m: \u001b[1;35mFunctionDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33msignature\u001b[0m=\u001b[1;35mFunctionSignature\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_types\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_names\u001b[0m=\u001b[1m(\u001b[0m\u001b[32m'__0'\u001b[0m, \u001b[32m'__1'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33marg_defaults\u001b[0m=\u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mreturn_type\u001b[0m=\u001b[1;35mTypeRefWithVars\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[33mvar_arg_type\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mcost\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdefault\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mon_merge\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmerge\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33munextractable\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_constants\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'zero'\u001b[0m: \u001b[1;35mConstantDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtype_ref\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_classes\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'Math'\u001b[0m: \u001b[1;35mClassDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33megg_name\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mtype_vars\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mbuiltin\u001b[0m=\u001b[3;91mFalse\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33minit\u001b[0m=\u001b[3;35mNone\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_variables\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mmethods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mproperties\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mpreserved_methods\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33m_rulesets\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m''\u001b[0m: \u001b[1;35mRulesetDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mrules\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'ruleset_4743265072'\u001b[0m: \u001b[1;35mRulesetDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mrules\u001b[0m=\u001b[1m[\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1;35mRewriteDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mlhs\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ │ \u001b[0m\u001b[33mexpr\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mConstantRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'zero'\u001b[0m\u001b[1m)\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mrhs\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mconditions\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33msubsume\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1;35mRewriteDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mlhs\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'y'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mrhs\u001b[0m=\u001b[1;35mCallDecl\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mcallable\u001b[0m=\u001b[1;35mFunctionRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'add'\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'y'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ │ \u001b[0m\u001b[1;35mTypedExprDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mtp\u001b[0m=\u001b[1;35mJustTypeRef\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Math'\u001b[0m, \u001b[33margs\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mexpr\u001b[0m=\u001b[1;35mVarDecl\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'x'\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ │ \u001b[0m\u001b[33mbound_tp_params\u001b[0m=\u001b[3;35mNone\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mconditions\u001b[0m=\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33msubsume\u001b[0m=\u001b[3;91mFalse\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pprint(rules.__egg_name__)\n", + "pprint(rules.__egg_decls__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When we then run a ruleset, or add an expression, the declerations are converted to egglog strings and run. If we want to manually parse the expression, we could match on the different types of declerations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'cast(Math, add(cast(Math, zero()), cast(Math, one())))'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from egglog.declarations import *\n", + "\n", + "\n", + "def expr_to_string(expr: TypedExprDecl) -> str:\n", + " \"\"\"\n", + " Recursively convert a typed expression into a string, to show how to traverse them\n", + " \"\"\"\n", + " tp = tp_to_string(expr.tp)\n", + " match expr.expr:\n", + " case CallDecl(f, args, _):\n", + " f_str = callable_ref_to_string(f)\n", + " args_str = \", \".join(expr_to_string(a) for a in args)\n", + " expr_str = f\"{f_str}({args_str})\"\n", + " case _:\n", + " raise NotImplementedError(f\"TypedExprDecl {type(expr)} not implemented\")\n", + "\n", + " return f\"cast({tp}, {expr_str})\"\n", + "\n", + "\n", + "def tp_to_string(tp: JustTypeRef) -> str:\n", + " name = tp.name\n", + " if tp.args:\n", + " args = \", \".join(tp_to_string(a) for a in tp.args)\n", + " return f\"{name}[{args}]\"\n", + " return name\n", + "\n", + "\n", + "def callable_ref_to_string(call: CallableRef) -> str:\n", + " match call:\n", + " case FunctionRef(name):\n", + " return name\n", + " case ConstantRef(name):\n", + " return name\n", + " case _:\n", + " raise NotImplementedError(f\"CallableRef {call} not implemented\")\n", + "\n", + "\n", + "expr_to_string(x.__egg_typed_expr__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This gives a sense of how you can use metaprogramming to construct or deconstruct egglog objects. If there are other protocols like z3, we could add support for using metaprogramming to go back and forth to and from egglog and those as well.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "egg-smol-python", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 58cff32a..1dc2884e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ docs = [ "line-profiler", "sphinxcontrib-mermaid", "ablog", + "rich", ] [tool.ruff] diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 55f4b9b5..a927a0f9 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -91,6 +91,8 @@ "Action", "Command", "check_eq", + "sort", + "fn", ] T = TypeVar("T") @@ -2002,3 +2004,42 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]: yield finally: _CURRENT_RULESET.reset(token) + + +## +# Mid level IR +## + + +def sort(name: str) -> type[Expr]: + """ + Create a new sort with the given name. + + Similar to subclassing `Expr`, but doesn't work as well with static type checking. + """ + res = RuntimeClass(Thunk.value(Declarations(_classes={name: ClassDecl()})), TypeRefWithVars(name)) + return cast(type[Expr], res) + + +def fn( + name: str, *args_and_return_tp: type[Expr], cost: int | None = None, unextractable: bool = False +) -> Callable[..., EXPR]: + """ + Create a new function with the given name and argument types. + + Similar to using the `@function` decorator, but more dynamic and so doesn't play as well with static typing. + """ + *args, return_type = args_and_return_tp + decls = Declarations() + decls._functions[name] = FunctionDecl( + FunctionSignature( + arg_types=tuple(resolve_type_annotation(decls, a) for a in args), + arg_names=tuple(f"__{i}" for i in range(len(args))), + arg_defaults=tuple(None for _ in args), + return_type=resolve_type_annotation(decls, return_type), + ), + cost=cost, + unextractable=unextractable, + ) + res = RuntimeFunction(Thunk.value(decls), Thunk.value(FunctionRef(name))) + return cast(Callable[..., EXPR], res)