From 9a9a568ae0b6bda760a805cb255586d6597cd538 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 1 Sep 2024 01:52:50 +0200 Subject: [PATCH] Simplify return of compile --- exla/lib/exla/defn.ex | 68 +++++++++++++---------------------- exla/lib/exla/defn/buffers.ex | 13 ++++--- exla/lib/exla/defn/outfeed.ex | 14 ++++---- exla/lib/exla/defn/stream.ex | 15 ++++---- 4 files changed, 48 insertions(+), 62 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 9c9bbdbdbb5..660a521267a 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -32,6 +32,7 @@ defmodule EXLA.Defn do @doc false def __stream__(key, input, acc, vars, fun, [args], options) do + {debug?, options} = Keyword.pop(options, :debug, false) {run_options, compile_options} = Keyword.pop(options, :run_options, []) {client_name, compile_options} = @@ -51,24 +52,26 @@ defmodule EXLA.Defn do comp_fun = &to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options) - {executable, used_inputs, {output, acc_output}, outfeed, extra, debug?} = + {executable, used_inputs, {output, acc_output}, outfeed, input_typespecs} = compile( client, - {:stream, key}, + key, vars, fun, compile_options, used_buffers, used_inputs, _stream = true, + debug?, comp_fun ) - {input_typespecs, input_indexes} = extra + # Now discard the infeed from used inputs, similar to how it is done to buffers. + # Note we discard all lazy transfers too, as they are not possible with streams. + used_inputs = for {i, nil} <- used_inputs, i >= used_buffers, do: {i, nil}, into: %{} - # Also discard the stream inputs from used inputs, similar to how it is done to buffers - # Note we discard all lazy transfers too, as they are not possible with streams - used_inputs = Enum.sort(for {i, nil} <- used_inputs, i >= used_buffers, do: i) + # And capture the typespecs for the infeed. + input_typespecs = Enum.take_while(input_typespecs, fn {i, _} -> i < input_length end) # Execution of streams requires the coordination of # multiple processes which is outlined below. @@ -120,7 +123,6 @@ defmodule EXLA.Defn do outfeed_pid, input, input_typespecs, - input_indexes, output, output_typespecs, acc_output @@ -151,9 +153,6 @@ defmodule EXLA.Defn do {input_typespecs, used_typespecs} = Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end) - # Get all input indexes and shape - input_indexes = Enum.map(input_typespecs, &elem(&1, 0)) - # Drop all accumulator entries from used_typespecs as we will handle it separately. {acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length) @@ -166,13 +165,10 @@ defmodule EXLA.Defn do # The input will be read as part of the infeed. acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1)) acc_typespec = List.to_tuple(acc_typespecs_l) - flag_typespec = Typespec.tensor({:pred, 8}, {}) args = EXLA.MLIR.Function.get_arguments(builder) - {token, [flag]} = Value.infeed(root_token, [flag_typespec]) - init = [flag, token | args] arg_typespecs = Enum.map(init, &Value.get_typespec/1) @@ -186,11 +182,9 @@ defmodule EXLA.Defn do {body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs) {acc, constant} = Enum.split(args, acc_length) - - {indices, input_typespecs} = Enum.unzip(input_typespecs) + {input_indices, input_typespecs} = Enum.unzip(input_typespecs) {token, input} = Value.infeed(token, input_typespecs) - - input_params = Enum.zip(indices, input) + input_params = Enum.zip(input_indices, input) {%Outfeed{token: token} = outfeed, acc} = case expr do @@ -226,9 +220,7 @@ defmodule EXLA.Defn do # Emit the stream hook to signal loop output {token, [flag]} = Value.infeed(token, [flag_typespec]) - Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant)) - Function.pop_region(builder) [_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init) @@ -238,8 +230,7 @@ defmodule EXLA.Defn do outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder) Value.func_return(builder, output) - - {{input_typespecs, input_indexes}, outfeed} + outfeed end @doc false @@ -249,6 +240,7 @@ defmodule EXLA.Defn do @doc false def __compile__(key, vars, fun, options) do + {debug?, options} = Keyword.pop(options, :debug, false) {run_options, compile_options} = Keyword.pop(options, :run_options, []) {client_name, compile_options} = @@ -258,8 +250,8 @@ defmodule EXLA.Defn do callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) - {executable, used_inputs, outputs, outfeed, :ok, debug?} = - compile(client, key, vars, fun, compile_options, 0, [], _stream = false, callback) + {executable, used_inputs, outputs, outfeed, _input_typespecs?} = + compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback) fn [args] -> {time, lock} = @@ -306,10 +298,8 @@ defmodule EXLA.Defn do {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) outfeed = cache |> get_outfeed() |> Outfeed.close(function) - Value.func_return(function, res) - - {:ok, outfeed} + outfeed end defp maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options) @@ -367,6 +357,7 @@ defmodule EXLA.Defn do used_buffers, used_inputs, stream?, + debug?, to_computation ) do {{expr_cache_fun, comp_cache_fun}, options} = @@ -379,8 +370,6 @@ defmodule EXLA.Defn do {{cache_fun, cache_fun}, options} end - {debug?, options} = Keyword.pop(options, :debug, false) - {args_key, reverse_args_identifiers} = Enum.map_reduce(vars, [], fn var, acc -> Nx.Defn.Composite.traverse(var, acc, fn @@ -396,7 +385,7 @@ defmodule EXLA.Defn do {eval_time, {expr, {ref, outputs, {used_inputs, defined_hooks}}}} = :timer.tc(fn -> - expr_cache_fun.({key, args_key, lazy_transfers}, fn -> + expr_cache_fun.({key, stream?, args_key, lazy_transfers}, fn -> expr = fun.(vars) inputs_and_hooks = Outfeed.used_inputs_and_hooks(expr, used_inputs, lazy_transfers) {expr, {make_ref(), Nx.to_template(expr), inputs_and_hooks}} @@ -412,12 +401,10 @@ defmodule EXLA.Defn do end {hooks, options} = Keyword.pop(options, :hooks, %{}) - outfeed = Outfeed.new(hooks, defined_hooks) - comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options} - {comp_time, {evaled, {xla_time, executable, extra, outfeed}}} = + {comp_time, {evaled, {xla_time, executable, inputs_and_typespecs, outfeed}}} = :timer.tc(fn -> comp_cache_fun.(comp_key, fn -> {reverse_inputs_and_typespecs, reverse_infeeds} = @@ -430,7 +417,7 @@ defmodule EXLA.Defn do inputs_and_typespecs = Enum.reverse(reverse_inputs_and_typespecs) - comp_arg_typespecs = + comp_typespecs = for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec outputs = @@ -451,7 +438,7 @@ defmodule EXLA.Defn do |> then(&Typespec.tensor(&1.type, &1.shape)) end) - EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> + EXLA.MLIR.Module.new(comp_typespecs, out_typespecs, fn builder -> # Only create the token when we know it will actually be # used, that is: streaming, lazy transfers or hooks outfeed = @@ -464,25 +451,20 @@ defmodule EXLA.Defn do end expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) - - {extra, outfeed} = - to_computation.(builder, expr, inputs_and_typespecs, outfeed) + outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed) {xla_time, executable} = :timer.tc(fn -> - typespecs = - for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec - EXLA.MLIR.Module.compile( builder.module, client, - typespecs, + comp_typespecs, builder.return_typespecs, options ) end) - {:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}} + {:ok, {xla_time, executable, inputs_and_typespecs, %{outfeed | infeeds: []}}} end) end) end) @@ -511,7 +493,7 @@ defmodule EXLA.Defn do end outfeed = Outfeed.with_user_hooks(outfeed, hooks) - {executable, used_inputs, outputs, outfeed, extra, debug?} + {executable, used_inputs, outputs, outfeed, inputs_and_typespecs} end defp us_to_ms(time), do: Float.round(time / 1000, 1) diff --git a/exla/lib/exla/defn/buffers.ex b/exla/lib/exla/defn/buffers.ex index b73c54742dc..5c806755f12 100644 --- a/exla/lib/exla/defn/buffers.ex +++ b/exla/lib/exla/defn/buffers.ex @@ -32,16 +32,19 @@ defmodule EXLA.Defn.Buffers do @doc """ Splits the given args by value and returns them as is. - - Entries with a map entry are discarded. """ def split_by_value(args, %{} = map, callback) do {_i, left, right} = Enum.reduce(args, {0, [], []}, fn arg, {i, left, right} -> case map do - %{^i => nil} -> {i + 1, [callback.(arg, i, nil) | left], right} - %{^i => value} -> {i + 1, left, [callback.(arg, i, value) | right]} - %{} -> {i + 1, left, right} + %{^i => nil} -> + {i + 1, [callback.(arg, i, nil) | left], right} + + %{^i => value} -> + {i + 1, left, [callback.(arg, i, value) | right]} + + %{} -> + {i + 1, left, right} end end) diff --git a/exla/lib/exla/defn/outfeed.ex b/exla/lib/exla/defn/outfeed.ex index 101946800bd..f528f19efa8 100644 --- a/exla/lib/exla/defn/outfeed.ex +++ b/exla/lib/exla/defn/outfeed.ex @@ -120,15 +120,15 @@ defmodule EXLA.Defn.Outfeed do {infeeds, {compiled_hooks, token}} = entries |> List.keysort(1, :desc) - |> Enum.map_reduce({compiled_hooks, token}, fn {pos, _, typespec}, - {compiled_hooks, token} -> - next_flag = next_hook(compiled_hooks) - compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec}) + |> Enum.map_reduce({compiled_hooks, token}, fn + {pos, _, typespec}, {compiled_hooks, token} -> + next_flag = next_hook(compiled_hooks) + compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec}) - token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token) - {token, [input]} = Value.infeed(token, [typespec]) + token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token) + {token, [input]} = Value.infeed(token, [typespec]) - {{pos, input}, {compiled_hooks, token}} + {{pos, input}, {compiled_hooks, token}} end) %{outfeed | compiled_hooks: compiled_hooks, token: token, infeeds: infeeds} diff --git a/exla/lib/exla/defn/stream.ex b/exla/lib/exla/defn/stream.ex index cfb0215a70d..67f68c3c9be 100644 --- a/exla/lib/exla/defn/stream.ex +++ b/exla/lib/exla/defn/stream.ex @@ -2,7 +2,7 @@ defmodule EXLA.Defn.Stream do @moduledoc false keys = - [:lock, :outfeed, :pid, :runner, :send, :send_typespecs, :send_indexes] ++ + [:lock, :outfeed, :pid, :runner, :send, :send_typespecs] ++ [:recv, :recv_length, :done, :client, :device_id] @derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]} @@ -16,7 +16,6 @@ defmodule EXLA.Defn.Stream do outfeed, send, send_typespecs, - send_indexes, recv, recv_typespecs, done @@ -40,7 +39,6 @@ defmodule EXLA.Defn.Stream do lock: lock, send: send, send_typespecs: send_typespecs, - send_indexes: send_indexes, recv: recv, recv_length: length(recv_typespecs), client: client, @@ -64,15 +62,14 @@ defmodule EXLA.Defn.Stream do client: client, device_id: device_id, send: send, - send_typespecs: send_typespecs, - send_indexes: send_indexes + send_typespecs: send_typespecs } = stream if pid != self() do raise "EXLA streams require recv to be called from the process that started the stream" end - {template, buffers} = nx_to_io(data, send_indexes) + {template, buffers} = nx_to_io(data, Enum.map(send_typespecs, &elem(&1, 0))) unless Nx.compatible?(send, template) do raise ArgumentError, """ @@ -87,7 +84,11 @@ defmodule EXLA.Defn.Stream do end pred = EXLA.Typespec.tensor({:pred, 8}, {}) - data_and_typespecs = Enum.zip(buffers, send_typespecs) + + data_and_typespecs = + Enum.zip_with(buffers, send_typespecs, fn buffer, {_index, typespec} -> + {buffer, typespec} + end) :ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred}]) :ok = EXLA.Client.to_infeed(client, device_id, data_and_typespecs)