feat: add to/from pointer #1473

Merged: 12 commits, Apr 25, 2024
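
For orientation, a minimal usage sketch of what this PR enables through Nx, assuming the Nx.to_pointer/2 and Nx.from_pointer/5 entry points exercised by the new test in exla/test/exla/device_memory_sharing_test.exs below (the :cuda client and option names are taken from that test):

# Sketch only: share one device buffer between two tensors via its raw pointer.
t1 = Nx.tensor([1, 2, 3], backend: {EXLA.Backend, client: :cuda})

# mode: :local returns the raw device address; mode: :cuda_ipc returns a
# CUDA IPC handle that another OS process can open.
{:ok, pointer} = Nx.to_pointer(t1, mode: :local)

# Build a second tensor that views the same device memory.
{:ok, t2} =
  Nx.from_pointer(EXLA.Backend, pointer, t1.type, t1.shape,
    backend_opts: [client_name: :cuda]
  )

true = Nx.to_binary(t1) == Nx.to_binary(t2)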
23 changes: 21 additions & 2 deletions exla/Makefile
@@ -1,4 +1,4 @@
# Environment variables passed via elixir_make
# Environment variables passed via elixir_make
# ERTS_INCLUDE_DIR
Collaborator (review comment): suggested change to drop the duplicated "# Environment variables passed via elixir_make" line.

# MIX_APP_PATH

@@ -29,8 +29,11 @@ CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compa
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
-std=c++17 -w -DLLVM_VERSION_STRING=

NVCCFLAGS = -shared -Xcompiler -fPIC

ifdef DEBUG
CFLAGS += -g
NVCCFLAGS += -g
else
CFLAGS += -O3
endif
@@ -60,7 +63,23 @@ $(EXLA_SO): $(EXLA_CACHE_SO)

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/mlir/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/mlir/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES))
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o


NVCC_RESULT := $(shell which nvcc 2> NULL)
NVCC_TEST := $(notdir $(NVCC_RESULT))

ifeq ($(NVCC_TEST),nvcc)
NVCC := nvcc
NVCCFLAGS += -DCUDA_ENABLED
else
NVCC := g++
NVCCFLAGS = $(CFLAGS)
endif

$(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cuda.h
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)
$(NVCC) $(NVCCFLAGS) -c $< -o $@

$(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS)
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)
101 changes: 101 additions & 0 deletions exla/c_src/exla/exla.cc
@@ -4,6 +4,7 @@
#include <string>

#include "exla_client.h"
#include "exla_cuda.h"
#include "exla_log_sink.h"
#include "exla_nif_util.h"
#include "mlir/ops.h"
@@ -134,6 +135,104 @@ ERL_NIF_TERM create_sub_builder(ErlNifEnv* env, int argc, const ERL_NIF_TERM arg

// ExlaBuffer Functions

ERL_NIF_TERM get_buffer_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
}

exla::ExlaClient** client;
exla::ExlaBuffer** buffer;
std::string pointer_kind;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get<exla::ExlaBuffer*>(env, argv[1], buffer)) {
return exla::nif::error(env, "Unable to get buffer.");
}
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
return exla::nif::error(env, "Unable to get device pointer kind.");
}

EXLA_ASSIGN_OR_RETURN_NIF(std::uintptr_t ptr,
(*buffer)->GetDevicePointer((*client)->client()), env);

std::vector<unsigned char> pointer_vec;
if (pointer_kind == "local") {
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
for (size_t i = 0; i < sizeof(void*); i++) {
pointer_vec.push_back(bytePtr[i]);
}
} else if (pointer_kind == "cuda_ipc") {
auto result = get_cuda_ipc_handle(ptr);
if (result.second) {
return exla::nif::error(env, "Unable to get cuda IPC handle");
}
pointer_vec = result.first;
}

EXLA_ASSIGN_OR_RETURN_NIF(unsigned long device_size, (*buffer)->GetOnDeviceSizeInBytes(), env);

ERL_NIF_TERM handle_list[pointer_vec.size()];
for (int i = 0; i < pointer_vec.size(); i++) {
handle_list[i] = enif_make_uint(env, pointer_vec[i]);
}

ERL_NIF_TERM handle_list_term = enif_make_list_from_array(env, handle_list, pointer_vec.size());
ERL_NIF_TERM device_size_term = enif_make_uint64(env, device_size);

return exla::nif::ok(env, enif_make_tuple2(env, handle_list_term, device_size_term));
}

ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 5) {
return exla::nif::error(env, "Bad argument count.");
}

exla::ExlaClient** client;
std::vector<int64_t> pointer_vec;
xla::Shape* shape;
int device_id;
std::string pointer_kind;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get_list(env, argv[1], pointer_vec)) {
return exla::nif::error(env, "Unable to get device pointer.");
}
if (!exla::nif::get_atom(env, argv[2], pointer_kind)) {
return exla::nif::error(env, "Unable to get device pointer kind.");
}
if (!exla::nif::get<xla::Shape>(env, argv[3], shape)) {
return exla::nif::error(env, "Unable to get shape.");
}
if (!exla::nif::get(env, argv[4], &device_id)) {
return exla::nif::error(env, "Unable to get device ordinal.");
}

void* ptr;
if (pointer_kind == "local") {
unsigned char* bytePtr = reinterpret_cast<unsigned char*>(&ptr);
for (size_t i = 0; i < sizeof(void*); i++) {
bytePtr[i] = pointer_vec[i];
}
} else if (pointer_kind == "cuda_ipc") {
auto result = get_pointer_for_ipc_handle(pointer_vec);
if (result.second) {
return exla::nif::error(env, "Unable to get pointer for IPC handle.");
}
ptr = result.first;
}

EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(device_id), env);

std::function<void()> on_delete_callback = []() {};
EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, *shape, device, on_delete_callback), env);
exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer));
return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
}

ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 4) {
return exla::nif::error(env, "Bad argument count.");
@@ -710,6 +809,8 @@ static ErlNifFunc exla_funcs[] = {
{"get_supported_platforms", 0, get_supported_platforms},
{"mlir_compile", 7, mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND},
// ExlaBuffer
{"get_buffer_device_pointer", 3, get_buffer_device_pointer},
{"create_buffer_from_device_pointer", 5, create_buffer_from_device_pointer},
{"binary_to_device_mem", 4, binary_to_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND},
{"read_device_mem", 2, read_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND},
{"deallocate_device_mem", 1, deallocate_device_mem, ERL_NIF_DIRTY_JOB_IO_BOUND},
8 changes: 8 additions & 0 deletions exla/c_src/exla/exla_client.h
@@ -31,6 +31,14 @@ class ExlaBuffer {
xla::StatusOr<ERL_NIF_TERM> ToBinary(ErlNifEnv* env, exla::int64 size);
xla::Status Deallocate();

xla::StatusOr<std::uintptr_t> GetDevicePointer(xla::PjRtClient* client) {
return client->UnsafeBufferPointer(buffer_.get());
}

xla::StatusOr<size_t> GetOnDeviceSizeInBytes() {
return buffer_.get()->GetOnDeviceSizeInBytes();
}

~ExlaBuffer() {
// Theoretically this may block if a computation is running
// but we always block the host until the computation is done.
55 changes: 55 additions & 0 deletions exla/c_src/exla/exla_cuda.cc
@@ -0,0 +1,55 @@
#include "exla_cuda.h"

#ifdef CUDA_ENABLED
#include <cuda_runtime.h>

#include <cstring>
#include <iostream>

std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t ptr) {
cudaIpcMemHandle_t ipc_handle;
cudaError_t status = cudaIpcGetMemHandle(&ipc_handle, reinterpret_cast<void*>(ptr));

// Assuming sizeof(cudaIpcMemHandle_t) is constant
const size_t size = sizeof(cudaIpcMemHandle_t);

// Copy the memory handle to a byte array
std::vector<unsigned char> result(size);
memcpy(result.data(), &ipc_handle, size);

return std::make_pair(result, status != cudaSuccess);
}

std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list) {
unsigned char ipc_handle_data[sizeof(cudaIpcMemHandle_t)];
for (int i = 0; i < sizeof(cudaIpcMemHandle_t); i++) {
ipc_handle_data[i] = (uint8_t)handle_list[i];
}

cudaIpcMemHandle_t ipc_handle;
memcpy(&ipc_handle, ipc_handle_data, sizeof(cudaIpcMemHandle_t));

int* ptr;
cudaError_t cuda_status = cudaSetDevice(0); // Assuming device 0, change as needed
if (cuda_status != cudaSuccess) {
printf("Error setting CUDA device: %s\n", cudaGetErrorString(cuda_status));
return std::make_pair(nullptr, 1); // Return with error status
}

cuda_status = cudaIpcOpenMemHandle((void**)&ptr, ipc_handle, cudaIpcMemLazyEnablePeerAccess);
if (cuda_status != cudaSuccess) {
printf("Error opening CUDA IPC memory handle: %s\n", cudaGetErrorString(cuda_status));
return std::make_pair(nullptr, 1); // Return with error status
}

return std::make_pair(ptr, cuda_status != cudaSuccess);
}
#else
std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t ptr) {
return std::make_pair(std::vector<unsigned char>(0), 1);
}

std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t> handle_list) {
return std::make_pair(nullptr, 1);
}
#endif
7 changes: 7 additions & 0 deletions exla/c_src/exla/exla_cuda.h
@@ -0,0 +1,7 @@
#pragma once

#include <cstdint>
#include <vector>

std::pair<std::vector<unsigned char>, int> get_cuda_ipc_handle(std::uintptr_t);
std::pair<void*, int> get_pointer_for_ipc_handle(std::vector<int64_t>);
2 changes: 1 addition & 1 deletion exla/c_src/exla/exla_nif_util.h
@@ -117,7 +117,7 @@ ERL_NIF_TERM make(ErlNifEnv* env, const char* string);
// their signatures are the same for retrieving/returning
// regular strings.

int get_atom(ErlNifEnv* env, ERL_NIF_TERM term, std::string* var);
int get_atom(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var);

ERL_NIF_TERM atom(ErlNifEnv* env, const char* status);

Expand Down
64 changes: 64 additions & 0 deletions exla/lib/exla/backend.ex
@@ -82,6 +82,70 @@ defmodule EXLA.Backend do
EXLA.DeviceBuffer.deallocate(buffer)
end

@impl true
def to_pointer(%T{data: %B{buffer: buffer}}, opts \\ []) do
opts = Keyword.validate!(opts, mode: :local)

mode =
case opts[:mode] do
mode when mode in [:local, :cuda_ipc] ->
mode

mode ->
raise ArgumentError, "expected one of :local, :cuda_ipc, got: #{inspect(mode)}"
end

case buffer do
%EXLA.DeviceBuffer{} ->
:ok

_ ->
raise ArgumentError, "tensor must be allocated via a #{DeviceBuffer}"
end

client = EXLA.Client.fetch!(buffer.client_name)

case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do
{:ok, {pointer, _size}} ->
{:ok, pointer}

error ->
error
end
end

@impl true
def from_pointer(pointer, type, dims, opts) do
template = Nx.template(dims, type, names: opts[:names])

opts = Keyword.validate!(opts[:backend_opts] || [], [:client_name, :device_id, mode: :local])
Collaborator (review comment): As commented below, this is mixing backend options (such as client name) with pointer options. We probably should treat them separately. (A hypothetical sketch of that separation follows this file's diff.)


client_name = opts[:client_name] || EXLA.Client.default_name()
client = EXLA.Client.fetch!(client_name)

device_id = opts[:device_id] || client.default_device_id

shape = EXLA.Shape.make_shape(type, dims)

result =
EXLA.NIF.create_buffer_from_device_pointer(
client.ref,
pointer,
opts[:mode],
shape.ref,
device_id
)

case result do
{:ok, ref} ->
buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, shape)
{:ok, %{template | data: %EXLA.Backend{buffer: buffer}}}

error ->
error
end
end

@impl true
def to_batched(out, tensor, opts) do
leftover = opts[:leftover]
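Editorial sketch, not part of the PR: one hypothetical way to address the review comment above, keeping pointer options (:mode) separate from backend options (:client_name, :device_id). The :pointer_opts key and the exact split are assumptions, not EXLA's actual API; the helper calls mirror the functions used in the diff above.

# Hypothetical sketch only: validate pointer and backend options separately.
def from_pointer(pointer, type, dims, opts) do
  template = Nx.template(dims, type, names: opts[:names])

  pointer_opts = Keyword.validate!(opts[:pointer_opts] || [], mode: :local)
  backend_opts = Keyword.validate!(opts[:backend_opts] || [], [:client_name, :device_id])

  client_name = backend_opts[:client_name] || EXLA.Client.default_name()
  client = EXLA.Client.fetch!(client_name)
  device_id = backend_opts[:device_id] || client.default_device_id
  shape = EXLA.Shape.make_shape(type, dims)

  case EXLA.NIF.create_buffer_from_device_pointer(
         client.ref, pointer, pointer_opts[:mode], shape.ref, device_id
       ) do
    {:ok, ref} ->
      buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, shape)
      {:ok, %{template | data: %EXLA.Backend{buffer: buffer}}}

    error ->
      error
  end
end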
11 changes: 11 additions & 0 deletions exla/lib/exla/nif.ex
@@ -229,6 +229,17 @@ defmodule EXLA.NIF do
),
do: :erlang.nif_error(:undef)

def get_buffer_device_pointer(_client, _buffer, _pointer_kind), do: :erlang.nif_error(:undef)

def create_buffer_from_device_pointer(
_client,
_opaque_pointer,
_pointer_kind,
_shape,
_device_id
),
do: :erlang.nif_error(:undef)

def binary_to_device_mem(_client, _binary, _shape, _device_ordinal),
do: :erlang.nif_error(:undef)

4 changes: 2 additions & 2 deletions exla/mix.exs
@@ -59,8 +59,8 @@ defmodule EXLA.MixProject do

defp deps do
[
{:nx, "~> 0.7.1"},
# {:nx, path: "../nx"},
# {:nx, "~> 0.7.1"},
{:nx, path: "../nx"},
{:telemetry, "~> 0.4.0 or ~> 1.0"},
{:xla, "~> 0.6.0", runtime: false},
{:elixir_make, "~> 0.6", runtime: false},
21 changes: 21 additions & 0 deletions exla/test/exla/device_memory_sharing_test.exs
@@ -0,0 +1,21 @@
defmodule EXLA.DeviceMemorySharingTest do
use EXLA.Case, async: false

@moduletag :cuda_required

test "buffer sharing works as expected" do
t1 = Nx.tensor([1, 2, 3], backend: {EXLA.Backend, client: :cuda})

assert inspect(t1) =~ "1, 2, 3"

assert {:ok, pointer} = Nx.to_pointer(t1, mode: :local)

assert {:ok, t2} =
Nx.from_pointer(EXLA.Backend, pointer, t1.type, t1.shape,
backend_opts: [client_name: :cuda]
)

assert t1.data.buffer.ref != t2.data.buffer.ref
assert Nx.to_binary(t1) == Nx.to_binary(t2)
end
end
9 changes: 8 additions & 1 deletion exla/test/test_helper.exs
@@ -31,8 +31,15 @@ if client.platform == :host and client.device_count == 1 and System.schedulers_o
)
end

cuda_required =
if Map.has_key?(EXLA.Client.get_supported_platforms(), :cuda) do
[]
else
[:cuda_required]
end

ExUnit.start(
exclude: [:platform, :integration] ++ exclude_multi_device ++ exclude,
exclude: [:platform, :integration] ++ exclude_multi_device ++ exclude ++ cuda_required,
include: [platform: String.to_atom(target)],
assert_receive_timeout: 1000
)