diff --git a/python/src/ops.cpp b/python/src/ops.cpp index bd64bf687..abfbbbc7c 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -105,7 +105,7 @@ void init_ops(nb::module_& m) { )pbdoc"); m.def( "unflatten", - &unflatten, + &mx::unflatten, nb::arg(), "axis"_a, "shape"_a,