Skip to content

Commit

Permalink
Simplify return of compile
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Sep 1, 2024
1 parent 44d6410 commit 9a9a568
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 62 deletions.
68 changes: 25 additions & 43 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ defmodule EXLA.Defn do

@doc false
def __stream__(key, input, acc, vars, fun, [args], options) do
{debug?, options} = Keyword.pop(options, :debug, false)
{run_options, compile_options} = Keyword.pop(options, :run_options, [])

{client_name, compile_options} =
Expand All @@ -51,24 +52,26 @@ defmodule EXLA.Defn do
comp_fun =
&to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options)

{executable, used_inputs, {output, acc_output}, outfeed, extra, debug?} =
{executable, used_inputs, {output, acc_output}, outfeed, input_typespecs} =
compile(
client,
{:stream, key},
key,
vars,
fun,
compile_options,
used_buffers,
used_inputs,
_stream = true,
debug?,
comp_fun
)

{input_typespecs, input_indexes} = extra
# Now discard the infeed from used inputs, similar to how it is done to buffers.
# Note we discard all lazy transfers too, as they are not possible with streams.
used_inputs = for {i, nil} <- used_inputs, i >= used_buffers, do: {i, nil}, into: %{}

# Also discard the stream inputs from used inputs, similar to how it is done to buffers
# Note we discard all lazy transfers too, as they are not possible with streams
used_inputs = Enum.sort(for {i, nil} <- used_inputs, i >= used_buffers, do: i)
# And capture the typespecs for the infeed.
input_typespecs = Enum.take_while(input_typespecs, fn {i, _} -> i < input_length end)

# Execution of streams requires the coordination of
# multiple processes which is outlined below.
Expand Down Expand Up @@ -120,7 +123,6 @@ defmodule EXLA.Defn do
outfeed_pid,
input,
input_typespecs,
input_indexes,
output,
output_typespecs,
acc_output
Expand Down Expand Up @@ -151,9 +153,6 @@ defmodule EXLA.Defn do
{input_typespecs, used_typespecs} =
Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end)

# Get all input indexes and shape
input_indexes = Enum.map(input_typespecs, &elem(&1, 0))

# Drop all accumulator entries from used_typespecs as we will handle it separately.
{acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length)

Expand All @@ -166,13 +165,10 @@ defmodule EXLA.Defn do
# The input will be read as part of the infeed.
acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1))
acc_typespec = List.to_tuple(acc_typespecs_l)

flag_typespec = Typespec.tensor({:pred, 8}, {})

args = EXLA.MLIR.Function.get_arguments(builder)

{token, [flag]} = Value.infeed(root_token, [flag_typespec])

init = [flag, token | args]

arg_typespecs = Enum.map(init, &Value.get_typespec/1)
Expand All @@ -186,11 +182,9 @@ defmodule EXLA.Defn do
{body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs)

{acc, constant} = Enum.split(args, acc_length)

{indices, input_typespecs} = Enum.unzip(input_typespecs)
{input_indices, input_typespecs} = Enum.unzip(input_typespecs)
{token, input} = Value.infeed(token, input_typespecs)

input_params = Enum.zip(indices, input)
input_params = Enum.zip(input_indices, input)

{%Outfeed{token: token} = outfeed, acc} =
case expr do
Expand Down Expand Up @@ -226,9 +220,7 @@ defmodule EXLA.Defn do

# Emit the stream hook to signal loop output
{token, [flag]} = Value.infeed(token, [flag_typespec])

Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant))

Function.pop_region(builder)

[_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init)
Expand All @@ -238,8 +230,7 @@ defmodule EXLA.Defn do

outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
Value.func_return(builder, output)

{{input_typespecs, input_indexes}, outfeed}
outfeed
end

@doc false
Expand All @@ -249,6 +240,7 @@ defmodule EXLA.Defn do

@doc false
def __compile__(key, vars, fun, options) do
{debug?, options} = Keyword.pop(options, :debug, false)
{run_options, compile_options} = Keyword.pop(options, :run_options, [])

{client_name, compile_options} =
Expand All @@ -258,8 +250,8 @@ defmodule EXLA.Defn do

callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client))

{executable, used_inputs, outputs, outfeed, :ok, debug?} =
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback)
{executable, used_inputs, outputs, outfeed, _input_typespecs?} =
compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback)

fn [args] ->
{time, lock} =
Expand Down Expand Up @@ -306,10 +298,8 @@ defmodule EXLA.Defn do

{res, cache} = recur_flatten(expr, state, new_cache(outfeed))
outfeed = cache |> get_outfeed() |> Outfeed.close(function)

Value.func_return(function, res)

{:ok, outfeed}
outfeed
end

defp maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options)
Expand Down Expand Up @@ -367,6 +357,7 @@ defmodule EXLA.Defn do
used_buffers,
used_inputs,
stream?,
debug?,
to_computation
) do
{{expr_cache_fun, comp_cache_fun}, options} =
Expand All @@ -379,8 +370,6 @@ defmodule EXLA.Defn do
{{cache_fun, cache_fun}, options}
end

{debug?, options} = Keyword.pop(options, :debug, false)

