Skip to content

Commit

Permalink
Remove Nx.Defn.stream and Nx.Stream
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Oct 14, 2024
1 parent 3a92566 commit c82702b
Show file tree
Hide file tree
Showing 14 changed files with 9 additions and 1,106 deletions.
87 changes: 0 additions & 87 deletions exla/lib/exla.ex
Original file line number Diff line number Diff line change
Expand Up @@ -297,65 +297,6 @@ defmodule EXLA do
Nx.Defn.compile(function, args, Keyword.put(options, :compiler, EXLA))
end

@doc """
Starts streaming the given anonymous function with just-in-time
compilation.
At least two arguments are expected:
1. The first argument is a tensor template of the data to
be streamed in
2. The second argument is a tensor with the stream initial state
The streaming function must return a two element tuple, the
first element is the data to be sent and the second is the
accumulator.
For each streamed chunk, you must call `Nx.Stream.send/2` and
`Nx.Stream.recv/1`. You don't need to call `recv` immediately
after `send`, but doing so can be a useful mechanism to provide
backpressure. Once all chunks are sent, you must use `Nx.Stream.done/1`
to receive the accumulated result. Let's see an example:
defmodule Streamed do
import Nx.Defn
defn sum(tensor, acc) do
{acc, tensor + acc}
end
end
Now let's invoke it:
stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 32}), 0])
for i <- 1..5 do
Nx.Stream.send(stream, i)
IO.inspect {:chunk, Nx.Stream.recv(stream)}
end
IO.inspect {:result, Nx.Stream.done(stream)}
It will print:
{:chunk, 0}
{:chunk, 1}
{:chunk, 2}
{:chunk, 3}
{:chunk, 4}
{:result, 5}
**Note:** While any process can call `Nx.Stream.send/2`, EXLA
expects the process that starts the streaming to be the one
calling `Nx.Stream.recv/1` and `Nx.Stream.done/1`.
See `jit/2` for supported options.
"""
def stream(function, args, options \\ []) do
Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA))
end

