Skip to content

Commit

Permalink
Capture dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed Aug 4, 2024
1 parent 5c438dc commit 6661862
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 51 deletions.
67 changes: 50 additions & 17 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ typedef struct {
void * handle;
} ExlaPlugin;

typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]);
typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims);

typedef struct {
const char* name;
Expand Down Expand Up @@ -962,21 +962,6 @@ ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL
return exla::nif::error(env, "Unable to open library.");
}

const ExlaPluginCustomCall* custom_calls = (ExlaPluginCustomCall*) dlsym(handle, "exla_custom_calls");

if(!custom_calls) {
dlclose(handle);
return exla::nif::error(env, "Unable to find exla_custom_calls");
}

int i = 0;
ExlaPluginCustomCall func = custom_calls[i];
while (func.name != NULL) {
// TODO: GPU flags
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(func.name, func.func);
func = custom_calls[++i];
}

ExlaPlugin* plugin = (ExlaPlugin*) enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin));
plugin->handle = handle;

Expand All @@ -986,6 +971,53 @@ ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL
return exla::nif::ok(env, result);
}

ERL_NIF_TERM register_custom_call_symbol(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
}

ExlaPlugin* plugin;
std::string symbol;
std::vector<std::vector<exla::int64>> dimensions;

if (!enif_get_resource(env, argv[0], exla_plugin_resource_type, (void **) &plugin)) {
return exla::nif::error(env, "Unable to get plugin.");
}
if (!exla::nif::get(env, argv[1], symbol)) {
return exla::nif::error(env, "Unable to get symbol.");
}
if (!exla::nif::get_list(env, argv[2], dimensions)) {
return exla::nif::error(env, "Unable to get dimensions.");
}

ExlaCustomCallFunction function = (ExlaCustomCallFunction) dlsym(plugin->handle, symbol.c_str());

if (!function) {
return exla::nif::error(env, "Could not find symbol.");
}

auto lambda = [&dimensions, function](void *in[], const void *out[]) {
std::vector<std::vector<int>> int_dims(dimensions.size());
for (size_t i = 0; i < dimensions.size(); ++i) {
int_dims[i].resize(dimensions[i].size());
std::transform(dimensions[i].begin(), dimensions[i].end(), int_dims[i].begin(),
[](exla::int64 x) { return static_cast<int>(x); });
}

std::vector<int*> dims_ptrs;
for (auto& d : int_dims) {
dims_ptrs.push_back(d.data());
}

function(in, out, dims_ptrs.data());
};

// TODO: GPU/Client flag
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol.c_str(), function);

return exla::nif::ok(env);
}

static ErlNifFunc exla_funcs[] = {
// MLIR Builder
{"mlir_new_context", 0, mlir_new_context},
Expand Down Expand Up @@ -1024,7 +1056,8 @@ static ErlNifFunc exla_funcs[] = {
{"serialize_executable", 1, serialize_executable},
{"deserialize_executable", 2, deserialize_executable},
// Plugins
{"load_custom_call_plugin_library", 1, load_custom_call_plugin_library}
{"load_custom_call_plugin_library", 1, load_custom_call_plugin_library},
{"register_custom_call_symbol", 3, register_custom_call_symbol}
};

ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL);
19 changes: 19 additions & 0 deletions exla/c_src/exla/exla_nif_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,25 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::string>& var) {
return 1;
}

int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::vector<int64>>& var) {
unsigned int length;
if (!enif_get_list_length(env, list, &length)) {
return 0;
}
var.reserve(length);
ERL_NIF_TERM head, tail;

while (enif_get_list_cell(env, list, &head, &tail)) {
std::vector<int64> elem;
if (!get_list(env, head, elem)) {
return 0;
}
var.push_back(elem);
list = tail;
}
return 1;
}

int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var) {
return enif_inspect_binary(env, term, var);
}
Expand Down
2 changes: 2 additions & 0 deletions exla/c_src/exla/exla_nif_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::string>& var);

int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<xla::Shape>& var);

int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::vector<int64>>& var);

