From 12491abbc048eb403d857472d2de7b6a0728b898 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Valim?=
Date: Sat, 15 Jun 2024 10:13:19 +0200
Subject: [PATCH] Do not compile stack/concatenate expressions

People may attempt to concatenate/stack thousands of entries, which
ends up spending too much time in the compiler. Using the binary
backend is simpler and faster.

Here is a simple benchmark showcasing the issue:

    Nx.default_backend(EXLA.Backend)

    for(_ <- 1..10_000, do: Nx.broadcast(0, {1024}))
    |> Nx.stack(name: :articles)
---
 exla/lib/exla/backend.ex | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex
index 21c2f8ccd1..0193e14109 100644
--- a/exla/lib/exla/backend.ex
+++ b/exla/lib/exla/backend.ex
@@ -234,24 +234,16 @@ defmodule EXLA.Backend do
 
   @impl true
   def concatenate(out, tensors, axis) do
-    out = Nx.to_template(out)
-
-    expr_fun = fn tensors ->
-      Nx.Defn.Expr.concatenate(out, Tuple.to_list(tensors), axis)
-    end
-
-    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
+    copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
+    result = Nx.BinaryBackend.concatenate(out, copied, axis)
+    Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
   end
 
   @impl true
   def stack(out, tensors, axis) do
-    out = Nx.to_template(out)
-
-    expr_fun = fn tensors ->
-      Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
-    end
-
-    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
+    copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
+    result = Nx.BinaryBackend.stack(out, copied, axis)
+    Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
   end
 
   @impl true
@@ -390,6 +382,10 @@ defmodule EXLA.Backend do
   defp jit(opts, fun, args), do: jit(opts, fun, args, args)
 
   defp jit(opts, fun, tensors, args) do
+    EXLA.jit_apply(fun, args, [on_conflict: :force] ++ jit_opts(opts, tensors))
+  end
+
+  defp jit_opts(opts, tensors) do
     {priority_client, priority_did, backup_client, backup_did} =
       for %T{data: %B{buffer: %EXLA.DeviceBuffer{client_name: client_name, device_id: device_id}}} <-
            tensors,
@@ -418,6 +414,6 @@ defmodule EXLA.Backend do
       opts[:device_id] || priority_did || backup_did ||
         EXLA.Client.fetch!(client).default_device_id
 
-    EXLA.jit_apply(fun, args, on_conflict: :force, client: client, device_id: device_id)
+    [client: client, device_id: device_id]
   end
 end