Skip to content

Commit

Permalink
Do not discard client/device in EXLA.Backend when it is host
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 12, 2023
1 parent 981ca89 commit 9f2854d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
28 changes: 20 additions & 8 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -320,22 +320,34 @@ defmodule EXLA.Backend do
defp jit(opts, fun, args), do: jit(opts, fun, args, args)

defp jit(opts, fun, tensors, args) do
{client, device_id} =
{priority_client, priority_did, backup_client, backup_did} =
for %T{data: %B{buffer: %EXLA.DeviceBuffer{client_name: client_name, device_id: device_id}}} <-
tensors,
reduce: {nil, nil} do
{^client_name, ^device_id} = acc ->
reduce: {nil, nil, nil, nil} do
{^client_name, ^device_id, _, _} = acc ->
acc

acc ->
{priority_client, priority_did, backup_client, backup_did} ->
# If the client supports automatic transfers (typically host),
# it should not win over the cuda/rocm. At the same time,
# if it is the only device, we don't want to discard it.
case EXLA.Client.fetch!(client_name) do
%{platform: :host, default_device_id: ^device_id} -> acc
_ -> {client_name, device_id}
%{automatic_transfers: true, default_device_id: ^device_id} ->
{priority_client, priority_did, client_name, device_id}

_ ->
{client_name, device_id, backup_client, backup_did}
end
end

client = opts[:client] || client || EXLA.Client.default_name()
device_id = opts[:device_id] || device_id || EXLA.Client.fetch!(client).default_device_id
client =
opts[:client] || priority_client || backup_client ||
EXLA.Client.default_name()

device_id =
opts[:device_id] || priority_did || backup_did ||
EXLA.Client.fetch!(client).default_device_id

EXLA.jit_apply(fun, args, on_conflict: :force, client: client, device_id: device_id)
end
end
9 changes: 9 additions & 0 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,17 @@ defmodule EXLA.BackendTest do

@tag :multi_device
test "multi-device" do
a = Nx.tensor(1, backend: {EXLA.Backend, client: :other_host, device_id: 0})
assert_equal(Nx.add(a, 2), Nx.tensor(3))

a = Nx.tensor(1, backend: {EXLA.Backend, client: :other_host, device_id: 1})
assert_equal(Nx.add(a, 2), Nx.tensor(3))

a = Nx.tensor([[1]], backend: {EXLA.Backend, client: :other_host, device_id: 0})
assert Nx.reshape(a, {1}).data.buffer.client_name == :other_host

a = Nx.tensor([[1]], backend: {EXLA.Backend, client: :other_host, device_id: 1})
assert Nx.reshape(a, {1}).data.buffer.client_name == :other_host
end

test "Kernel.inspect/2" do
Expand Down

0 comments on commit 9f2854d

Please sign in to comment.