@doc ~S'''
Takes in a function, the argument templates and the compilation
options and returns the textual representation of the MLIR module.
Expand Down Expand Up @@ -442,31 +383,6 @@ defmodule EXLA do
{:cached?, bool} -> bool
end

@doc """
Checks if the JIT compilation of stream with
args is cached.
Note that hooks are part of the cache, and
therefore they must be included in the options.
## Examples
iex> left = Nx.tensor(1, type: {:u, 8})
iex> right = Nx.tensor([1, 2, 3], type: {:u, 16})
iex> fun = fn x, acc -> {acc, Nx.add(x, acc)} end
iex> stream = EXLA.stream(fun, [left, right])
iex> Nx.Stream.done(stream)
iex> EXLA.stream_cached?(fun, [left, right])
true
iex> EXLA.stream_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})])
false
"""
def stream_cached?(function, args, options \\ []) do
stream(function, args, [{EXLA, cached_check()} | options])
catch
{:cached?, bool} -> bool
end

defp cached_check do
expr_cache_fun = fn key, _callback ->
if res = EXLA.Defn.LockedCache.get(key) do
Expand All @@ -489,9 +405,6 @@ defmodule EXLA do
@impl true
defdelegate __jit__(key, vars, fun, args, opts), to: EXLA.Defn

@impl true
defdelegate __stream__(key, input, acc, vars, fun, args, opts), to: EXLA.Defn

@impl true
defdelegate __partitions_options__(opts), to: EXLA.Defn

Expand Down
205 changes: 6 additions & 199 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -30,190 +30,6 @@ defmodule EXLA.Defn do
{EXLA.Backend, [client: client_name, device_id: device_id]}
end

@doc false
def __stream__(key, input, acc, vars, fun, [args], options) do
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
debug? = Keyword.get(compile_options, :debug, false)
compile_options = Keyword.put(compile_options, :lazy_transfers, :never)

input_length = length(Nx.Defn.Composite.flatten_list([input]))
acc_length = length(Nx.Defn.Composite.flatten_list([acc]))

# The input vars should not be converted to buffers as they come from infeed
# Accs are always considered as used
used_buffers = input_length
used_inputs = Enum.to_list(input_length..(input_length + acc_length - 1)//1)

comp_fun =
&to_stream_computation(input_length, acc_length, &1, &2, &3, &4, &5, compile_options)

{executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} =
compile(key, vars, fun, compile_options, used_buffers, used_inputs, true, comp_fun)

# Now discard the infeed from used inputs, similar to how it is done to buffers.
# Note we discard all lazy transfers too, as they are not possible with streams.
used_inputs = for {i, nil} <- used_inputs, i >= used_buffers, do: {i, nil}, into: %{}

# And capture the typespecs for the infeed.
input_typespecs = Enum.take_while(input_typespecs, fn {i, _} -> i < input_length end)

# Execution of streams requires the coordination of
# multiple processes which is outlined below.

# First, we get a lock on the executable, because we want
# to avoid transfer to the device unless we know we are
# ready to use the device.
{time, lock} =
:timer.tc(fn ->
EXLA.Defn.Lock.lock(run_key(executable))
end)

debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms")

{time, streams} =
:timer.tc(fn ->
buffers =
EXLA.Defn.Buffers.filter_by_indexes(args, used_inputs, fn arg, _ ->
EXLA.Defn.Buffers.from_nx!(arg, executable)
end)

# Now that we have transferred to device, we spawn a runner process
# to execute the stream. We use a runner instead of a task to avoid
# leaking messages in the inbox. We also don't use a supervisor
# to keep them linked, which is safe because the agent is not used
# outside the scope of the current process.
#
# Finally, note the runner cannot start immediately, we need to
# setup the outfeed reader and register the on_unlock callback
# that cancels the stream atomically. This is done inside
# EXLA.Defn.Stream.run.
{:ok, runner} =
EXLA.Defn.Runner.start_link(lock, fn ->
EXLA.Executable.run(executable, [buffers], run_options)
end)

# The outfeed reader will redirect all outputs with flag 1 to the current
# process. Once flag 0 is emitted, we know the stream is done.
{output_typespecs, outfeed} = Outfeed.configure_stream_hook(outfeed, self(), lock)
{:ok, outfeed_pid} = Outfeed.start_child(executable, outfeed, Process.group_leader())

stream =
EXLA.Defn.Stream.run(
executable,
lock,
runner,
outfeed_pid,
input,
input_typespecs,
output,
output_typespecs,
acc_output
)

[stream]
end)

debug? &&
Logger.debug("EXLA stream start on device #{executable.device_id} in #{us_to_ms(time)}ms")

streams
end

defp to_stream_computation(
input_length,
acc_length,
%Function{} = builder,
expr,
used_typespecs,
outfeed,
client,
options
) do
%{token: root_token, infeeds: []} = outfeed

{input_typespecs, used_typespecs} =
Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end)

# Drop all accumulator entries from used_typespecs as we will handle it separately.
{acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length)

# The stream loop will be a three element tuple:
#
# The result of calling infeed.
# The looping accumulator.
# The looping constants.
#
# The input will be read as part of the infeed.
acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1))
acc_typespec = List.to_tuple(acc_typespecs_l)
flag_typespec = Typespec.tensor({:pred, 8}, {})

args = EXLA.MLIR.Function.get_arguments(builder)
{token, [flag]} = Value.infeed(root_token, [flag_typespec])
init = [flag, token | args]

arg_typespecs = Enum.map(init, &Value.get_typespec/1)
{pred_computation, [flag | _]} = Function.push_region(builder, arg_typespecs)
typespec = Typespec.tensor({:pred, 8}, {})
r0 = Value.constant(builder, [1], typespec)
pred_op = Value.equal(flag, r0, typespec)
Value.return(builder, [pred_op])
Function.pop_region(builder)

{body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs)

{acc, constant} = Enum.split(args, acc_length)
{input_indices, input_typespecs} = Enum.unzip(input_typespecs)
{token, input} = Value.infeed(token, input_typespecs)
input_params = Enum.zip(input_indices, input)

{%Outfeed{token: token} = outfeed, acc} =
case expr do
{output_expr, acc_expr} ->
acc_params =
Enum.map(acc_typespecs, fn {pos, _typespec} ->
{pos, Enum.fetch!(acc, pos - input_length)}
end)

constant_params =
Enum.with_index(used_typespecs, fn {pos, _typespec}, index ->
{pos, Enum.fetch!(constant, index)}
end)

state = %{
client: client,
builder: builder,
precision: Keyword.get(options, :precision, :default),
params: Map.new(input_params ++ acc_params ++ constant_params),
scope_ids: Tree.scope_ids(expr)
}

outfeed = Outfeed.with_token(outfeed, token)
{output, cache} = recur_flatten(output_expr, state, new_cache(outfeed))
{acc, cache} = recur_flatten(acc_expr, state, cache)
outfeed = cache |> get_outfeed() |> Outfeed.add_stream_hook(builder, output)
{outfeed, acc}

_ ->
raise "expected the function given to Nx.stream/3 to return a two-element tuple, got: " <>
inspect(expr)
end

# Emit the stream hook to signal loop output
{token, [flag]} = Value.infeed(token, [flag_typespec])
Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant))
Function.pop_region(builder)

[_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init)

acc = Enum.take(results, acc_length)
output = wrap_tuple_result(acc, acc_typespec)

outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
Value.func_return(builder, output)
outfeed
end

@doc false
def __jit__(key, vars, fun, args_list, options) do
__compile__(key, vars, fun, options).(args_list)
Expand All @@ -223,10 +39,10 @@ defmodule EXLA.Defn do
def __compile__(key, vars, fun, options) do
{run_options, compile_options} = Keyword.pop(options, :run_options, [])
debug? = Keyword.get(compile_options, :debug, false)
callback = &to_root_computation(&1, &2, &3, &4, &5, compile_options)
callback = &to_computation(&1, &2, &3, &4, &5, compile_options)

{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
compile(key, vars, fun, compile_options, 0, [], _stream = false, callback)
compile(key, vars, fun, compile_options, 0, [], callback)

if compile_options[:module_compilation] == :to_mlir do
throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs})
Expand All @@ -252,7 +68,7 @@ defmodule EXLA.Defn do
end
end

defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do
defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do
params =
Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg ->
{pos, arg}
Expand Down Expand Up @@ -322,7 +138,7 @@ defmodule EXLA.Defn do

## Compile

defp compile(key, vars, fun, options, used_buffers, used_inputs, stream?, to_computation) do
defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation) do
{cache, options} = Keyword.pop(options, :cache, true)
{hooks, options} = Keyword.pop(options, :hooks, %{})
{debug?, options} = Keyword.pop(options, :debug, false)
Expand Down Expand Up @@ -361,7 +177,7 @@ defmodule EXLA.Defn do

{eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} =
:timer.tc(fn ->
expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn ->
expr_cache_fun.({key, args_key, lazy_transfers}, fn ->
expr = fun.(vars)
inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers)
{expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}}
Expand Down Expand Up @@ -395,15 +211,6 @@ defmodule EXLA.Defn do
comp_typespecs =
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec

outputs =
if stream? do
# The computation returns the final accumulator value
{_chunk_result, acc} = outputs
acc
else
outputs
end

out_typespecs =
[outputs]
|> Nx.Defn.Composite.flatten_list()
Expand All @@ -417,7 +224,7 @@ defmodule EXLA.Defn do
# Only create the token when we know it will actually be
# used, that is: streaming, lazy transfers or hooks
outfeed =
if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
if reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do
outfeed
|> Outfeed.with_token(Value.create_token(builder))
|> Outfeed.add_infeeds(builder, reverse_infeeds)
Expand Down
21 changes: 0 additions & 21 deletions exla/lib/exla/defn/outfeed.ex
Original file line number Diff line number Diff line change
Expand Up @@ -153,23 +153,6 @@ defmodule EXLA.Defn.Outfeed do
end
end

@doc """
Adds a stream hook.
Used by streams. Only one is allowed. Requires configuration.
"""
def add_stream_hook(%Outfeed{} = outfeed, builder, tuple) do
{outfeed, flag, typespecs} = outfeed_flat_tuple(outfeed, builder, tuple)
# We don't know the pid+ref pair for the stream, so we store it
# under a special key called :stream and revert to the flag once configured
put_in(outfeed.compiled_hooks[:stream], {flag, typespecs})
end

def configure_stream_hook(%Outfeed{} = outfeed, pid, ref) when is_pid(pid) do
{{flag, typespecs}, outfeed} = pop_in(outfeed.compiled_hooks[:stream])
{typespecs, put_in(outfeed.compiled_hooks[flag], {:stream, typespecs, pid, ref})}
end

@doc """
Closes the outfeed at the end of a pipeline.
Expand Down Expand Up @@ -254,10 +237,6 @@ defmodule EXLA.Defn.Outfeed do
EXLA.Client.to_infeed(client, device_id, [{data, data_typespec}])
loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds)

{:stream, typespecs, recv_pid, recv_ref} ->
:ok = EXLA.Client.from_outfeed(client, device_id, typespecs, recv_pid, recv_ref)
loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds)

{:function, typespecs, name, template} ->
fun = Map.fetch!(hooks, name)
length = length(typespecs)
Expand Down
Loading

0 comments on commit c82702b

Please sign in to comment.