diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 924f948210..eb48c6ce03 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -26,7 +26,7 @@ typedef struct { void * handle; } ExlaPlugin; -typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]); +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims); typedef struct { const char* name; @@ -962,21 +962,6 @@ ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL return exla::nif::error(env, "Unable to open library."); } - const ExlaPluginCustomCall* custom_calls = (ExlaPluginCustomCall*) dlsym(handle, "exla_custom_calls"); - - if(!custom_calls) { - dlclose(handle); - return exla::nif::error(env, "Unable to find exla_custom_calls"); - } - - int i = 0; - ExlaPluginCustomCall func = custom_calls[i]; - while (func.name != NULL) { - // TODO: GPU flags - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(func.name, func.func); - func = custom_calls[++i]; - } - ExlaPlugin* plugin = (ExlaPlugin*) enif_alloc_resource(exla_plugin_resource_type, sizeof(ExlaPlugin)); plugin->handle = handle; @@ -986,6 +971,53 @@ ERL_NIF_TERM load_custom_call_plugin_library(ErlNifEnv* env, int argc, const ERL return exla::nif::ok(env, result); } +ERL_NIF_TERM register_custom_call_symbol(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 3) { + return exla::nif::error(env, "Bad argument count."); + } + + ExlaPlugin* plugin; + std::string symbol; + std::vector> dimensions; + + if (!enif_get_resource(env, argv[0], exla_plugin_resource_type, (void **) &plugin)) { + return exla::nif::error(env, "Unable to get plugin."); + } + if (!exla::nif::get(env, argv[1], symbol)) { + return exla::nif::error(env, "Unable to get symbol."); + } + if (!exla::nif::get_list(env, argv[2], dimensions)) { + return exla::nif::error(env, "Unable to get dimensions."); + } + + ExlaCustomCallFunction function = (ExlaCustomCallFunction) dlsym(plugin->handle, symbol.c_str()); + + if (!function) { + return exla::nif::error(env, "Could not find symbol."); + } + + auto lambda = [&dimensions, function](void *in[], const void *out[]) { + std::vector> int_dims(dimensions.size()); + for (size_t i = 0; i < dimensions.size(); ++i) { + int_dims[i].resize(dimensions[i].size()); + std::transform(dimensions[i].begin(), dimensions[i].end(), int_dims[i].begin(), + [](exla::int64 x) { return static_cast(x); }); + } + + std::vector dims_ptrs; + for (auto& d : int_dims) { + dims_ptrs.push_back(d.data()); + } + + function(in, out, dims_ptrs.data()); + }; + + // TODO: GPU/Client flag + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol.c_str(), function); + + return exla::nif::ok(env); +} + static ErlNifFunc exla_funcs[] = { // MLIR Builder {"mlir_new_context", 0, mlir_new_context}, @@ -1024,7 +1056,8 @@ static ErlNifFunc exla_funcs[] = { {"serialize_executable", 1, serialize_executable}, {"deserialize_executable", 2, deserialize_executable}, // Plugins - {"load_custom_call_plugin_library", 1, load_custom_call_plugin_library} + {"load_custom_call_plugin_library", 1, load_custom_call_plugin_library}, + {"register_custom_call_symbol", 3, register_custom_call_symbol} }; ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, &upgrade, NULL); diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index d38785f6ed..d802f2a55d 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -248,6 +248,25 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { return 1; } +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var) { + unsigned int length; + if (!enif_get_list_length(env, list, &length)) { + return 0; + } + var.reserve(length); + ERL_NIF_TERM head, tail; + + while (enif_get_list_cell(env, list, &head, &tail)) { + std::vector elem; + if (!get_list(env, head, elem)) { + return 0; + } + var.push_back(elem); + list = tail; + } + return 1; +} + int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var) { return enif_inspect_binary(env, term, var); } diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 5abf7e3cda..8244511174 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -247,6 +247,8 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector>& var); + template int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { unsigned int length; diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 3bdfa30d0c..03dc4b7e4c 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -18,6 +18,7 @@ defmodule EXLA.Application do name: EXLA.MLIR.ContextPool, lazy: true}, EXLA.Client, + EXLA.Plugin, EXLA.Defn.Lock, EXLA.Defn.LockedCache, {Task.Supervisor, name: EXLA.Defn.TaskSupervisor} diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 907398d699..dd90ced016 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -114,4 +114,6 @@ defmodule EXLA.NIF do def load_pjrt_plugin(_device_type, _library_path), do: :erlang.nif_error(:undef) def load_custom_call_plugin_library(_library_path), do: :erlang.nif_error(:undef) + + def register_custom_call_symbol(_plugin, _symbol, _dimensions), do: :erlang.nif_error(:undef) end diff --git a/exla/lib/exla/plugin.ex b/exla/lib/exla/plugin.ex index 0a16f1d662..b7f682867e 100644 --- a/exla/lib/exla/plugin.ex +++ b/exla/lib/exla/plugin.ex @@ -2,28 +2,55 @@ defmodule EXLA.Plugin do @moduledoc """ Plugin system for registering custom calls. """ + use GenServer - def register(library_path) do - unless File.exists?(library_path) do - raise ArgumentError, "#{library_path} does not exist" + # TODO: Register and lookup per client + + def start_link(_opts) do + GenServer.start_link(__MODULE__, %{}, name: __MODULE__) + end + + def register(key, library_path) do + GenServer.cast(__MODULE__, {:register, key, library_path}) + end + + def lookup(key) do + GenServer.call(__MODULE__, {:lookup, key}) + end + + def register_symbol(key, symbol, dimensions) do + if ref = lookup(key) do + EXLA.NIF.register_custom_call_symbol(ref, symbol, dimensions) end + end + + @impl true + def init(_opts) do + {:ok, %{}} + end + + @impl true + def handle_cast({:register, key, library_path}, state) do + case state do + %{^key => _ref} -> + {:noreply, state} - case :persistent_term.get({__MODULE__, library_path}, nil) do - nil -> + %{} -> ref = library_path |> EXLA.NIF.load_custom_call_plugin_library() |> unwrap!() - # we need to keep the ref from getting garbage collected so - # we can use the symbols within it at anytime - :persistent_term.put({__MODULE__, library_path}, ref) - - _ref -> - :ok + {:noreply, Map.put(state, key, ref)} end end + @impl true + def handle_call({:lookup, key}, _from, state) do + value = Map.get(state, key) + {:reply, value, state} + end + defp unwrap!({:ok, ref}), do: ref defp unwrap!({:error, reason}), do: raise("#{reason}") end diff --git a/exla/test/exla/plugin_test.exs b/exla/test/exla/plugin_test.exs index e7f7d0d7f3..ca9f9dfce1 100644 --- a/exla/test/exla/plugin_test.exs +++ b/exla/test/exla/plugin_test.exs @@ -2,20 +2,8 @@ defmodule EXLA.PluginTest do use ExUnit.Case describe "register/1" do - test "raises if file does not exist" do - assert_raise ArgumentError, ~r/does not exist/, fn -> - EXLA.Plugin.register("test/support/c/doesnotexist.so") - end - end - - test "does not crash on invalid files" do - assert_raise RuntimeError, ~r/Unable to open/, fn -> - EXLA.Plugin.register(__ENV__.file) - end - end - test "registers a plugin" do - assert :ok = EXLA.Plugin.register("test/support/c/libcustom_plugin.so") + assert :ok = EXLA.Plugin.register(:custom_plugin, "test/support/c/libcustom_plugin.so") end end -end \ No newline at end of file +end diff --git a/exla/test/support/c/custom_plugin.c b/exla/test/support/c/custom_plugin.c index 085b64f6be..b3c70f0950 100644 --- a/exla/test/support/c/custom_plugin.c +++ b/exla/test/support/c/custom_plugin.c @@ -1,16 +1,16 @@ #include #include -typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[]); +typedef void (*ExlaCustomCallFunction)(void *out[], const void *in[], int **dims); typedef struct { const char* name; ExlaCustomCallFunction func; } ExlaPluginCustomCall; -void custom_increment(void *out[], const void *in[]) { +extern "C" void custom_increment(void *out[], const void *in[], int **dims) { int64_t *operand = (int64_t *)in[0]; - int64_t *dim_sizes = (int64_t *)in[1]; + int64_t *dim_sizes = (int64_t *)dims[0]; int64_t *out_buffer = (int64_t *)out[0]; @@ -19,9 +19,4 @@ void custom_increment(void *out[], const void *in[]) { for (int64_t i = 0; i < n; i++) { out_buffer[i] = operand[i] + 1; } -} - -extern "C" ExlaPluginCustomCall exla_custom_calls[] = { - {"custom_increment", custom_increment}, - {NULL, NULL} -}; \ No newline at end of file +} \ No newline at end of file diff --git a/exla/test/support/c/libcustom_plugin.so b/exla/test/support/c/libcustom_plugin.so index cfc8eb7b93..90e7640043 100755 Binary files a/exla/test/support/c/libcustom_plugin.so and b/exla/test/support/c/libcustom_plugin.so differ