-
Notifications
You must be signed in to change notification settings - Fork 200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: add support for defn compilation in EXLA.to_mlir_module #1530
Conversation
exla/test/exla_test.exs
Outdated
|
||
assert Nx.compatible?(container, Nx.template({}, :s32)) | ||
|
||
assert used_inputs == %{0 => nil, 1 => nil} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should returns something more meaningful, like a list or mapset?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be MapSet of Map.keys(used_inputs). WDYT?
exla/lib/exla.ex
Outdated
@@ -360,11 +360,18 @@ defmodule EXLA do | |||
Takes in a function, the argument templates and the compilation | |||
options and returns the textual representation of the MLIR module. | |||
|
|||
## Options | |||
|
|||
* `:nested_defn_compilation` - a boolean that indicates whether |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we check Nx.Defn.Compiler.defn?()
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried, but that only gets set inside the runtime fun call, and that's after what I need
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about Nx.Defn.Compiler.current()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried both of those I believe
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref}) | ||
{:mlir_module, ref, used_inputs, output_container} -> | ||
%{ | ||
used_inputs: used_inputs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely a MapSet. :)
exla/lib/exla/defn.ex
Outdated
{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}} | ||
{:ok, | ||
{xla_time, executable, {inputs_and_typespecs, used_inputs, outputs}, | ||
%{outfeed | infeeds: []}}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding outputs
and used_inputs
here means those values will be cached twice. I think for to_mlir_module
, we don't want to use the cache funs anymore. I would instead check inside __compile__
if module_compilation
is set to :mlir
and have it throw there the exact data that you need.
Alternatively you can create a new EXLA.Defn.MLIRCompiler, that calls EXLA.Defn
. We would need to make some APIs public.
EXLA.to_mlir_module
wasn't working when being called from within a Defn compiler.The new option adds the control needed for this to work.