Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add EXLA.to_mlir_module/2 #1497

Merged
merged 7 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
#include <string>

#include "exla_mlir.h"
#include "exla_client.h"
#include "exla_cuda.h"
#include "exla_log_sink.h"
#include "exla_mlir.h"
#include "exla_nif_util.h"

#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"

#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
Expand Down Expand Up @@ -202,19 +200,19 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a

auto arg_types = std::vector<mlir::Type>{};

for (auto const & type_string : arg_type_strings) {
for (auto const& type_string : arg_type_strings) {
auto type = (*module)->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
arg_types.push_back(type);
}

auto ret_types = std::vector<mlir::Type>{};

for (auto const & type_string : ret_type_strings) {
for (auto const& type_string : ret_type_strings) {
auto type = (*module)->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
ret_types.push_back(type);
Expand Down Expand Up @@ -281,19 +279,19 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {

auto result_types = std::vector<mlir::Type>{};

for (auto const & type_string : result_type_strings) {
for (auto const& type_string : result_type_strings) {
auto type = (*function)->module()->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
result_types.push_back(type);
}

auto attributes = std::vector<std::pair<std::string, mlir::Attribute>>{};

for (auto const & pair : attributes_kwlist) {
for (auto const& pair : attributes_kwlist) {
auto attribute_value = (*function)->module()->ParseAttribute(pair.second);
if(attribute_value == nullptr) {
if (attribute_value == nullptr) {
return attribute_parsing_error(env, pair.second);
}
attributes.push_back(std::pair{pair.first, attribute_value});
Expand All @@ -304,7 +302,6 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, exla::nif::make_list<mlir::Value>(env, results));
}


ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 2) {
return exla::nif::error(env, "Bad argument count.");
Expand All @@ -322,9 +319,9 @@ ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[

auto types = std::vector<mlir::Type>{};

for (auto const & type_string : arg_types) {
for (auto const& type_string : arg_types) {
auto type = (*function)->module()->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
types.push_back(type);
Expand Down Expand Up @@ -379,9 +376,13 @@ ERL_NIF_TERM mlir_module_to_string(ErlNifEnv* env, int argc, const ERL_NIF_TERM
return exla::nif::error(env, "Unable to get builder.");
}

auto string = (*module)->ToString();
std::string string = (*module)->ToString();

ErlNifBinary bin;
enif_alloc_binary(string.size(), &bin);
memcpy(bin.data, string.c_str(), string.size());

return exla::nif::ok(env, exla::nif::make(env, string));
return exla::nif::ok(env, exla::nif::make(env, bin));
}

// ExlaBuffer Functions
Expand Down
37 changes: 37 additions & 0 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,43 @@ defmodule EXLA do
Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA))
end

@doc """
Takes in a function, the template variables and the compilation options
and returns the textual representation of the MLIR module as a string.

## Examples

iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [1.0, 2.0]
iex> EXLA.to_mlir_module(fun, args)
\"\"\"
module {
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = stablehlo.sine %arg0 : tensor<f32>
%1 = stablehlo.cosine %arg1 : tensor<f32>
%2 = stablehlo.add %0, %1 : tensor<f32>
return %2 : tensor<f32>
}
}
\"\"\"
"""
def to_mlir_module(function, args, options \\ []) do
comp_fun = fn _key, callback ->
{:ok, {_xla_time, executable, _extra, _outfeed}} = callback.()
throw({:mlir_module, executable.ref})
Comment on lines +380 to +381
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gah, this is relying on internals of another module but I can't think of anything better for now, so ship it.

end

function
|> jit([
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
{:module_compilation, :to_mlir} | options
])
|> apply(args)
catch
{:mlir_module, ref} ->
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
end

@doc """
Checks if the compilation of function with args is cached.

Expand Down
31 changes: 21 additions & 10 deletions exla/lib/exla/mlir/module.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ defmodule EXLA.MLIR.Module do
* `:use_spmd` - enables Single-Program Multiple-Data partitioning.
This is set to true if `:num_partitions` is more than one, otherwise is `false`.

* `:module_compilation` - either `:to_mlir` or `:to_pjrt`. The default is `:to_pjrt`.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@josevalim not sure about the option naming here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is private, so it is fine!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we return %EXLA.MLIR.Module{} in the other function, shouldn't the whole module be public?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. We should return the string only, not the module struct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The public-facing function is now returning the binary directly


* `:to_pjrt` - the `EXLA.Executable` `:ref` field will hold the reference to a PjRt executable.
* `:to_mlir` - the `EXLA.Executable` `:ref` field will hold the reference to an MLIR module.

Currently those options do not have an effect as they relate to running the
same compiled executable on multiple replicas.

Expand Down Expand Up @@ -102,16 +107,22 @@ defmodule EXLA.MLIR.Module do
# module |> as_string() |> IO.puts()

ref =
EXLA.NIF.mlir_compile(
client.ref,
module.ref,
Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1),
num_replicas,
num_partitions,
use_spmd,
device_id
)
|> unwrap!()
case Keyword.get(options, :module_compilation, :to_pjrt) do
:to_mlir ->
module.ref

:to_pjrt ->
EXLA.NIF.mlir_compile(
client.ref,
module.ref,
Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1),
num_replicas,
num_partitions,
use_spmd,
device_id
)
|> unwrap!()
end

%Executable{
client: client,
Expand Down
Loading