Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for reflective operators #39

Merged
merged 2 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea

- Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35)
- Makes conversions transitive [#38](https://github.com/metadsl/egglog-python/pull/38)
- Add support for reflective operators [#39](https://github.com/metadsl/egglog-python/pull/39)

## 0.5.1 (2023-07-18)

Expand Down
25 changes: 12 additions & 13 deletions docs/tutorials/array-api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload complete"
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -271,27 +271,25 @@
"unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n",
"unique_inverse(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY))))[Int(0)].shape[Int(0)]\n",
" -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).shape[Int(0)]\n",
" -> Int(3)\n",
"unique_values(concat((TupleNDArray(unique_values(asarray(asarray(reshape(asarray(NDArray.var(\"y\")), (TupleInt(Int(-1)) + TupleInt.EMPTY)))))) + TupleNDArray.EMPTY))).size\n",
" -> unique_values(reshape(NDArray.var(\"y\"), TupleInt(Int(-1)))).size\n",
" -> Int(3)\n"
]
},
{
"ename": "TypeError",
"evalue": "NDArray has no method __iter__",
"ename": "AttributeError",
"evalue": "module 'egglog.exp.array_api' has no attribute 'sqrt'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:452\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 451\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 452\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_decls__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_class_decl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m__egg_typed_expr__\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpreserved_methods\u001b[49m\u001b[43m[\u001b[49m\u001b[43m__name\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n",
"\u001b[0;31mKeyError\u001b[0m: '__iter__'",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 21\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 8\u001b[0m egraph\u001b[38;5;241m.\u001b[39mregister(\n\u001b[1;32m 9\u001b[0m rewrite(X_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(X\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[1;32m 10\u001b[0m rewrite(y_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(y\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m rewrite(unique_values(y_arr)\u001b[38;5;241m.\u001b[39mshape)\u001b[38;5;241m.\u001b[39mto(TupleInt(Int(\u001b[38;5;241m3\u001b[39m))),\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 21\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_arr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_arr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;66;03m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 24\u001b[0m \n\u001b[1;32m 25\u001b[0m \u001b[38;5;66;03m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 22\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# Add values for the constants\u001b[39;00m\n\u001b[1;32m 8\u001b[0m egraph\u001b[38;5;241m.\u001b[39mregister(\n\u001b[1;32m 9\u001b[0m rewrite(X_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(X\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[1;32m 10\u001b[0m rewrite(y_arr\u001b[38;5;241m.\u001b[39mdtype, runtime_ruleset)\u001b[38;5;241m.\u001b[39mto(convert(y\u001b[38;5;241m.\u001b[39mdtype, DType)),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 18\u001b[0m rewrite(unique_values(y_arr)\u001b[38;5;241m.\u001b[39msize)\u001b[38;5;241m.\u001b[39mto(Int(\u001b[38;5;241m3\u001b[39m)),\n\u001b[1;32m 19\u001b[0m )\n\u001b[0;32m---> 22\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_arr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_arr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;66;03m# X_obj, y_obj = egraph.save_object(X), egraph.save_object(y)\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \n\u001b[1;32m 26\u001b[0m \u001b[38;5;66;03m# X_arr = NDArray(X_obj)\u001b[39;00m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;66;03m# y_arr = NDArray(y_obj)\u001b[39;00m\n",
"Cell \u001b[0;32mIn[1], line 15\u001b[0m, in \u001b[0;36mfit\u001b[0;34m(X, y)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(array_api_dispatch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 14\u001b[0m lda \u001b[38;5;241m=\u001b[39m LinearDiscriminantAnalysis(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m X_r2 \u001b[38;5;241m=\u001b[39m \u001b[43mlda\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mtransform(X)\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m X_r2\n\u001b[1;32m 18\u001b[0m target_names \u001b[38;5;241m=\u001b[39m iris\u001b[38;5;241m.\u001b[39mtarget_names\n",
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/base.py:1151\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1144\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[1;32m 1146\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[1;32m 1147\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[1;32m 1148\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[1;32m 1149\u001b[0m )\n\u001b[1;32m 1150\u001b[0m ):\n\u001b[0;32m-> 1151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:629\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_estimator \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 624\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 625\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcovariance estimator \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 626\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mis not supported \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwith svd solver. Try another solver\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 628\u001b[0m )\n\u001b[0;32m--> 629\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_solve_svd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msolver \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlsqr\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_solve_lstsq(\n\u001b[1;32m 632\u001b[0m X,\n\u001b[1;32m 633\u001b[0m y,\n\u001b[1;32m 634\u001b[0m shrinkage\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshrinkage,\n\u001b[1;32m 635\u001b[0m covariance_estimator\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_estimator,\n\u001b[1;32m 636\u001b[0m )\n",
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:506\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcovariance_ \u001b[38;5;241m=\u001b[39m _class_cov(X, y, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpriors_)\n\u001b[1;32m 505\u001b[0m Xc \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 506\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m idx, group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43menumerate\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclasses_\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 507\u001b[0m Xg \u001b[38;5;241m=\u001b[39m X[y \u001b[38;5;241m==\u001b[39m group]\n\u001b[1;32m 508\u001b[0m Xc\u001b[38;5;241m.\u001b[39mappend(Xg \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmeans_[idx, :])\n",
"File \u001b[0;32m~/p/egg-smol-python/python/egglog/runtime.py:454\u001b[0m, in \u001b[0;36m_preserved_method\u001b[0;34m(self, __name)\u001b[0m\n\u001b[1;32m 452\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_decls__\u001b[38;5;241m.\u001b[39mget_class_decl(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_typed_expr__\u001b[38;5;241m.\u001b[39mtp\u001b[38;5;241m.\u001b[39mname)\u001b[38;5;241m.\u001b[39mpreserved_methods[__name]\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m:\n\u001b[0;32m--> 454\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__egg_typed_expr__\u001b[38;5;241m.\u001b[39mtp\u001b[38;5;241m.\u001b[39mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m has no method \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m__name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 455\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m method(\u001b[38;5;28mself\u001b[39m)\n",
"\u001b[0;31mTypeError\u001b[0m: NDArray has no method __iter__"
"File \u001b[0;32m/usr/local/Caskroom/miniconda/base/envs/egg-smol-python/lib/python3.10/site-packages/sklearn/discriminant_analysis.py:521\u001b[0m, in \u001b[0;36mLinearDiscriminantAnalysis._solve_svd\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 518\u001b[0m fac \u001b[38;5;241m=\u001b[39m xp\u001b[38;5;241m.\u001b[39masarray(\u001b[38;5;241m1.0\u001b[39m \u001b[38;5;241m/\u001b[39m (n_samples \u001b[38;5;241m-\u001b[39m n_classes))\n\u001b[1;32m 520\u001b[0m \u001b[38;5;66;03m# 2) Within variance scaling\u001b[39;00m\n\u001b[0;32m--> 521\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[43mxp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msqrt\u001b[49m(fac) \u001b[38;5;241m*\u001b[39m (Xc \u001b[38;5;241m/\u001b[39m std)\n\u001b[1;32m 522\u001b[0m \u001b[38;5;66;03m# SVD of centered (within)scaled data\u001b[39;00m\n\u001b[1;32m 523\u001b[0m U, S, Vt \u001b[38;5;241m=\u001b[39m svd(X, full_matrices\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
"\u001b[0;31mAttributeError\u001b[0m: module 'egglog.exp.array_api' has no attribute 'sqrt'"
]
}
],
Expand All @@ -313,6 +311,7 @@
" rewrite(X_arr.size, runtime_ruleset).to(Int(X.size)),\n",
" rewrite(y_arr.size, runtime_ruleset).to(Int(y.size)),\n",
" rewrite(unique_values(y_arr).shape).to(TupleInt(Int(3))),\n",
" rewrite(unique_values(y_arr).size).to(Int(3)),\n",
")\n",
"\n",
"\n",
Expand Down
25 changes: 24 additions & 1 deletion python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,31 @@
"__truediv__": "/",
"__floordiv__": "//",
"__mod__": "%",
"__divmod__": "divmod",
# TODO: Support divmod, with tuple return value
# "__divmod__": "divmod",
# TODO: Three arg power
"__pow__": "**",
"__lshift__": "<<",
"__rshift__": ">>",
"__and__": "&",
"__xor__": "^",
"__or__": "|",
}
REFECLTED_BINARY_METHODS = {
"__radd__": "+",
"__rsub__": "-",
"__rmul__": "*",
"__rmatmul__": "@",
"__rtruediv__": "/",
"__rfloordiv__": "//",
"__rmod__": "%",
"__rpow__": "**",
"__rlshift__": "<<",
"__rrshift__": ">>",
"__rand__": "&",
"__rxor__": "^",
"__ror__": "|",
}
UNARY_METHODS = {
"__pos__": "+",
"__neg__": "-",
Expand Down Expand Up @@ -629,6 +646,12 @@ def pretty(self, context: PrettyContext, parens=True, **kwargs) -> str:
assert len(args) == 1
expr = f"{slf.pretty(context )} {BINARY_METHODS[name]} {args[0].pretty(context, wrap_lit=False)}"
return expr if not parens else f"({expr})"
elif name in REFECLTED_BINARY_METHODS:
assert len(args) == 1
expr = (
f"{args[0].pretty(context, wrap_lit=False)} {REFECLTED_BINARY_METHODS[name]} {slf.pretty(context)}"
)
return expr if not parens else f"({expr})"
elif name == "__getitem__":
assert len(args) == 1
return f"{slf.pretty(context)}[{args[0].pretty(context, wrap_lit=False)}]"
Expand Down
25 changes: 23 additions & 2 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import numbers
import sys
from typing import Any, ClassVar, TypeVar
from typing import Any, ClassVar, Iterator, TypeVar

import numpy as np
from egglog import *
Expand All @@ -23,6 +23,10 @@
runtime_ruleset = egraph.ruleset("runtime")


# For now, have this global e-graph for this module, a bit hacky, but works as a proof of concept.
# We need a global e-graph so that we have the preserved methods reference it to extract when they are called.


def extract_py(e: Expr) -> Any:
print(e)
egraph.register(e)
Expand Down Expand Up @@ -193,6 +197,9 @@ def __add__(self, other: Int) -> Int:
def __sub__(self, other: Int) -> Int:
...

def __rtruediv__(self, other: Int) -> Int:
...

@egraph.method(preserve=True)
def __int__(self) -> int:
return extract_py(self)
Expand Down Expand Up @@ -237,6 +244,7 @@ def _int(i: i64, j: i64, r: Bool, o: Int):


converter(int, Int, lambda x: Int(x))
converter(float, Int, lambda x: Int(int(x)))

assert expr_parts(egraph.simplify(Int(1) == Int(1), 10)) == expr_parts(TRUE)
assert expr_parts(egraph.simplify(Int(1) == Int(2), 10)) == expr_parts(FALSE)
Expand Down Expand Up @@ -440,6 +448,11 @@ def size(self) -> Int:
def __len__(self) -> int:
return int(self.size)

@egraph.method(preserve=True)
def __iter__(self) -> Iterator[NDArray]:
for i in range(len(self)):
yield self[IndexKey.int(Int(i))]

def __getitem__(self, key: IndexKey) -> NDArray:
...

Expand All @@ -449,6 +462,9 @@ def __setitem__(self, key: IndexKey, value: NDArray) -> None:
def __truediv__(self, other: NDArray) -> NDArray:
...

def __matmul__(self, other: NDArray) -> NDArray:
...

def __sub__(self, other: NDArray) -> NDArray:
...

Expand Down Expand Up @@ -669,7 +685,7 @@ def _unique_values(x: NDArray):


@egraph.function
def concat(arrays: TupleNDArray) -> NDArray:
def concat(arrays: TupleNDArray, axis: OptionalInt = OptionalInt.none) -> NDArray:
...


Expand Down Expand Up @@ -708,6 +724,11 @@ def _astype(x: NDArray, dtype: DType, i: i64):
]


@egraph.function
def std(x: NDArray, axis: OptionalTupleInt = OptionalTupleInt.none) -> NDArray:
...


@egraph.function
def any(x: NDArray) -> NDArray:
...
Expand Down
9 changes: 7 additions & 2 deletions python/egglog/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from . import bindings, config # noqa: F401
from .declarations import *
from .declarations import BINARY_METHODS, UNARY_METHODS
from .declarations import BINARY_METHODS, REFECLTED_BINARY_METHODS, UNARY_METHODS
from .type_constraint_solver import *

if TYPE_CHECKING:
Expand Down Expand Up @@ -432,7 +432,12 @@ def __setstate__(self, d):


# Define each of the special methods, since we have already declared them for pretty printing
for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]:
for name in (
list(BINARY_METHODS)
+ list(REFECLTED_BINARY_METHODS)
+ list(UNARY_METHODS)
+ ["__getitem__", "__call__", "__setitem__", "__delitem__"]
):

def _special_method(self: RuntimeExpr, *args: object, __name: str = name) -> Optional[RuntimeExpr]:
# First, try to resolve as preserved method
Expand Down
19 changes: 19 additions & 0 deletions python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,22 @@ def incr(x: Math) -> None:
egraph.register(rewrite(incr_i).to(i + Math(1)), x)
egraph.run(10)
egraph.check(eq(x).to(Math(10) + Math(1)))


def test_reflected_binary_method():
egraph = EGraph()

@egraph.class_
class Math(Expr):
def __init__(self) -> None:
...

def __radd__(self, other: i64Like) -> Math: # type: ignore[empty-body]
...

expr = 10 + Math()
assert str(expr) == "10 + Math()"
assert expr_parts(expr) == TypedExprDecl(
JustTypeRef("Math"),
CallDecl(MethodRef("Math", "__radd__"), (expr_parts(Math()), expr_parts(i64(10)))),
)
Loading