Skip to content

Commit

Permalink
Do not compile stack/concatenate expressions
Browse files Browse the repository at this point in the history
People may attempt to concatenate/stack thousands
of entries, which ends up spending too much time
in the compiler. Using the binary backend is simpler
and faster.

Here is a simple benchmark showcasing the issue:

    Nx.default_backend(EXLA.Backend)
    for(_ <- 1..10_000, do: Nx.broadcast(0, {1024}))
    |> Nx.stack(name: :articles)
  • Loading branch information
josevalim committed Jun 15, 2024
1 parent 791a63d commit 12491ab
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -234,24 +234,16 @@ defmodule EXLA.Backend do

@impl true
def concatenate(out, tensors, axis) do
out = Nx.to_template(out)

expr_fun = fn tensors ->
Nx.Defn.Expr.concatenate(out, Tuple.to_list(tensors), axis)
end

jit([], expr_fun, tensors, [List.to_tuple(tensors)])
copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
result = Nx.BinaryBackend.concatenate(out, copied, axis)
Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
end

@impl true
def stack(out, tensors, axis) do
out = Nx.to_template(out)

expr_fun = fn tensors ->
Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
end

jit([], expr_fun, tensors, [List.to_tuple(tensors)])
copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
result = Nx.BinaryBackend.stack(out, copied, axis)
Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
end

@impl true
Expand Down Expand Up @@ -390,6 +382,10 @@ defmodule EXLA.Backend do
defp jit(opts, fun, args), do: jit(opts, fun, args, args)

defp jit(opts, fun, tensors, args) do
EXLA.jit_apply(fun, args, [on_conflict: :force] ++ jit_opts(tensors, opts))
end

defp jit_opts(opts, tensors) do
{priority_client, priority_did, backup_client, backup_did} =
for %T{data: %B{buffer: %EXLA.DeviceBuffer{client_name: client_name, device_id: device_id}}} <-
tensors,
Expand Down Expand Up @@ -418,6 +414,6 @@ defmodule EXLA.Backend do
opts[:device_id] || priority_did || backup_did ||
EXLA.Client.fetch!(client).default_device_id

EXLA.jit_apply(fun, args, on_conflict: :force, client: client, device_id: device_id)
[client: client, device_id: device_id]
end
end

0 comments on commit 12491ab

Please sign in to comment.