Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: export executable from EXLA #1496

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,8 @@ defmodule EXLA.Defn do
def __compile__(key, vars, fun, options) do
{run_options, compile_options} = Keyword.pop(options, :run_options, [])

{client_name, compile_options} =
Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0)

client = EXLA.Client.fetch!(client_name)

callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))

{executable, used_inputs, outputs, outfeed, :ok, debug?} =
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)
{:ok, {executable, {used_inputs, outputs, outfeed, debug?}}} =
compile_executable(key, vars, fun, compile_options)

fn [args] ->
{time, lock} =
Expand All @@ -284,6 +277,64 @@ defmodule EXLA.Defn do
end
end

@doc """
Takes in a function, the templates variables and the compilation options
and returns the `EXLA.Executable` struct.

## Examples

iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [Nx.to_template(1.0), Nx.to_template(2.0)]
iex> {:ok, %EXLA.Executable{ref: ref, mlir_module: %EXLA.MLIR.Module{}}} = EXLA.Defn.export_executable(fun, args)
iex> is_reference(ref)
true
iex> {:ok, %EXLA.Executable{ref: nil, mlir_module: %EXLA.MLIR.Module{}}} = EXLA.Defn.export_executable(fun, args, compile_mlir: false)
"""
def export_executable(fun, vars, options \\ []) do
runtime_fun =
fn args ->
fun
|> apply(args)
|> Nx.Defn.Composite.traverse(&Nx.Defn.Expr.tensor/1)
end

{params, _} =
Enum.map_reduce(vars, 0, fn
arg, i
when is_list(arg)
when is_function(arg)
when is_tuple(arg) and is_function(elem(arg, 0)) ->
{arg, i}

container, i ->
Nx.Defn.Composite.traverse(container, i, fn
template, i ->
{Nx.Defn.Expr.parameter(template, :root, i), i + 1}
end)
end)

{:ok, {executable, _}} =
compile_executable(fun, params, runtime_fun, Keyword.delete(options, :run_options))

{:ok, executable}
end

defp compile_executable(key, vars, fun, compile_options) do
{client_name, compile_options} =
Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0)

compile_options = Keyword.put_new(compile_options, :runtime, :xla)

client = EXLA.Client.fetch!(client_name)

callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))

{executable, used_inputs, outputs, outfeed, :ok, debug?} =
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)

{:ok, {executable, {used_inputs, outputs, outfeed, debug?}}}
end

defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do
params =
Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg ->
Expand Down
20 changes: 18 additions & 2 deletions exla/lib/exla/executable.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,24 @@ defmodule EXLA.Executable do
alias __MODULE__
alias EXLA.{BinaryBuffer, DeviceBuffer}

@enforce_keys [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id]
defstruct [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id]
@enforce_keys [
:client,
:ref,
:output_typespecs,
:num_replicas,
:num_partitions,
:device_id,
:mlir_module
]
defstruct [
:client,
:ref,
:output_typespecs,
:num_replicas,
:num_partitions,
:device_id,
:mlir_module
]

@doc """
Runs the given executable with a list of lists as inputs and the given options.
Expand Down
31 changes: 21 additions & 10 deletions exla/lib/exla/mlir/module.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ defmodule EXLA.MLIR.Module do
* `:use_spmd` - enables Single-Program Multiple-Data partioning.
This is set to true if `:num_partitions` is more than one, otherwise is `false`.

* `:compile_mlir` - whether to compile the MLIR module. If set to `false`,
the `ref` field of the returned `Executable` struct will be `nil`. Useful
if you want to serialize the module and compile it with a different stack.

Currently those options do not have an effect as they related to running the
same compiled executable on multiple replicas.

Expand Down Expand Up @@ -101,21 +105,28 @@ defmodule EXLA.MLIR.Module do
# Uncomment to debug the module MLIR source
# module |> as_string() |> IO.puts()

compile_mlir = Keyword.get(options, :compile_mlir, true)

ref =
EXLA.NIF.mlir_compile(
client.ref,
module.ref,
Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1),
num_replicas,
num_partitions,
use_spmd,
device_id
)
|> unwrap!()
if compile_mlir do
EXLA.NIF.mlir_compile(
client.ref,
module.ref,
Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1),
num_replicas,
num_partitions,
use_spmd,
device_id
)
|> unwrap!()
else
nil
end

%Executable{
client: client,
ref: ref,
mlir_module: module,
output_typespecs: return_typespecs,
num_replicas: num_replicas,
num_partitions: num_partitions,
Expand Down
5 changes: 5 additions & 0 deletions exla/test/exla/defn_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
defmodule EXLA.DefnTest do
use EXLA.Case, async: true

doctest EXLA.Defn
end
Loading