{args_key, reverse_args_identifiers} =
Enum.map_reduce(vars, [], fn var, acc ->
Nx.Defn.Composite.traverse(var, acc, fn
Expand All @@ -396,7 +385,7 @@ defmodule EXLA.Defn do

{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
:timer.tc(fn ->
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn ->
expr = fun.(vars)
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
Expand All @@ -412,12 +401,10 @@ defmodule EXLA.Defn do
end

{hooks, options} = Keyword.pop(options, :hooks, %{})

outfeed = Outfeed.new(hooks, defined_hooks)

comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options}

{comp_time, {evaled, {xla_time, executable, extra, outfeed}}} =
{comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} =
:timer.tc(fn ->
comp_cache_fun.(comp_key, fn ->
{reverse_inputs_and_typespecs, reverse_infeeds} =
Expand All @@ -430,7 +417,7 @@ defmodule EXLA.Defn do

inputs_and_typespecs = Enum.reverse(reverse_inputs_and_typespecs)

comp_arg_typespecs =
comp_typespecs =
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec

outputs =
Expand All @@ -451,7 +438,7 @@ defmodule EXLA.Defn do
|> then(&Typespec.tensor(&1.type, &1.shape))
end)

EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder ->
EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder ->
# Only create the token when we know it will actually be
# used, that is: streaming, lazy transfers or hooks
outfeed =
Expand All @@ -464,25 +451,20 @@ defmodule EXLA.Defn do
end

expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)

{extra, outfeed} =
to_computation.(builder, expr, inputs_and_typespecs, outfeed)
outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed)

{xla_time, executable} =
:timer.tc(fn ->
typespecs =
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec

EXLA.MLIR.Module.compile(
builder.module,
client,
typespecs,
comp_typespecs,
builder.return_typespecs,
options
)
end)

{:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}}
{:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}}
end)
end)
end)
Expand Down Expand Up @@ -511,7 +493,7 @@ defmodule EXLA.Defn do
end

outfeed = Outfeed.with_user_hooks(outfeed, hooks)
{executable, used_inputs, outputs, outfeed, extra, debug?}
{executable, used_inputs, outputs, outfeed, inputs_and_typespecs}
end

defp us_to_ms(time), do: Float.round(time / 1000, 1)
Expand Down
13 changes: 8 additions & 5 deletions exla/lib/exla/defn/buffers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,19 @@ defmodule EXLA.Defn.Buffers do

@doc """
Splits the given args by value and returns them as is.
Entries with a map entry are discarded.
"""
def split_by_value(args, %{} = map, callback) do
{_i, left, right} =
Enum.reduce(args, {0, [], []}, fn arg, {i, left, right} ->
case map do
%{^i => nil} -> {i + 1, [callback.(arg, i, nil) | left], right}
%{^i => value} -> {i + 1, left, [callback.(arg, i, value) | right]}
%{} -> {i + 1, left, right}
%{^i => nil} ->
{i + 1, [callback.(arg, i, nil) | left], right}

%{^i => value} ->
{i + 1, left, [callback.(arg, i, value) | right]}

%{} ->
{i + 1, left, right}
end
end)

Expand Down
14 changes: 7 additions & 7 deletions exla/lib/exla/defn/outfeed.ex
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,15 @@ defmodule EXLA.Defn.Outfeed do
{infeeds, {compiled_hooks, token}} =
entries
|> List.keysort(1, :desc)
|> Enum.map_reduce({compiled_hooks, token}, fn {pos, _, typespec},
{compiled_hooks, token} ->
next_flag = next_hook(compiled_hooks)
compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec})
|> Enum.map_reduce({compiled_hooks, token}, fn
{pos, _, typespec}, {compiled_hooks, token} ->
next_flag = next_hook(compiled_hooks)
compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec})

token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token)
{token, [input]} = Value.infeed(token, [typespec])
token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token)
{token, [input]} = Value.infeed(token, [typespec])

{{pos, input}, {compiled_hooks, token}}
{{pos, input}, {compiled_hooks, token}}
end)

%{outfeed | compiled_hooks: compiled_hooks, token: token, infeeds: infeeds}
Expand Down
15 changes: 8 additions & 7 deletions exla/lib/exla/defn/stream.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule EXLA.Defn.Stream do
@moduledoc false

keys =
[:lock, :outfeed, :pid, :runner, :send, :send_typespecs, :send_indexes] ++
[:lock, :outfeed, :pid, :runner, :send, :send_typespecs] ++
[:recv, :recv_length, :done, :client, :device_id]

@derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]}
Expand All @@ -16,7 +16,6 @@ defmodule EXLA.Defn.Stream do
outfeed,
send,
send_typespecs,
send_indexes,
recv,
recv_typespecs,
done
Expand All @@ -40,7 +39,6 @@ defmodule EXLA.Defn.Stream do
lock: lock,
send: send,
send_typespecs: send_typespecs,
send_indexes: send_indexes,
recv: recv,
recv_length: length(recv_typespecs),
client: client,
Expand All @@ -64,15 +62,14 @@ defmodule EXLA.Defn.Stream do
client: client,
device_id: device_id,
send: send,
send_typespecs: send_typespecs,
send_indexes: send_indexes
send_typespecs: send_typespecs
} = stream

if pid != self() do
raise "EXLA streams require recv to be called from the process that started the stream"
end

{template, buffers} = nx_to_io(data, send_indexes)
{template, buffers} = nx_to_io(data, Enum.map(send_typespecs, &elem(&1, 0)))

unless Nx.compatible?(send, template) do
raise ArgumentError, """
Expand All @@ -87,7 +84,11 @@ defmodule EXLA.Defn.Stream do
end

pred = EXLA.Typespec.tensor({:pred, 8}, {})
data_and_typespecs = Enum.zip(buffers, send_typespecs)

data_and_typespecs =
Enum.zip_with(buffers, send_typespecs, fn buffer, {_index, typespec} ->
{buffer, typespec}
end)

:ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred}])
:ok = EXLA.Client.to_infeed(client, device_id, data_and_typespecs)
Expand Down

0 comments on commit 9a9a568

Please sign in to comment.