Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Dec 19, 2024
1 parent cc0b30f commit b117dec
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions python/src/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@
#include "mlx/graph_utils.h"
#include "python/src/trees.h"

namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;

template <typename T>
bool check_arrs(const T& iterable) {
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
if (!nb::isinstance<array>(*it)) {
if (!nb::isinstance<mx::array>(*it)) {
return false;
}
}
return true;
};

bool valid_inputs(const nb::args& inputs) {
if (inputs.size() > 0 && nb::isinstance<array>(inputs[0])) {
if (inputs.size() > 0 && nb::isinstance<mx::array>(inputs[0])) {
return check_arrs(inputs);
} else if (inputs.size() > 1) {
return false;
Expand All @@ -42,7 +42,7 @@ bool valid_inputs(const nb::args& inputs) {
}

bool valid_outputs(const nb::object& outputs) {
if (nb::isinstance<array>(outputs)) {
if (nb::isinstance<mx::array>(outputs)) {
return true;
} else if (nb::isinstance<nb::list>(outputs)) {
return check_arrs(nb::cast<nb::list>(outputs));
Expand All @@ -65,10 +65,9 @@ void init_export(nb::module_& m) {
"of arrays or a single tuple or list of arrays.");
}

std::vector<array> inputs = tree_flatten(arrays, true);
auto wrapped_fun =
[&fun,
&arrays](const std::vector<array>& inputs) -> std::vector<array> {
std::vector<mx::array> inputs = tree_flatten(arrays, true);
auto wrapped_fun = [&fun, &arrays](const std::vector<mx::array>& inputs)
-> std::vector<mx::array> {
auto outputs = fun(*tree_unflatten(arrays, inputs));
if (!valid_outputs(outputs)) {
throw std::invalid_argument(
Expand All @@ -77,7 +76,7 @@ void init_export(nb::module_& m) {
}
return tree_flatten(outputs, true);
};
export_function(path, wrapped_fun, inputs, shapeless);
mx::export_function(path, wrapped_fun, inputs, shapeless);
},
"path"_a,
"fun"_a,
Expand All @@ -101,7 +100,7 @@ void init_export(nb::module_& m) {
"import_function",
[](const std::string& path) {
return nb::cpp_function(
[fn = import_function(path)](const nb::args& arrays) {
[fn = mx::import_function(path)](const nb::args& arrays) {
if (!valid_inputs(arrays)) {
throw std::invalid_argument(
"[import_function::call] Inputs can be either a variable "
Expand All @@ -124,13 +123,13 @@ void init_export(nb::module_& m) {
m.def(
"export_to_dot",
[](nb::object file, const nb::args& args) {
std::vector<array> arrays = tree_flatten(args);
std::vector<mx::array> arrays = tree_flatten(args);
if (nb::isinstance<nb::str>(file)) {
std::ofstream out(nb::cast<std::string>(file));
export_to_dot(out, arrays);
mx::export_to_dot(out, arrays);
} else if (nb::hasattr(file, "write")) {
std::ostringstream out;
export_to_dot(out, arrays);
mx::export_to_dot(out, arrays);
auto write = file.attr("write");
write(out.str());
} else {
Expand Down

0 comments on commit b117dec

Please sign in to comment.