template <typename T>
int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<T*>& var) {
unsigned int length;
Expand Down
1 change: 1 addition & 0 deletions exla/lib/exla/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ defmodule EXLA.Application do
name: EXLA.MLIR.ContextPool,
lazy: true},
EXLA.Client,
EXLA.Plugin,
EXLA.Defn.Lock,
EXLA.Defn.LockedCache,
{Task.Supervisor, name: EXLA.Defn.TaskSupervisor}
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,6 @@ defmodule EXLA.NIF do
def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef)

def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef)

def register_custom_call_symbol(_plugin, _symbol, _dimensions), do: :erlang.nif_error(:undef)
end
49 changes: 38 additions & 11 deletions exla/lib/exla/plugin.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,55 @@ defmodule EXLA.Plugin do
@moduledoc """
Plugin system for registering custom calls.
"""
use GenServer

def register(library_path) do
unless File.exists?(library_path) do
raise ArgumentError, "#{library_path} does not exist"
# TODO: Register and lookup per client

def start_link(_opts) do
GenServer.start_link(__MODULE__, %{}, name: __MODULE__)
end

def register(key, library_path) do
GenServer.cast(__MODULE__, {:register, key, library_path})
end

def lookup(key) do
GenServer.call(__MODULE__, {:lookup, key})
end

def register_symbol(key, symbol, dimensions) do
if ref = lookup(key) do
EXLA.NIF.register_custom_call_symbol(ref, symbol, dimensions)
end
end

@impl true
def init(_opts) do
{:ok, %{}}
end

@impl true
def handle_cast({:register, key, library_path}, state) do
case state do
%{^key => _ref} ->
{:noreply, state}

case :persistent_term.get({__MODULE__, library_path}, nil) do
nil ->
%{} ->
ref =
library_path
|> EXLA.NIF.load_custom_call_plugin_library()
|> unwrap!()

# we need to keep the ref from getting garbage collected so
# we can use the symbols within it at anytime
:persistent_term.put({__MODULE__, library_path}, ref)

_ref ->
:ok
{:noreply, Map.put(state, key, ref)}
end
end

@impl true
def handle_call({:lookup, key}, _from, state) do
value = Map.get(state, key)
{:reply, value, state}
end

defp unwrap!({:ok, ref}), do: ref
defp unwrap!({:error, reason}), do: raise("#{reason}")
end
16 changes: 2 additions & 14 deletions exla/test/exla/plugin_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,8 @@ defmodule EXLA.PluginTest do
use ExUnit.Case

describe "register/1" do
test "raises if file does not exist" do
assert_raise ArgumentError, ~r/does not exist/, fn ->
EXLA.Plugin.register("test/support/c/doesnotexist.so")
end
end

test "does not crash on invalid files" do
assert_raise RuntimeError, ~r/Unable to open/, fn ->
EXLA.Plugin.register(__ENV__.file)
end
end

test "registers a plugin" do
assert :ok = EXLA.Plugin.register("test/support/c/libcustom_plugin.so")
assert :ok = EXLA.Plugin.register(:custom_plugin, "test/support/c/libcustom_plugin.so")
end
end
end
end
13 changes: 4 additions & 9 deletions exla/test/support/c/custom_plugin.c
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#include <cstdint>
#include <stddef.h>

typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]);
typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims);

typedef struct {
const char* name;
ExlaCustomCallFunction func;
} ExlaPluginCustomCall;

void custom_increment(void *out[], const void *in[]) {
extern "C" void custom_increment(void *out[], const void *in[], int **dims) {
int64_t *operand = (int64_t *)in[0];
int64_t *dim_sizes = (int64_t *)in[1];
int64_t *dim_sizes = (int64_t *)dims[0];

int64_t *out_buffer = (int64_t *)out[0];

Expand All @@ -19,9 +19,4 @@ void custom_increment(void *out[], const void *in[]) {
for (int64_t i = 0; i < n; i++) {
out_buffer[i] = operand[i] + 1;
}
}

extern "C" ExlaPluginCustomCall exla_custom_calls[] = {
{"custom_increment", custom_increment},
{NULL, NULL}
};
}
Binary file modified exla/test/support/c/libcustom_plugin.so
Binary file not shown.

0 comments on commit 6661862

Please sign in to comment.