diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index ebfbe35487..f4714c1f8d 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -360,11 +360,18 @@ defmodule EXLA do Takes in a function, the argument templates and the compilation options and returns the textual representation of the MLIR module. + ## Options + + * `:within_defn_compiler` - a boolean that indicates whether + this function is being called from within a `defn` compiler. + Defaults to `false`. + ## Examples iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end iex> args = [1.0, 2.0] - iex> EXLA.to_mlir_module(fun, args) + iex> %{mlir_module: mlir_module} = EXLA.to_mlir_module(fun, args) + iex> mlir_module """ module { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { @@ -377,20 +384,26 @@ defmodule EXLA do """ ''' def to_mlir_module(function, args, options \\ []) do - comp_fun = fn _key, callback -> - {:ok, {_xla_time, executable, _extra, _outfeed}} = callback.() - throw({:mlir_module, executable.ref}) - end + {nested_compilation?, options} = Keyword.pop(options, :within_defn_compiler, false) - opts = [ - {EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}}, - {:module_compilation, :to_mlir} | options - ] + opts = + Keyword.merge(options, + module_compilation: :to_mlir, + compiler: EXLA + ) - jit_apply(function, args, opts) + if nested_compilation? do + EXLA.Defn.__compile__(function, args, function, opts) + else + Nx.Defn.compile(function, args, opts) + end catch - {:mlir_module, ref} -> - EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}) + {:mlir_module, ref, used_inputs, output_container} -> + %{ + used_inputs: used_inputs, + output_container: output_container, + mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}) + } end @doc """ diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 652a5fefc7..d55250f371 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -250,6 +250,10 @@ defmodule EXLA.Defn do {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback) + if compile_options[:module_compilation] == :to_mlir do + throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) + end + fn [args] -> {time, lock} = :timer.tc(fn -> diff --git a/exla/test/exla_test.exs b/exla/test/exla_test.exs index 5558c73af5..2092ca9e90 100644 --- a/exla/test/exla_test.exs +++ b/exla/test/exla_test.exs @@ -17,4 +17,57 @@ defmodule EXLATest do end end end + + defmodule ValidCompiler do + def __jit__(key, vars, fun, args_list, opts) do + __compile__(key, vars, fun, opts).(args_list) + end + + def __compile__(_key, vars, fun, opts) do + result = EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :within_defn_compiler, true)) + throw({__MODULE__, result}) + end + end + + defmodule InvalidCompiler do + def __jit__(key, vars, fun, args_list, opts) do + __compile__(key, vars, fun, opts).(args_list) + end + + def __compile__(_key, vars, fun, opts) do + # Keyword.delete to ensure default is false + EXLA.to_mlir_module(fun, vars, Keyword.delete(opts, :within_defn_compiler)) + end + end + + describe "to_mlir_module/3" do + test "fails if the compiler doesn't set the nested compilation flag" do + assert_raise BadArityError, fn -> + Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.InvalidCompiler) + end + end + + test "works if the compiler sets the nested compilation flag" do + try do + Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.ValidCompiler) + catch + {__MODULE__.ValidCompiler, result} -> + assert %{mlir_module: module, output_container: container, used_inputs: used_inputs} = + result + + assert module == """ + module { + func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = stablehlo.add %arg0, %arg1 : tensor + return %0 : tensor + } + } + """ + + assert Nx.compatible?(container, Nx.template({}, :s32)) + + assert MapSet.equal?(used_inputs, MapSet.new([0, 1])) + end + end + end end