diff --git a/exla/lib/exla.ex b/exla/lib/exla.ex index f4714c1f8d..f22d39f76a 100644 --- a/exla/lib/exla.ex +++ b/exla/lib/exla.ex @@ -297,65 +297,6 @@ defmodule EXLA do Nx.Defn.compile(function, args, Keyword.put(options, :compiler, EXLA)) end - @doc """ - Starts streaming the given anonymous function with just-in-time - compilation. - - At least two arguments are expected: - - 1. The first argument is a tensor template of the data to - be streamed in - - 2. The second argument is a tensor with the stream initial state - - The streaming function must return a two element tuple, the - first element is the data to be sent and the second is the - accumulator. - - For each streamed chunk, you must call `Nx.Stream.send/2` and - `Nx.Stream.recv/1`. You don't need to call `recv` immediately - after `send`, but doing so can be a useful mechanism to provide - backpressure. Once all chunks are sent, you must use `Nx.Stream.done/1` - to receive the accumulated result. Let's see an example: - - defmodule Streamed do - import Nx.Defn - - defn sum(tensor, acc) do - {acc, tensor + acc} - end - end - - Now let's invoke it: - - stream = EXLA.stream(&Streamed.sum/2, [Nx.template({}, {:s, 32}), 0]) - - for i <- 1..5 do - Nx.Stream.send(stream, i) - IO.inspect {:chunk, Nx.Stream.recv(stream)} - end - - IO.inspect {:result, Nx.Stream.done(stream)} - - It will print: - - {:chunk, 0} - {:chunk, 1} - {:chunk, 2} - {:chunk, 3} - {:chunk, 4} - {:result, 5} - - **Note:** While any process can call `Nx.Stream.send/2`, EXLA - expects the process that starts the streaming to be the one - calling `Nx.Stream.recv/1` and `Nx.Stream.done/1`. - - See `jit/2` for supported options. - """ - def stream(function, args, options \\ []) do - Nx.Defn.stream(function, args, Keyword.put(options, :compiler, EXLA)) - end - @doc ~S''' Takes in a function, the argument templates and the compilation options and returns the textual representation of the MLIR module. @@ -442,31 +383,6 @@ defmodule EXLA do {:cached?, bool} -> bool end - @doc """ - Checks if the JIT compilation of stream with - args is cached. - - Note that hooks are part of the cache, and - therefore they must be included in the options. - - ## Examples - - iex> left = Nx.tensor(1, type: {:u, 8}) - iex> right = Nx.tensor([1, 2, 3], type: {:u, 16}) - iex> fun = fn x, acc -> {acc, Nx.add(x, acc)} end - iex> stream = EXLA.stream(fun, [left, right]) - iex> Nx.Stream.done(stream) - iex> EXLA.stream_cached?(fun, [left, right]) - true - iex> EXLA.stream_cached?(fun, [left, Nx.tensor([1, 2, 3, 4], type: {:u, 16})]) - false - """ - def stream_cached?(function, args, options \\ []) do - stream(function, args, [{EXLA, cached_check()} | options]) - catch - {:cached?, bool} -> bool - end - defp cached_check do expr_cache_fun = fn key, _callback -> if res = EXLA.Defn.LockedCache.get(key) do @@ -489,9 +405,6 @@ defmodule EXLA do @impl true defdelegate __jit__(key, vars, fun, args, opts), to: EXLA.Defn - @impl true - defdelegate __stream__(key, input, acc, vars, fun, args, opts), to: EXLA.Defn - @impl true defdelegate __partitions_options__(opts), to: EXLA.Defn diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 2f73b25562..e7f7bfae47 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -30,190 +30,6 @@ defmodule EXLA.Defn do {EXLA.Backend, [client: client_name, device_id: device_id]} end - @doc false - def __stream__(key, input, acc, vars, fun, [args], options) do - {run_options, compile_options} = Keyword.pop(options, :run_options, []) - debug? = Keyword.get(compile_options, :debug, false) - compile_options = Keyword.put(compile_options, :lazy_transfers, :never) - - input_length = length(Nx.Defn.Composite.flatten_list([input])) - acc_length = length(Nx.Defn.Composite.flatten_list([acc])) - - # The input vars should not be converted to buffers as they come from infeed - # Accs are always considered as used - used_buffers = input_length - used_inputs = Enum.to_list(input_length..(input_length + acc_length - 1)//1) - - comp_fun = - &to_stream_computation(input_length, acc_length, &1, &2, &3, &4, &5, compile_options) - - {executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} = - compile(key, vars, fun, compile_options, used_buffers, used_inputs, true, comp_fun) - - # Now discard the infeed from used inputs, similar to how it is done to buffers. - # Note we discard all lazy transfers too, as they are not possible with streams. - used_inputs = for {i, nil} <- used_inputs, i >= used_buffers, do: {i, nil}, into: %{} - - # And capture the typespecs for the infeed. - input_typespecs = Enum.take_while(input_typespecs, fn {i, _} -> i < input_length end) - - # Execution of streams requires the coordination of - # multiple processes which is outlined below. - - # First, we get a lock on the executable, because we want - # to avoid transfer to the device unless we know we are - # ready to use the device. - {time, lock} = - :timer.tc(fn -> - EXLA.Defn.Lock.lock(run_key(executable)) - end) - - debug? && Logger.debug("EXLA device #{executable.device_id} lock in #{us_to_ms(time)}ms") - - {time, streams} = - :timer.tc(fn -> - buffers = - EXLA.Defn.Buffers.filter_by_indexes(args, used_inputs, fn arg, _ -> - EXLA.Defn.Buffers.from_nx!(arg, executable) - end) - - # Now that we have transferred to device, we spawn a runner process - # to execute the stream. We use a runner instead of a task to avoid - # leaking messages in the inbox. We also don't use a supervisor - # to keep them linked, which is safe because the agent is not used - # outside the scope of the current process. - # - # Finally, note the runner cannot start immediately, we need to - # setup the outfeed reader and register the on_unlock callback - # that cancels the stream atomically. This is done inside - # EXLA.Defn.Stream.run. - {:ok, runner} = - EXLA.Defn.Runner.start_link(lock, fn -> - EXLA.Executable.run(executable, [buffers], run_options) - end) - - # The outfeed reader will redirect all outputs with flag 1 to the current - # process. Once flag 0 is emitted, we know the stream is done. - {output_typespecs, outfeed} = Outfeed.configure_stream_hook(outfeed, self(), lock) - {:ok, outfeed_pid} = Outfeed.start_child(executable, outfeed, Process.group_leader()) - - stream = - EXLA.Defn.Stream.run( - executable, - lock, - runner, - outfeed_pid, - input, - input_typespecs, - output, - output_typespecs, - acc_output - ) - - [stream] - end) - - debug? && - Logger.debug("EXLA stream start on device #{executable.device_id} in #{us_to_ms(time)}ms") - - streams - end - - defp to_stream_computation( - input_length, - acc_length, - %Function{} = builder, - expr, - used_typespecs, - outfeed, - client, - options - ) do - %{token: root_token, infeeds: []} = outfeed - - {input_typespecs, used_typespecs} = - Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end) - - # Drop all accumulator entries from used_typespecs as we will handle it separately. - {acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length) - - # The stream loop will be a three element tuple: - # - # The result of calling infeed. - # The looping accumulator. - # The looping constants. - # - # The input will be read as part of the infeed. - acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1)) - acc_typespec = List.to_tuple(acc_typespecs_l) - flag_typespec = Typespec.tensor({:pred, 8}, {}) - - args = EXLA.MLIR.Function.get_arguments(builder) - {token, [flag]} = Value.infeed(root_token, [flag_typespec]) - init = [flag, token | args] - - arg_typespecs = Enum.map(init, &Value.get_typespec/1) - {pred_computation, [flag | _]} = Function.push_region(builder, arg_typespecs) - typespec = Typespec.tensor({:pred, 8}, {}) - r0 = Value.constant(builder, [1], typespec) - pred_op = Value.equal(flag, r0, typespec) - Value.return(builder, [pred_op]) - Function.pop_region(builder) - - {body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs) - - {acc, constant} = Enum.split(args, acc_length) - {input_indices, input_typespecs} = Enum.unzip(input_typespecs) - {token, input} = Value.infeed(token, input_typespecs) - input_params = Enum.zip(input_indices, input) - - {%Outfeed{token: token} = outfeed, acc} = - case expr do - {output_expr, acc_expr} -> - acc_params = - Enum.map(acc_typespecs, fn {pos, _typespec} -> - {pos, Enum.fetch!(acc, pos - input_length)} - end) - - constant_params = - Enum.with_index(used_typespecs, fn {pos, _typespec}, index -> - {pos, Enum.fetch!(constant, index)} - end) - - state = %{ - client: client, - builder: builder, - precision: Keyword.get(options, :precision, :default), - params: Map.new(input_params ++ acc_params ++ constant_params), - scope_ids: Tree.scope_ids(expr) - } - - outfeed = Outfeed.with_token(outfeed, token) - {output, cache} = recur_flatten(output_expr, state, new_cache(outfeed)) - {acc, cache} = recur_flatten(acc_expr, state, cache) - outfeed = cache |> get_outfeed() |> Outfeed.add_stream_hook(builder, output) - {outfeed, acc} - - _ -> - raise "expected the function given to Nx.stream/3 to return a two-element tuple, got: " <> - inspect(expr) - end - - # Emit the stream hook to signal loop output - {token, [flag]} = Value.infeed(token, [flag_typespec]) - Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant)) - Function.pop_region(builder) - - [_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init) - - acc = Enum.take(results, acc_length) - output = wrap_tuple_result(acc, acc_typespec) - - outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder) - Value.func_return(builder, output) - outfeed - end - @doc false def __jit__(key, vars, fun, args_list, options) do __compile__(key, vars, fun, options).(args_list) @@ -223,10 +39,10 @@ defmodule EXLA.Defn do def __compile__(key, vars, fun, options) do {run_options, compile_options} = Keyword.pop(options, :run_options, []) debug? = Keyword.get(compile_options, :debug, false) - callback = &to_root_computation(&1, &2, &3, &4, &5, compile_options) + callback = &to_computation(&1, &2, &3, &4, &5, compile_options) {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = - compile(key, vars, fun, compile_options, 0, [], _stream = false, callback) + compile(key, vars, fun, compile_options, 0, [], callback) if compile_options[:module_compilation] == :to_mlir do throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs}) @@ -252,7 +68,7 @@ defmodule EXLA.Defn do end end - defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do + defp to_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> {pos, arg} @@ -322,7 +138,7 @@ defmodule EXLA.Defn do ## Compile - defp compile(key, vars, fun, options, used_buffers, used_inputs, stream?, to_computation) do + defp compile(key, vars, fun, options, used_buffers, used_inputs, to_computation) do {cache, options} = Keyword.pop(options, :cache, true) {hooks, options} = Keyword.pop(options, :hooks, %{}) {debug?, options} = Keyword.pop(options, :debug, false) @@ -361,7 +177,7 @@ defmodule EXLA.Defn do {eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} = :timer.tc(fn -> - expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn -> + expr_cache_fun.({key, args_key, lazy_transfers}, fn -> expr = fun.(vars) inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers) {expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}} @@ -395,15 +211,6 @@ defmodule EXLA.Defn do comp_typespecs = for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec - outputs = - if stream? do - # The computation returns the final accumulator value - {_chunk_result, acc} = outputs - acc - else - outputs - end - out_typespecs = [outputs] |> Nx.Defn.Composite.flatten_list() @@ -417,7 +224,7 @@ defmodule EXLA.Defn do # Only create the token when we know it will actually be # used, that is: streaming, lazy transfers or hooks outfeed = - if stream? or reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do + if reverse_infeeds != [] or hooks != %{} or defined_hooks != %{} do outfeed |> Outfeed.with_token(Value.create_token(builder)) |> Outfeed.add_infeeds(builder, reverse_infeeds) diff --git a/exla/lib/exla/defn/outfeed.ex b/exla/lib/exla/defn/outfeed.ex index f528f19efa..780a5fef79 100644 --- a/exla/lib/exla/defn/outfeed.ex +++ b/exla/lib/exla/defn/outfeed.ex @@ -153,23 +153,6 @@ defmodule EXLA.Defn.Outfeed do end end - @doc """ - Adds a stream hook. - - Used by streams. Only one is allowed. Requires configuration. - """ - def add_stream_hook(%Outfeed{} = outfeed, builder, tuple) do - {outfeed, flag, typespecs} = outfeed_flat_tuple(outfeed, builder, tuple) - # We don't know the pid+ref pair for the stream, so we store it - # under a special key called :stream and revert to the flag once configured - put_in(outfeed.compiled_hooks[:stream], {flag, typespecs}) - end - - def configure_stream_hook(%Outfeed{} = outfeed, pid, ref) when is_pid(pid) do - {{flag, typespecs}, outfeed} = pop_in(outfeed.compiled_hooks[:stream]) - {typespecs, put_in(outfeed.compiled_hooks[flag], {:stream, typespecs, pid, ref})} - end - @doc """ Closes the outfeed at the end of a pipeline. @@ -254,10 +237,6 @@ defmodule EXLA.Defn.Outfeed do EXLA.Client.to_infeed(client, device_id, [{data, data_typespec}]) loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) - {:stream, typespecs, recv_pid, recv_ref} -> - :ok = EXLA.Client.from_outfeed(client, device_id, typespecs, recv_pid, recv_ref) - loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) - {:function, typespecs, name, template} -> fun = Map.fetch!(hooks, name) length = length(typespecs) diff --git a/exla/lib/exla/defn/stream.ex b/exla/lib/exla/defn/stream.ex deleted file mode 100644 index 67f68c3c9b..0000000000 --- a/exla/lib/exla/defn/stream.ex +++ /dev/null @@ -1,162 +0,0 @@ -defmodule EXLA.Defn.Stream do - @moduledoc false - - keys = - [:lock, :outfeed, :pid, :runner, :send, :send_typespecs] ++ - [:recv, :recv_length, :done, :client, :device_id] - - @derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]} - @enforce_keys keys - defstruct keys - - def run( - executable, - lock, - runner, - outfeed, - send, - send_typespecs, - recv, - recv_typespecs, - done - ) do - %{client: client, device_id: device_id} = executable - - # With the task and outfeed in place, we now register the unlock callback: - # if the current process shuts down, we send an infeed to stop the loop, - # and then we block until the outfeed completes. - ^lock = - EXLA.Defn.Lock.on_unlock( - lock, - fn -> send(runner, lock) end, - fn -> halt_stream(client, device_id, outfeed) end - ) - - %EXLA.Defn.Stream{ - pid: self(), - runner: runner, - outfeed: outfeed, - lock: lock, - send: send, - send_typespecs: send_typespecs, - recv: recv, - recv_length: length(recv_typespecs), - client: client, - device_id: device_id, - done: done - } - end - - # It is time to halt the stream, we do it by sending 0 for the loop infeed. - # Then we wait for the outfeed process to read all. - defp halt_stream(client, device_id, outfeed) do - pred = EXLA.Typespec.tensor({:pred, 8}, {}) - :ok = EXLA.Client.to_infeed(client, device_id, [{<<0::8-native>>, pred}]) - {:transfer, outfeed} - end - - defimpl Nx.Stream do - def send(stream, data) do - %{ - pid: pid, - client: client, - device_id: device_id, - send: send, - send_typespecs: send_typespecs - } = stream - - if pid != self() do - raise "EXLA streams require recv to be called from the process that started the stream" - end - - {template, buffers} = nx_to_io(data, Enum.map(send_typespecs, &elem(&1, 0))) - - unless Nx.compatible?(send, template) do - raise ArgumentError, """ - Nx stream expected a tensor of type, shape, and names on send: - - #{inspect(send)} - - But got tensor: - - #{inspect(template)} - """ - end - - pred = EXLA.Typespec.tensor({:pred, 8}, {}) - - data_and_typespecs = - Enum.zip_with(buffers, send_typespecs, fn buffer, {_index, typespec} -> - {buffer, typespec} - end) - - :ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred}]) - :ok = EXLA.Client.to_infeed(client, device_id, data_and_typespecs) - end - - defp nx_to_io(container, indexes) do - {template, buffers} = - Nx.LazyContainer.traverse(container, [], fn template, fun, acc -> - {template, [fun | acc]} - end) - - {template, - buffers - |> Enum.reverse() - |> EXLA.Defn.Buffers.filter_by_indexes(indexes) - |> Enum.map(fn fun -> Nx.to_binary(fun.()) end)} - end - - def recv(%{pid: pid, outfeed: outfeed, lock: lock, recv: recv, recv_length: length}) do - if pid != self() do - raise "EXLA streams require recv to be called from the process that started the stream" - end - - unless Process.alive?(outfeed) do - raise "cannot recv from stream because it has been terminated" - end - - buffers = - for _ <- 1..length//1 do - receive do - {^lock, binary} -> binary - end - end - - EXLA.Defn.Buffers.to_nx!(buffers, recv) - end - - def done(%{ - lock: lock, - outfeed: outfeed, - pid: pid, - runner: runner, - done: done - }) do - if pid != self() do - raise "EXLA streams require recv to be called from the process that started the stream" - end - - # This will write to infeed to stop the loop. We know unlocking - # is race free because we can only write to infeed from this process - # (or it is automatically written if this process is dead). - # - # Once we unlock, the lock process will now wait until the outfeed - # terminates. - EXLA.Defn.Lock.unlock(lock) - - # We also wait until the outfeed completes to ensure - # all output has been consumed before we return. - outfeed_ref = Process.monitor(outfeed) - - receive do - {^lock, _} -> - raise "cannot mark stream as done when there are recv messages pending" - - {:DOWN, ^outfeed_ref, _, _, _} -> - [result] = EXLA.Defn.Runner.read(runner) - EXLA.Defn.Buffers.to_nx!(result, done) - end - end - end -end diff --git a/exla/mix.exs b/exla/mix.exs index 4036616379..184a48cb94 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -63,8 +63,8 @@ defmodule EXLA.MixProject do defp deps do [ - {:nx, "~> 0.9.0"}, - # {:nx, path: "../nx"}, + # {:nx, "~> 0.9.0"}, + {:nx, path: "../nx"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, {:xla, "~> 0.8.0", runtime: false}, {:elixir_make, "~> 0.6", runtime: false}, diff --git a/exla/test/exla/defn/api_test.exs b/exla/test/exla/defn/api_test.exs index 1596cce7ee..319f30b619 100644 --- a/exla/test/exla/defn/api_test.exs +++ b/exla/test/exla/defn/api_test.exs @@ -167,165 +167,6 @@ defmodule EXLA.Defn.APITest do end end - describe "stream" do - defn defn_sum(entry, acc), do: {acc, entry + acc} - - test "immediately done" do - stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert %Nx.Tensor{data: %EXLA.Backend{}} = done = Nx.Stream.done(stream) - assert_equal(Nx.backend_transfer(done), Nx.tensor(0)) - - stream = EXLA.stream(&defn_sum/2, [1, 2]) - assert %Nx.Tensor{data: %EXLA.Backend{}} = done = Nx.Stream.done(stream) - assert_equal(Nx.backend_transfer(done), Nx.tensor(2)) - end - - test "send/recv" do - %_{} = stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(0)) - - assert Nx.Stream.send(stream, 2) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(1)) - - assert_equal(Nx.Stream.done(stream), Nx.tensor(3)) - end - - test "send x2/recv x2" do - %_{} = stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - assert Nx.Stream.send(stream, 2) == :ok - - assert_equal(Nx.Stream.recv(stream), Nx.tensor(0)) - assert_equal(Nx.Stream.recv(stream), Nx.tensor(1)) - - assert_equal(Nx.Stream.done(stream), Nx.tensor(3)) - end - - defn stream_composite(i, {a, {b, c}}) do - a = a + i - b = b * i - c = Nx.pow(c, i) - {{{a, b}, c}, {a, {b, c}}} - end - - test "send/recv with composite types" do - %_{} = stream = EXLA.stream(&stream_composite/2, [0, {0, {1, 2}}]) - assert Nx.Stream.send(stream, 1) == :ok - assert_equal(Nx.Stream.recv(stream), {{Nx.tensor(1), Nx.tensor(1)}, Nx.tensor(2)}) - - assert Nx.Stream.send(stream, 2) == :ok - assert_equal(Nx.Stream.recv(stream), {{Nx.tensor(3), Nx.tensor(2)}, Nx.tensor(4)}) - - assert_equal(Nx.Stream.done(stream), {Nx.tensor(3), {Nx.tensor(2), Nx.tensor(4)}}) - end - - defn stream_empty_outfeed(i, t), do: {{}, i + t} - - test "send/recv with empty outfeed" do - %_{} = stream = EXLA.stream(&stream_empty_outfeed/2, [0, 0.0]) - assert Nx.Stream.send(stream, 1) == :ok - assert Nx.Stream.recv(stream) == {} - - assert Nx.Stream.send(stream, 2) == :ok - assert Nx.Stream.recv(stream) == {} - - assert_equal(Nx.Stream.done(stream), Nx.tensor(3.0)) - end - - defn stream_empty_acc(i, {}), do: {i * i, {}} - - test "send/recv with empty acc" do - %_{} = stream = EXLA.stream(&stream_empty_acc/2, [0, {}]) - assert Nx.Stream.send(stream, 1) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(1)) - - assert Nx.Stream.send(stream, 2) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(4)) - - assert Nx.Stream.done(stream) == {} - end - - test "handles failure before writing" do - {_, ref} = spawn_monitor(fn -> EXLA.stream(&defn_sum/2, [0, 0]) end) - assert_receive {:DOWN, ^ref, _, _, _} - - %_{} = stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(0)) - assert_equal(Nx.Stream.done(stream), Nx.tensor(1)) - end - - test "handles failure after writing" do - {_, ref} = - spawn_monitor(fn -> - stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - end) - - assert_receive {:DOWN, ^ref, _, _, _} - - %_{} = stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(0)) - assert_equal(Nx.Stream.done(stream), Nx.tensor(1)) - end - - test "raises if recv is pending on done" do - %_{} = stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - - assert_raise RuntimeError, - "cannot mark stream as done when there are recv messages pending", - fn -> Nx.Stream.done(stream) end - end - - test "raises if stream is done when recving" do - %_{} = stream = EXLA.stream(&defn_sum/2, [0, 0]) - assert_equal(Nx.Stream.done(stream), Nx.tensor(0)) - - assert_raise RuntimeError, - "cannot recv from stream because it has been terminated", - fn -> Nx.Stream.recv(stream) end - end - - defn container_stream(%Container{a: a} = elem, %Container{b: b} = acc) do - {%{elem | a: a + b}, %{acc | b: a + b}} - end - - test "container in and out" do - args = [%Container{a: 0, b: 0, c: :reset, d: :elem}, %Container{a: 0, b: 0, d: :acc}] - %_{} = stream = EXLA.stream(&container_stream/2, args) - - assert Nx.Stream.send(stream, %Container{a: 1, b: -1}) == :ok - - assert_equal(Nx.Stream.recv(stream), %Container{a: Nx.tensor(1), b: Nx.tensor(-1), d: :elem}) - - assert Nx.Stream.send(stream, %Container{a: 2, b: -2}) == :ok - - assert_equal(Nx.Stream.recv(stream), %Container{a: Nx.tensor(3), b: Nx.tensor(-2), d: :elem}) - - assert_equal(Nx.Stream.done(stream), %Container{a: Nx.tensor(0), b: Nx.tensor(3), d: :acc}) - end - - defn lazy_container_stream(%LazyWrapped{a: a, c: c}, acc) do - {acc, acc + a - c} - end - - test "lazy container in" do - args = [%LazyOnly{a: 0, b: 0, c: 0}, 0] - %_{} = stream = EXLA.stream(&lazy_container_stream/2, args) - - assert Nx.Stream.send(stream, %LazyOnly{a: 3, b: 0, c: -1}) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(0)) - - assert Nx.Stream.send(stream, %LazyOnly{a: 5, b: 0, c: 2}) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(4)) - - assert_equal(Nx.Stream.done(stream), Nx.tensor(7)) - end - end - describe "hooks" do require Logger @@ -450,28 +291,6 @@ defmodule EXLA.Defn.APITest do assert_equal(a, Nx.tensor(1)) assert_equal(b, Nx.tensor(2)) end - - defn hook_stream(entry, acc), do: hook({acc, entry + acc}, :stream) - - test "executes hook with stream" do - %_{} = stream = EXLA.stream(&hook_stream/2, [0, 0], hooks: %{stream: send_to_self(:tag)}) - assert Nx.Stream.send(stream, 1) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(0)) - assert_receive {:tag, {previous_acc, new_acc}} - assert_equal(previous_acc, Nx.tensor(0)) - assert_equal(new_acc, Nx.tensor(1)) - refute_received _ - - assert Nx.Stream.send(stream, 2) == :ok - assert_equal(Nx.Stream.recv(stream), Nx.tensor(1)) - assert_receive {:tag, {previous_acc, new_acc}} - assert_equal(previous_acc, Nx.tensor(1)) - assert_equal(new_acc, Nx.tensor(3)) - refute_received _ - - assert_equal(Nx.Stream.done(stream), Nx.tensor(3)) - refute_received _ - end end describe "telemetry" do diff --git a/nx/README.md b/nx/README.md index f4afac2c3d..8b422a9720 100644 --- a/nx/README.md +++ b/nx/README.md @@ -16,7 +16,7 @@ Nx is a multi-dimensional tensors library for Elixir with multi-staged compilati * Built-in distributed² serving: encapsulate complex numerical pipelines into `Nx.Serving`. Servings provide batching, streaming, and partitioning out of the box. You can distribute servings over multiple CPU cores and GPU devices, as well as over a cluster of machines; - * Support for data streaming and hooks, allowing developers to send and receive data from CPUs/GPUs/TPUs while computations are running; + * Support for hooks, allowing developers to send and receive data from CPUs/GPUs/TPUs while computations are running; * Support for linear algebra primitives via `Nx.LinAlg`; diff --git a/nx/lib/nx/defn.ex b/nx/lib/nx/defn.ex index 71c80a6731..e0d1cc1a30 100644 --- a/nx/lib/nx/defn.ex +++ b/nx/lib/nx/defn.ex @@ -485,100 +485,6 @@ defmodule Nx.Defn do res end - @doc """ - Starts streaming the given anonymous function with just-in-time - compilation. - - At least two arguments are expected: - - 1. The first argument is a tensor template of the data to - be streamed in - - 2. The second argument is a tensor with the stream initial state - - The streaming function must return a two element tuple, the - first element is the data to be sent and the second is the - accumulator. - - For each streamed chunk, you must call `Nx.Stream.send/2` and - `Nx.Stream.recv/1`. You don't need to call `recv` immediately - after `send`, but doing so can be a useful mechanism to provide - backpressure. Once all chunks are sent, you must use `Nx.Stream.done/1` - to receive the accumulated result. Let's see an example: - - defmodule Streamed do - import Nx.Defn - - defn sum(tensor, acc) do - {acc, tensor + acc} - end - end - - Now let's invoke it: - - stream = Nx.Defn.stream(&Streamed.sum/2, [Nx.template({}, {:s, 32}), 0]) - - for i <- 1..5 do - Nx.Stream.send(stream, i) - IO.inspect {:chunk, Nx.Stream.recv(stream)} - end - - IO.inspect {:result, Nx.Stream.done(stream)} - - It will print: - - {:chunk, 0} - {:chunk, 1} - {:chunk, 2} - {:chunk, 3} - {:chunk, 4} - {:result, 5} - - ## Options - - * `:hooks` - a map of hooks to execute. See `Nx.Defn.Kernel.hook/3` - - ## Beware: deadlocks - - Some backends (such as XLA) place locks around devices. For example, - if you start streaming on the GPU, you cannot perform any other - operation on the GPU until streaming is over. - - This means if we modify the loop above to the following: - - for i <- 1..5 do - Nx.Stream.send(stream, Nx.tensor(i) |> Nx.multiply(2)) - IO.inspect {:chunk, Nx.Stream.recv(stream)} - end - - The loop may deadlock at the time it performs the multiplication. - In practice, this means you should perform the streaming on the GPU - and the remaining operations on the CPU. If you only have a single - device (i.e. only a CPU), then it may not be possible to perform the - above and you will have to restructure your code to manipulate the - input before streaming starts. - """ - @deprecated "Move the streaming loop to Elixir instead" - def stream(fun, args, opts \\ []) - when is_function(fun) and is_list(args) and is_list(opts) do - if Nx.Defn.Compiler.current() do - raise "cannot call Nx.Defn.stream/3 when there is a JIT compilation happening" - end - - opts = prepare_options(opts) - {fun, params, _templates, flatten} = Nx.Defn.Compiler.to_lazy_params(fun, args) - - case args do - [_input, acc | _] -> - acc = Nx.Defn.Composite.traverse(acc, &Nx.to_tensor/1) - [stream] = Nx.Defn.Compiler.__stream__(fun, hd(params), acc, params, [flatten], opts) - stream - - _ -> - raise ArgumentError, "Nx.Defn.stream/3 expects at least two arguments" - end - end - defp prepare_options(opts) do opts = Keyword.merge(default_options(), opts) diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex index ebb4b047c7..b529eae365 100644 --- a/nx/lib/nx/defn/compiler.ex +++ b/nx/lib/nx/defn/compiler.ex @@ -57,31 +57,6 @@ defmodule Nx.Defn.Compiler do ) :: ([[Nx.Tensor.t()]] -> [Nx.Container.t()]) when vars: [Nx.Container.t()] - @doc """ - Callback for streaming (on top of JIT compilation). - - It receives the same arguments as `c:__jit__/5` with the addition - of the streaming input and accumulator templates. If the input - and accumulator are containers, they are kept in their container - shapes. As in `c:__jit__/5`, both `vars` and `args_list` are flat - lists of tensors (without their container shape). - - It must return a struct that implements the `Nx.Stream` protocol. - """ - @callback __stream__( - key :: term, - input, - acc, - vars, - fun :: (vars -> {output, acc}), - args_list :: [[(-> Nx.t())]], - opts :: keyword - ) :: [Nx.Stream.t()] - when input: Nx.Container.t(), - output: Nx.Container.t(), - acc: Nx.Container.t(), - vars: [Nx.Container.t()] - @doc """ Receives a keyword list of compiler options and returns a list of compiler options, each to run @@ -153,12 +128,6 @@ defmodule Nx.Defn.Compiler do compiler.__jit__(fun, params, runtime_fun, args_list, opts) end - @doc false - def __stream__(fun, input, acc, params, args_list, opts) do - {compiler, runtime_fun, opts} = prepare_options(fun, opts) - compiler.__stream__(fun, input, acc, params, runtime_fun, args_list, opts) - end - defp prepare_options(fun, opts) do {compiler, opts} = Keyword.pop(opts, :compiler, Nx.Defn.Evaluator) {compiler, &runtime_fun(&1, fun, compiler), opts} diff --git a/nx/lib/nx/defn/debug.ex b/nx/lib/nx/defn/debug.ex index c314ae35d8..c59ef20aad 100644 --- a/nx/lib/nx/defn/debug.ex +++ b/nx/lib/nx/defn/debug.ex @@ -9,9 +9,6 @@ defmodule Nx.Defn.Debug do @impl true def __to_backend__(_), do: raise("not implemented") - @impl true - def __stream__(_, _, _, _, _, _, _), do: raise("not implemented") - @impl true def __compile__(_, _, _, _), do: raise("not implemented") diff --git a/nx/lib/nx/defn/evaluator.ex b/nx/lib/nx/defn/evaluator.ex index 382430694d..d0337f2c1c 100644 --- a/nx/lib/nx/defn/evaluator.ex +++ b/nx/lib/nx/defn/evaluator.ex @@ -33,26 +33,6 @@ defmodule Nx.Defn.Evaluator do Nx.default_backend() end - @impl true - def __stream__(_key, input, acc, vars, fun, [args], opts) do - count = Nx.Defn.Composite.count(input) + Nx.Defn.Composite.count(acc) - rest_params = Enum.drop(args, count) - hooks = Keyword.get(opts, :hooks, %{}) - gc? = Keyword.get(opts, :garbage_collect, false) - {expr, output, cache} = precompile(fun, vars, hooks) - - [ - Nx.Defn.Stream.start_link(input, acc, fn input_params, acc -> - acc_params = [acc] |> Nx.Defn.Composite.flatten_list() |> Enum.map(&fn -> &1 end) - params = input_params ++ acc_params ++ rest_params - - expr - |> composite_eval(%{params: params, gc: gc?}, [cache]) - |> apply_output(output) - end) - ] - end - @impl true def __jit__(key, vars, fun, args_list, opts) do __compile__(key, vars, fun, opts).(args_list) diff --git a/nx/lib/nx/defn/stream.ex b/nx/lib/nx/defn/stream.ex deleted file mode 100644 index 3b55a39ab2..0000000000 --- a/nx/lib/nx/defn/stream.ex +++ /dev/null @@ -1,115 +0,0 @@ -defmodule Nx.Defn.Stream do - # Default implementation for Nx.Stream - @moduledoc false - use GenServer - - @doc false - @enforce_keys [:pid, :input, :output] - defstruct [:pid, :input, :output] - - @doc false - def start_link(input, acc, fun) do - {backend, backend_options} = Nx.default_backend() - {:ok, pid} = GenServer.start_link(__MODULE__, {backend, backend_options, acc, fun}) - %Nx.Defn.Stream{input: input, output: Nx.to_template(acc), pid: pid} - end - - @impl true - def init({backend, backend_options, acc, fun}) do - Nx.default_backend({backend, backend_options}) - {:ok, {:queue.new(), :queue.new(), acc, fun}} - end - - @impl true - def handle_cast({:send, params}, {output, waiting, acc, fun}) do - {data, acc} = fun.(params, acc) - - case :queue.out(waiting) do - {:empty, waiting} -> - {:noreply, {:queue.in(data, output), waiting, acc, fun}} - - {{:value, from}, waiting} -> - GenServer.reply(from, {:ok, data}) - {:noreply, {output, waiting, acc, fun}} - end - end - - @impl true - def handle_call(:recv, from, {output, waiting, acc, fun}) do - case :queue.out(output) do - {:empty, output} -> - {:noreply, {output, :queue.in(from, waiting), acc, fun}} - - {{:value, data}, output} -> - {:reply, {:ok, data}, {output, waiting, acc, fun}} - end - end - - @impl true - def handle_call(:done, _from, {output, waiting, acc, fun}) do - if :queue.is_empty(output) do - for from <- :queue.to_list(waiting) do - GenServer.reply(from, :done) - end - - {:stop, :normal, {:ok, acc}, {output, waiting, acc, fun}} - else - {:reply, :recv_pending, {output, waiting, acc, fun}} - end - end - - defimpl Nx.Stream do - def send(%{pid: pid, input: input}, data) do - {template, funs} = - Nx.LazyContainer.traverse(data, [], fn template, fun, acc -> - {template, [fun | acc]} - end) - - unless Nx.compatible?(input, template) do - raise ArgumentError, """ - Nx stream expected a tensor of type, shape, and names on send: - - #{inspect(input)} - - But got tensor: - - #{inspect(template)} - """ - end - - GenServer.cast(pid, {:send, Enum.reverse(funs)}) - end - - def recv(%{pid: pid, output: output}) do - case GenServer.call(pid, :recv, :infinity) do - {:ok, data} -> - unless Nx.compatible?(output, data) do - raise ArgumentError, """ - Nx stream expected a tensor of type, shape, and names on recv: - - #{inspect(output)} - - But got tensor: - - #{inspect(data)} - """ - end - - data - - :done -> - raise "cannot recv from stream because it has been terminated" - end - end - - def done(%{pid: pid}) do - case GenServer.call(pid, :done, :infinity) do - {:ok, acc} -> - acc - - :recv_pending -> - raise "cannot mark stream as done when there are recv messages pending" - end - end - end -end diff --git a/nx/lib/nx/stream.ex b/nx/lib/nx/stream.ex deleted file mode 100644 index ae7a7b75f3..0000000000 --- a/nx/lib/nx/stream.ex +++ /dev/null @@ -1,26 +0,0 @@ -defprotocol Nx.Stream do - @moduledoc """ - The protocol for streaming data in and out of backends. - """ - - @doc """ - Sends a tensor. - - Returns the given tensor. - """ - def send(stream, tensor) - - @doc """ - Receives data from the stream. - - It may be a tensor, a tuple of tensors, or a map of tensors. - """ - def recv(stream) - - @doc """ - Returns the output of the stream. - - It may be a tensor, a tuple of tensors, or a map of tensors. - """ - def done(stream) -end diff --git a/nx/test/nx/defn/stream_test.exs b/nx/test/nx/defn/stream_test.exs deleted file mode 100644 index c1f0770ce9..0000000000 --- a/nx/test/nx/defn/stream_test.exs +++ /dev/null @@ -1,164 +0,0 @@ -defmodule Nx.Defn.StreamTest do - use ExUnit.Case, async: true - - import Nx.Defn - import ExUnit.CaptureLog - - defn defn_sum(entry, acc), do: {acc, entry + acc} - - def elixir_sum(entry, acc) do - true = Process.get(Nx.Defn.Compiler) in [Nx.Defn.Evaluator, Nx.Defn.Debug] - {acc, Nx.add(entry, acc)} - end - - test "runs defn stream" do - %_{} = stream = Nx.Defn.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(0) - - assert Nx.Stream.send(stream, 2) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(1) - - assert Nx.Stream.done(stream) == Nx.tensor(3) - end - - defn defn_sum_with_args(entry, acc, a, b), do: {acc, entry + acc + (a - b)} - - test "runs defn stream with args" do - %_{} = stream = Nx.Defn.stream(&defn_sum_with_args/4, [0, 0, 2, 1]) - assert Nx.Stream.send(stream, 1) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(0) - - assert Nx.Stream.send(stream, 2) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(2) - - assert Nx.Stream.done(stream) == Nx.tensor(5) - end - - test "runs elixir stream" do - %_{} = stream = Nx.Defn.stream(&elixir_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(0) - - assert Nx.Stream.send(stream, 2) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(1) - - assert Nx.Stream.done(stream) == Nx.tensor(3) - end - - test "converts accumulator to tensors" do - assert %_{} = stream = Nx.Defn.stream(fn _, _ -> {0, 0} end, [1, {2, 3}]) - assert Nx.Stream.done(stream) == {Nx.tensor(2), Nx.tensor(3)} - end - - test "can recv before send" do - %_{} = stream = Nx.Defn.stream(&defn_sum/2, [0, 0]) - task = Task.async(fn -> Nx.Stream.recv(stream) end) - Process.sleep(100) - assert Nx.Stream.send(stream, 1) == :ok - assert Task.await(task) == Nx.tensor(0) - end - - @tag :capture_log - test "raises on errors" do - Process.flag(:trap_exit, true) - assert %_{} = stream = Nx.Defn.stream(fn _, _ -> 0 end, [1, 2]) - - assert Nx.Stream.send(stream, 1) == :ok - assert catch_exit(Nx.Stream.recv(stream)) - - ref = Process.monitor(stream.pid) - assert_receive {:DOWN, ^ref, _, _, _} - end - - test "raises if stream is not compatible on send" do - assert %_{} = stream = Nx.Defn.stream(fn _, _ -> {0, 0} end, [1, {2, 3}]) - - assert_raise ArgumentError, - ~r/Nx stream expected a tensor of type, shape, and names on send/, - fn -> Nx.Stream.send(stream, Nx.iota({3})) end - end - - test "raises if stream is not compatible on recv" do - assert %_{} = stream = Nx.Defn.stream(fn _a, {b, c} -> {b, c} end, [1, {2, 3}]) - - assert Nx.Stream.send(stream, Nx.iota({})) == :ok - - assert_raise ArgumentError, - ~r/Nx stream expected a tensor of type, shape, and names on recv/, - fn -> Nx.Stream.recv(stream) end - end - - test "raises if already done" do - assert %_{} = stream = Nx.Defn.stream(fn _, _ -> 0 end, [1, 2]) - assert Nx.Stream.done(stream) == Nx.tensor(2) - assert {:noproc, _} = catch_exit(Nx.Stream.done(stream)) - end - - test "raises if recv is pending on done" do - %_{} = stream = Nx.Defn.stream(&defn_sum/2, [0, 0]) - assert Nx.Stream.send(stream, 1) == :ok - - assert_raise RuntimeError, - "cannot mark stream as done when there are recv messages pending", - fn -> Nx.Stream.done(stream) end - end - - test "raises if stream is done when recving" do - Process.flag(:trap_exit, true) - assert %_{} = stream = Nx.Defn.stream(fn _, _ -> 0 end, [1, 2]) - - assert capture_log(fn -> - Task.start_link(fn -> Nx.Stream.recv(stream) end) - Process.sleep(100) - Nx.Stream.done(stream) - assert_receive {:EXIT, _, {%RuntimeError{}, _}} - end) =~ "cannot recv from stream because it has been terminated" - end - - defn stream_iota(_, _), do: {Nx.iota({}), Nx.iota({})} - - @tag :capture_log - test "uses the default backend on iota" do - Process.flag(:trap_exit, true) - args = [Nx.tensor(1), Nx.tensor(2)] - Nx.default_backend(ProcessBackend) - assert %_{} = stream = Nx.Defn.stream(&stream_iota/2, args) - assert Nx.Stream.send(stream, hd(args)) - assert_receive {:EXIT, _, {%RuntimeError{message: "not supported"}, _}}, 500 - end - - defn container_stream(%Container{a: a} = elem, %Container{b: b} = acc) do - {%{elem | a: a + b}, %{acc | b: a + b}} - end - - test "container in and out" do - args = [%Container{a: 0, b: 0, c: :reset, d: :elem}, %Container{a: 0, b: 0, d: :acc}] - %_{} = stream = Nx.Defn.stream(&container_stream/2, args) - - assert Nx.Stream.send(stream, %Container{a: 1, b: -1}) == :ok - assert Nx.Stream.recv(stream) == %Container{a: Nx.tensor(1), b: Nx.tensor(-1), d: :elem} - - assert Nx.Stream.send(stream, %Container{a: 2, b: -2}) == :ok - assert Nx.Stream.recv(stream) == %Container{a: Nx.tensor(3), b: Nx.tensor(-2), d: :elem} - - assert Nx.Stream.done(stream) == %Container{a: Nx.tensor(0), b: Nx.tensor(3), d: :acc} - end - - defn lazy_container_stream(%LazyWrapped{a: a, c: c}, acc) do - {acc, acc + a - c} - end - - test "lazy container in" do - args = [%LazyOnly{a: 0, b: 0, c: 0}, 0] - %_{} = stream = Nx.Defn.stream(&lazy_container_stream/2, args) - - assert Nx.Stream.send(stream, %LazyOnly{a: 3, b: 0, c: -1}) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(0) - - assert Nx.Stream.send(stream, %LazyOnly{a: 5, b: 0, c: 2}) == :ok - assert Nx.Stream.recv(stream) == Nx.tensor(4) - - assert Nx.Stream.done(stream) == Nx.tensor(7) - end -end