From 9ba0721a943c190450f1e521069da46253daa951 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 11 Jul 2024 01:56:26 -0300 Subject: [PATCH] refactor: return binary instead of EXLA.MLIR.Module struct --- exla/c_src/exla/exla.cc | 37 +++++++++++++++++++------------------ exla/lib/exla.ex | 8 ++++---- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index f1cf82a29d..2bf3fadd26 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,18 +1,16 @@ #include -#include "exla_mlir.h" #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" +#include "exla_mlir.h" #include "exla_nif_util.h" - -#include "xla/pjrt/pjrt_api.h" -#include "xla/service/platform_util.h" - #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/service/platform_util.h" // All of these are created with calls to `new` and subsequently // passed to the VM as pointers-to-pointers so we balance it out @@ -202,9 +200,9 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a auto arg_types = std::vector{}; - for (auto const & type_string : arg_type_strings) { + for (auto const& type_string : arg_type_strings) { auto type = (*module)->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } arg_types.push_back(type); @@ -212,9 +210,9 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a auto ret_types = std::vector{}; - for (auto const & type_string : ret_type_strings) { + for (auto const& type_string : ret_type_strings) { auto type = (*module)->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } ret_types.push_back(type); @@ -281,9 +279,9 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { auto result_types = std::vector{}; - for (auto const & type_string : result_type_strings) { + for (auto const& type_string : result_type_strings) { auto type = (*function)->module()->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } result_types.push_back(type); @@ -291,9 +289,9 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { auto attributes = std::vector>{}; - for (auto const & pair : attributes_kwlist) { + for (auto const& pair : attributes_kwlist) { auto attribute_value = (*function)->module()->ParseAttribute(pair.second); - if(attribute_value == nullptr) { + if (attribute_value == nullptr) { return attribute_parsing_error(env, pair.second); } attributes.push_back(std::pair{pair.first, attribute_value}); @@ -304,7 +302,6 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::ok(env, exla::nif::make_list(env, results)); } - ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 2) { return exla::nif::error(env, "Bad argument count."); @@ -322,9 +319,9 @@ ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[ auto types = std::vector{}; - for (auto const & type_string : arg_types) { + for (auto const& type_string : arg_types) { auto type = (*function)->module()->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } types.push_back(type); @@ -379,9 +376,13 @@ ERL_NIF_TERM mlir_module_to_string(ErlNifEnv* env, int argc, const ERL_NIF_TERM return exla::nif::error(env, "Unable to get builder."); } - auto string = (*module)->ToString(); + std::string string = (*module)->ToString(); + + ErlNifBinary bin; + enif_alloc_binary(string.size(), &bin); + memcpy(bin.data, string.c_str(), string.size()); - return exla::nif::ok(env, exla::nif::make(env, string)); + return exla::nif::ok(env, exla::nif::make(env, bin)); } // ExlaBuffer Functions diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 5f167d7a75..55da871384 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -363,9 +363,8 @@ defmodule EXLA do iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end iex> args = [1.0, 2.0] - iex> module = EXLA.to_mlir_module(fun, args) - iex> EXLA.MLIR.Module.as_string(module) - ~c\"\"\" + iex> EXLA.to_mlir_module(fun, args) + \"\"\" module { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.sine %arg0 : tensor @@ -389,7 +388,8 @@ defmodule EXLA do ]) |> apply(args) catch - {:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref} + {:mlir_module, ref} -> + EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}) end @doc """