diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 86cceb4b5d..1619065812 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -251,15 +251,8 @@ defmodule EXLA.Defn do def __compile__(key, vars, fun, options) do {run_options, compile_options} = Keyword.pop(options, :run_options, []) - {client_name, compile_options} = - Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) - - client = EXLA.Client.fetch!(client_name) - - callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) - - {executable, used_inputs, outputs, outfeed, :ok, debug?} = - compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback) + {:ok, {executable, {used_inputs, outputs, outfeed, debug?}}} = + compile_executable(key, vars, fun, compile_options) fn [args] -> {time, lock} = @@ -284,6 +277,64 @@ defmodule EXLA.Defn do end end + @doc """ + Takes in a function, the templates variables and the compilation options + and returns the `EXLA.Executable` struct. + + ## Examples + + iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end + iex> args = [Nx.to_template(1.0), Nx.to_template(2.0)] + iex> {:ok, %EXLA.Executable{ref: ref, mlir_module: %EXLA.MLIR.Module{}}} = EXLA.Defn.export_executable(fun, args) + iex> is_reference(ref) + true + iex> {:ok, %EXLA.Executable{ref: nil, mlir_module: %EXLA.MLIR.Module{}}} = EXLA.Defn.export_executable(fun, args, compile_mlir: false) + """ + def export_executable(fun, vars, options \\ []) do + runtime_fun = + fn args -> + fun + |> apply(args) + |> Nx.Defn.Composite.traverse(&Nx.Defn.Expr.tensor/1) + end + + {params, _} = + Enum.map_reduce(vars, 0, fn + arg, i + when is_list(arg) + when is_function(arg) + when is_tuple(arg) and is_function(elem(arg, 0)) -> + {arg, i} + + container, i -> + Nx.Defn.Composite.traverse(container, i, fn + template, i -> + {Nx.Defn.Expr.parameter(template, :root, i), i + 1} + end) + end) + + {:ok, {executable, _}} = + compile_executable(fun, params, runtime_fun, Keyword.delete(options, :run_options)) + + {:ok, executable} + end + + defp compile_executable(key, vars, fun, compile_options) do + {client_name, compile_options} = + Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) + + compile_options = Keyword.put_new(compile_options, :runtime, :xla) + + client = EXLA.Client.fetch!(client_name) + + callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) + + {executable, used_inputs, outputs, outfeed, :ok, debug?} = + compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback) + + {:ok, {executable, {used_inputs, outputs, outfeed, debug?}}} + end + defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index a6a0c8cbdf..41e9e5db76 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -6,8 +6,24 @@ defmodule EXLA.Executable do alias __MODULE__ alias EXLA.{BinaryBuffer, DeviceBuffer} - @enforce_keys [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] - defstruct [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] + @enforce_keys [ + :client, + :ref, + :output_typespecs, + :num_replicas, + :num_partitions, + :device_id, + :mlir_module + ] + defstruct [ + :client, + :ref, + :output_typespecs, + :num_replicas, + :num_partitions, + :device_id, + :mlir_module + ] @doc """ Runs the given executable with a list of lists as inputs and the given options. diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index dff1b7ae80..278c814ec5 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -71,6 +71,10 @@ defmodule EXLA.MLIR.Module do * `:use_spmd` - enables Single-Program Multiple-Data partioning. This is set to true if `:num_partitions` is more than one, otherwise is `false`. + * `:compile_mlir` - whether to compile the MLIR module. If set to `false`, + the `ref` field of the returned `Executable` struct will be `nil`. Useful + if you want to serialize the module and compile it with a different stack. + Currently those options do not have an effect as they related to running the same compiled executable on multiple replicas. @@ -101,21 +105,28 @@ defmodule EXLA.MLIR.Module do # Uncomment to debug the module MLIR source # module |> as_string() |> IO.puts() + compile_mlir = Keyword.get(options, :compile_mlir, true) + ref = - EXLA.NIF.mlir_compile( - client.ref, - module.ref, - Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1), - num_replicas, - num_partitions, - use_spmd, - device_id - ) - |> unwrap!() + if compile_mlir do + EXLA.NIF.mlir_compile( + client.ref, + module.ref, + Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1), + num_replicas, + num_partitions, + use_spmd, + device_id + ) + |> unwrap!() + else + nil + end %Executable{ client: client, ref: ref, + mlir_module: module, output_typespecs: return_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, diff --git a/exla/test/exla/defn_test.exs b/exla/test/exla/defn_test.exs new file mode 100644 index 0000000000..f3c7211b70 --- /dev/null +++ b/exla/test/exla/defn_test.exs @@ -0,0 +1,5 @@ +defmodule EXLA.DefnTest do + use EXLA.Case, async: true + + doctest EXLA.Defn +end