Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add support for defn compilation in EXLA.to_mlir_module #1530

Merged
merged 5 commits into from
Sep 5, 2024

Conversation

polvalente
Copy link
Contributor

EXLA.to_mlir_module wasn't working when being called from within a Defn compiler.
The new option adds the control needed for this to work.

@polvalente polvalente self-assigned this Sep 5, 2024

assert Nx.compatible?(container, Nx.template({}, :s32))

assert used_inputs == %{0 => nil, 1 => nil}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should returns something more meaningful, like a list or mapset?

Copy link
Contributor Author

@polvalente polvalente Sep 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be MapSet of Map.keys(used_inputs). WDYT?

exla/lib/exla.ex Outdated
@@ -360,11 +360,18 @@ defmodule EXLA do
Takes in a function, the argument templates and the compilation
options and returns the textual representation of the MLIR module.

## Options

* `:nested_defn_compilation` - a boolean that indicates whether
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we check Nx.Defn.Compiler.defn?() instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried, but that only gets set inside the runtime fun call, and that's after what I need

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about Nx.Defn.Compiler.current()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried both of those I believe

EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
{:mlir_module, ref, used_inputs, output_container} ->
%{
used_inputs: used_inputs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely a MapSet. :)

{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}}
{:ok,
{xla_time, executable, {inputs_and_typespecs, used_inputs, outputs},
%{outfeed | infeeds: []}}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding outputs and used_inputs here means those values will be cached twice. I think for to_mlir_module, we don't want to use the cache funs anymore. I would instead check inside __compile__ if module_compilation is set to :mlir and have it throw there the exact data that you need.

Alternatively you can create a new EXLA.Defn.MLIRCompiler, that calls EXLA.Defn. We would need to make some APIs public.

@polvalente polvalente merged commit ad28ea7 into main Sep 5, 2024
8 checks passed
@polvalente polvalente deleted the pv-fix/to-mlir-moudle branch September 5, 2024 18:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants