From 981ca89387f7ffe821ed8c96b41b0f948a3b513a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 12 Nov 2023 16:29:05 +0100 Subject: [PATCH] Move automatic transfer configuration to each client --- exla/lib/exla/client.ex | 13 +++++++------ exla/lib/exla/defn/buffers.ex | 5 ++--- exla/mix.exs | 3 +-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/exla/lib/exla/client.ex b/exla/lib/exla/client.ex index 7c9e7f2c9a..9bb2084e41 100644 --- a/exla/lib/exla/client.ex +++ b/exla/lib/exla/client.ex @@ -11,8 +11,8 @@ defmodule EXLA.Client do @name __MODULE__ - @enforce_keys [:ref, :platform, :name, :device_count, :default_device_id] - defstruct [:ref, :platform, :name, :device_count, :default_device_id] + @enforce_keys [:ref, :platform, :name, :device_count, :default_device_id, :automatic_transfers] + defstruct [:ref, :platform, :name, :device_count, :default_device_id, :automatic_transfers] @doc """ Returns the name of the default client. @@ -140,12 +140,9 @@ defmodule EXLA.Client do defp build_client(name, options) do platform = Keyword.get(options, :platform) - default_device_id = Keyword.get(options, :default_device_id, 0) memory_fraction = Keyword.get(options, :memory_fraction, 0.9) - preallocate = Keyword.get(options, :preallocate, true) preallocate_int = if preallocate, do: 1, else: 0 - platforms = Map.keys(EXLA.Client.get_supported_platforms()) ref = @@ -176,17 +173,21 @@ defmodule EXLA.Client do |> unwrap!() device_count = EXLA.NIF.get_device_count(ref) |> unwrap!() + default_device_id = Keyword.get(options, :default_device_id, 0) if default_device_id not in 0..(device_count - 1) do raise ArgumentError, ":default_device_id must be a number between 0 and #{device_count - 1}" end + automatic_transfers = Keyword.get(options, :automatic_transfers, platform == :host) + %EXLA.Client{ ref: ref, platform: platform, name: name, device_count: device_count, - default_device_id: default_device_id + default_device_id: default_device_id, + automatic_transfers: automatic_transfers } end diff --git a/exla/lib/exla/defn/buffers.ex b/exla/lib/exla/defn/buffers.ex index 813ff6f949..2394d59d0b 100644 --- a/exla/lib/exla/defn/buffers.ex +++ b/exla/lib/exla/defn/buffers.ex @@ -116,9 +116,8 @@ defmodule EXLA.Defn.Buffers do when transfer? and buffer.client_name != executable.client.name when transfer? and buffer.device_id != executable.device_id -> buffer_client = EXLA.Client.fetch!(buffer.client_name) - automatic = Application.fetch_env!(:exla, :automatic_device_transfer_platforms) - if buffer_client.platform in automatic do + if buffer_client.automatic_transfers do EXLA.DeviceBuffer.copy_to_device(buffer, executable.client, executable.device_id) else default = EXLA.Client.fetch!(EXLA.Client.default_name()) @@ -129,7 +128,7 @@ defmodule EXLA.Defn.Buffers do but one of the input tensors are allocated on #{buffer_client.name} \ ##{buffer.device_id} (#{buffer_client.platform}). - EXLA only transfers tensors allocated on host to other clients. \ + EXLA by default only transfers tensors allocated on host to other clients. \ You can force `:host` as your default backend with: # via config diff --git a/exla/mix.exs b/exla/mix.exs index 1fc88795d5..cfadb4c2d8 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -49,8 +49,7 @@ defmodule EXLA.MixProject do tpu: [platform: :tpu], host: [platform: :host] ], - preferred_clients: [:cuda, :rocm, :tpu, :host], - automatic_device_transfer_platforms: [:host] + preferred_clients: [:cuda, :rocm, :tpu, :host] ] ] end