From b117dec8855371220883e3461ce1f96b2e30c962 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Dec 2024 09:03:32 -0800 Subject: [PATCH] rebase --- python/src/export.cpp | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/python/src/export.cpp b/python/src/export.cpp index 5c71567e5..c09e71f60 100644 --- a/python/src/export.cpp +++ b/python/src/export.cpp @@ -10,14 +10,14 @@ #include "mlx/graph_utils.h" #include "python/src/trees.h" +namespace mx = mlx::core; namespace nb = nanobind; using namespace nb::literals; -using namespace mlx::core; template bool check_arrs(const T& iterable) { for (auto it = iterable.begin(); it != iterable.end(); ++it) { - if (!nb::isinstance(*it)) { + if (!nb::isinstance(*it)) { return false; } } @@ -25,7 +25,7 @@ bool check_arrs(const T& iterable) { }; bool valid_inputs(const nb::args& inputs) { - if (inputs.size() > 0 && nb::isinstance(inputs[0])) { + if (inputs.size() > 0 && nb::isinstance(inputs[0])) { return check_arrs(inputs); } else if (inputs.size() > 1) { return false; @@ -42,7 +42,7 @@ bool valid_inputs(const nb::args& inputs) { } bool valid_outputs(const nb::object& outputs) { - if (nb::isinstance(outputs)) { + if (nb::isinstance(outputs)) { return true; } else if (nb::isinstance(outputs)) { return check_arrs(nb::cast(outputs)); @@ -65,10 +65,9 @@ void init_export(nb::module_& m) { "of arrays or a single tuple or list of arrays."); } - std::vector inputs = tree_flatten(arrays, true); - auto wrapped_fun = - [&fun, - &arrays](const std::vector& inputs) -> std::vector { + std::vector inputs = tree_flatten(arrays, true); + auto wrapped_fun = [&fun, &arrays](const std::vector& inputs) + -> std::vector { auto outputs = fun(*tree_unflatten(arrays, inputs)); if (!valid_outputs(outputs)) { throw std::invalid_argument( @@ -77,7 +76,7 @@ void init_export(nb::module_& m) { } return tree_flatten(outputs, true); }; - export_function(path, wrapped_fun, inputs, shapeless); + mx::export_function(path, wrapped_fun, inputs, shapeless); }, "path"_a, "fun"_a, @@ -101,7 +100,7 @@ void init_export(nb::module_& m) { "import_function", [](const std::string& path) { return nb::cpp_function( - [fn = import_function(path)](const nb::args& arrays) { + [fn = mx::import_function(path)](const nb::args& arrays) { if (!valid_inputs(arrays)) { throw std::invalid_argument( "[import_function::call] Inputs can be either a variable " @@ -124,13 +123,13 @@ void init_export(nb::module_& m) { m.def( "export_to_dot", [](nb::object file, const nb::args& args) { - std::vector arrays = tree_flatten(args); + std::vector arrays = tree_flatten(args); if (nb::isinstance(file)) { std::ofstream out(nb::cast(file)); - export_to_dot(out, arrays); + mx::export_to_dot(out, arrays); } else if (nb::hasattr(file, "write")) { std::ostringstream out; - export_to_dot(out, arrays); + mx::export_to_dot(out, arrays); auto write = file.attr("write"); write(out.str()); } else {