Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update XLA and fix infeed on the GPU #1487

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ endif
LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared

ifeq ($(shell uname -s), Darwin)
LDFLAGS += -flat_namespace -undefined suppress -rpath @loader_path/xla_extension/lib
LDFLAGS += -flat_namespace -undefined dynamic_lookup -rpath @loader_path/xla_extension/lib
else
# Use a relative RPATH, so at runtime libexla.so looks for libxla_extension.so
# in ./lib regardless of the absolute location. This way priv can be safely
Expand Down
29 changes: 20 additions & 9 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
ptr = result.first;
}

EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(device_id), env);
EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);

std::function<void()> on_delete_callback = []() {};
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env);
Expand Down Expand Up @@ -573,29 +573,40 @@ ERL_NIF_TERM transfer_to_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM arg
return exla::nif::error(env, "Unable to get device ID.");
}

std::vector<ErlNifBinary> buffer_bins;
std::vector<xla::Shape> shapes;

ERL_NIF_TERM head, tail;
while (enif_get_list_cell(env, data, &head, &tail)) {
const ERL_NIF_TERM* terms;
int count;
xla::Shape shape;

if (!enif_get_tuple(env, head, &count, &terms) && count != 2) {
return exla::nif::error(env, "Unable to binary-shape tuple.");
return exla::nif::error(env, "Unable to {binary, shape} tuple.");
}

ErlNifBinary buffer_bin;
if (!exla::nif::get_binary(env, terms[0], &buffer_bin)) {
return exla::nif::error(env, "Unable to binary.");
}

xla::Shape shape;
if (!exla::nif::get_typespec_as_xla_shape(env, terms[1], &shape)) {
return exla::nif::error(env, "Unable to get shape.");
}

xla::Status transfer_status = (*client)->TransferToInfeed(env, terms[0], shape, device_id);

if (!transfer_status.ok()) {
return exla::nif::error(env, transfer_status.message().data());
}
buffer_bins.push_back(buffer_bin);
shapes.push_back(shape);

data = tail;
}

xla::Status transfer_status = (*client)->TransferToInfeed(env, buffer_bins, shapes, device_id);

if (!transfer_status.ok()) {
return exla::nif::error(env, transfer_status.message().data());
}

return exla::nif::ok(env);
}

Expand Down Expand Up @@ -668,7 +679,7 @@ ERL_NIF_TERM copy_buffer_to_device(ErlNifEnv* env, int argc, const ERL_NIF_TERM
}

EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device,
(*client)->client()->LookupDevice(device_id), env);
(*client)->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)), env);
EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaBuffer * buf,
(*buffer)->CopyToDevice(device), env);

Expand Down
65 changes: 44 additions & 21 deletions exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "xla/pjrt/pjrt_c_api_client.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/shape_util.h"

namespace exla {

Expand Down Expand Up @@ -61,10 +62,10 @@ xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PjRtBufferFromBinary(xla::PjRtCl
return xla::InvalidArgument("Expected buffer to be binary.");
}

xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kZeroCopy;
xla::PjRtClient::HostBufferSemantics semantics = xla::PjRtClient::HostBufferSemantics::kImmutableZeroCopy;
std::function<void()> on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); };

EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(device_id));
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer(
binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device));

Expand Down Expand Up @@ -292,7 +293,7 @@ xla::StatusOr<ERL_NIF_TERM> ExlaExecutable::Run(ErlNifEnv* env,
// executable, meaning we need to find the device corresponding to the specific device
// id and execute on that device, we've already guaranteed this executable only has 1
// replica
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->client()->LookupDevice(device_id));
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
// because this is a portable executable, it only has 1 replica and so we only need
// to get the arguments at the first position of the input buffers
std::vector<xla::PjRtBuffer*> portable_args = input_buffers.at(0);
Expand Down Expand Up @@ -390,30 +391,49 @@ xla::StatusOr<ExlaExecutable*> ExlaClient::Compile(const mlir::OwningOpRef<mlir:
}

xla::Status ExlaClient::TransferToInfeed(ErlNifEnv* env,
ERL_NIF_TERM data,
const xla::Shape& shape,
std::vector<ErlNifBinary> buffer_bins,
std::vector<xla::Shape> shapes,
int device_id) {
// Fast path to avoid any traversal when not sending Tuples
ERL_NIF_TERM head, tail;
if (!enif_get_list_cell(env, data, &head, &tail)) {
return xla::InvalidArgument("infeed operation expects a list of binaries");
}
std::vector<const char*> buf_ptrs;
buf_ptrs.reserve(buffer_bins.size());

ErlNifBinary binary;
if (!nif::get_binary(env, head, &binary)) {
return xla::InvalidArgument("infeed operation expects a list of binaries");
for (const auto & buffer_bin : buffer_bins) {
const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(buffer_bin.data));
buf_ptrs.push_back(data_ptr);
}

const char* data_ptr = const_cast<char*>(reinterpret_cast<char*>(binary.data));
xla::BorrowingLiteral literal(data_ptr, shape);

EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id));

return device->TransferToInfeed(literal);
auto shape = xla::ShapeUtil::MakeTupleShape(shapes);

// Instead of pushing each buffer separately, we create a flat tuple
// literal and push the whole group of buffers.
//
// On the CPU, XLA infeed reads buffers from a queue one at a time [1][2]
// (or rather, the infeed operation is lowered to multiple queue reads),
jonatanklosko marked this conversation as resolved.
Show resolved Hide resolved
// hence pushing one at a time works fine. Pushing a flat tuple works
// effectively the same, since it basically adds each element to the
// queue [3].
//
// On the GPU, XLA infeed reads only a single "literal" from a queue [4]
// and expects it to carry all buffers for the given infeed operation.
// Consequently, we need to push all buffers as a single literal.
//
// Given that a flat tuple works in both cases, we just do that.
//
// [1]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/ir_emitter.cc#L449-L450
// [2]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/cpu_runtime.cc#L222
// [3]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/cpu/cpu_xfeed.cc#L178
// [4]: https://github.com/openxla/xla/blob/fd58925adee147d38c25a085354e15427a12d00a/xla/service/gpu/runtime/infeed_thunk.cc#L40-L41
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful explanation 👏

xla::BorrowingLiteral literal(buf_ptrs, shape);

EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));

xla::Status status = device->TransferToInfeed(literal);

return status;
}

xla::StatusOr<ERL_NIF_TERM> ExlaClient::TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape) {
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(device_id));
EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client_->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));

auto literal = std::make_shared<xla::Literal>(shape);

Expand Down Expand Up @@ -445,8 +465,11 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
.memory_fraction = memory_fraction,
.preallocate = preallocate};

xla::GpuClientOptions client_options = {
.allocator_config = allocator_config};

EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetStreamExecutorGpuClient(false, allocator_config, 0));
xla::GetStreamExecutorGpuClient(client_options));

return new ExlaClient(std::move(client));
}
Expand Down
4 changes: 2 additions & 2 deletions exla/c_src/exla/exla_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class ExlaClient {

// TODO(seanmor5): This is device logic and should be refactored
xla::Status TransferToInfeed(ErlNifEnv* env,
ERL_NIF_TERM data,
const xla::Shape& shape,
std::vector<ErlNifBinary> buffer_bins,
std::vector<xla::Shape> shapes,
int device_id);

xla::StatusOr<ERL_NIF_TERM> TransferFromOutfeed(ErlNifEnv* env, int device_id, xla::Shape& shape);
Expand Down
14 changes: 5 additions & 9 deletions exla/lib/exla/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -88,19 +88,15 @@ defmodule EXLA.Client do
@doc """
Sends `data_and_typespecs` to device infeed.

`data_and_typespecs` must be a list of two element tuples where the
first element is a binary or a flat list of binaries and the second
element is a `EXLA.Typespec`.
`data_and_typespecs` is a list of values corresponding to a single
infeed operation. It must be a list of two element tuples where the
first element is a binary and the second element is a `EXLA.Typespec`.
"""
def to_infeed(%EXLA.Client{ref: client}, device_id, data_and_typespecs)
when is_list(data_and_typespecs) do
data_and_typespecs =
Enum.map(data_and_typespecs, fn
{binary, typespec} when is_binary(binary) ->
{[binary], EXLA.Typespec.nif_encode(typespec)}

{[binary | _] = data, typespec} when is_binary(binary) ->
{data, EXLA.Typespec.nif_encode(typespec)}
Enum.map(data_and_typespecs, fn {binary, typespec} when is_binary(binary) ->
{binary, EXLA.Typespec.nif_encode(typespec)}
end)

EXLA.NIF.transfer_to_infeed(client, device_id, data_and_typespecs) |> unwrap!()
Expand Down
20 changes: 7 additions & 13 deletions exla/lib/exla/defn/stream.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule EXLA.Defn.Stream do
@moduledoc false

keys =
[:lock, :outfeed, :pid, :runner, :send, :send_typespec, :send_indexes] ++
[:lock, :outfeed, :pid, :runner, :send, :send_typespecs, :send_indexes] ++
[:recv, :recv_length, :done, :client, :device_id]

@derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]}
Expand All @@ -15,7 +15,7 @@ defmodule EXLA.Defn.Stream do
runner,
outfeed,
send,
send_typespec,
send_typespecs,
send_indexes,
recv,
recv_typespecs,
Expand All @@ -39,7 +39,7 @@ defmodule EXLA.Defn.Stream do
outfeed: outfeed,
lock: lock,
send: send,
send_typespec: send_typespec,
send_typespecs: send_typespecs,
send_indexes: send_indexes,
recv: recv,
recv_length: length(recv_typespecs),
Expand All @@ -64,7 +64,7 @@ defmodule EXLA.Defn.Stream do
client: client,
device_id: device_id,
send: send,
send_typespec: send_typespec,
send_typespecs: send_typespecs,
send_indexes: send_indexes
} = stream

Expand All @@ -86,17 +86,11 @@ defmodule EXLA.Defn.Stream do
"""
end

data_and_typespecs =
if client.platform == :host do
Enum.zip(buffers, send_typespec)
else
[{buffers, send_typespec}]
end

pred = EXLA.Typespec.tensor({:pred, 8}, {})
data_and_typespecs = Enum.zip(buffers, send_typespecs)

:ok =
EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred} | data_and_typespecs])
:ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred}])
:ok = EXLA.Client.to_infeed(client, device_id, data_and_typespecs)
end

defp nx_to_io(container, indexes) do
Expand Down
2 changes: 1 addition & 1 deletion exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ defmodule EXLA.MixProject do
# {:nx, "~> 0.7.1"},
{:nx, path: "../nx"},
{:telemetry, "~> 0.4.0 or ~> 1.0"},
{:xla, "~> 0.6.0", runtime: false},
{:xla, "~> 0.7.0", runtime: false},
{:elixir_make, "~> 0.6", runtime: false},
{:benchee, "~> 1.0", only: :dev},
{:ex_doc, "~> 0.29", only: :docs},
Expand Down
4 changes: 2 additions & 2 deletions exla/mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"},
"elixir_make": {:hex, :elixir_make, "0.8.3", "d38d7ee1578d722d89b4d452a3e36bcfdc644c618f0d063b874661876e708683", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "5c99a18571a756d4af7a4d89ca75c28ac899e6103af6f223982f09ce44942cc9"},
"ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"},
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
Expand All @@ -13,5 +13,5 @@
"nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"},
"statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"},
"xla": {:hex, :xla, "0.7.0", "413880fb8f665d93636908092a409e549545e190b38b91107832e78379190d93", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "8eb5c5510e6737fd9e4860bfb0d8cafb13ab94b1b4123edd347562a71e19ec27"},
}
Loading