Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add support for defn compilation in EXLA.to_mlir_module #1530

Merged
merged 5 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,18 @@ defmodule EXLA do
Takes in a function, the argument templates and the compilation
options and returns the textual representation of the MLIR module.

## Options

* `:nested_defn_compilation` - a boolean that indicates whether
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we check Nx.Defn.Compiler.defn?() instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried, but that only gets set inside the runtime fun call, and that's after what I need

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about Nx.Defn.Compiler.current()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried both of those I believe

this function is being called from within a `defn` compiler.
Defaults to `false`.

## Examples

iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
iex> args = [1.0, 2.0]
iex> EXLA.to_mlir_module(fun, args)
iex> %{mlir_module: mlir_module} = EXLA.to_mlir_module(fun, args)
iex> mlir_module
"""
module {
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
Expand All @@ -378,19 +385,31 @@ defmodule EXLA do
'''
def to_mlir_module(function, args, options \\ []) do
comp_fun = fn _key, callback ->
{:ok, {_xla_time, executable, _extra, _outfeed}} = callback.()
throw({:mlir_module, executable.ref})
{:ok, {_xla_time, executable, {_, used_inputs, outputs}, _outfeed}} = callback.()
throw({:mlir_module, executable.ref, used_inputs, outputs})
end

opts = [
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
{:module_compilation, :to_mlir} | options
]
{nested_compilation?, options} = Keyword.pop(options, :nested_defn_compilation, false)

jit_apply(function, args, opts)
opts =
Keyword.merge(options, [
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
module_compilation: :to_mlir,
compiler: EXLA
])

if nested_compilation? do
EXLA.Defn.__compile__(function, args, function, opts)
else
Nx.Defn.compile(function, args, opts)
end
catch
{:mlir_module, ref} ->
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
{:mlir_module, ref, used_inputs, output_container} ->
%{
used_inputs: used_inputs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely a MapSet. :)

output_container: output_container,
mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
}
end

@doc """
Expand Down
7 changes: 5 additions & 2 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ defmodule EXLA.Defn do
outfeed = Outfeed.new(hooks, defined_hooks)
comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options}

{comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} =
{comp_time,
{evaled, {xla_time, executable, {inputs_and_typespecs, _used_inputs, _outputs}, outfeed}}} =
:timer.tc(fn ->
comp_cache_fun.(comp_key, fn ->
{reverse_inputs_and_typespecs, reverse_infeeds} =
Expand Down Expand Up @@ -466,7 +467,9 @@ defmodule EXLA.Defn do
)
end)

{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}}
{:ok,
{xla_time, executable, {inputs_and_typespecs, used_inputs, outputs},
%{outfeed | infeeds: []}}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding outputs and used_inputs here means those values will be cached twice. I think for to_mlir_module, we don't want to use the cache funs anymore. I would instead check inside __compile__ if module_compilation is set to :mlir and have it throw there the exact data that you need.

Alternatively you can create a new EXLA.Defn.MLIRCompiler, that calls EXLA.Defn. We would need to make some APIs public.

end)
end)
end)
Expand Down
52 changes: 52 additions & 0 deletions exla/test/exla_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,56 @@ defmodule EXLATest do
end
end
end

defmodule ValidCompiler do
def __jit__(key, vars, fun, args_list, opts) do
__compile__(key, vars, fun, opts).(args_list)
end

def __compile__(_key, vars, fun, opts) do
result = EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :nested_defn_compilation, true))
throw({__MODULE__, result})
end
end

defmodule InvalidCompiler do
def __jit__(key, vars, fun, args_list, opts) do
__compile__(key, vars, fun, opts).(args_list)
end

def __compile__(_key, vars, fun, opts) do
EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :nested_defn_compilation, false))
end
end

describe "to_mlir_module/3" do
test "fails if the compiler doesn't set the nested compilation flag" do
assert_raise BadArityError, fn ->
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.InvalidCompiler)
end
end

test "works if the compiler sets the nested compilation flag" do
try do
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.ValidCompiler)
catch
{__MODULE__.ValidCompiler, result} ->
assert %{mlir_module: module, output_container: container, used_inputs: used_inputs} =
result

assert module == """
module {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<i32>
return %0 : tensor<i32>
}
}
"""

assert Nx.compatible?(container, Nx.template({}, :s32))

assert used_inputs == %{0 => nil, 1 => nil}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should returns something more meaningful, like a list or mapset?

Copy link
Contributor Author

@polvalente polvalente Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be MapSet of Map.keys(used_inputs). WDYT?

end
end
end
end
Loading