Skip to content

Commit

Permalink
refactor: return binary instead of EXLA.MLIR.Module struct
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Jul 11, 2024
1 parent bbef0a2 commit 9ba0721
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
37 changes: 19 additions & 18 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
#include <string>

#include "exla_mlir.h"
#include "exla_client.h"
#include "exla_cuda.h"
#include "exla_log_sink.h"
#include "exla_mlir.h"
#include "exla_nif_util.h"

#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"

#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
Expand Down Expand Up @@ -202,19 +200,19 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a

auto arg_types = std::vector<mlir::Type>{};

for (auto const & type_string : arg_type_strings) {
for (auto const& type_string : arg_type_strings) {
auto type = (*module)->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
arg_types.push_back(type);
}

auto ret_types = std::vector<mlir::Type>{};

for (auto const & type_string : ret_type_strings) {
for (auto const& type_string : ret_type_strings) {
auto type = (*module)->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
ret_types.push_back(type);
Expand Down Expand Up @@ -281,19 +279,19 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {

auto result_types = std::vector<mlir::Type>{};

for (auto const & type_string : result_type_strings) {
for (auto const& type_string : result_type_strings) {
auto type = (*function)->module()->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
result_types.push_back(type);
}

auto attributes = std::vector<std::pair<std::string, mlir::Attribute>>{};

for (auto const & pair : attributes_kwlist) {
for (auto const& pair : attributes_kwlist) {
auto attribute_value = (*function)->module()->ParseAttribute(pair.second);
if(attribute_value == nullptr) {
if (attribute_value == nullptr) {
return attribute_parsing_error(env, pair.second);
}
attributes.push_back(std::pair{pair.first, attribute_value});
Expand All @@ -304,7 +302,6 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, exla::nif::make_list<mlir::Value>(env, results));
}


ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 2) {
return exla::nif::error(env, "Bad argument count.");
Expand All @@ -322,9 +319,9 @@ ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[

auto types = std::vector<mlir::Type>{};

for (auto const & type_string : arg_types) {
for (auto const& type_string : arg_types) {
auto type = (*function)->module()->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
types.push_back(type);
Expand Down Expand Up @@ -379,9 +376,13 @@ ERL_NIF_TERM mlir_module_to_string(ErlNifEnv* env, int argc, const ERL_NIF_TERM
return exla::nif::error(env, "Unable to get builder.");
}

auto string = (*module)->ToString();
std::string string = (*module)->ToString();

ErlNifBinary bin;
enif_alloc_binary(string.size(), &bin);
memcpy(bin.data, string.c_str(), string.size());

return exla::nif::ok(env, exla::nif::make(env, string));
return exla::nif::ok(env, exla::nif::make(env, bin));
}

// ExlaBuffer Functions
Expand Down
8 changes: 4 additions & 4 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,8 @@ defmodule EXLA do
iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [1.0, 2.0]
iex> module = EXLA.to_mlir_module(fun, args)
iex> EXLA.MLIR.Module.as_string(module)
~c\"\"\"
iex> EXLA.to_mlir_module(fun, args)
\"\"\"
module {
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.sine %arg0 : tensor<f32>
Expand All @@ -389,7 +388,8 @@ defmodule EXLA do
])
|> apply(args)
catch
{:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref}
{:mlir_module, ref} ->
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
end

@doc """
Expand Down

0 comments on commit 9ba0721

Please sign in to comment.