Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add support for defn compilation in EXLA.to_mlir_module #1530

Merged
merged 5 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,18 @@ defmodule EXLA do
Takes in a function, the argument templates and the compilation
options and returns the textual representation of the MLIR module.

## Options

* `:within_defn_compiler` - a boolean that indicates whether
this function is being called from within a `defn` compiler.
Defaults to `false`.

## Examples

iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [1.0, 2.0]
iex> EXLA.to_mlir_module(fun, args)
iex> %{mlir_module: mlir_module} = EXLA.to_mlir_module(fun, args)
iex> mlir_module
"""
module {
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
Expand All @@ -377,20 +384,26 @@ defmodule EXLA do
"""
'''
def to_mlir_module(function, args, options \\ []) do
comp_fun = fn _key, callback ->
{:ok, {_xla_time, executable, _extra, _outfeed}} = callback.()
throw({:mlir_module, executable.ref})
end
{nested_compilation?, options} = Keyword.pop(options, :within_defn_compiler, false)

opts = [
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
{:module_compilation, :to_mlir} | options
]
opts =
Keyword.merge(options,
module_compilation: :to_mlir,
compiler: EXLA
)

jit_apply(function, args, opts)
if nested_compilation? do
EXLA.Defn.__compile__(function, args, function, opts)
else
Nx.Defn.compile(function, args, opts)
end
catch
{:mlir_module, ref} ->
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
{:mlir_module, ref, used_inputs, output_container} ->
%{
used_inputs: used_inputs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely a MapSet. :)

output_container: output_container,
mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
}
end

@doc """
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ defmodule EXLA.Defn do
{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback)

if compile_options[:module_compilation] == :to_mlir do
throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs})
end

fn [args] ->
{time, lock} =
:timer.tc(fn ->
Expand Down
53 changes: 53 additions & 0 deletions exla/test/exla_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,57 @@ defmodule EXLATest do
end
end
end

defmodule ValidCompiler do
def __jit__(key, vars, fun, args_list, opts) do
__compile__(key, vars, fun, opts).(args_list)
end

def __compile__(_key, vars, fun, opts) do
result = EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :within_defn_compiler, true))
throw({__MODULE__, result})
end
end

defmodule InvalidCompiler do
def __jit__(key, vars, fun, args_list, opts) do
__compile__(key, vars, fun, opts).(args_list)
end

def __compile__(_key, vars, fun, opts) do
# Keyword.delete to ensure default is false
EXLA.to_mlir_module(fun, vars, Keyword.delete(opts, :within_defn_compiler))
end
end

describe "to_mlir_module/3" do
test "fails if the compiler doesn't set the nested compilation flag" do
assert_raise BadArityError, fn ->
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.InvalidCompiler)
end
end

test "works if the compiler sets the nested compilation flag" do
try do
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.ValidCompiler)
catch
{__MODULE__.ValidCompiler, result} ->
assert %{mlir_module: module, output_container: container, used_inputs: used_inputs} =
result

assert module == """
module {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32>
return %0 : tensor<i32>
}
}
"""

assert Nx.compatible?(container, Nx.template({}, :s32))

assert MapSet.equal?(used_inputs, MapSet.new([0, 1]))
end
end
end
end
Loading