From 765f657389ae2331f7d3437ebc30b531b6e3376e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 May 2024 18:55:10 -0300 Subject: [PATCH 1/6] feat: add EXLA.to_mlir_module/3 --- exla/lib/exla.ex | 15 +++++++++++++++ exla/lib/exla/mlir/module.ex | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 782c9b4305..6d8e5e5d8f 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -355,6 +355,21 @@ defmodule EXLA do Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA)) end + def to_mlir_module(function, args, options \\ []) do + {expr_fun, _} = cached_check() + + comp_fun = fn _key, callback -> + {:ok, {_xla_time, executable, _extra, _outfeed}} = callback.() + throw({:mlir_module, executable.ref}) + end + + function + |> jit([{EXLA, {expr_fun, comp_fun}}, {:module_compilation, :to_mlir} | options]) + |> apply(args) + catch + {:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref} + end + @doc """ Checks if the compilation of function with args is cached. diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index dff1b7ae80..d72e44681e 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -71,6 +71,10 @@ defmodule EXLA.MLIR.Module do * `:use_spmd` - enables Single-Program Multiple-Data partioning. This is set to true if `:num_partitions` is more than one, otherwise is `false`. + * `:module_compilation` - either `:to_mlir` or `:to_pjrt`. The default is `:to_pjrt`. + If `:to_pjrt`, the Executable `:ref` field will hold the reference to a PjRt executable. + If `:to_mlir`, the Executable `:ref` field will hold the reference to an MLIR executable. + Currently those options do not have an effect as they related to running the same compiled executable on multiple replicas. From de3e5a71f005508b09c5b071db16540f1a88c270 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 May 2024 19:02:49 -0300 Subject: [PATCH 2/6] doc: add doctest --- exla/lib/exla.ex | 25 ++++++++++++++++++++++--- exla/lib/exla/mlir/module.ex | 26 ++++++++++++++++---------- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 6d8e5e5d8f..8b818a58ca 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -355,16 +355,35 @@ defmodule EXLA do Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA)) end - def to_mlir_module(function, args, options \\ []) do - {expr_fun, _} = cached_check() + @doc """ + Takes in a function, the templates variables and the compilation options + and returns the `EXLA.Executable` struct. + + ## Examples + iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end + iex> args = [1.0, 2.0] + iex> module = EXLA.to_mlir_module(fun, args) + iex> EXLA.MLIR.Module.as_string(module) + ~c\"\"\" + module { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.sine %arg0 : tensor + %1 = stablehlo.cosine %arg1 : tensor + %2 = stablehlo.add %0, %1 : tensor + stablehlo.return %2 : tensor + } + } + \"\"\" + """ + def to_mlir_module(function, args, options \\ []) do comp_fun = fn _key, callback -> {:ok, {_xla_time, executable, _extra, _outfeed}} = callback.() throw({:mlir_module, executable.ref}) end function - |> jit([{EXLA, {expr_fun, comp_fun}}, {:module_compilation, :to_mlir} | options]) + |> jit([{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}}, {:module_compilation, :to_mlir} | options]) |> apply(args) catch {:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref} diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index d72e44681e..c98bd77b1b 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -106,16 +106,22 @@ defmodule EXLA.MLIR.Module do # module |> as_string() |> IO.puts() ref = - EXLA.NIF.mlir_compile( - client.ref, - module.ref, - Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1), - num_replicas, - num_partitions, - use_spmd, - device_id - ) - |> unwrap!() + case Keyword.get(options, :module_compilation, :to_pjrt) do + :to_mlir -> + module.ref + + :to_pjrt -> + EXLA.NIF.mlir_compile( + client.ref, + module.ref, + Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1), + num_replicas, + num_partitions, + use_spmd, + device_id + ) + |> unwrap!() + end %Executable{ client: client, From a9eecd9e72007701a9fa9a1da77b9bdd3a413603 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 May 2024 19:05:38 -0300 Subject: [PATCH 3/6] docs: update docs --- exla/lib/exla/mlir/module.ex | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index c98bd77b1b..074ad0f721 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -72,8 +72,9 @@ defmodule EXLA.MLIR.Module do This is set to true if `:num_partitions` is more than one, otherwise is `false`. * `:module_compilation` - either `:to_mlir` or `:to_pjrt`. The default is `:to_pjrt`. - If `:to_pjrt`, the Executable `:ref` field will hold the reference to a PjRt executable. - If `:to_mlir`, the Executable `:ref` field will hold the reference to an MLIR executable. + + * `:to_pjrt` - the `EXLA.Executable` `:ref` field will hold the reference to a PjRt executable. + * `:to_mlir` - the `EXLA.Executable` `:ref` field will hold the reference to an MLIR module. Currently those options do not have an effect as they related to running the same compiled executable on multiple replicas. From bbef0a21f322bf2a18404ac43900c5871059f027 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 May 2024 19:19:15 -0300 Subject: [PATCH 4/6] format --- exla/lib/exla.ex | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 8b818a58ca..5f167d7a75 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -383,7 +383,10 @@ defmodule EXLA do end function - |> jit([{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}}, {:module_compilation, :to_mlir} | options]) + |> jit([ + {EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}}, + {:module_compilation, :to_mlir} | options + ]) |> apply(args) catch {:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref} From 9ba0721a943c190450f1e521069da46253daa951 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 11 Jul 2024 01:56:26 -0300 Subject: [PATCH 5/6] refactor: return binary instead of EXLA.MLIR.Module struct --- exla/c_src/exla/exla.cc | 37 +++++++++++++++++++------------------ exla/lib/exla.ex | 8 ++++---- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index f1cf82a29d..2bf3fadd26 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,18 +1,16 @@ #include -#include "exla_mlir.h" #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" +#include "exla_mlir.h" #include "exla_nif_util.h" - -#include "xla/pjrt/pjrt_api.h" -#include "xla/service/platform_util.h" - #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/service/platform_util.h" // All of these are created with calls to `new` and subsequently // passed to the VM as pointers-to-pointers so we balance it out @@ -202,9 +200,9 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a auto arg_types = std::vector{}; - for (auto const & type_string : arg_type_strings) { + for (auto const& type_string : arg_type_strings) { auto type = (*module)->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } arg_types.push_back(type); @@ -212,9 +210,9 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a auto ret_types = std::vector{}; - for (auto const & type_string : ret_type_strings) { + for (auto const& type_string : ret_type_strings) { auto type = (*module)->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } ret_types.push_back(type); @@ -281,9 +279,9 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { auto result_types = std::vector{}; - for (auto const & type_string : result_type_strings) { + for (auto const& type_string : result_type_strings) { auto type = (*function)->module()->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } result_types.push_back(type); @@ -291,9 +289,9 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { auto attributes = std::vector>{}; - for (auto const & pair : attributes_kwlist) { + for (auto const& pair : attributes_kwlist) { auto attribute_value = (*function)->module()->ParseAttribute(pair.second); - if(attribute_value == nullptr) { + if (attribute_value == nullptr) { return attribute_parsing_error(env, pair.second); } attributes.push_back(std::pair{pair.first, attribute_value}); @@ -304,7 +302,6 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::ok(env, exla::nif::make_list(env, results)); } - ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 2) { return exla::nif::error(env, "Bad argument count."); @@ -322,9 +319,9 @@ ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[ auto types = std::vector{}; - for (auto const & type_string : arg_types) { + for (auto const& type_string : arg_types) { auto type = (*function)->module()->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } types.push_back(type); @@ -379,9 +376,13 @@ ERL_NIF_TERM mlir_module_to_string(ErlNifEnv* env, int argc, const ERL_NIF_TERM return exla::nif::error(env, "Unable to get builder."); } - auto string = (*module)->ToString(); + std::string string = (*module)->ToString(); + + ErlNifBinary bin; + enif_alloc_binary(string.size(), &bin); + memcpy(bin.data, string.c_str(), string.size()); - return exla::nif::ok(env, exla::nif::make(env, string)); + return exla::nif::ok(env, exla::nif::make(env, bin)); } // ExlaBuffer Functions diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 5f167d7a75..55da871384 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -363,9 +363,8 @@ defmodule EXLA do iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end iex> args = [1.0, 2.0] - iex> module = EXLA.to_mlir_module(fun, args) - iex> EXLA.MLIR.Module.as_string(module) - ~c\"\"\" + iex> EXLA.to_mlir_module(fun, args) + \"\"\" module { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { %0 = stablehlo.sine %arg0 : tensor @@ -389,7 +388,8 @@ defmodule EXLA do ]) |> apply(args) catch - {:mlir_module, ref} -> %EXLA.MLIR.Module{ref: ref} + {:mlir_module, ref} -> + EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}) end @doc """ From 6e12151cdab4909fd147b1dcec0310732d0af7e3 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 11 Jul 2024 02:01:41 -0300 Subject: [PATCH 6/6] fix: return value --- exla/lib/exla.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index 55da871384..4dc7b85228 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -370,7 +370,7 @@ defmodule EXLA do %0 = stablehlo.sine %arg0 : tensor %1 = stablehlo.cosine %arg1 : tensor %2 = stablehlo.add %0, %1 : tensor - stablehlo.return %2 : tensor + return %2 : tensor } } \"\"\"