Skip to content

Commit

Permalink
Move automatic transfer configuration to each client
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 12, 2023
1 parent 1c4ac8e commit 981ca89
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
13 changes: 7 additions & 6 deletions exla/lib/exla/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ defmodule EXLA.Client do

@name __MODULE__

@enforce_keys [:ref, :platform, :name, :device_count, :default_device_id]
defstruct [:ref, :platform, :name, :device_count, :default_device_id]
@enforce_keys [:ref, :platform, :name, :device_count, :default_device_id, :automatic_transfers]
defstruct [:ref, :platform, :name, :device_count, :default_device_id, :automatic_transfers]

@doc """
Returns the name of the default client.
Expand Down Expand Up @@ -140,12 +140,9 @@ defmodule EXLA.Client do

defp build_client(name, options) do
platform = Keyword.get(options, :platform)
default_device_id = Keyword.get(options, :default_device_id, 0)
memory_fraction = Keyword.get(options, :memory_fraction, 0.9)

preallocate = Keyword.get(options, :preallocate, true)
preallocate_int = if preallocate, do: 1, else: 0

platforms = Map.keys(EXLA.Client.get_supported_platforms())

ref =
Expand Down Expand Up @@ -176,17 +173,21 @@ defmodule EXLA.Client do
|> unwrap!()

device_count = EXLA.NIF.get_device_count(ref) |> unwrap!()
default_device_id = Keyword.get(options, :default_device_id, 0)

if default_device_id not in 0..(device_count - 1) do
raise ArgumentError, ":default_device_id must be a number between 0 and #{device_count - 1}"
end

automatic_transfers = Keyword.get(options, :automatic_transfers, platform == :host)

%EXLA.Client{
ref: ref,
platform: platform,
name: name,
device_count: device_count,
default_device_id: default_device_id
default_device_id: default_device_id,
automatic_transfers: automatic_transfers
}
end

Expand Down
5 changes: 2 additions & 3 deletions exla/lib/exla/defn/buffers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ defmodule EXLA.Defn.Buffers do
when transfer? and buffer.client_name != executable.client.name
when transfer? and buffer.device_id != executable.device_id ->
buffer_client = EXLA.Client.fetch!(buffer.client_name)
automatic = Application.fetch_env!(:exla, :automatic_device_transfer_platforms)

if buffer_client.platform in automatic do
if buffer_client.automatic_transfers do
EXLA.DeviceBuffer.copy_to_device(buffer, executable.client, executable.device_id)
else
default = EXLA.Client.fetch!(EXLA.Client.default_name())
Expand All @@ -129,7 +128,7 @@ defmodule EXLA.Defn.Buffers do
but one of the input tensors are allocated on #{buffer_client.name} \
##{buffer.device_id} (#{buffer_client.platform}).
EXLA only transfers tensors allocated on host to other clients. \
EXLA by default only transfers tensors allocated on host to other clients. \
You can force `:host` as your default backend with:
# via config
Expand Down
3 changes: 1 addition & 2 deletions exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ defmodule EXLA.MixProject do
tpu: [platform: :tpu],
host: [platform: :host]
],
preferred_clients: [:cuda, :rocm, :tpu, :host],
automatic_device_transfer_platforms: [:host]
preferred_clients: [:cuda, :rocm, :tpu, :host]
]
]
end
Expand Down

0 comments on commit 981ca89

Please sign in to comment.