From 6cb3fa11c93633c50418d9d53b4b48246350077e Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 7 May 2024 23:21:57 -0300 Subject: [PATCH 01/40] chore: pick makefile --- exla/Makefile | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index bd1b1da6ae..f06d72eb4c 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -7,19 +7,30 @@ XLA_EXTENSION_DIR = cache/xla_extension XLA_EXTENSION_LIB = $(XLA_EXTENSION_DIR)/lib XLA_INCLUDE_PATH = $(XLA_EXTENSION_DIR)/include +IREE_COMPILER_DIR = iree/build/lib +IREE_COMPILER_LIB = cache/$(IREE_COMPILER_DIR) +IREE_COMPILER_INCLUDE_PATH = cache/iree/compiler/bindings/c + +LLVM_MLIR_INCLUDES = -Icache/iree/third_party/llvm-project/mlir/include + # Cache configuration EXLA_CACHE_SO = cache/libexla.so EXLA_CACHE_OBJ_DIR = cache/objs +EXLA_CACHE_IREE_COMPILER_SO = cache/libireecompiler.so # Private configuration EXLA_DIR = c_src/exla PRIV_DIR = $(MIX_APP_PATH)/priv EXLA_SO = $(PRIV_DIR)/libexla.so +EXLA_IREE_COMPILER_SO = $(PRIV_DIR)/libireecompiler.so EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib +EXLA_IREE_LIB_DIR = $(PRIV_DIR)/$(IREE_COMPILER_DIR) # Link paths -XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB) +XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_DIR)/$(XLA_EXTENSION_LIB) +IREE_COMPILER_LIB_LINK_PATH = ../../$(CWD_RELATIVE_TO_PRIV_PATH)/$(IREE_COMPILER_LIB) EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO) +EXLA_CACHE_IREE_COMPILER_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_IREE_COMPILER_SO) # Build flags # c++17 is needed, otherwise xla headers @@ -29,6 +40,8 @@ CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compa -Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \ -std=c++17 -w -DLLVM_VERSION_STRING= +IREE_CFLAGS = $(CFLAGS) -I$(IREE_COMPILER_INCLUDE_PATH) + NVCCFLAGS = -shared -Xcompiler -fPIC ifdef DEBUG @@ -39,32 +52,41 @@ else endif LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared +IREE_LDFLAGS = $(LDFLAGS) -lIREECompiler ifeq ($(shell uname -s), Darwin) - LDFLAGS += -flat_namespace -undefined suppress -rpath @loader_path/xla_extension/lib + LDFLAGS += -flat_namespace -undefined suppress -rpath @loader_path/xla_extension/lib -rpath @loader_path/$(IREE_COMPILER_DIR) else # Use a relative RPATH, so at runtime libexla.so looks for libxla_extension.so # in ./lib regardless of the absolute location. This way priv can be safely # packed into an Elixir release. Also, we use $$ to escape Makefile variable # and single quotes to escape shell variable - LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' + LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' -Wl,-rpath,'$$ORIGIN/$(IREE_COMPILER_DIR)' endif -$(EXLA_SO): $(EXLA_CACHE_SO) +$(EXLA_SO): $(EXLA_CACHE_SO) $(EXLA_CACHE_IREE_COMPILER_SO) @ mkdir -p $(PRIV_DIR) @ mkdir -p $(PRIV_DIR)/xla_extension + @ mkdir -p $(PRIV_DIR)/iree/build @ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \ cp -a $(abspath $(XLA_EXTENSION_LIB)) $(EXLA_LIB_DIR) ; \ + cp -a $(abspath $(IREE_COMPILER_LIB)) $(EXLA_IREE_LIB_DIR) ; \ cp -a $(abspath $(EXLA_CACHE_SO)) $(EXLA_SO) ; \ + cp -a $(abspath $(EXLA_CACHE_IREE_COMPILER_SO)) $(EXLA_IREE_COMPILER_SO) ; \ else \ ln -sf $(XLA_EXTENSION_LIB_LINK_PATH) $(EXLA_LIB_DIR) ; \ + ln -sf $(IREE_COMPILER_LIB_LINK_PATH) $(EXLA_IREE_LIB_DIR) ; \ ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \ + ln -sf $(EXLA_CACHE_IREE_COMPILER_SO_LINK_PATH) $(EXLA_IREE_COMPILER_SO) ; \ fi SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o +IREE_SOURCES = $(EXLA_DIR)/iree/iree.cc $(EXLA_DIR)/iree/compiler.cc $(EXLA_DIR)/iree/runtime.cc +IREE_HEADERS = $(EXLA_DIR)/iree/compiler.h $(EXLA_DIR)/iree/runtime.h $(EXLA_DIR)/exla_nif_util.h +IREE_OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(IREE_SOURCES)) NVCC_RESULT := $(shell which nvcc 2> /dev/null) NVCC_TEST := $(notdir $(NVCC_RESULT)) @@ -77,17 +99,29 @@ else NVCCFLAGS = $(CFLAGS) endif +$(EXLA_CACHE_OBJ_DIR)/iree/%.o: $(EXLA_DIR)/iree/%.cc $(IREE_HEADERS) + @ mkdir -p $(EXLA_CACHE_OBJ_DIR)/iree + $(CXX) $(IREE_CFLAGS) $(LLVM_MLIR_INCLUDES) -c $< -o $@ + $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cuda.h @ mkdir -p $(EXLA_CACHE_OBJ_DIR) $(NVCC) $(NVCCFLAGS) -c $< -o $@ $(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS) - @ mkdir -p $(EXLA_CACHE_OBJ_DIR) - @ mkdir -p $(EXLA_CACHE_OBJ_DIR)/mlir + @ mkdir -p $(EXLA_CACHE_OBJ_DIR)/iree $(CXX) $(CFLAGS) -c $< -o $@ -$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(OBJECTS) +$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(IREE_COMPILER_LIB) $(OBJECTS) $(CXX) $(OBJECTS) -o $(EXLA_CACHE_SO) $(LDFLAGS) +$(EXLA_CACHE_IREE_COMPILER_SO): $(EXLA_CACHE_OBJ_DIR)/iree/iree.o $(EXLA_CACHE_OBJ_DIR)/iree/compiler.o $(EXLA_CACHE_OBJ_DIR)/iree/runtime.o $(EXLA_CACHE_OBJ_DIR)/exla_nif_util.o + $(CXX) $^ -o $@ $(IREE_LDFLAGS) + +$(IREE_COMPILER_LIB): + # TO-DO: setup proper download and caching of the iree compiler + @ln -s $(HOME)/coding/iree cache/iree + cmake -G Ninja -B cache/iree/build -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree + cmake --build cache/iree/build + clean: rm -rf cache From f1830a0d1ce92434e0b2e2f4e13d2a7e9a92f343 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 7 May 2024 23:42:49 -0300 Subject: [PATCH 02/40] chore: import more changes --- exla/c_src/exla/exla.cc | 29 ++++--- exla/c_src/exla/iree/compiler.cc | 130 +++++++++++++++++++++++++++++++ exla/c_src/exla/iree/compiler.h | 4 + exla/c_src/exla/iree/iree.cc | 35 +++++++++ exla/c_src/exla/iree/runtime.cc | 9 +++ exla/c_src/exla/iree/runtime.h | 4 + exla/lib/exla/application.ex | 2 + exla/lib/exla/defn.ex | 109 +++++++++++++++++++------- exla/lib/exla/mlir/function.ex | 2 +- exla/lib/exla/mlir/iree.ex | 15 ++++ 10 files changed, 293 insertions(+), 46 deletions(-) create mode 100644 exla/c_src/exla/iree/compiler.cc create mode 100644 exla/c_src/exla/iree/compiler.h create mode 100644 exla/c_src/exla/iree/iree.cc create mode 100644 exla/c_src/exla/iree/runtime.cc create mode 100644 exla/c_src/exla/iree/runtime.h create mode 100644 exla/lib/exla/mlir/iree.ex diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 3fc0d10fdb..1fb8d30dca 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -1,18 +1,16 @@ #include -#include "exla_mlir.h" #include "exla_client.h" #include "exla_cuda.h" #include "exla_log_sink.h" +#include "exla_mlir.h" #include "exla_nif_util.h" - -#include "xla/pjrt/pjrt_api.h" -#include "xla/service/platform_util.h" - #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/service/platform_util.h" // All of these are created with calls to `new` and subsequently // passed to the VM as pointers-to-pointers so we balance it out @@ -202,9 +200,9 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a auto arg_types = std::vector{}; - for (auto const & type_string : arg_type_strings) { + for (auto const& type_string : arg_type_strings) { auto type = (*module)->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } arg_types.push_back(type); @@ -212,9 +210,9 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a auto ret_types = std::vector{}; - for (auto const & type_string : ret_type_strings) { + for (auto const& type_string : ret_type_strings) { auto type = (*module)->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } ret_types.push_back(type); @@ -281,9 +279,9 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { auto result_types = std::vector{}; - for (auto const & type_string : result_type_strings) { + for (auto const& type_string : result_type_strings) { auto type = (*function)->module()->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } result_types.push_back(type); @@ -291,9 +289,9 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { auto attributes = std::vector>{}; - for (auto const & pair : attributes_kwlist) { + for (auto const& pair : attributes_kwlist) { auto attribute_value = (*function)->module()->ParseAttribute(pair.second); - if(attribute_value == nullptr) { + if (attribute_value == nullptr) { return attribute_parsing_error(env, pair.second); } attributes.push_back(std::pair{pair.first, attribute_value}); @@ -304,7 +302,6 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::ok(env, exla::nif::make_list(env, results)); } - ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 2) { return exla::nif::error(env, "Bad argument count."); @@ -322,9 +319,9 @@ ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[ auto types = std::vector{}; - for (auto const & type_string : arg_types) { + for (auto const& type_string : arg_types) { auto type = (*function)->module()->ParseType(type_string); - if(type == nullptr) { + if (type == nullptr) { return type_parsing_error(env, type_string); } types.push_back(type); diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc new file mode 100644 index 0000000000..225464542a --- /dev/null +++ b/exla/c_src/exla/iree/compiler.cc @@ -0,0 +1,130 @@ +#include "compiler.h" + +#include // For O_WRONLY, O_CREAT, O_TRUNC +#include +#include +#include +#include +#include +#include +#include // For mode constants +#include // For open, close + +#include "../exla_mlir.h" + +typedef struct compiler_state_t { + iree_compiler_session_t *session; + iree_compiler_source_t *source; + iree_compiler_output_t *output; + iree_compiler_invocation_t *invocation; + MlirContext context; +} compiler_state_t; + +void handle_compiler_error(iree_compiler_error_t *error) { + const char *msg = ireeCompilerErrorGetMessage(error); + fprintf(stderr, "Error from compiler API:\n%s\n", msg); + ireeCompilerErrorDestroy(error); +} + +void cleanup_compiler_state(compiler_state_t s) { + if (s.invocation) + ireeCompilerInvocationDestroy(s.invocation); + if (s.output) + ireeCompilerOutputDestroy(s.output); + if (s.source) + ireeCompilerSourceDestroy(s.source); + if (s.session) + ireeCompilerSessionDestroy(s.session); + // ireeCompilerGlobalShutdown(); +} + +static void initializeCompiler(struct compiler_state_t *state) { + // ireeCompilerGlobalInitialize(); + state->session = ireeCompilerSessionCreate(); + state->context = ireeCompilerSessionBorrowContext(state->session); +} + +static void shutdownCompiler(struct compiler_state_t *state) { + ireeCompilerSessionDestroy(state->session); + // ireeCompilerGlobalShutdown(); +} + +ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 2) { + return exla::nif::error(env, "Bad argument count."); + } + + exla::MLIRModule **module; + + if (!exla::nif::get(env, argv[0], module)) { + return exla::nif::error(env, "Unable to get module."); + } + + compiler_state_t state; + state.session = NULL; + state.source = NULL; + state.output = NULL; + state.invocation = NULL; + iree_compiler_error_t *error = NULL; + + initializeCompiler(&state); + + // std::string module_str = (*module)->toMLIRString(); + std::string module_str = ""; + MlirOperation module_op = mlirOperationCreateParse( + state.context, + mlirStringRefCreateFromCString(module_str.c_str()), + mlirStringRefCreateFromCString("source.stablehlo")); + if (mlirOperationIsNull(module_op)) { + return exla::nif::error(env, "Unable to create MlirOperation module."); + } + + // Set flags. + iree_compiler_error_t *err; + const char *flags[] = { + "--iree-hal-target-backends=metal-spirv", + "--iree-input-type=stablehlo_xla", + "--iree-execution-model=async-external"}; + err = ireeCompilerSessionSetFlags(state.session, 1, flags); + if (err) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to set flags."); + } + + state.invocation = ireeCompilerInvocationCreate(state.session); + ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation); + + if (!ireeCompilerInvocationImportStealModule(state.invocation, module_op)) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to import module."); + } + + // Compile. + if (!ireeCompilerInvocationPipeline(state.invocation, iree_compiler_pipeline_t::IREE_COMPILER_PIPELINE_STD)) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to compile module."); + } + + fflush(stdout); + auto fd = open("/tmp/iree_output.metal", O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); + error = ireeCompilerOutputOpenFD(fd, &state.output); + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return exla::nif::error(env, "Error opening output file descriptor"); + } + + // Print IR to the output stream. + // When compiling to the 'end' phase, a compiler tool would typically use + // either |ireeCompilerInvocationOutputVMBytecode| or + // |ireeCompilerInvocationOutputVMCSource|. + error = ireeCompilerInvocationOutputVMBytecode(state.invocation, state.output); + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return 1; + } + + cleanup_compiler_state(state); + return exla::nif::ok(env); +} \ No newline at end of file diff --git a/exla/c_src/exla/iree/compiler.h b/exla/c_src/exla/iree/compiler.h new file mode 100644 index 0000000000..6693ced742 --- /dev/null +++ b/exla/c_src/exla/iree/compiler.h @@ -0,0 +1,4 @@ +#pragma once +#include "../exla_nif_util.h" + +ERL_NIF_TERM compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); \ No newline at end of file diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc new file mode 100644 index 0000000000..8faf0f8d7f --- /dev/null +++ b/exla/c_src/exla/iree/iree.cc @@ -0,0 +1,35 @@ +#include +#include + +#include "../exla_mlir.h" +#include "../exla_nif_util.h" +#include "compiler.h" +#include "runtime.h" + +ERL_NIF_TERM global_initialize(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + ireeCompilerGlobalInitialize(); + return exla::nif::ok(env); +} + +static ErlNifFunc iree_funcs[] = { + // MLIR Builder + {"global_initialize", 0, global_initialize}, + {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, + {"run_module", 2, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}}; + +static int open_resources(ErlNifEnv *env) { + const char *mod = "EXLA"; + + if (!exla::nif::open_resource(env, mod, "ExlaMLIRModule")) { + return -1; + } + return 1; +} + +static int load(ErlNifEnv *env, void **priv, ERL_NIF_TERM load_info) { + if (open_resources(env) == -1) return -1; + + return 0; +} + +ERL_NIF_INIT(Elixir.EXLA.MLIR.IREE, iree_funcs, &load, NULL, NULL, NULL); \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc new file mode 100644 index 0000000000..f2c0b0f27c --- /dev/null +++ b/exla/c_src/exla/iree/runtime.cc @@ -0,0 +1,9 @@ +#include "runtime.h" + +ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 2) { + return enif_make_badarg(env); + } + + return enif_make_atom(env, "ok"); +} \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h new file mode 100644 index 0000000000..240f541b5a --- /dev/null +++ b/exla/c_src/exla/iree/runtime.h @@ -0,0 +1,4 @@ +#pragma once +#include "../exla_nif_util.h" + +ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); \ No newline at end of file diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 3bdfa30d0c..82c31210fb 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -10,6 +10,8 @@ defmodule EXLA.Application do _ -> :os.set_signal(:sigchld, :default) end + EXLA.MLIR.IREE.global_initialize() + children = [ EXLA.Logger, {NimblePool, diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 72b75a371f..3d59a6376c 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -295,20 +295,41 @@ defmodule EXLA.Defn do raise ArgumentError, "missing client" end + compiler_mode = Keyword.fetch!(options, :compiler_mode) + + unless compiler_mode do + raise ArgumentError, "missing compiler_mode" + end + + state_params = + if compiler_mode == :iree do + Map.new(params) + else + Map.new(params ++ outfeed.infeeds) + end + state = %{ client: client, precision: Keyword.get(options, :precision, :default), builder: function, - params: Map.new(params ++ outfeed.infeeds), + params: state_params, scope_ids: Tree.scope_ids(expr) } - {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) - outfeed = cache |> get_outfeed() |> Outfeed.close(function) + if compiler_mode == :iree do + {res, _cache} = recur_flatten(expr, state, no_token_cache()) + Value.return(function, res) + {:ok, nil} + else + {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) + + outfeed = + cache |> get_outfeed() |> Outfeed.close(function) - Value.return(function, res) + Value.return(function, res) - {:ok, outfeed} + {:ok, outfeed} + end end defp maybe_outfeed(lock, executable, args, used_inputs, outputs, outfeed, run_options) @@ -369,6 +390,7 @@ defmodule EXLA.Defn do end {debug?, options} = Keyword.pop(options, :debug, false) + {compiler_mode, options} = Keyword.pop(options, :compiler_mode) {args_key, reverse_args_identifiers} = Enum.map_reduce(vars, [], fn var, acc -> @@ -402,7 +424,8 @@ defmodule EXLA.Defn do {hooks, options} = Keyword.pop(options, :hooks, %{}) - outfeed = Outfeed.new(hooks, defined_hooks) + outfeed = + Outfeed.new(hooks, defined_hooks) comp_key = {ref, client.name, outfeed.used_hooks, lazy_transfers, options} @@ -433,9 +456,11 @@ defmodule EXLA.Defn do EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> outfeed = - outfeed - |> Outfeed.with_token(Value.create_token(builder)) - |> Outfeed.add_infeeds(builder, reverse_infeeds) + if compiler_mode != :iree do + outfeed + |> Outfeed.with_token(Value.create_token(builder)) + |> Outfeed.add_infeeds(builder, reverse_infeeds) + end expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) @@ -447,13 +472,18 @@ defmodule EXLA.Defn do typespecs = for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec - EXLA.MLIR.Module.compile( - builder.module, - client, - typespecs, - builder.return_typespecs, - options - ) + if compiler_mode == :iree do + :ok = EXLA.MLIR.IREE.compile(builder.module.ref, "metal") + raise "compiler not returning computation yet" + else + EXLA.MLIR.Module.compile( + builder.module, + client, + typespecs, + builder.return_typespecs, + options + ) + end end) {:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}} @@ -649,9 +679,6 @@ defmodule EXLA.Defn do {computation, cache} %{} -> - {computation, cache} = token_computation("optional", call_args, expr, state, cache) - {computation, Map.put(cache, key, computation)} - end typespecs = [Typespec.token() | container_to_typespecs(expr)] @@ -1606,21 +1633,41 @@ defmodule EXLA.Defn do {region, merge_outfeed(cache, comp_cache)} end - defp token_computation(name, args, expr, %{builder: %Function{}} = state, cache) do + expr, + %{builder: %Function{compiler: compiler}} = state, + cache + ) do %Function{module: module, name: name} = subbuilder(state.builder, name) token_typespec = Typespec.token() arg_typespecs = Enum.map(args, &Value.get_typespec/1) out_typespecs = container_to_typespecs(expr) + in_types = + if compiler == :iree do + arg_typespecs + else + [token_typespec | arg_typespecs] + end + + out_types = + if compiler == :iree do + out_typespecs + else + [token_typespec | out_typespecs] + end + function = - EXLA.MLIR.Module.add_function(module, name, [token_typespec | arg_typespecs], [ - token_typespec | out_typespecs - ]) + EXLA.MLIR.Module.add_function(module, name, in_types, out_types) [arg_token | tail] = EXLA.MLIR.Function.get_arguments(function) - params = Enum.with_index(tail, fn param, i -> {i, param} end) + params = + if compiler == :iree do + Enum.with_index([arg_token | tail], fn param, i -> {i, param} end) + else + Enum.with_index(tail, fn param, i -> {i, param} end) + end state = %{ state @@ -1629,11 +1676,15 @@ defmodule EXLA.Defn do scope_ids: Tree.scope_ids(expr) } - {res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token)) - - Value.return(function, [get_token(comp_cache) | List.flatten(res)]) - - {function, merge_outfeed(cache, comp_cache)} + if compiler_mode == :iree do + {res, comp_cache} = recur_composite(expr, state, cache) + Value.return(function, List.flatten(res)) + {function, merge_outfeed(cache, comp_cache)} + else + {res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token)) + Value.return(function, [get_token(comp_cache) | List.flatten(res)]) + {function, merge_outfeed(cache, comp_cache)} + end end # The cache is built on top of call args because we need to handle pred/u8. diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index d9aaa4f7b1..bfd897c2c3 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -2,7 +2,7 @@ defmodule EXLA.MLIR.Function do @moduledoc false # Representation of an MLIR Function or `func.func` type. - defstruct [:module, :ref, :name, :return_typespecs] + defstruct [:module, :ref, :name, :return_typespecs, :compiler_mode] alias __MODULE__, as: Function alias EXLA.MLIR.Value diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex new file mode 100644 index 0000000000..0d6c932e0a --- /dev/null +++ b/exla/lib/exla/mlir/iree.ex @@ -0,0 +1,15 @@ +defmodule EXLA.MLIR.IREE do + @moduledoc false + @on_load :__on_load__ + + def __on_load__ do + path = :filename.join(:code.priv_dir(:exla), ~c"libireecompiler") + :erlang.load_nif(path, 0) + end + + def compile(_module, _target), do: :erlang.nif_error(:undef) + + def global_initialize, do: :erlang.nif_error(:undef) + + def run_module(_module, _inputs), do: :erlang.nif_error(:undef) +end From 5a4522b5b64c5c92f40a272470f9152e4e97c2c2 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 7 May 2024 23:58:34 -0300 Subject: [PATCH 03/40] wip --- exla/Makefile | 4 +-- exla/c_src/exla/iree/compiler.cc | 3 +-- exla/lib/exla/defn.ex | 45 +++++++++++++++++++++++++++----- exla/lib/exla/mlir/function.ex | 2 +- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index f06d72eb4c..469e7c5bdc 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -85,7 +85,7 @@ HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_cl OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o IREE_SOURCES = $(EXLA_DIR)/iree/iree.cc $(EXLA_DIR)/iree/compiler.cc $(EXLA_DIR)/iree/runtime.cc -IREE_HEADERS = $(EXLA_DIR)/iree/compiler.h $(EXLA_DIR)/iree/runtime.h $(EXLA_DIR)/exla_nif_util.h +IREE_HEADERS = $(EXLA_DIR)/iree/compiler.h $(EXLA_DIR)/iree/runtime.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_mlir.h IREE_OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(IREE_SOURCES)) NVCC_RESULT := $(shell which nvcc 2> /dev/null) @@ -114,7 +114,7 @@ $(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS) $(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(IREE_COMPILER_LIB) $(OBJECTS) $(CXX) $(OBJECTS) -o $(EXLA_CACHE_SO) $(LDFLAGS) -$(EXLA_CACHE_IREE_COMPILER_SO): $(EXLA_CACHE_OBJ_DIR)/iree/iree.o $(EXLA_CACHE_OBJ_DIR)/iree/compiler.o $(EXLA_CACHE_OBJ_DIR)/iree/runtime.o $(EXLA_CACHE_OBJ_DIR)/exla_nif_util.o +$(EXLA_CACHE_IREE_COMPILER_SO): $(EXLA_CACHE_OBJ_DIR)/iree/iree.o $(EXLA_CACHE_OBJ_DIR)/iree/compiler.o $(EXLA_CACHE_OBJ_DIR)/iree/runtime.o $(EXLA_CACHE_OBJ_DIR)/exla_nif_util.o $(EXLA_CACHE_OBJ_DIR)/exla_mlir.o $(CXX) $^ -o $@ $(IREE_LDFLAGS) $(IREE_COMPILER_LIB): diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index 225464542a..564358d94a 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -69,8 +69,7 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { initializeCompiler(&state); - // std::string module_str = (*module)->toMLIRString(); - std::string module_str = ""; + std::string module_str = (*module)->ToString(); MlirOperation module_op = mlirOperationCreateParse( state.context, mlirStringRefCreateFromCString(module_str.c_str()), diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 3d59a6376c..d30790ff7d 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -39,6 +39,7 @@ defmodule EXLA.Defn do client = EXLA.Client.fetch!(client_name) compile_options = Keyword.put(compile_options, :lazy_transfers, :never) + compile_options = Keyword.put_new(compile_options, :compiler_mode, :iree) input_length = length(Nx.Defn.Composite.flatten_list([input])) acc_length = length(Nx.Defn.Composite.flatten_list([acc])) @@ -145,6 +146,10 @@ defmodule EXLA.Defn do outfeed, options ) do + if builder.compiler == :iree do + raise ArgumentError, "streaming not supported when compiling with IREE" + end + %{token: root_token, infeeds: []} = outfeed {input_typespecs, used_typespecs} = @@ -253,6 +258,8 @@ defmodule EXLA.Defn do {client_name, compile_options} = Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) + compile_options = Keyword.put_new(compile_options, :compiler_mode, :iree) + client = EXLA.Client.fetch!(client_name) callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) @@ -455,6 +462,8 @@ defmodule EXLA.Defn do end) EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> + builder = %EXLA.MLIR.Function{builder | compiler: compiler_mode} + outfeed = if compiler_mode != :iree do outfeed @@ -607,10 +616,10 @@ defmodule EXLA.Defn do ] } }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, + %{client: %EXLA.Client{platform: :host}, builder: %Function{compiler: compiler}} = state, cache ) - when type_kind != :c do + when type_kind != :c and compiler != :iree do # We match only on platform: :host for MLIR, as we want to support # QR-on-cpu as a custom call only in this case {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() @@ -679,13 +688,23 @@ defmodule EXLA.Defn do {computation, cache} %{} -> + {computation, cache} = token_computation("optional", call_args, expr, state, cache) + {computation, Map.put(cache, key, computation)} + end + + if state.builder.compiler == :iree do + typespecs = container_to_typespecs(expr) - typespecs = [Typespec.token() | container_to_typespecs(expr)] + result = Value.call(state.builder, call_args, call_body, typespecs) + {wrap_tuple_result(result, expr), cache} + else + typespecs = [Typespec.token() | container_to_typespecs(expr)] - [token | result] = - Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs) + [token | result] = + Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs) - {wrap_tuple_result(result, expr), update_token(cache, token)} + {wrap_tuple_result(result, expr), update_token(cache, token)} + end end defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do @@ -694,6 +713,15 @@ defmodule EXLA.Defn do {op, cache} end + defp cached_recur_operator( + :token, + %T{data: %Expr{args: [_token]}}, + %{builder: %Function{compiler: :iree}} = state, + cache + ) do + {[], cache} + end + defp cached_recur_operator(:token, %T{data: %Expr{args: [token]}}, state, cache) do builder = state.builder @@ -1633,6 +1661,9 @@ defmodule EXLA.Defn do {region, merge_outfeed(cache, comp_cache)} end + defp token_computation( + name, + args, expr, %{builder: %Function{compiler: compiler}} = state, cache @@ -1676,7 +1707,7 @@ defmodule EXLA.Defn do scope_ids: Tree.scope_ids(expr) } - if compiler_mode == :iree do + if compiler == :iree do {res, comp_cache} = recur_composite(expr, state, cache) Value.return(function, List.flatten(res)) {function, merge_outfeed(cache, comp_cache)} diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index bfd897c2c3..83e01f9bf3 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -2,7 +2,7 @@ defmodule EXLA.MLIR.Function do @moduledoc false # Representation of an MLIR Function or `func.func` type. - defstruct [:module, :ref, :name, :return_typespecs, :compiler_mode] + defstruct [:module, :ref, :name, :return_typespecs, :compiler] alias __MODULE__, as: Function alias EXLA.MLIR.Value From 5f1e2579271668abdc083d99dec2c30903fe282d Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 8 May 2024 00:50:41 -0300 Subject: [PATCH 04/40] feat: export module and pass it to the runtime --- exla/c_src/exla/iree/compiler.cc | 32 ++++++++++++++++++++-------- exla/c_src/exla/iree/runtime.cc | 2 +- exla/lib/exla/defn.ex | 24 ++++++++++++++++----- exla/lib/exla/executable.ex | 36 ++++++++++++++++++++++++++++---- 4 files changed, 75 insertions(+), 19 deletions(-) diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index 564358d94a..a82cee712d 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -10,6 +10,9 @@ #include // For mode constants #include // For open, close +#include +#include + #include "../exla_mlir.h" typedef struct compiler_state_t { @@ -105,25 +108,36 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { } fflush(stdout); - auto fd = open("/tmp/iree_output.metal", O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); - error = ireeCompilerOutputOpenFD(fd, &state.output); + + error = ireeCompilerOutputOpenMembuffer(&state.output); if (error) { handle_compiler_error(error); cleanup_compiler_state(state); - return exla::nif::error(env, "Error opening output file descriptor"); + return exla::nif::error(env, "Error opening output membuffer"); } - // Print IR to the output stream. - // When compiling to the 'end' phase, a compiler tool would typically use - // either |ireeCompilerInvocationOutputVMBytecode| or - // |ireeCompilerInvocationOutputVMCSource|. error = ireeCompilerInvocationOutputVMBytecode(state.invocation, state.output); if (error) { handle_compiler_error(error); cleanup_compiler_state(state); - return 1; + return exla::nif::error(env, "Failed to output VM Bytecode"); + } + + uint64_t size; + + ErlNifBinary binary; + + error = ireeCompilerOutputMapMemory(state.output, reinterpret_cast(&binary.data), &size); + + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return exla::nif::error(env, "Failed to map output to output binary"); } + enif_alloc_binary(size, &binary); + cleanup_compiler_state(state); - return exla::nif::ok(env); + + return exla::nif::ok(env, enif_make_binary(env, &binary)); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index f2c0b0f27c..fefc156791 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -5,5 +5,5 @@ ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return enif_make_badarg(env); } - return enif_make_atom(env, "ok"); + return exla::nif::error(env, "runtime not implemented yet"); } \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index d30790ff7d..5d252f2989 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -482,8 +482,17 @@ defmodule EXLA.Defn do for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec if compiler_mode == :iree do - :ok = EXLA.MLIR.IREE.compile(builder.module.ref, "metal") - raise "compiler not returning computation yet" + {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(builder.module.ref, "metal") + + %EXLA.Executable{ + client: client, + ref: module_bytecode, + output_typespecs: builder.return_typespecs, + num_replicas: 1, + num_partitions: 1, + device_id: -1, + runtime: :iree + } else EXLA.MLIR.Module.compile( builder.module, @@ -495,7 +504,12 @@ defmodule EXLA.Defn do end end) - {:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}} + outfeed = + if outfeed do + %{outfeed | infeeds: []} + end + + {:ok, {xla_time, executable, extra, outfeed}} end) end) end) @@ -523,7 +537,7 @@ defmodule EXLA.Defn do :telemetry.execute([:exla, :compilation], measurements, %{key: key}) end - outfeed = Outfeed.with_user_hooks(outfeed, hooks) + outfeed = if outfeed, do: Outfeed.with_user_hooks(outfeed, hooks) {executable, used_inputs, outputs, outfeed, extra, debug?} end @@ -716,7 +730,7 @@ defmodule EXLA.Defn do defp cached_recur_operator( :token, %T{data: %Expr{args: [_token]}}, - %{builder: %Function{compiler: :iree}} = state, + %{builder: %Function{compiler: :iree}}, cache ) do {[], cache} diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index a6a0c8cbdf..3c79d6c1c2 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -7,17 +7,31 @@ defmodule EXLA.Executable do alias EXLA.{BinaryBuffer, DeviceBuffer} @enforce_keys [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] - defstruct [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] + defstruct [ + :client, + :ref, + :output_typespecs, + :num_replicas, + :num_partitions, + :device_id, + runtime: :xla + ] @doc """ Runs the given executable with a list of lists as inputs and the given options. """ def run(%Executable{} = executable, [subinputs | _] = inputs, options \\ []) when is_list(subinputs) do - %{client: client, device_id: device_id, output_typespecs: output_typespecs, ref: ref} = + %{ + runtime: runtime, + client: client, + device_id: device_id, + output_typespecs: output_typespecs, + ref: ref + } = executable - for data_and_device_id <- run(client, ref, device_id, inputs, options) do + for data_and_device_id <- run(runtime, client, ref, device_id, inputs, options) do decompose_output(data_and_device_id, output_typespecs, client) end end @@ -63,7 +77,7 @@ defmodule EXLA.Executable do end end - defp run(client, ref, device_id, inputs, _options) do + defp run(:xla, client, ref, device_id, inputs, _options) do inputs = for subinputs <- inputs do Enum.map(subinputs, fn @@ -84,6 +98,20 @@ defmodule EXLA.Executable do unwrap!(data) end + defp run(:iree, _client, ref, _device_id, inputs, _options) do + inputs = + for subinputs <- inputs do + Enum.map(subinputs, fn + %BinaryBuffer{data: data, typespec: typespec} -> + {data, EXLA.Typespec.nif_encode(typespec)} + end) + end + + ref + |> EXLA.MLIR.IREE.run_module(List.flatten(inputs)) + |> unwrap!() + end + defp decompose_output({data, device_id}, output_typespecs, client) do Enum.zip_with(data, output_typespecs, fn buf, typespec when is_reference(buf) -> From 0c36ef51ae09643466a36e7a13e8f2e81bcb9ccc Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 8 May 2024 01:41:23 -0300 Subject: [PATCH 05/40] feat: lay scaffolding out for runtime --- exla/Makefile | 3 +- exla/c_src/exla/exla_nif_util.cc | 16 +++ exla/c_src/exla/exla_nif_util.h | 3 +- exla/c_src/exla/iree/runtime.cc | 182 ++++++++++++++++++++++++++++++- 4 files changed, 200 insertions(+), 4 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index 469e7c5bdc..ff83395d41 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -10,6 +10,7 @@ XLA_INCLUDE_PATH = $(XLA_EXTENSION_DIR)/include IREE_COMPILER_DIR = iree/build/lib IREE_COMPILER_LIB = cache/$(IREE_COMPILER_DIR) IREE_COMPILER_INCLUDE_PATH = cache/iree/compiler/bindings/c +IREE_RUNTIME_INCLUDE_PATH = cache/iree/runtime/src LLVM_MLIR_INCLUDES = -Icache/iree/third_party/llvm-project/mlir/include @@ -40,7 +41,7 @@ CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compa -Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \ -std=c++17 -w -DLLVM_VERSION_STRING= -IREE_CFLAGS = $(CFLAGS) -I$(IREE_COMPILER_INCLUDE_PATH) +IREE_CFLAGS = $(CFLAGS) -I$(IREE_COMPILER_INCLUDE_PATH) -I$(IREE_RUNTIME_INCLUDE_PATH) NVCCFLAGS = -shared -Xcompiler -fPIC diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index d38785f6ed..4f7092a54c 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -197,6 +197,22 @@ int get_tuple(ErlNifEnv* env, ERL_NIF_TERM tuple, std::vector& var) { return 1; } +int get_list(ErlNifEnv* env, + ERL_NIF_TERM list, + std::vector& var) { + unsigned int length; + if (!enif_get_list_length(env, list, &length)) return 0; + var.reserve(length); + ERL_NIF_TERM head, tail; + + while (enif_get_list_cell(env, list, &head, &tail)) { + var.push_back(head); + list = tail; + } + + return 1; +} + int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 5abf7e3cda..615ddd7b5d 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -9,10 +9,10 @@ #include #include "erl_nif.h" +#include "mlir/IR/Builders.h" #include "xla/shape.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "mlir/IR/Builders.h" #if !defined(__GNUC__) && (defined(__WIN32__) || defined(_WIN32) || defined(_WIN32_)) typedef unsigned __int64 nif_uint64_t; @@ -246,6 +246,7 @@ int get_list(ErlNifEnv* env, int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); +int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var); template int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index fefc156791..b3a3daca9a 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -1,9 +1,187 @@ #include "runtime.h" -ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { +#include + +typedef struct { + void *data; + size_t size; + std::vector dims; + xla::PrimitiveType type; +} IREEInput; + +int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { + const ERL_NIF_TERM *tuple, *typespec; + int length; + ErlNifBinary bin; + xla::PrimitiveType type; + std::vector dims; + IREEInput item; + + for (ERL_NIF_TERM term : terms) { + if (!enif_get_tuple(env, term, &length, &tuple)) { + return 0; + } + + if (!enif_inspect_binary(env, tuple[0], &bin)) { + return 0; + } + + item.data = bin.data; + item.size = bin.size; + + if (!enif_get_tuple(env, tuple[1], &length, &typespec)) { + return 0; + } + + if (!exla::nif::get_primitive_type(env, typespec[0], &item.type)) { + return 0; + } + + if (!exla::nif::get_tuple(env, typespec[1], item.dims)) { + return 0; + } + + loaded.push_back(item); + } + + return 1; +} + +iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs) { + iree_runtime_call_t call; + + IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name( + session, iree_make_cstring_view("module.main"), &call)); + + // Append the function inputs with the HAL device allocator in use by the + // session. The buffers will be usable within the session and _may_ be usable + // in other sessions depending on whether they share a compatible device. + iree_hal_device_t *device = iree_runtime_session_device(session); + iree_hal_allocator_t *device_allocator = + iree_runtime_session_device_allocator(session); + iree_allocator_t host_allocator = + iree_runtime_session_host_allocator(session); + + iree_status_t status = iree_ok_status(); + if (iree_status_is_ok(status)) { + // TO-DO: make this process inputs vector + // // %lhs: tensor<4xf32> + // iree_hal_buffer_view_t *lhs = NULL; + // if (iree_status_is_ok(status)) { + // static const iree_hal_dim_t lhs_shape[1] = {4}; + // static const float lhs_data[4] = {1.0f, 1.1f, 1.2f, 1.3f}; + // status = iree_hal_buffer_view_allocate_buffer_copy( + // device, device_allocator, + // // Shape rank and dimensions: + // IREE_ARRAYSIZE(lhs_shape), lhs_shape, + // // Element type: + // IREE_HAL_ELEMENT_TYPE_FLOAT_32, + // // Encoding type: + // IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + // (iree_hal_buffer_params_t){ + // // Where to allocate (host or device): + // .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + // // Access to allow to this memory: + // .access = IREE_HAL_MEMORY_ACCESS_ALL, + // // Intended usage of the buffer (transfers, dispatches, etc): + // .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + // }, + // // The actual heap buffer to wrap or clone and its allocator: + // iree_make_const_byte_span(lhs_data, sizeof(lhs_data)), + // // Buffer view + storage are returned and owned by the caller: + // &lhs); + // if (iree_status_is_ok(status)) { + // IREE_IGNORE_ERROR(iree_hal_buffer_view_fprint( + // stdout, lhs, /*max_element_count=*/4096, host_allocator)); + // // Add to the call inputs list (which retains the buffer view). + // status = iree_runtime_call_inputs_push_back_buffer_view(&call, lhs); + // } + // // Since the call retains the buffer view we can release it here. + // iree_hal_buffer_view_release(lhs); + } + + return 0; +} + +ERL_NIF_TERM +run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 2) { return enif_make_badarg(env); } - return exla::nif::error(env, "runtime not implemented yet"); + ErlNifBinary bytecode_binary; + std::vector input_terms; + std::vector inputs; + + if (!enif_inspect_binary(env, argv[0], &bytecode_binary)) { + return exla::nif::error(env, "Unable to load bytecode binary"); + } + + if (!exla::nif::get_list(env, argv[1], input_terms)) { + return exla::nif::error(env, "Unable to load input terms"); + } + + if (!load_inputs(env, input_terms, inputs)) { + return exla::nif::error(env, "Unable to decode input terms"); + } + + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + iree_runtime_instance_t *instance = NULL; + iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance); + + iree_hal_device_t *device = NULL; + char *device_uri = "metal"; // TO-DO: change this to an argument + if (iree_status_is_ok(status)) { + status = iree_hal_create_device( + iree_runtime_instance_driver_registry(instance), + iree_make_cstring_view(device_uri), + iree_runtime_instance_host_allocator(instance), &device); + } + + iree_runtime_session_t *session = NULL; + if (iree_status_is_ok(status)) { + iree_runtime_session_options_t session_options; + iree_runtime_session_options_initialize(&session_options); + status = iree_runtime_session_create_with_device( + instance, &session_options, device, + iree_runtime_instance_host_allocator(instance), &session); + } + + if (iree_status_is_ok(status)) { + status = iree_runtime_session_append_bytecode_module_from_memory(session, reinterpret_cast(bytecode_binary.data), iree_runtime_instance_host_allocator(instance)); + } + + if (iree_status_is_ok(status)) { + // this is where we actually call code + // status = iree_runtime_demo_perform_mul(session); + status = call_module(session, inputs) + } + + // Release the session and free all cached resources. + iree_runtime_session_release(session); + + // Release shared device once all sessions using it have been released. + iree_hal_device_release(device); + + // Release the shared instance - it will be deallocated when all sessions + // using it have been released (here it is deallocated immediately). + iree_runtime_instance_release(instance); + + int ret = (int)iree_status_code(status); + if (!iree_status_is_ok(status)) { + // Dump nice status messages to stderr on failure. + // An application can route these through its own logging infrastructure as + // needed. Note that the status is a handle and must be freed! + iree_status_fprint(stderr, status); + iree_status_ignore(status); + } + + if (!ret) { + exla::nif::error(env, "Fail to execute IREE runtime"); + } + + // TO-DO: we want to get output values too + return exla::nif::ok(env); } \ No newline at end of file From d261ff6cac761006fff013767d3b70e64b4ec0d0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 8 May 2024 23:24:45 -0300 Subject: [PATCH 06/40] feat: jit workflow mostly working (sans outputs) --- exla/Makefile | 36 +++-- exla/c_src/exla/exla.cc | 1 + exla/c_src/exla/exla_mlir_nif_util.cc | 65 ++++++++ exla/c_src/exla/exla_mlir_nif_util.h | 10 ++ exla/c_src/exla/exla_nif_util.cc | 71 ++------- exla/c_src/exla/exla_nif_util.h | 6 +- exla/c_src/exla/iree/compiler.cc | 25 +-- exla/c_src/exla/iree/iree.cc | 1 + exla/c_src/exla/iree/runtime.cc | 206 +++++++++++++++++-------- exla/c_src/iree_runtime/CMakeLists.txt | 102 ++++++++++++ exla/lib/exla/defn.ex | 6 +- exla/lib/exla/defn/buffers.ex | 5 + exla/lib/exla/executable.ex | 4 +- 13 files changed, 390 insertions(+), 148 deletions(-) create mode 100644 exla/c_src/exla/exla_mlir_nif_util.cc create mode 100644 exla/c_src/exla/exla_mlir_nif_util.h create mode 100644 exla/c_src/iree_runtime/CMakeLists.txt diff --git a/exla/Makefile b/exla/Makefile index ff83395d41..c8029070f9 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -12,7 +12,7 @@ IREE_COMPILER_LIB = cache/$(IREE_COMPILER_DIR) IREE_COMPILER_INCLUDE_PATH = cache/iree/compiler/bindings/c IREE_RUNTIME_INCLUDE_PATH = cache/iree/runtime/src -LLVM_MLIR_INCLUDES = -Icache/iree/third_party/llvm-project/mlir/include +IREE_INSTALL_PREFIX = $(abspath cache/iree/build/) # Cache configuration EXLA_CACHE_SO = cache/libexla.so @@ -81,8 +81,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO) $(EXLA_CACHE_IREE_COMPILER_SO) ln -sf $(EXLA_CACHE_IREE_COMPILER_SO_LINK_PATH) $(EXLA_IREE_COMPILER_SO) ; \ fi -SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc -HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h +SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/exla_mlir_nif_util.cc +HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/exla_mlir_nif_util.h OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o IREE_SOURCES = $(EXLA_DIR)/iree/iree.cc $(EXLA_DIR)/iree/compiler.cc $(EXLA_DIR)/iree/runtime.cc @@ -100,10 +100,6 @@ else NVCCFLAGS = $(CFLAGS) endif -$(EXLA_CACHE_OBJ_DIR)/iree/%.o: $(EXLA_DIR)/iree/%.cc $(IREE_HEADERS) - @ mkdir -p $(EXLA_CACHE_OBJ_DIR)/iree - $(CXX) $(IREE_CFLAGS) $(LLVM_MLIR_INCLUDES) -c $< -o $@ - $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cuda.h @ mkdir -p $(EXLA_CACHE_OBJ_DIR) $(NVCC) $(NVCCFLAGS) -c $< -o $@ @@ -115,8 +111,30 @@ $(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS) $(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(IREE_COMPILER_LIB) $(OBJECTS) $(CXX) $(OBJECTS) -o $(EXLA_CACHE_SO) $(LDFLAGS) -$(EXLA_CACHE_IREE_COMPILER_SO): $(EXLA_CACHE_OBJ_DIR)/iree/iree.o $(EXLA_CACHE_OBJ_DIR)/iree/compiler.o $(EXLA_CACHE_OBJ_DIR)/iree/runtime.o $(EXLA_CACHE_OBJ_DIR)/exla_nif_util.o $(EXLA_CACHE_OBJ_DIR)/exla_mlir.o - $(CXX) $^ -o $@ $(IREE_LDFLAGS) + +IREE_CMAKE_BUILD_DIR = cache/objs/exla_iree_cmake + +ifdef DEBUG + IREE_CMAKE_CONFIG = RelWithDebInfo +else + IREE_CMAKE_CONFIG = Release +endif + +$(EXLA_CACHE_IREE_COMPILER_SO): + @mkdir -p $(IREE_CMAKE_BUILD_DIR) + @mkdir -p cache/objs/iree_cmake_out + @mkdir -p cache/objs/mlir_cmake_out + @mkdir -p cache/objs/llvm_cmake_out + cmake -S c_src/iree_runtime -B $(IREE_CMAKE_BUILD_DIR) \ + -DIREE_COMPILER_INCLUDE_PATH=$(IREE_COMPILER_INCLUDE_PATH) \ + -DIREE_COMPILER_DIR=$(IREE_COMPILER_DIR) \ + -DXLA_INCLUDE_PATH=$(abspath $(XLA_INCLUDE_PATH)) \ + -DIREE_INSTALL_PREFIX=$(IREE_INSTALL_PREFIX) \ + -DCACHE_DIR=cache\ + -DXLA_EXTENSION_LIB=$(abspath $(XLA_EXTENSION_LIB))\ + -DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG) + cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --verbose + cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache $(IREE_COMPILER_LIB): # TO-DO: setup proper download and caching of the iree compiler diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 1fb8d30dca..293bc5b45c 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -4,6 +4,7 @@ #include "exla_cuda.h" #include "exla_log_sink.h" #include "exla_mlir.h" +#include "exla_mlir_nif_util.h" #include "exla_nif_util.h" #include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/exla/c_src/exla/exla_mlir_nif_util.cc b/exla/c_src/exla/exla_mlir_nif_util.cc new file mode 100644 index 0000000000..417575673f --- /dev/null +++ b/exla/c_src/exla/exla_mlir_nif_util.cc @@ -0,0 +1,65 @@ +#include "exla_mlir_nif_util.h" + +#include "mlir/IR/Builders.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace exla { +namespace nif { + +std::string mlir_numeric_type_to_string(mlir::Type type) { + if (type.isSignlessInteger(1)) { + return "pred"; + } + if (auto integer_type = type.dyn_cast()) { + if (integer_type.isUnsigned()) { + return "u" + std::to_string(integer_type.getWidth()); + } else { + return "s" + std::to_string(integer_type.getWidth()); + } + } + if (type.isBF16()) { + return "bf16"; + } + if (auto float_type = type.dyn_cast()) { + return "f" + std::to_string(float_type.getWidth()); + } + if (auto complex_type = type.dyn_cast()) { + auto element_type = complex_type.getElementType(); + return "c" + std::to_string(element_type.cast().getWidth() * 2); + } + + std::cerr << "Unexpected mlir type" << std::endl; + exit(1); +} + +ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type) { + if (type.isa()) { + auto type_term = make(env, "token"); + auto shape_term = enif_make_tuple(env, 0); + + return enif_make_tuple(env, 2, type_term, shape_term); + } + + if (type.isa()) { + auto tensor_type = type.cast(); + auto dims = tensor_type.getShape(); + auto element_type = tensor_type.getElementType(); + + auto dims_array = std::vector{}; + dims_array.reserve(dims.size()); + + for (auto dim : dims) { + dims_array.push_back(enif_make_int(env, dim)); + } + + auto type_term = make(env, mlir_numeric_type_to_string(element_type)); + auto shape_term = enif_make_tuple_from_array(env, dims_array.data(), dims_array.size()); + + return enif_make_tuple(env, 2, type_term, shape_term); + } + + std::cerr << "Unexpected mlir type" << std::endl; + exit(1); +} +} // namespace nif +} // namespace exla \ No newline at end of file diff --git a/exla/c_src/exla/exla_mlir_nif_util.h b/exla/c_src/exla/exla_mlir_nif_util.h new file mode 100644 index 0000000000..7b630751a5 --- /dev/null +++ b/exla/c_src/exla/exla_mlir_nif_util.h @@ -0,0 +1,10 @@ +#pragma once +#include "exla_nif_util.h" +#include "mlir/IR/BuiltinTypes.h" + +namespace exla { +namespace nif { +// Extracts information from `GetShape` into a usable term. +ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type); +} // namespace nif +} // namespace exla \ No newline at end of file diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index 4f7092a54c..c898ddd464 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -112,6 +112,21 @@ ERL_NIF_TERM make(ErlNifEnv* env, int32 var) { return enif_make_int(env, var); } +ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result) { + size_t n = result.size(); + + std::vector nif_terms; + nif_terms.reserve(n); + + for (size_t i = 0; i < n; i++) { + nif_terms[i] = enif_make_binary(env, &result[i]); + } + + auto data = nif_terms.data(); + auto list = enif_make_list_from_array(env, &data[0], n); + return list; +} + // Standard types int get(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var) { @@ -329,61 +344,5 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { return 1; } -std::string mlir_numeric_type_to_string(mlir::Type type) { - if (type.isSignlessInteger(1)) { - return "pred"; - } - if (auto integer_type = type.dyn_cast()) { - if (integer_type.isUnsigned()) { - return "u" + std::to_string(integer_type.getWidth()); - } else { - return "s" + std::to_string(integer_type.getWidth()); - } - } - if (type.isBF16()) { - return "bf16"; - } - if (auto float_type = type.dyn_cast()) { - return "f" + std::to_string(float_type.getWidth()); - } - if (auto complex_type = type.dyn_cast()) { - auto element_type = complex_type.getElementType(); - return "c" + std::to_string(element_type.cast().getWidth() * 2); - } - - std::cerr << "Unexpected mlir type" << std::endl; - exit(1); -} - -ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type) { - if (type.isa()) { - auto type_term = make(env, "token"); - auto shape_term = enif_make_tuple(env, 0); - - return enif_make_tuple(env, 2, type_term, shape_term); - } - - if (type.isa()) { - auto tensor_type = type.cast(); - auto dims = tensor_type.getShape(); - auto element_type = tensor_type.getElementType(); - - auto dims_array = std::vector{}; - dims_array.reserve(dims.size()); - - for (auto dim : dims) { - dims_array.push_back(enif_make_int(env, dim)); - } - - auto type_term = make(env, mlir_numeric_type_to_string(element_type)); - auto shape_term = enif_make_tuple_from_array(env, dims_array.data(), dims_array.size()); - - return enif_make_tuple(env, 2, type_term, shape_term); - } - - std::cerr << "Unexpected mlir type" << std::endl; - exit(1); -} - } // namespace nif } // namespace exla diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 615ddd7b5d..481a87b26b 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -9,7 +9,6 @@ #include #include "erl_nif.h" -#include "mlir/IR/Builders.h" #include "xla/shape.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -189,6 +188,8 @@ ERL_NIF_TERM make(ErlNifEnv* env, T& var) { return ret; } +ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result); + template ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result) { size_t n = result.size(); @@ -333,9 +334,6 @@ T get_value(ErlNifEnv* env, ERL_NIF_TERM term) { return value; } -// Extracts information from `GetShape` into a usable term. -ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type); - } // namespace nif } // namespace exla diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index a82cee712d..dd34cec7fd 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -13,8 +13,6 @@ #include #include -#include "../exla_mlir.h" - typedef struct compiler_state_t { iree_compiler_session_t *session; iree_compiler_source_t *source; @@ -57,9 +55,9 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Bad argument count."); } - exla::MLIRModule **module; + std::string module_str; - if (!exla::nif::get(env, argv[0], module)) { + if (!exla::nif::get(env, argv[0], module_str)) { return exla::nif::error(env, "Unable to get module."); } @@ -72,10 +70,9 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { initializeCompiler(&state); - std::string module_str = (*module)->ToString(); MlirOperation module_op = mlirOperationCreateParse( state.context, - mlirStringRefCreateFromCString(module_str.c_str()), + mlirStringRefCreate(module_str.c_str(), module_str.size()), mlirStringRefCreateFromCString("source.stablehlo")); if (mlirOperationIsNull(module_op)) { return exla::nif::error(env, "Unable to create MlirOperation module."); @@ -86,7 +83,8 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { const char *flags[] = { "--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", - "--iree-execution-model=async-external"}; + "--iree-execution-model=async-internal", + "--output-format=vm-bytecode"}; err = ireeCompilerSessionSetFlags(state.session, 1, flags); if (err) { cleanup_compiler_state(state); @@ -123,11 +121,16 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Failed to output VM Bytecode"); } + uint8_t *contents; uint64_t size; - ErlNifBinary binary; + error = ireeCompilerOutputMapMemory(state.output, (void **)&contents, &size); - error = ireeCompilerOutputMapMemory(state.output, reinterpret_cast(&binary.data), &size); + std::vector bytes_term; + bytes_term.resize(size); + for (size_t i = 0; i < size; i++) { + bytes_term[i] = enif_make_uint(env, static_cast(contents[i])); + } if (error) { handle_compiler_error(error); @@ -135,9 +138,7 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Failed to map output to output binary"); } - enif_alloc_binary(size, &binary); - cleanup_compiler_state(state); - return exla::nif::ok(env, enif_make_binary(env, &binary)); + return exla::nif::ok(env, enif_make_list_from_array(env, bytes_term.data(), bytes_term.size())); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index 8faf0f8d7f..7f7bfe7a06 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -7,6 +7,7 @@ #include "runtime.h" ERL_NIF_TERM global_initialize(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + ireeCompilerLoadLibrary("libIREECompiler.dylib"); ireeCompilerGlobalInitialize(); return exla::nif::ok(env); } diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index b3a3daca9a..14d9905c8f 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -5,7 +5,7 @@ typedef struct { void *data; size_t size; - std::vector dims; + std::vector dims; xla::PrimitiveType type; } IREEInput; @@ -13,9 +13,10 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector dims; IREEInput item; + std::vector dims; + + loaded.reserve(terms.size()); for (ERL_NIF_TERM term : terms) { if (!enif_get_tuple(env, term, &length, &tuple)) { @@ -37,17 +38,88 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector inputs) { +std::pair primitive_type_to_iree_element_type(xla::PrimitiveType t) { + using xla::PrimitiveType; + + switch (t) { + case PrimitiveType::PRED: + return {IREE_HAL_ELEMENT_TYPE_BOOL_8, true}; + case PrimitiveType::S8: + return {IREE_HAL_ELEMENT_TYPE_INT_8, true}; + case PrimitiveType::S16: + return {IREE_HAL_ELEMENT_TYPE_INT_16, true}; + case PrimitiveType::S32: + return {IREE_HAL_ELEMENT_TYPE_INT_32, true}; + case PrimitiveType::S64: + return {IREE_HAL_ELEMENT_TYPE_INT_64, true}; + case PrimitiveType::U8: + return {IREE_HAL_ELEMENT_TYPE_UINT_8, true}; + case PrimitiveType::U16: + return {IREE_HAL_ELEMENT_TYPE_UINT_16, true}; + case PrimitiveType::U32: + return {IREE_HAL_ELEMENT_TYPE_UINT_32, true}; + case PrimitiveType::U64: + return {IREE_HAL_ELEMENT_TYPE_UINT_64, true}; + case PrimitiveType::BF16: + return {IREE_HAL_ELEMENT_TYPE_BFLOAT_16, true}; + case PrimitiveType::F16: + return {IREE_HAL_ELEMENT_TYPE_FLOAT_16, true}; + case PrimitiveType::F32: + return {IREE_HAL_ELEMENT_TYPE_FLOAT_32, true}; + case PrimitiveType::F64: + return {IREE_HAL_ELEMENT_TYPE_FLOAT_64, true}; + case PrimitiveType::C64: + return {IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, true}; + case PrimitiveType::C128: + return {IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, true}; + default: + return {IREE_HAL_ELEMENT_TYPE_BOOL_8, false}; + } +} + +iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput &input, iree_hal_device_t *device, iree_hal_allocator_t *device_allocator) { + auto result = primitive_type_to_iree_element_type(input.type); + if (!result.second) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT); + } + + auto type = result.first; + const iree_const_byte_span_t data_span = iree_make_const_byte_span(input.data, input.size); + + iree_hal_buffer_params_t buffer_params = { + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + }; + + return iree_hal_buffer_view_allocate_buffer_copy( + device, + device_allocator, + input.dims.size(), + input.dims.data(), + type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + buffer_params, + data_span, + arg); +} + +iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector *result) { iree_runtime_call_t call; IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name( @@ -59,48 +131,40 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vector - // iree_hal_buffer_view_t *lhs = NULL; - // if (iree_status_is_ok(status)) { - // static const iree_hal_dim_t lhs_shape[1] = {4}; - // static const float lhs_data[4] = {1.0f, 1.1f, 1.2f, 1.3f}; - // status = iree_hal_buffer_view_allocate_buffer_copy( - // device, device_allocator, - // // Shape rank and dimensions: - // IREE_ARRAYSIZE(lhs_shape), lhs_shape, - // // Element type: - // IREE_HAL_ELEMENT_TYPE_FLOAT_32, - // // Encoding type: - // IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - // (iree_hal_buffer_params_t){ - // // Where to allocate (host or device): - // .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - // // Access to allow to this memory: - // .access = IREE_HAL_MEMORY_ACCESS_ALL, - // // Intended usage of the buffer (transfers, dispatches, etc): - // .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - // }, - // // The actual heap buffer to wrap or clone and its allocator: - // iree_make_const_byte_span(lhs_data, sizeof(lhs_data)), - // // Buffer view + storage are returned and owned by the caller: - // &lhs); - // if (iree_status_is_ok(status)) { - // IREE_IGNORE_ERROR(iree_hal_buffer_view_fprint( - // stdout, lhs, /*max_element_count=*/4096, host_allocator)); - // // Add to the call inputs list (which retains the buffer view). - // status = iree_runtime_call_inputs_push_back_buffer_view(&call, lhs); - // } - // // Since the call retains the buffer view we can release it here. - // iree_hal_buffer_view_release(lhs); - } - - return 0; + std::cout << "before call\n"; + IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0)); + std::cout << "after call\n"; + + iree_vm_list_t *outputs = iree_runtime_call_outputs(&call); + + std::cout << "size: " << iree_vm_list_size(outputs) << "\n"; + + ErlNifBinary binary; + size_t size = iree_vm_list_size(outputs); + + // for (iree_vm_size_t i = 0; i < size; i++) { + // iree_hal_buffer_view_t *out_buffer_view; + // iree_runtime_call_outputs_pop_front_buffer_view(&call, &out_buffer_view); + + // auto length = iree_hal_buffer_view_byte_length(out_buffer_view); + // iree_hal_buffer_t *buffer = iree_hal_buffer_view_buffer(out_buffer_view); + + // enif_alloc_binary(length, &binary); + + // IREE_RETURN_IF_ERROR(iree_hal_buffer_map_read( + // buffer, 0, &binary.data, length)); + + // result->push_back(binary); + // } + + return iree_make_status(IREE_STATUS_OK); } ERL_NIF_TERM @@ -109,14 +173,22 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return enif_make_badarg(env); } - ErlNifBinary bytecode_binary; + std::vector bytecode_vec; std::vector input_terms; std::vector inputs; + std::vector bytecode; - if (!enif_inspect_binary(env, argv[0], &bytecode_binary)) { + if (!exla::nif::get_list(env, argv[0], bytecode_vec)) { return exla::nif::error(env, "Unable to load bytecode binary"); } + bytecode.resize(bytecode_vec.size()); + unsigned int byte; + for (int i = 0; i < bytecode_vec.size(); i++) { + enif_get_uint(env, bytecode_vec[i], &byte); + bytecode[i] = static_cast(byte); + } + if (!exla::nif::get_list(env, argv[1], input_terms)) { return exla::nif::error(env, "Unable to load input terms"); } @@ -132,7 +204,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance); iree_hal_device_t *device = NULL; - char *device_uri = "metal"; // TO-DO: change this to an argument + char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument if (iree_status_is_ok(status)) { status = iree_hal_create_device( iree_runtime_instance_driver_registry(instance), @@ -149,39 +221,43 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { iree_runtime_instance_host_allocator(instance), &session); } + iree_const_byte_span_t span{.data = bytecode.data(), .data_length = bytecode.size()}; + if (iree_status_is_ok(status)) { - status = iree_runtime_session_append_bytecode_module_from_memory(session, reinterpret_cast(bytecode_binary.data), iree_runtime_instance_host_allocator(instance)); + status = iree_runtime_session_append_bytecode_module_from_memory(session, span, iree_runtime_instance_host_allocator(instance)); } + std::vector results; if (iree_status_is_ok(status)) { // this is where we actually call code // status = iree_runtime_demo_perform_mul(session); - status = call_module(session, inputs) + status = call_module(session, inputs, &results); } - // Release the session and free all cached resources. - iree_runtime_session_release(session); + if (session) { + // Release the session and free all cached resources. + iree_runtime_session_release(session); + } - // Release shared device once all sessions using it have been released. - iree_hal_device_release(device); + if (device) { + // Release shared device once all sessions using it have been released. + iree_hal_device_release(device); + } - // Release the shared instance - it will be deallocated when all sessions - // using it have been released (here it is deallocated immediately). - iree_runtime_instance_release(instance); + if (instance) { + // Release the shared instance - it will be deallocated when all sessions + // using it have been released (here it is deallocated immediately). + iree_runtime_instance_release(instance); + } - int ret = (int)iree_status_code(status); if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. // An application can route these through its own logging infrastructure as // needed. Note that the status is a handle and must be freed! iree_status_fprint(stderr, status); iree_status_ignore(status); + return exla::nif::error(env, "Failed to execute IREE runtime"); } - if (!ret) { - exla::nif::error(env, "Fail to execute IREE runtime"); - } - - // TO-DO: we want to get output values too - return exla::nif::ok(env); + return exla::nif::ok(env, exla::nif::make_list(env, results)); } \ No newline at end of file diff --git a/exla/c_src/iree_runtime/CMakeLists.txt b/exla/c_src/iree_runtime/CMakeLists.txt new file mode 100644 index 0000000000..fe74466c6a --- /dev/null +++ b/exla/c_src/iree_runtime/CMakeLists.txt @@ -0,0 +1,102 @@ +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) + +set(_NAME ireecompiler) + +project(${_NAME} VERSION 1.0 LANGUAGES CXX C) +set_property(GLOBAL PROPERTY USE_FOLDERS ON) +include(CheckCCompilerFlag) + +set(LLVM_DIR "${IREE_INSTALL_PREFIX}/llvm-project/lib/cmake/llvm") +set(MLIR_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/mlir") +set(LLD_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/lld") +set(Clang_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/clang") + +set(LLVM_ABI_BREAKING_CHECKS FORCE_OFFrm ) + +set(IREE_BUILD_COMPILER ON) +set(IREE_INPUT_STABLEHLO ON) +set(IREE_BUILD_BUNDLED_LLVM OFF) +set(IREE_BUILD_TESTS OFF) +set(IREE_BUILD_SAMPLES OFF) + +set(IREE_HAL_DRIVER_DEFAULTS ON) +set(IREE_HAL_DRIVER_LOCAL_SYNC ON) +set(IREE_HAL_EXECUTABLE_LOADER_DEFAULTS OFF) +set(IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF ON) + +if(CMAKE_BUILD_TYPE MATCHES MinSizeRel) + set(IREE_SIZE_OPTIMIZED ON) +endif() + + +set(C_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../exla/iree") +file(GLOB iree_runtime_sources CONFIGURE_DEPENDS "${C_SRC}/*.cc" "${C_SRC}/*.h" "${C_SRC}/../exla_nif_util.cc" "${C_SRC}/../exla_nif_util.h") + +add_library(${_NAME} SHARED ${iree_runtime_sources}) +set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD 17) + +target_include_directories(${_NAME} PUBLIC $ENV{ERTS_INCLUDE_DIR}) +target_include_directories(${_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/../../${XLA_INCLUDE_PATH}") +target_include_directories(${_NAME} SYSTEM + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/../../${IREE_COMPILER_INCLUDE_PATH}/iree/compiler" +) +target_include_directories(${_NAME} SYSTEM + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/llvm-project/mlir/include" +) + +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/objs/iree_cmake_out" EXCLUDE_FROM_ALL) + +install( + TARGETS ${_NAME} + DESTINATION "." +) + +set_target_properties(${_NAME} PROPERTIES SUFFIX ".so") + +set_target_properties(${_NAME} PROPERTIES + INSTALL_RPATH_USE_LINK_PATH TRUE + BUILD_WITH_INSTALL_RPATH TRUE +) + +if(NOT APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -shared") + set_target_properties(${_NAME} PROPERTIES INSTALL_RPATH "\$ORIGIN/${IREE_COMPILER_DIR}") +else() + # Although the compiler complains about not using these, + # things only work with them set + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -undefined dynamic_lookup") + check_c_compiler_flag("-arch arm64" ARM64_SUPPORTED) + if(ARM64_SUPPORTED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMAC_ARM64") + endif() + # set(CMAKE_SHARED_LINKER_FLAGS "-bundle -flat_namespace -undefined suppress") + set_target_properties(${_NAME} PROPERTIES INSTALL_RPATH "@loader_path/${IREE_COMPILER_DIR}") +endif() + +target_compile_options(${_NAME} PRIVATE ${IREE_DEFAULT_COPTS}) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -Wextra -Wno-unused-function -Wno-sign-compare -Wno-comment ") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers -DLLVM_VERSION_STRING= -std=c++17") + +if($ENV{DEBUG}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") +endif() + +add_definitions(-DLLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING=1) + +set(XLA_EXTENSION_LIB_PATH ${XLA_EXTENSION_LIB}) +set(XLA_EXTENSION_INCLUDE_PATH ${XLA_INCLUDE_PATH}) +include_directories(${XLA_EXTENSION_INCLUDE_PATH}) +target_link_libraries(${_NAME} "${XLA_EXTENSION_LIB}/libxla_extension.so") + +target_link_libraries(${_NAME} iree_runtime_runtime) +target_link_libraries(${_NAME} iree_compiler_bindings_c_loader) + +if(NOT APPLE) + target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.so") +else() + target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.dylib") +endif() \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 5d252f2989..a2322912a3 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -375,6 +375,7 @@ defmodule EXLA.Defn do EXLA.Executable.run(executable, [buffers], run_options) else [result] -> + dbg(result) [EXLA.Defn.Buffers.to_nx!(result, outputs)] after EXLA.Defn.Lock.unlock(lock) @@ -482,7 +483,10 @@ defmodule EXLA.Defn do for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec if compiler_mode == :iree do - {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(builder.module.ref, "metal") + {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) + dbg(module_charlist) + {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") + dbg(module_bytecode) %EXLA.Executable{ client: client, diff --git a/exla/lib/exla/defn/buffers.ex b/exla/lib/exla/defn/buffers.ex index d8bce0a4a5..e276b9da9c 100644 --- a/exla/lib/exla/defn/buffers.ex +++ b/exla/lib/exla/defn/buffers.ex @@ -107,6 +107,11 @@ defmodule EXLA.Defn.Buffers do %Nx.Tensor{data: data} = tensor = Nx.devectorize(fun.()) case data do + %EXLA.Backend{buffer: %EXLA.DeviceBuffer{ref: ref} = buffer} + when executable.runtime == :iree -> + binary = EXLA.DeviceBuffer.read(buffer) + EXLA.BinaryBuffer.from_binary(binary, to_typespec(tensor)) + %EXLA.Backend{buffer: %EXLA.DeviceBuffer{ref: ref} = buffer} when node(ref) != node() -> binary = :erpc.call(node(ref), EXLA.DeviceBuffer, :read, [buffer]) diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 3c79d6c1c2..1068d190a3 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -98,7 +98,7 @@ defmodule EXLA.Executable do unwrap!(data) end - defp run(:iree, _client, ref, _device_id, inputs, _options) do + defp run(:iree, _client, ref, device_id, inputs, _options) do inputs = for subinputs <- inputs do Enum.map(subinputs, fn @@ -110,6 +110,8 @@ defmodule EXLA.Executable do ref |> EXLA.MLIR.IREE.run_module(List.flatten(inputs)) |> unwrap!() + |> Enum.map(&{&1, device_id}) + |> dbg() end defp decompose_output({data, device_id}, output_typespecs, client) do From 44a0e1874755cb8fd4edfce5669a2091016d6945 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 9 May 2024 14:00:41 -0300 Subject: [PATCH 07/40] wip --- exla/c_src/exla/iree/runtime.cc | 101 +++++++++++++++++++++++++++----- 1 file changed, 87 insertions(+), 14 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 14d9905c8f..ff3152cc47 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -1,5 +1,7 @@ #include "runtime.h" +#include +#include #include typedef struct { @@ -149,20 +151,91 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorpush_back(binary); - // } + for (iree_vm_size_t i = 0; i < size; i++) { + iree_vm_ref_t ref; + iree_hal_buffer_view_t *buffer_view; + iree_hal_buffer_t *out_buffer; + iree_hal_buffer_t *host_buffer; + + IREE_RETURN_IF_ERROR(iree_vm_list_get_ref_assign(outputs, i, &ref)); + + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(ref, &buffer_view)); + + out_buffer = iree_hal_buffer_view_buffer(buffer_view); + + iree_hal_buffer_params_t buffer_params = { + .type = IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .usage = IREE_HAL_BUFFER_USAGE_MAPPING_SCOPED | IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET}; + IREE_RETURN_IF_ERROR(iree_hal_allocator_allocate_buffer( + device_allocator, + buffer_params, + iree_hal_buffer_byte_length(out_buffer), + &host_buffer)); + + iree_hal_semaphore_t *semaphore; + IREE_RETURN_IF_ERROR(iree_hal_semaphore_create(device, 0, &semaphore)); + + // IREE_API_EXPORT iree_status_t iree_hal_device_queue_copy( + // iree_hal_device_t * device, iree_hal_queue_affinity_t queue_affinity, + // const iree_hal_semaphore_list_t wait_semaphore_list, + // const iree_hal_semaphore_list_t signal_semaphore_list, + // iree_hal_buffer_t *source_buffer, iree_device_size_t source_offset, + // iree_hal_buffer_t *target_buffer, iree_device_size_t target_offset, + // iree_device_size_t length); + + iree_hal_fence_t *wait_fence; + iree_hal_fence_t *signal_fence; + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(semaphore, 0, iree_allocator_system(), &wait_fence)); + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(semaphore, 1, iree_allocator_system(), &signal_fence)); + + iree_hal_semaphore_list_t wait_semaphore_list = iree_hal_fence_semaphore_list(wait_fence); + iree_hal_semaphore_list_t signal_semaphore_list = iree_hal_fence_semaphore_list(signal_fence); + + IREE_RETURN_IF_ERROR(iree_hal_device_queue_copy( + device, + IREE_HAL_QUEUE_AFFINITY_ANY, + wait_semaphore_list, + signal_semaphore_list, + out_buffer, + iree_hal_buffer_byte_offset(out_buffer), + host_buffer, + iree_hal_buffer_byte_offset(host_buffer), + iree_hal_buffer_byte_length(out_buffer))); + + iree_hal_fence_t *result_fence; + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(semaphore, 1, iree_allocator_system(), &result_fence)); + IREE_RETURN_IF_ERROR(iree_hal_fence_wait(result_fence, iree_infinite_timeout())); + + iree_hal_element_type_t raw_dtype = iree_hal_buffer_view_element_type(buffer_view); + + iree_hal_buffer_mapping_t mapped_memory; + size_t byte_size = iree_hal_buffer_byte_length(host_buffer); + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + host_buffer, + IREE_HAL_MAPPING_MODE_SCOPED, + IREE_HAL_MEMORY_ACCESS_READ, + 0, + byte_size, + &mapped_memory)); + + // iree_hal_buffer_map_range( + // iree_hal_buffer_t * buffer, iree_hal_mapping_mode_t mapping_mode, + // iree_hal_memory_access_t memory_access, iree_device_size_t byte_offset, + // iree_device_size_t byte_length, + // iree_hal_buffer_mapping_t * out_buffer_mapping) + + // HalElementType.map_to_dtype(self._buffer_view.element_type); + + // iree_hal_allocator_allocate_buffer( + // iree_hal_allocator_t * IREE_RESTRICT allocator, + // iree_hal_buffer_params_t params, iree_device_size_t allocation_size, + // iree_hal_buffer_t * *IREE_RESTRICT out_buffer) + + enif_alloc_binary(byte_size, &binary); + memcpy(binary.data, mapped_memory.contents.data, byte_size); + + result->push_back(binary); + } return iree_make_status(IREE_STATUS_OK); } From 32451133ae0d84f511622c02b279eb928e918d86 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 9 May 2024 14:27:30 -0300 Subject: [PATCH 08/40] feat: float32 working --- exla/c_src/exla/iree/runtime.cc | 92 +++++---------------------------- exla/lib/exla/executable.ex | 2 +- 2 files changed, 14 insertions(+), 80 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index ff3152cc47..6ccc79b31c 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -152,87 +152,21 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorpush_back(binary); } diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 1068d190a3..0957286f7d 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -110,7 +110,7 @@ defmodule EXLA.Executable do ref |> EXLA.MLIR.IREE.run_module(List.flatten(inputs)) |> unwrap!() - |> Enum.map(&{&1, device_id}) + |> then(&[{&1, device_id}]) |> dbg() end From 0fca4375412a7a472cd49853eb573800fe5e5581 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 10 May 2024 14:35:36 -0300 Subject: [PATCH 09/40] feat: more support for other types --- exla/c_src/exla/iree/runtime.cc | 149 +++++++++++++++---------- exla/c_src/iree_runtime/CMakeLists.txt | 2 +- 2 files changed, 94 insertions(+), 57 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 6ccc79b31c..1fdf2eac16 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -8,15 +8,74 @@ typedef struct { void *data; size_t size; std::vector dims; - xla::PrimitiveType type; + iree_hal_element_type_t type; } IREEInput; +bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_type_t *type) { + using xla::PrimitiveType; + using type_enum = iree_hal_element_types_t; + + switch (t) { + case PrimitiveType::PRED: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_BOOL_8; + return true; + case PrimitiveType::S8: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_8; + return true; + case PrimitiveType::S16: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_16; + return true; + case PrimitiveType::S32: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_32; + return true; + case PrimitiveType::S64: + // forced demotion from compiler + *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_32; + return true; + case PrimitiveType::U8: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_8; + return true; + case PrimitiveType::U16: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_16; + return true; + case PrimitiveType::U32: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32; + return true; + case PrimitiveType::U64: + // forced demotion from compiler + *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32; + return true; + case PrimitiveType::BF16: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_BFLOAT_16; + return true; + case PrimitiveType::F16: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_16; + return true; + case PrimitiveType::F32: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_32; + return true; + case PrimitiveType::F64: + // forced demotion from compiler + *type = type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_32; + return true; + case PrimitiveType::C64: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64; + return true; + case PrimitiveType::C128: + *type = type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128; + return true; + default: + return false; + } +} + int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { const ERL_NIF_TERM *tuple, *typespec; int length; ErlNifBinary bin; IREEInput item; std::vector dims; + xla::PrimitiveType primitive_type; loaded.reserve(terms.size()); @@ -36,7 +95,11 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector terms, std::vector primitive_type_to_iree_element_type(xla::PrimitiveType t) { - using xla::PrimitiveType; - - switch (t) { - case PrimitiveType::PRED: - return {IREE_HAL_ELEMENT_TYPE_BOOL_8, true}; - case PrimitiveType::S8: - return {IREE_HAL_ELEMENT_TYPE_INT_8, true}; - case PrimitiveType::S16: - return {IREE_HAL_ELEMENT_TYPE_INT_16, true}; - case PrimitiveType::S32: - return {IREE_HAL_ELEMENT_TYPE_INT_32, true}; - case PrimitiveType::S64: - return {IREE_HAL_ELEMENT_TYPE_INT_64, true}; - case PrimitiveType::U8: - return {IREE_HAL_ELEMENT_TYPE_UINT_8, true}; - case PrimitiveType::U16: - return {IREE_HAL_ELEMENT_TYPE_UINT_16, true}; - case PrimitiveType::U32: - return {IREE_HAL_ELEMENT_TYPE_UINT_32, true}; - case PrimitiveType::U64: - return {IREE_HAL_ELEMENT_TYPE_UINT_64, true}; - case PrimitiveType::BF16: - return {IREE_HAL_ELEMENT_TYPE_BFLOAT_16, true}; - case PrimitiveType::F16: - return {IREE_HAL_ELEMENT_TYPE_FLOAT_16, true}; - case PrimitiveType::F32: - return {IREE_HAL_ELEMENT_TYPE_FLOAT_32, true}; - case PrimitiveType::F64: - return {IREE_HAL_ELEMENT_TYPE_FLOAT_64, true}; - case PrimitiveType::C64: - return {IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64, true}; - case PrimitiveType::C128: - return {IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128, true}; - default: - return {IREE_HAL_ELEMENT_TYPE_BOOL_8, false}; - } -} - iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput &input, iree_hal_device_t *device, iree_hal_allocator_t *device_allocator) { - auto result = primitive_type_to_iree_element_type(input.type); - if (!result.second) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT); - } - - auto type = result.first; const iree_const_byte_span_t data_span = iree_make_const_byte_span(input.data, input.size); iree_hal_buffer_params_t buffer_params = { @@ -114,7 +132,7 @@ iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput &inp device_allocator, input.dims.size(), input.dims.data(), - type, + input.type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, data_span, @@ -123,9 +141,20 @@ iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput &inp iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector *result) { iree_runtime_call_t call; + iree_vm_function_t function; - IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name( - session, iree_make_cstring_view("module.main"), &call)); + IREE_RETURN_IF_ERROR( + iree_runtime_session_lookup_function(session, iree_make_cstring_view("module.main"), &function)); + + IREE_RETURN_IF_ERROR(iree_runtime_call_initialize(session, function, &call)); + + iree_vm_function_signature_t signature = + iree_vm_function_signature(&function); + + iree_string_view_t arguments; + iree_string_view_t results; + IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( + &signature, &arguments, &results)); // Append the function inputs with the HAL device allocator in use by the // session. The buffers will be usable within the session and _may_ be usable @@ -133,8 +162,12 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vector Date: Fri, 10 May 2024 20:44:55 -0300 Subject: [PATCH 10/40] feat: make all types work (with 64->32 demotion for ints and floats) --- exla/c_src/exla/exla_nif_util.cc | 15 ---- exla/c_src/exla/exla_nif_util.h | 2 - exla/c_src/exla/iree/compiler.cc | 4 +- exla/c_src/exla/iree/runtime.cc | 114 +++++++++++++++++++++++++------ exla/lib/exla/defn.ex | 3 - exla/lib/exla/defn/buffers.ex | 2 +- exla/lib/exla/executable.ex | 29 +++++++- 7 files changed, 125 insertions(+), 44 deletions(-) diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index c898ddd464..9a22dc106b 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -112,21 +112,6 @@ ERL_NIF_TERM make(ErlNifEnv* env, int32 var) { return enif_make_int(env, var); } -ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result) { - size_t n = result.size(); - - std::vector nif_terms; - nif_terms.reserve(n); - - for (size_t i = 0; i < n; i++) { - nif_terms[i] = enif_make_binary(env, &result[i]); - } - - auto data = nif_terms.data(); - auto list = enif_make_list_from_array(env, &data[0], n); - return list; -} - // Standard types int get(ErlNifEnv* env, ERL_NIF_TERM term, std::string& var) { diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h index 481a87b26b..ba8be5cb80 100644 --- a/exla/c_src/exla/exla_nif_util.h +++ b/exla/c_src/exla/exla_nif_util.h @@ -188,8 +188,6 @@ ERL_NIF_TERM make(ErlNifEnv* env, T& var) { return ret; } -ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result); - template ERL_NIF_TERM make_list(ErlNifEnv* env, std::vector result) { size_t n = result.size(); diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index dd34cec7fd..4f275785a9 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -84,7 +84,9 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { "--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal", - "--output-format=vm-bytecode"}; + "--output-format=vm-bytecode", + "--iree-opt-demote-f64-to-f32=true", + "--iree-opt-demote-i64-to-i32=true"}; err = ireeCompilerSessionSetFlags(state.session, 1, flags); if (err) { cleanup_compiler_state(state); diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 1fdf2eac16..caa8c8094a 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -29,7 +29,6 @@ bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_ *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_32; return true; case PrimitiveType::S64: - // forced demotion from compiler *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_32; return true; case PrimitiveType::U8: @@ -42,7 +41,6 @@ bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_ *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32; return true; case PrimitiveType::U64: - // forced demotion from compiler *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32; return true; case PrimitiveType::BF16: @@ -55,8 +53,7 @@ bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_ *type = type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_32; return true; case PrimitiveType::F64: - // forced demotion from compiler - *type = type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_32; + *type = type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_64; return true; case PrimitiveType::C64: *type = type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64; @@ -69,6 +66,60 @@ bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_ } } +bool iree_element_type_to_nx_type(iree_hal_element_type_t type, std::string &nx_type) { + using type_enum = iree_hal_element_types_t; + + switch (type) { + case type_enum::IREE_HAL_ELEMENT_TYPE_BOOL_8: + nx_type = "pred"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_INT_8: + nx_type = "s8"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_INT_16: + nx_type = "s16"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_INT_32: + nx_type = "s32"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_INT_64: + nx_type = "s64"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_UINT_8: + nx_type = "u8"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_UINT_16: + nx_type = "u16"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32: + nx_type = "u32"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_UINT_64: + nx_type = "u64"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_BFLOAT_16: + nx_type = "bf16"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_16: + nx_type = "f16"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_32: + nx_type = "f32"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_64: + nx_type = "f32"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: + nx_type = "c64"; + return true; + case type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: + nx_type = "c64"; + return true; + default: + return false; + } +} + int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { const ERL_NIF_TERM *tuple, *typespec; int length; @@ -139,17 +190,21 @@ iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput &inp arg); } -iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector *result) { +iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector> *result) { iree_runtime_call_t call; iree_vm_function_t function; - IREE_RETURN_IF_ERROR( - iree_runtime_session_lookup_function(session, iree_make_cstring_view("module.main"), &function)); + IREE_RETURN_IF_ERROR(iree_runtime_session_lookup_function(session, iree_make_cstring_view("module.main"), &function)); IREE_RETURN_IF_ERROR(iree_runtime_call_initialize(session, function, &call)); - iree_vm_function_signature_t signature = - iree_vm_function_signature(&function); + iree_vm_function_t export_function; + iree_string_view_t export_function_name; + iree_vm_function_signature_t export_function_signature; + + IREE_RETURN_IF_ERROR(function.module->get_function(function.module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT, function.ordinal, &export_function, &export_function_name, &export_function_signature)); + + iree_vm_function_signature_t signature = iree_vm_function_signature(&function); iree_string_view_t arguments; iree_string_view_t results; @@ -173,9 +228,7 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorpush_back(binary); + result->push_back({element_type, binary}); } return iree_make_status(IREE_STATUS_OK); } +ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector> results) { + size_t n = results.size(); + + std::vector nif_terms; + nif_terms.reserve(n); + + for (auto [iree_type, binary] : results) { + std::string nx_type; + if (!iree_element_type_to_nx_type(iree_type, nx_type)) { + return exla::nif::error(env, "Unable to convert IREE type to NX type"); + } + ERL_NIF_TERM type = exla::nif::make(env, nx_type); + ERL_NIF_TERM bin_term = enif_make_binary(env, &binary); + + nif_terms.push_back(enif_make_tuple2(env, type, bin_term)); + } + + auto data = nif_terms.data(); + auto list = enif_make_list_from_array(env, &data[0], n); + return exla::nif::ok(env, list); +} + ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 2) { @@ -271,7 +345,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { status = iree_runtime_session_append_bytecode_module_from_memory(session, span, iree_runtime_instance_host_allocator(instance)); } - std::vector results; + std::vector> results; if (iree_status_is_ok(status)) { // this is where we actually call code // status = iree_runtime_demo_perform_mul(session); @@ -303,5 +377,5 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Failed to execute IREE runtime"); } - return exla::nif::ok(env, exla::nif::make_list(env, results)); + return return_results(env, results); } \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index a2322912a3..e18b453a14 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -375,7 +375,6 @@ defmodule EXLA.Defn do EXLA.Executable.run(executable, [buffers], run_options) else [result] -> - dbg(result) [EXLA.Defn.Buffers.to_nx!(result, outputs)] after EXLA.Defn.Lock.unlock(lock) @@ -484,9 +483,7 @@ defmodule EXLA.Defn do if compiler_mode == :iree do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) - dbg(module_charlist) {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") - dbg(module_bytecode) %EXLA.Executable{ client: client, diff --git a/exla/lib/exla/defn/buffers.ex b/exla/lib/exla/defn/buffers.ex index e276b9da9c..16f4b67284 100644 --- a/exla/lib/exla/defn/buffers.ex +++ b/exla/lib/exla/defn/buffers.ex @@ -107,7 +107,7 @@ defmodule EXLA.Defn.Buffers do %Nx.Tensor{data: data} = tensor = Nx.devectorize(fun.()) case data do - %EXLA.Backend{buffer: %EXLA.DeviceBuffer{ref: ref} = buffer} + %EXLA.Backend{buffer: %EXLA.DeviceBuffer{} = buffer} when executable.runtime == :iree -> binary = EXLA.DeviceBuffer.read(buffer) EXLA.BinaryBuffer.from_binary(binary, to_typespec(tensor)) diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 0957286f7d..74d1d7572d 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -103,7 +103,20 @@ defmodule EXLA.Executable do for subinputs <- inputs do Enum.map(subinputs, fn %BinaryBuffer{data: data, typespec: typespec} -> - {data, EXLA.Typespec.nif_encode(typespec)} + if typespec.type in [f: 64, c: 128, s: 64, u: 64] do + {t, w} = typespec.type + w2 = div(w, 2) + target_type = {t, w2} + + data = + data |> Nx.from_binary(typespec.type) |> Nx.as_type(target_type) |> Nx.to_binary() + + data = <> + + {data, EXLA.Typespec.nif_encode(typespec)} + else + {data, EXLA.Typespec.nif_encode(typespec)} + end end) end @@ -111,11 +124,23 @@ defmodule EXLA.Executable do |> EXLA.MLIR.IREE.run_module(List.flatten(inputs)) |> unwrap!() |> then(&[{&1, device_id}]) - |> dbg() end defp decompose_output({data, device_id}, output_typespecs, client) do Enum.zip_with(data, output_typespecs, fn + {type, buf}, target_typespec when is_binary(buf) and is_list(type) -> + source_typespec = EXLA.Typespec.nif_decode({type, target_typespec.shape}) + + if source_typespec == target_typespec do + BinaryBuffer.from_binary(buf, target_typespec) + else + buf + |> Nx.from_binary(source_typespec.type) + |> Nx.as_type(target_typespec.type) + |> Nx.to_binary() + |> BinaryBuffer.from_binary(target_typespec) + end + buf, typespec when is_reference(buf) -> DeviceBuffer.from_ref(buf, client, device_id, typespec) From b223c59f2a80ad00e9edcabd97d67e3a88e702eb Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 11 May 2024 03:08:12 -0300 Subject: [PATCH 11/40] test: get tests to run without beam crashes --- exla/lib/exla/defn.ex | 26 +++++++++++++++++------ exla/test/exla/backend_test.exs | 2 ++ exla/test/exla/defn/api_test.exs | 4 ++++ exla/test/exla/defn/expr_test.exs | 23 ++++++++++++++++++++ exla/test/exla/defn/vectorize_test.exs | 4 +++- exla/test/exla/executable_test.exs | 1 + exla/test/exla/mlir/custom_call_test.exs | 1 + exla/test/exla/nx_linalg_doctest_test.exs | 3 +++ exla/test/exla/random_test.exs | 2 ++ exla/test/exla/serving_test.exs | 2 ++ exla/test/exla_test.exs | 1 + exla/test/test_helper.exs | 21 +++++++++++++++++- 12 files changed, 82 insertions(+), 8 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index e18b453a14..b5b46cffdc 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -574,19 +574,31 @@ defmodule EXLA.Defn do cache ) do [initial_arg, _arg, pred, body] = args - initial_with_token = {get_token(cache), initial_arg} + + initial_with_token = + if state.builder.compiler == :iree do + initial_arg + else + [get_token(cache), initial_arg] + end {initial, cache} = recur_composite(initial_with_token, state, cache) {pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache) {body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache) - [token | results] = - Value.while(function, pred_computation, body_computation, List.flatten(initial)) + output = Value.while(function, pred_computation, body_computation, List.flatten(initial)) - result = wrap_tuple_result(results, initial_arg) + case state.builder.compiler do + :iree -> + result = wrap_tuple_result(output, initial_arg) + {result, cache} - {result, update_token(cache, token)} + _ -> + [token | results] = output + result = wrap_tuple_result(results, initial_arg) + {result, update_token(cache, token)} + end end defp cached_recur_operator(:cond, %T{data: %Expr{args: args}} = t, state, cache) do @@ -1706,11 +1718,13 @@ defmodule EXLA.Defn do function = EXLA.MLIR.Module.add_function(module, name, in_types, out_types) + function = %{function | compiler: compiler} + [arg_token | tail] = EXLA.MLIR.Function.get_arguments(function) params = if compiler == :iree do - Enum.with_index([arg_token | tail], fn param, i -> {i, param} end) + Enum.with_index(tail, fn param, i -> {i, param} end) else Enum.with_index(tail, fn param, i -> {i, param} end) end diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index a81c6d7857..54ea6bfac9 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -27,6 +27,8 @@ defmodule EXLA.BackendTest do @skip_mac_arm [] end + @moduletag :iree_hangup_error + doctest Nx, except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm diff --git a/exla/test/exla/defn/api_test.exs b/exla/test/exla/defn/api_test.exs index d8eb5a7c7c..62a3913039 100644 --- a/exla/test/exla/defn/api_test.exs +++ b/exla/test/exla/defn/api_test.exs @@ -98,6 +98,7 @@ defmodule EXLA.Defn.APITest do end describe "batch" do + @tag :iree_shape_mismatch_error test "when padded" do input = Nx.tensor([[1, 2, 3]], backend: EXLA.Backend) batch = [input] |> Nx.Batch.concatenate() |> Nx.Batch.pad(1) @@ -127,6 +128,7 @@ defmodule EXLA.Defn.APITest do end describe "stream" do + @describetag :token defn defn_sum(entry, acc), do: {acc, entry + acc} test "immediately done" do @@ -286,6 +288,7 @@ defmodule EXLA.Defn.APITest do end describe "hooks" do + @describetag :token require Logger defp send_to_self(tag) do @@ -440,6 +443,7 @@ defmodule EXLA.Defn.APITest do send(self(), {measurements, metadata}) end + @tag :iree_shape_mismatch_error test "executes event when function is compiled" do :ok = :telemetry.attach(__MODULE__, [:exla, :compilation], &__MODULE__.telemetry_handler/4, nil) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index c540784c67..8ca1eb124d 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -392,6 +392,7 @@ defmodule EXLA.Defn.ExprTest do end describe "element-wise bitwise operators" do + @describetag :iree_shape_mismatch_error @left Nx.tensor([-2, -1, 0, 1, 2]) @right Nx.tensor([[-2], [-1], [0], [1], [2]]) @@ -497,6 +498,7 @@ defmodule EXLA.Defn.ExprTest do end describe "equal" do + @describetag :iree_shape_mismatch_error defn equal(a, b), do: Nx.equal(a, b) test "computes equality of scalars" do @@ -526,6 +528,7 @@ defmodule EXLA.Defn.ExprTest do end describe "not equal" do + @describetag :iree_shape_mismatch_error defn not_equal(a, b), do: Nx.not_equal(a, b) test "computes equality of scalars" do @@ -548,6 +551,7 @@ defmodule EXLA.Defn.ExprTest do end describe "less" do + @describetag :iree_shape_mismatch_error defn less(a, b), do: Nx.less(a, b) test "compares scalars" do @@ -567,6 +571,7 @@ defmodule EXLA.Defn.ExprTest do end describe "greater" do + @describetag :iree_shape_mismatch_error defn greater(a, b), do: Nx.greater(a, b) test "compares scalars" do @@ -589,6 +594,7 @@ defmodule EXLA.Defn.ExprTest do end describe "less equal" do + @describetag :iree_shape_mismatch_error defn less_equal(a, b), do: Nx.less_equal(a, b) test "compares scalars" do @@ -611,6 +617,7 @@ defmodule EXLA.Defn.ExprTest do end describe "greater equal" do + @describetag :iree_shape_mismatch_error defn greater_equal(a, b), do: Nx.greater_equal(a, b) test "compares scalars" do @@ -633,6 +640,7 @@ defmodule EXLA.Defn.ExprTest do end describe "logical" do + @describetag :iree_shape_mismatch_error defn logical_and(a, b), do: Nx.logical_and(a, b) test "and" do @@ -741,6 +749,7 @@ defmodule EXLA.Defn.ExprTest do end describe "select" do + @describetag :iree_shape_mismatch_error defn select(pred, x, y), do: Nx.select(pred, x, y) test "selects one or the other with a scalar" do @@ -809,6 +818,7 @@ defmodule EXLA.Defn.ExprTest do end describe "complex ops" do + @describetag :iree_unsupported_fft_error defn fft(t, opts \\ []), do: Nx.fft(t, opts) defn ifft(t, opts \\ []), do: Nx.ifft(t, opts) @@ -1366,6 +1376,7 @@ defmodule EXLA.Defn.ExprTest do end describe "while/3" do + @describetag :iree_key_not_found_error defn upto10(x) do while x, Nx.less(x, 10) do x + 1 @@ -1723,6 +1734,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window_scatter_min/max" do + @describetag :iree_segfault_error defn window_scatter_max_no_padding(t) do Nx.window_scatter_max( t, @@ -2260,6 +2272,7 @@ defmodule EXLA.Defn.ExprTest do end describe "argmax/argmin" do + @describetag :iree_wrong_result_error defn argmax(t), do: Nx.argmax(t) defn argmin(t), do: Nx.argmin(t) defn argmax_axis(t), do: Nx.argmax(t, axis: 1) @@ -2330,6 +2343,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window sum" do + @describetag :iree_segfault_error defn window_sum1(t), do: Nx.window_sum(t, {1, 2, 1}) defn window_sum2(t), @@ -2380,6 +2394,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window mean" do + @describetag :iree_segfault_error defn window_mean1(t), do: Nx.window_mean(t, {1, 2, 1}) defn window_mean2(t), @@ -2435,6 +2450,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window max" do + @describetag :iree_segfault_error defn window_max1(t), do: Nx.window_max(t, {1, 2, 1}) defn window_max2(t), @@ -2500,6 +2516,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window min" do + @describetag :iree_wrong_result_error defn window_min0(t), do: Nx.window_min(t, {2}) defn window_min1(t), do: Nx.window_min(t, {1, 2, 1}) @@ -2570,6 +2587,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window product" do + @describetag :iree_segfault_error defn window_product1(t), do: Nx.window_product(t, {1, 2, 1}) defn window_product2(t), @@ -2718,6 +2736,7 @@ defmodule EXLA.Defn.ExprTest do end describe "convolution" do + @describetag :iree_shape_mismatch_error defn conv_valid_no_stride(inp, kernel), do: Nx.conv(inp, kernel) defn conv_valid_stride(inp, kernel), @@ -3381,6 +3400,7 @@ defmodule EXLA.Defn.ExprTest do end describe "put slice" do + @describetag :iree_shape_mismatch_error defn put_slice1(t1, t2), do: Nx.put_slice(t1, [2], t2) defn put_slice2(t1, t2), do: Nx.put_slice(t1, [1, 2], t2) defn put_slice3(t1, t2), do: Nx.put_slice(t1, [2, 2], t2) @@ -3859,6 +3879,7 @@ defmodule EXLA.Defn.ExprTest do end describe "sort" do + @describetag :iree_segfault_error defn sort0(t), do: Nx.sort(t, axis: 0) defn sort1(t), do: Nx.sort(t, axis: 1) defn sort1_asc(t), do: Nx.sort(t, axis: 1, direction: :asc) @@ -3918,6 +3939,7 @@ defmodule EXLA.Defn.ExprTest do end describe "argsort" do + @describetag :iree_segfault_error defn argsort0(t), do: Nx.argsort(t, axis: 0) defn argsort1(t), do: Nx.argsort(t, axis: 1) defn argsort1_asc(t), do: Nx.argsort(t, axis: 1, direction: :asc) @@ -4171,6 +4193,7 @@ defmodule EXLA.Defn.ExprTest do end end + @tag :iree_key_not_found_error test "computes while inside cond" do assert {i} = while_in_cond(0) assert_equal(i, Nx.tensor(5)) diff --git a/exla/test/exla/defn/vectorize_test.exs b/exla/test/exla/defn/vectorize_test.exs index c81661c5f6..30a51fc548 100644 --- a/exla/test/exla/defn/vectorize_test.exs +++ b/exla/test/exla/defn/vectorize_test.exs @@ -4,6 +4,8 @@ defmodule EXLA.Defn.VectorizeTest do import Nx.Defn import Nx, only: :sigils + @moduletag :iree_shape_mismatch_error + setup do Nx.default_backend(EXLA.Backend) @@ -18,7 +20,7 @@ defmodule EXLA.Defn.VectorizeTest do %{base: base, vectorized: vectorized} end - defn add_n(x, y), do: Nx.add(x, y) + defn(add_n(x, y), do: Nx.add(x, y)) def add(x, y) do EXLA.jit_apply(&add_n/2, [x, y]) diff --git a/exla/test/exla/executable_test.exs b/exla/test/exla/executable_test.exs index 28e276edfc..4dcc67885b 100644 --- a/exla/test/exla/executable_test.exs +++ b/exla/test/exla/executable_test.exs @@ -155,6 +155,7 @@ defmodule EXLA.ExecutableFeedTest do import EXLAHelpers describe "infeed/outfeed" do + @describetag :token test "successfully sends to/from device asynchronously" do t = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) diff --git a/exla/test/exla/mlir/custom_call_test.exs b/exla/test/exla/mlir/custom_call_test.exs index c0b454cb94..50149ac9ab 100644 --- a/exla/test/exla/mlir/custom_call_test.exs +++ b/exla/test/exla/mlir/custom_call_test.exs @@ -2,6 +2,7 @@ defmodule EXLA.MLIR.CustomCallTest do use EXLA.Case, async: true describe "qr" do + @describetag :iree_key_not_found_error for type <- [bf: 16, f: 16, f: 32, f: 64] do tol_opts = case type do diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 6df3aeec10..68e47993a4 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -12,9 +12,12 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do ] @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2] + @iree_error_doctests [qr: 2] + @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++ @invalid_type_error_doctests ++ + @iree_error_doctests ++ [:moduledoc] doctest Nx.LinAlg, except: @excluded_doctests end diff --git a/exla/test/exla/random_test.exs b/exla/test/exla/random_test.exs index 3b86fc64b9..db1bea2ec4 100644 --- a/exla/test/exla/random_test.exs +++ b/exla/test/exla/random_test.exs @@ -1,6 +1,8 @@ defmodule EXLA.NxRandomTest do use EXLA.Case, async: true + @moduletag :iree_hangup_error + setup do Nx.default_backend(EXLA.Backend) :ok diff --git a/exla/test/exla/serving_test.exs b/exla/test/exla/serving_test.exs index d45f17827e..a7df891e88 100644 --- a/exla/test/exla/serving_test.exs +++ b/exla/test/exla/serving_test.exs @@ -1,6 +1,8 @@ defmodule EXLA.ServingTest do use EXLA.Case, async: true + @moduletag :token + defmodule ExecuteSync do @behaviour Nx.Serving diff --git a/exla/test/exla_test.exs b/exla/test/exla_test.exs index 5558c73af5..45b0b5e144 100644 --- a/exla/test/exla_test.exs +++ b/exla/test/exla_test.exs @@ -1,6 +1,7 @@ defmodule EXLATest do use EXLA.Case, async: true + @moduletag :token doctest EXLA describe "integration" do diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 4a00ad450c..8e60e6d395 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -38,8 +38,27 @@ cuda_required = [:cuda_required] end +compiler_mode = :iree + +iree_excludes = + if compiler_mode == :iree do + [ + :token, + :iree_hangup_error, + :iree_shape_mismatch_error, + :iree_key_not_found_error, + :iree_wrong_result_error, + :iree_unsupported_fft_error, + :iree_segfault_error, + :multi_device + ] + else + [] + end + ExUnit.start( - exclude: [:platform, :integration] ++ exclude_multi_device ++ exclude ++ cuda_required, + exclude: + [:platform, :integration] ++ exclude_multi_device ++ exclude ++ cuda_required ++ iree_excludes, include: [platform: String.to_atom(target)], assert_receive_timeout: 1000 ) From 23a468d89af00fa2ab0d99a39dcb426e3c1baa5d Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 11 May 2024 03:20:20 -0300 Subject: [PATCH 12/40] test: skip broken tests --- exla/test/exla/defn/api_test.exs | 4 +++- exla/test/exla/defn/expr_test.exs | 30 ++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/exla/test/exla/defn/api_test.exs b/exla/test/exla/defn/api_test.exs index 62a3913039..71be9aa3e0 100644 --- a/exla/test/exla/defn/api_test.exs +++ b/exla/test/exla/defn/api_test.exs @@ -7,6 +7,7 @@ defmodule EXLA.Defn.APITest do defn add_two(a, b), do: a + b describe "multi-client" do + @describetag :iree_key_not_found_error test "converts from host to separate client" do a = Nx.tensor(1, backend: {EXLA.Backend, client: :host}) b = Nx.tensor(2, backend: {EXLA.Backend, client: :host}) @@ -29,6 +30,7 @@ defmodule EXLA.Defn.APITest do end describe "options" do + @tag :iree_shape_mismatch_error test "logs when debugging" do logs = capture_log(fn -> @@ -37,7 +39,7 @@ defmodule EXLA.Defn.APITest do assert logs =~ ~r"EXLA defn evaluation #Function<[^>]+> cache (hit|miss) in \d+\.\dms" assert logs =~ ~r"EXLA compilation #Function<[^>]+> cache (hit|miss) in \d+\.\dms" - assert logs =~ ~r"EXLA device \d lock in \d+\.\dms" + assert logs =~ ~r"EXLA device -?\d lock in \d+\.\dms" assert logs =~ ~r"EXLA execution on device \d in \d+\.\dms" logs = diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 8ca1eb124d..aa4294ef76 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -9,6 +9,7 @@ defmodule EXLA.Defn.ExprTest do end describe "tuples" do + @describetag :iree_shape_mismatch_error defn add_subtract_tuple(a, b), do: {a + b, a - b} test "on results" do @@ -148,6 +149,7 @@ defmodule EXLA.Defn.ExprTest do end describe "+/2" do + @describetag :iree_shape_mismatch_error defn add_two(a, b), do: a + b test "same shape and type" do @@ -237,6 +239,7 @@ defmodule EXLA.Defn.ExprTest do end describe "//2" do + @describetag :iree_shape_mismatch_error defn divide_two(a, b), do: a / b test "parameters" do @@ -277,6 +280,7 @@ defmodule EXLA.Defn.ExprTest do end describe "remainder" do + @describetag :iree_shape_mismatch_error defn remainder(a, b), do: Nx.remainder(a, b) test "integers" do @@ -286,6 +290,7 @@ defmodule EXLA.Defn.ExprTest do assert_all_close(remainder(left, right), Nx.remainder(left, right)) end + @tag :iree_shape_mismatch_error test "floats" do left = Nx.tensor([-8.3, -8.4, -8.5, 8.3, 8.4, 8.5]) right = Nx.tensor([[-4.2], [-4.1], [-4.0], [4.0], [4.1], [4.2]]) @@ -295,6 +300,7 @@ defmodule EXLA.Defn.ExprTest do end describe "element-wise arith operators" do + @describetag :iree_shape_mismatch_error @tensors [ {1, 2}, {1, Nx.tensor([1.0, 2.0, 3.0])}, @@ -793,6 +799,7 @@ defmodule EXLA.Defn.ExprTest do end describe "unary float ops" do + @describetag :iree_shape_mismatch_error @int_tensor Nx.tensor([1, 2, 3]) @float_tensor Nx.tensor([1.0, 2.0, 3.0]) @@ -1134,6 +1141,7 @@ defmodule EXLA.Defn.ExprTest do end describe "if" do + @describetag :iree_shape_mismatch_error defn if3(a, b, c), do: if(a, do: b, else: c) test "one param per branch" do @@ -1307,6 +1315,7 @@ defmodule EXLA.Defn.ExprTest do end describe "cond" do + @describetag :iree_shape_mismatch_error defn cond3(a, b, c) do d = Nx.sum(a) @@ -1511,6 +1520,7 @@ defmodule EXLA.Defn.ExprTest do end describe "map" do + @describetag :iree_shape_mismatch_error defn map_plus(t), do: Nx.map(t, fn x -> x + 1 end) defn map_equal(t), do: Nx.map(t, [type: {:f, 64}], fn x -> Nx.equal(x, 1) end) defn map_exp(t), do: Nx.map(t, [type: {:f, 64}], fn x -> Nx.exp(x) end) @@ -1809,6 +1819,7 @@ defmodule EXLA.Defn.ExprTest do end describe "indexed_add" do + @describetag :iree_shape_mismatch_error defn indexed_add(t, i, u) do Nx.indexed_add(t, i, u) end @@ -1883,6 +1894,7 @@ defmodule EXLA.Defn.ExprTest do end describe "indexed_put" do + @describetag :iree_shape_mismatch_error defn indexed_put(t, i, u) do Nx.indexed_put(t, i, u) end @@ -2011,6 +2023,7 @@ defmodule EXLA.Defn.ExprTest do end describe "sum" do + @describetag :iree_shape_mismatch_error defn sum(t), do: Nx.sum(t) test "computes the sum across types" do @@ -2059,6 +2072,7 @@ defmodule EXLA.Defn.ExprTest do end describe "product" do + @describetag :iree_shape_mismatch_error defn product(t), do: Nx.product(t) test "computes the product across types" do @@ -2107,6 +2121,7 @@ defmodule EXLA.Defn.ExprTest do end describe "mean" do + @describetag :iree_shape_mismatch_error defn mean(t), do: Nx.mean(t) test "computes mean without axis" do @@ -2170,6 +2185,7 @@ defmodule EXLA.Defn.ExprTest do end describe "reduce_max" do + @describetag :iree_shape_mismatch_error defn reduce_max(t), do: Nx.reduce_max(t) test "computes the maximum across types" do @@ -2219,6 +2235,7 @@ defmodule EXLA.Defn.ExprTest do end describe "reduce_min" do + @describetag :iree_shape_mismatch_error defn reduce_min(t), do: Nx.reduce_min(t) test "computes the minimum across types" do @@ -2642,6 +2659,7 @@ defmodule EXLA.Defn.ExprTest do end describe "dot product" do + @describetag :iree_shape_mismatch_error defn dot(a, b), do: Nx.dot(a, b) test "computes the dot product of scalars" do @@ -3439,6 +3457,7 @@ defmodule EXLA.Defn.ExprTest do end describe "take" do + @describetag :iree_shape_mismatch_error defn take_axis_0(t, idx), do: Nx.take(t, idx) defn take_axis_1(t, idx), do: Nx.take(t, idx, axis: 1) @@ -3501,6 +3520,7 @@ defmodule EXLA.Defn.ExprTest do end describe "gather" do + @describetag :iree_shape_mismatch_error defn gather(t, idx), do: Nx.gather(t, idx) test "1d result" do @@ -3623,6 +3643,7 @@ defmodule EXLA.Defn.ExprTest do end describe "concatenate" do + @describetag :iree_shape_mismatch_error defn concatenate0(t1, t2, t3), do: Nx.concatenate([t1, t2, t3], axis: 0) defn concatenate1(t1, t2, t3), do: Nx.concatenate([t1, t2, t3], axis: 1) defn concatenate2(t1, t2, t3), do: Nx.concatenate([t1, t2, t3], axis: 2) @@ -3784,6 +3805,7 @@ defmodule EXLA.Defn.ExprTest do end describe "decompositions" do + @describetag :iree_shape_mismatch_error defn ts(a, b, opts \\ []), do: Nx.LinAlg.triangular_solve(a, b, opts) test "triangular_solve" do @@ -3924,6 +3946,7 @@ defmodule EXLA.Defn.ExprTest do end describe "top_k" do + @describetag :iree_shape_mismatch_error defn top_1(t), do: Nx.top_k(t, k: 1) test "returns top 1 values and indices" do @@ -3995,6 +4018,7 @@ defmodule EXLA.Defn.ExprTest do describe "optional" do defn determinant(t), do: Nx.LinAlg.determinant(t) + @tag :iree_key_not_found_error test "determinant" do two_by_two = Nx.tensor([[1, 2], [3, 4]], names: [:x, :y]) assert_equal(determinant(two_by_two), Nx.tensor(-2.0)) @@ -4002,6 +4026,7 @@ defmodule EXLA.Defn.ExprTest do defn double_determinant(a, b), do: Nx.LinAlg.determinant(a) * Nx.LinAlg.determinant(b) + @tag :iree_key_not_found_error test "multiple determinant" do from_one = Nx.tensor([[1, 2], [3, 4]]) from_ten = Nx.tensor([[10, 20], [30, 40]]) @@ -4010,6 +4035,7 @@ defmodule EXLA.Defn.ExprTest do end describe "cholesky" do + @describetag :iree_shape_mismatch_error defn cholesky(t), do: Nx.LinAlg.cholesky(t) test "works on 2x2 matrix" do @@ -4022,6 +4048,7 @@ defmodule EXLA.Defn.ExprTest do assert_all_close(lhs, rhs) end + @tag :iree_key_not_found_error test "works on a 4x4 matrix" do lhs = cholesky( @@ -4061,6 +4088,7 @@ defmodule EXLA.Defn.ExprTest do end describe "bfloat16" do + @describetag :iree_shape_mismatch_error defn add(t1, t2), do: t1 + t2 test "accepts bfloat16 input" do @@ -4071,6 +4099,7 @@ defmodule EXLA.Defn.ExprTest do end describe "precision" do + @describetag :iree_shape_mismatch_error defn precision(t1, t2), do: Nx.dot(t1, t2) test "raises on bad precision" do @@ -4100,6 +4129,7 @@ defmodule EXLA.Defn.ExprTest do end describe "take_along_axis/3" do + @describetag :iree_shape_mismatch_error defn take_along_axis(t, idx, opts \\ [axis: 0]), do: Nx.take_along_axis(t, idx, opts) defn sort_with_take_along_axis(t, opts \\ []) do From 29604965018891e27dc2153f0326d2d0f806b7a6 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 11 May 2024 04:24:41 -0300 Subject: [PATCH 13/40] fix: shape allocation issues --- exla/c_src/exla/iree/runtime.cc | 89 ++++++++++++++++++++----------- exla/test/exla/defn/api_test.exs | 24 ++++----- exla/test/exla/defn/expr_test.exs | 43 +++++++++++---- exla/test/test_helper.exs | 2 + 4 files changed, 106 insertions(+), 52 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index caa8c8094a..39ed8ddea4 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -4,12 +4,46 @@ #include #include -typedef struct { +class IREEInput { + public: void *data; size_t size; - std::vector dims; + iree_hal_dim_t *dims; + size_t rank; iree_hal_element_type_t type; -} IREEInput; + + // Default constructor + IREEInput(void *data, size_t size, std::vector in_dims, iree_hal_element_type_t type) : size(size), type(type) { + rank = in_dims.size(); + dims = reinterpret_cast(iree_alloca(rank * sizeof(iree_hal_dim_t))); + + for (size_t i = 0; i < rank; i++) { + dims[i] = in_dims[i]; + } + + this->data = std::malloc(size); // Allocate memory + std::memcpy(this->data, data, size); + } + + // Destructor + ~IREEInput() { + if (data) { + std::free(data); + data = nullptr; + } + + if (dims) { + std::free(dims); + dims = nullptr; + } + } + + // Disable copy and move semantics for simplicity + IREEInput(const IREEInput &) = delete; + IREEInput &operator=(const IREEInput &) = delete; + IREEInput(IREEInput &&) = delete; + IREEInput &operator=(IREEInput &&) = delete; +}; bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_type_t *type) { using xla::PrimitiveType; @@ -120,17 +154,20 @@ bool iree_element_type_to_nx_type(iree_hal_element_type_t type, std::string &nx_ } } -int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { +int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { const ERL_NIF_TERM *tuple, *typespec; int length; ErlNifBinary bin; - IREEInput item; - std::vector dims; - xla::PrimitiveType primitive_type; - loaded.reserve(terms.size()); + loaded.clear(); + loaded.resize(terms.size()); + + for (size_t i = 0; i < terms.size(); i++) { + ERL_NIF_TERM term = terms[i]; + std::vector dims; + xla::PrimitiveType primitive_type; + iree_hal_element_type_t type; - for (ERL_NIF_TERM term : terms) { if (!enif_get_tuple(env, term, &length, &tuple)) { return 0; } @@ -139,9 +176,6 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector terms, std::vector terms, std::vectordata, input->size); iree_hal_buffer_params_t buffer_params = { .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, @@ -181,16 +210,16 @@ iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput &inp return iree_hal_buffer_view_allocate_buffer_copy( device, device_allocator, - input.dims.size(), - input.dims.data(), - input.type, + input->rank, + input->dims, + input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, data_span, arg); } -iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector> *result) { +iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector> *result) { iree_runtime_call_t call; iree_vm_function_t function; @@ -219,7 +248,7 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vector bytecode_vec; - std::vector input_terms; - std::vector inputs; - std::vector bytecode; + std::vector bytecode_vec = {}; + std::vector input_terms = {}; + std::vector inputs = {}; + std::vector bytecode = {}; if (!exla::nif::get_list(env, argv[0], bytecode_vec)) { return exla::nif::error(env, "Unable to load bytecode binary"); diff --git a/exla/test/exla/defn/api_test.exs b/exla/test/exla/defn/api_test.exs index 71be9aa3e0..2db80ca51c 100644 --- a/exla/test/exla/defn/api_test.exs +++ b/exla/test/exla/defn/api_test.exs @@ -4,7 +4,7 @@ defmodule EXLA.Defn.APITest do import Nx.Defn import ExUnit.CaptureLog - defn add_two(a, b), do: a + b + defn(add_two(a, b), do: a + b) describe "multi-client" do @describetag :iree_key_not_found_error @@ -30,7 +30,6 @@ defmodule EXLA.Defn.APITest do end describe "options" do - @tag :iree_shape_mismatch_error test "logs when debugging" do logs = capture_log(fn -> @@ -40,7 +39,7 @@ defmodule EXLA.Defn.APITest do assert logs =~ ~r"EXLA defn evaluation #Function<[^>]+> cache (hit|miss) in \d+\.\dms" assert logs =~ ~r"EXLA compilation #Function<[^>]+> cache (hit|miss) in \d+\.\dms" assert logs =~ ~r"EXLA device -?\d lock in \d+\.\dms" - assert logs =~ ~r"EXLA execution on device \d in \d+\.\dms" + assert logs =~ ~r"EXLA execution on device -?\d in \d+\.\dms" logs = capture_log(fn -> @@ -49,8 +48,8 @@ defmodule EXLA.Defn.APITest do assert logs =~ ~r"EXLA defn evaluation #Function<[^>]+> cache hit in \d+\.\dms" assert logs =~ ~r"EXLA compilation #Function<[^>]+> cache hit in \d+\.\dms" - assert logs =~ ~r"EXLA device \d lock in \d+\.\d+ms" - assert logs =~ ~r"EXLA execution on device \d in \d+\.\dms" + assert logs =~ ~r"EXLA device -?\d lock in \d+\.\d+ms" + assert logs =~ ~r"EXLA execution on device -?\d in \d+\.\dms" end end @@ -100,7 +99,7 @@ defmodule EXLA.Defn.APITest do end describe "batch" do - @tag :iree_shape_mismatch_error + @tag :iree_resource_exhausted_error test "when padded" do input = Nx.tensor([[1, 2, 3]], backend: EXLA.Backend) batch = [input] |> Nx.Batch.concatenate() |> Nx.Batch.pad(1) @@ -121,7 +120,7 @@ defmodule EXLA.Defn.APITest do %{"x" => rand(), "y" => rand()} end - deftransformp rand, do: :rand.uniform() + deftransformp(rand, do: :rand.uniform()) test "considers map keys in cache keys" do assert_equal(merge(%{"x" => 10})["x"], Nx.tensor(10)) @@ -131,7 +130,7 @@ defmodule EXLA.Defn.APITest do describe "stream" do @describetag :token - defn defn_sum(entry, acc), do: {acc, entry + acc} + defn(defn_sum(entry, acc), do: {acc, entry + acc}) test "immediately done" do stream = EXLA.stream(&defn_sum/2, [0, 0]) @@ -183,7 +182,7 @@ defmodule EXLA.Defn.APITest do assert_equal(Nx.Stream.done(stream), {Nx.tensor(3), {Nx.tensor(2), Nx.tensor(4)}}) end - defn stream_empty_outfeed(i, t), do: {{}, i + t} + defn(stream_empty_outfeed(i, t), do: {{}, i + t}) test "send/recv with empty outfeed" do %_{} = stream = EXLA.stream(&stream_empty_outfeed/2, [0, 0.0]) @@ -196,7 +195,7 @@ defmodule EXLA.Defn.APITest do assert_equal(Nx.Stream.done(stream), Nx.tensor(3.0)) end - defn stream_empty_acc(i, {}), do: {i * i, {}} + defn(stream_empty_acc(i, {}), do: {i * i, {}}) test "send/recv with empty acc" do %_{} = stream = EXLA.stream(&stream_empty_acc/2, [0, {}]) @@ -415,7 +414,7 @@ defmodule EXLA.Defn.APITest do assert_equal(b, Nx.tensor(2)) end - defn hook_stream(entry, acc), do: hook({acc, entry + acc}, :stream) + defn(hook_stream(entry, acc), do: hook({acc, entry + acc}, :stream)) test "executes hook with stream" do %_{} = stream = EXLA.stream(&hook_stream/2, [0, 0], hooks: %{stream: send_to_self(:tag)}) @@ -439,13 +438,12 @@ defmodule EXLA.Defn.APITest do end describe "telemetry" do - defn telemetry_add_two(a, b), do: a + b + defn(telemetry_add_two(a, b), do: a + b) def telemetry_handler(_event_name, measurements, metadata, _config) do send(self(), {measurements, metadata}) end - @tag :iree_shape_mismatch_error test "executes event when function is compiled" do :ok = :telemetry.attach(__MODULE__, [:exla, :compilation], &__MODULE__.telemetry_handler/4, nil) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index aa4294ef76..514fd8e0c9 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -149,7 +149,6 @@ defmodule EXLA.Defn.ExprTest do end describe "+/2" do - @describetag :iree_shape_mismatch_error defn add_two(a, b), do: a + b test "same shape and type" do @@ -300,7 +299,6 @@ defmodule EXLA.Defn.ExprTest do end describe "element-wise arith operators" do - @describetag :iree_shape_mismatch_error @tensors [ {1, 2}, {1, Nx.tensor([1.0, 2.0, 3.0])}, @@ -312,6 +310,7 @@ defmodule EXLA.Defn.ExprTest do defn subtract_two(a, b), do: a - b + @tag :iree_shape_mismatch_error test "-" do for {left, right} <- @tensors do assert_all_close(subtract_two(left, right), Nx.subtract(left, right)) @@ -321,6 +320,7 @@ defmodule EXLA.Defn.ExprTest do defn multiply_two(a, b), do: a * b + @tag :iree_shape_mismatch_error test "*" do for {left, right} <- @tensors do assert_all_close(multiply_two(left, right), Nx.multiply(left, right)) @@ -330,6 +330,7 @@ defmodule EXLA.Defn.ExprTest do defn unary_minus(a), do: -a + @tag :iree_shape_mismatch_error test "negate" do for t <- [ Nx.tensor([-1, 0, 1], type: {:u, 8}), @@ -342,6 +343,7 @@ defmodule EXLA.Defn.ExprTest do defn max_two(a, b), do: max(a, b) + @tag :iree_shape_mismatch_error test "max" do for {left, right} <- @tensors do assert_all_close(max_two(left, right), Nx.max(left, right)) @@ -351,6 +353,7 @@ defmodule EXLA.Defn.ExprTest do defn min_two(a, b), do: min(a, b) + @tag :iree_shape_mismatch_error test "min" do for {left, right} <- @tensors do assert_all_close(min_two(left, right), Nx.min(left, right)) @@ -360,6 +363,7 @@ defmodule EXLA.Defn.ExprTest do defn power_two(a, b), do: Nx.pow(a, b) + @tag :iree_shape_mismatch_error test "pow" do for {left, right} <- @tensors do assert_all_close(power_two(left, right), Nx.pow(left, right)) @@ -369,6 +373,7 @@ defmodule EXLA.Defn.ExprTest do defn atan2_two(a, b), do: Nx.atan2(a, b) + @tag :iree_resource_exhausted_error test "atan2" do <> = <<0x8000000000000000::64>> left = Nx.tensor([-1.0, neg_zero, 0.0, 1.0]) @@ -380,6 +385,7 @@ defmodule EXLA.Defn.ExprTest do defn quotient_two(a, b), do: Nx.quotient(a, b) + @tag :iree_shape_mismatch_error test "quotient" do int_tensors = [ {1, 2}, @@ -398,12 +404,12 @@ defmodule EXLA.Defn.ExprTest do end describe "element-wise bitwise operators" do - @describetag :iree_shape_mismatch_error @left Nx.tensor([-2, -1, 0, 1, 2]) @right Nx.tensor([[-2], [-1], [0], [1], [2]]) defn bitwise_and(a, b), do: a &&& b + @tag :iree_resource_exhausted_error test "bitwise_and" do assert Nx.shape(bitwise_and(@left, @right)) == {5, 5} assert_equal(bitwise_and(@left, @right), Nx.bitwise_and(@left, @right)) @@ -411,6 +417,7 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_or(a, b), do: a ||| b + @tag :iree_resource_exhausted_error test "bitwise_or" do assert Nx.shape(bitwise_or(@left, @right)) == {5, 5} assert_equal(bitwise_or(@left, @right), Nx.bitwise_or(@left, @right)) @@ -418,6 +425,7 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_not(a), do: ~~~a + @tag :iree_resource_exhausted_error test "bitwise_not" do assert Nx.shape(bitwise_not(@left)) == {5} assert_equal(bitwise_not(@left), Nx.bitwise_not(@left)) @@ -425,6 +433,7 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_pc(a), do: Nx.population_count(a) + @tag :iree_illegal_op_error test "population_count" do assert Nx.shape(bitwise_pc(@left)) == {5} assert_equal(bitwise_pc(@left), Nx.population_count(@left)) @@ -432,6 +441,7 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_clz(a), do: Nx.count_leading_zeros(a) + @tag :iree_illegal_op_error test "count_leading_zeros" do assert Nx.shape(bitwise_clz(@left)) == {5} assert_equal(bitwise_clz(@left), Nx.count_leading_zeros(@left)) @@ -442,6 +452,7 @@ defmodule EXLA.Defn.ExprTest do defn left_shift(a, b), do: a <<< b + @tag :iree_resource_exhausted_error test "left_shift" do assert Nx.shape(left_shift(@left, @right)) == {5, 5} assert_equal(left_shift(@left, @right), Nx.left_shift(@left, @right)) @@ -455,6 +466,7 @@ defmodule EXLA.Defn.ExprTest do defn right_shift(a, b), do: a >>> b + @tag :iree_resource_exhausted_error test "right_shift" do assert Nx.shape(right_shift(@left_signed, @right_signed)) == {9, 9} @@ -534,13 +546,14 @@ defmodule EXLA.Defn.ExprTest do end describe "not equal" do - @describetag :iree_shape_mismatch_error defn not_equal(a, b), do: Nx.not_equal(a, b) + @tag :iree_shape_mismatch_error test "computes equality of scalars" do assert_equal(not_equal(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(1, type: {:u, 8})) end + @tag :iree_shape_mismatch_error test "computes equality with broadcasting" do assert_equal( not_equal(Nx.tensor(1), Nx.tensor([1, 2, 3])), @@ -548,6 +561,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_shape_mismatch_error test "computes equality with mixed types" do assert_equal( not_equal(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), @@ -755,7 +769,6 @@ defmodule EXLA.Defn.ExprTest do end describe "select" do - @describetag :iree_shape_mismatch_error defn select(pred, x, y), do: Nx.select(pred, x, y) test "selects one or the other with a scalar" do @@ -785,6 +798,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_shape_mismatch_error test "selects with broadcasting" do assert_equal( select(Nx.tensor([1, 0, 1, 0, 1]), Nx.tensor([10]), Nx.tensor([1, 2, 3, 4, 5])), @@ -1866,6 +1880,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_resource_exhausted_error test "indexed_add handles different input types" do target = Nx.tensor([0]) indices = Nx.tensor([[0]]) @@ -1894,11 +1909,11 @@ defmodule EXLA.Defn.ExprTest do end describe "indexed_put" do - @describetag :iree_shape_mismatch_error defn indexed_put(t, i, u) do Nx.indexed_put(t, i, u) end + @tag :iree_resource_exhausted_error test "indexed_add works for multi-dim tensor" do target = Nx.broadcast(0, {2, 3, 4}) @@ -1939,6 +1954,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_resource_exhausted_error test "indexed_put handles different input types" do target = Nx.tensor([0]) indices = Nx.tensor([[0]]) @@ -2223,6 +2239,7 @@ defmodule EXLA.Defn.ExprTest do defn reduce_max_keep(t), do: Nx.reduce_max(t, keep_axes: true) defn reduce_max_keep_2(t), do: Nx.reduce_max(t, axes: [0, 2], keep_axes: true) + @tag :iree_shape_mismatch_error test "keeps dimensions if keep_axes" do assert_equal(Nx.tensor([1, 2, 3]) |> reduce_max_keep(), Nx.tensor([3])) assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> reduce_max_keep(), Nx.tensor([3.0])) @@ -2235,9 +2252,9 @@ defmodule EXLA.Defn.ExprTest do end describe "reduce_min" do - @describetag :iree_shape_mismatch_error defn reduce_min(t), do: Nx.reduce_min(t) + @tag :iree_resource_exhausted_error test "computes the minimum across types" do assert_equal(Nx.tensor([1, 2, 3]) |> reduce_min(), Nx.tensor(1)) @@ -2259,6 +2276,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_resource_exhausted_error test "computes the minimum across nan" do assert_equal(Nx.tensor([:nan, :nan, :nan]) |> reduce_min(), Nx.tensor(:nan)) end @@ -2267,6 +2285,7 @@ defmodule EXLA.Defn.ExprTest do defn reduce_min_neg_axis(t), do: Nx.reduce_min(t, axes: [-3]) defn reduce_min_pos_neg_axis(t), do: Nx.reduce_min(t, axes: [1, -3]) + @tag :iree_shape_mismatch_error test "computes the min on a given axis" do t = Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) assert_equal(reduce_min_pos_axis(t), Nx.reduce_min(t, axes: [1])) @@ -2277,6 +2296,7 @@ defmodule EXLA.Defn.ExprTest do defn reduce_min_keep(t), do: Nx.reduce_min(t, keep_axes: true) defn reduce_min_keep_2(t), do: Nx.reduce_min(t, axes: [0, 2], keep_axes: true) + @tag :iree_shape_mismatch_error test "keeps dimensions if keep_axes" do assert_equal(Nx.tensor([1, 2, 3]) |> reduce_min_keep(), Nx.tensor([1])) assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> reduce_min_keep(), Nx.tensor([1.0])) @@ -2935,6 +2955,7 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op + @tag :iree_resource_exhausted_error test "computes the convolution with general padding, stride" do img = Nx.iota({2, 1, 12, 24}, type: {:f, 64}) kernel = Nx.iota({2, 1, 6, 6}, type: {:f, 64}) @@ -3805,9 +3826,9 @@ defmodule EXLA.Defn.ExprTest do end describe "decompositions" do - @describetag :iree_shape_mismatch_error defn ts(a, b, opts \\ []), do: Nx.LinAlg.triangular_solve(a, b, opts) + @tag :iree_key_not_found_error test "triangular_solve" do a = Nx.tensor([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) b = Nx.tensor([4, 2, 4, 2]) @@ -3826,6 +3847,7 @@ defmodule EXLA.Defn.ExprTest do defn qr(t), do: Nx.LinAlg.qr(t) defn qr_complete(t), do: Nx.LinAlg.qr(t, mode: :complete) + @tag :iree_key_not_found_error test "qr" do input = Nx.iota({3, 2}) output = Nx.as_type(input, {:f, 32}) @@ -3851,6 +3873,7 @@ defmodule EXLA.Defn.ExprTest do defn svd(t), do: Nx.LinAlg.svd(t) + @tag :iree_key_not_found_error test "svd" do input = Nx.iota({3, 3}) output = Nx.as_type(input, {:f, 32}) @@ -3867,6 +3890,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_key_not_found_error test "svd (tall matrix)" do input = Nx.tensor([[2, 0], [0, 1], [0, 0]]) output = Nx.as_type(input, {:f, 32}) @@ -3883,6 +3907,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_key_not_found_error test "svd (wide matrix)" do input = Nx.tensor([[2, 0, 0], [0, 1, 0]]) output = Nx.as_type(input, {:f, 32}) @@ -4088,9 +4113,9 @@ defmodule EXLA.Defn.ExprTest do end describe "bfloat16" do - @describetag :iree_shape_mismatch_error defn add(t1, t2), do: t1 + t2 + @tag :iree_shape_mismatch_error test "accepts bfloat16 input" do lhs = Nx.tensor([1.0, 2.0, 3.0], type: {:bf, 16}) rhs = Nx.tensor([4.0, 5.0, 6.0], type: {:bf, 16}) diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 8e60e6d395..5610669ff6 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -46,10 +46,12 @@ iree_excludes = :token, :iree_hangup_error, :iree_shape_mismatch_error, + :iree_resource_exhausted_error, :iree_key_not_found_error, :iree_wrong_result_error, :iree_unsupported_fft_error, :iree_segfault_error, + :iree_illegal_op_error, :multi_device ] else From 859aab073a3330b209e8fd7935854129354cf4ba Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 11 May 2024 18:03:01 -0300 Subject: [PATCH 14/40] fix: input and shape vector handling --- exla/c_src/exla/iree/runtime.cc | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 39ed8ddea4..5435a4b188 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -8,17 +8,15 @@ class IREEInput { public: void *data; size_t size; - iree_hal_dim_t *dims; - size_t rank; + std::vector dims; iree_hal_element_type_t type; // Default constructor IREEInput(void *data, size_t size, std::vector in_dims, iree_hal_element_type_t type) : size(size), type(type) { - rank = in_dims.size(); - dims = reinterpret_cast(iree_alloca(rank * sizeof(iree_hal_dim_t))); + dims.reserve(in_dims.size()); - for (size_t i = 0; i < rank; i++) { - dims[i] = in_dims[i]; + for (auto dim : in_dims) { + dims.push_back(static_cast(dim)); } this->data = std::malloc(size); // Allocate memory @@ -31,11 +29,6 @@ class IREEInput { std::free(data); data = nullptr; } - - if (dims) { - std::free(dims); - dims = nullptr; - } } // Disable copy and move semantics for simplicity @@ -160,10 +153,9 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector dims; xla::PrimitiveType primitive_type; iree_hal_element_type_t type; @@ -192,7 +184,7 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vectorrank, - input->dims, + input->dims.size(), + input->dims.data(), input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params, From 0c007c82245d24546909d44a6efa0b097aa9e3a8 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sat, 11 May 2024 18:29:04 -0300 Subject: [PATCH 15/40] test: unskip shape error tests --- exla/test/exla/defn/expr_test.exs | 63 +++++++------------------- exla/test/exla/defn/vectorize_test.exs | 2 +- exla/test/test_helper.exs | 3 +- 3 files changed, 20 insertions(+), 48 deletions(-) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 514fd8e0c9..640f1db58d 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -9,9 +9,9 @@ defmodule EXLA.Defn.ExprTest do end describe "tuples" do - @describetag :iree_shape_mismatch_error defn add_subtract_tuple(a, b), do: {a + b, a - b} + @tag :iree_offset_error test "on results" do assert_equal(add_subtract_tuple(2, 3), {Nx.tensor(5), Nx.tensor(-1)}) @@ -238,8 +238,7 @@ defmodule EXLA.Defn.ExprTest do end describe "//2" do - @describetag :iree_shape_mismatch_error - defn divide_two(a, b), do: a / b + defn divide_two(a, b), do: a / b test "parameters" do tensors = [ @@ -279,7 +278,6 @@ defmodule EXLA.Defn.ExprTest do end describe "remainder" do - @describetag :iree_shape_mismatch_error defn remainder(a, b), do: Nx.remainder(a, b) test "integers" do @@ -289,7 +287,6 @@ defmodule EXLA.Defn.ExprTest do assert_all_close(remainder(left, right), Nx.remainder(left, right)) end - @tag :iree_shape_mismatch_error test "floats" do left = Nx.tensor([-8.3, -8.4, -8.5, 8.3, 8.4, 8.5]) right = Nx.tensor([[-4.2], [-4.1], [-4.0], [4.0], [4.1], [4.2]]) @@ -310,7 +307,6 @@ defmodule EXLA.Defn.ExprTest do defn subtract_two(a, b), do: a - b - @tag :iree_shape_mismatch_error test "-" do for {left, right} <- @tensors do assert_all_close(subtract_two(left, right), Nx.subtract(left, right)) @@ -320,7 +316,6 @@ defmodule EXLA.Defn.ExprTest do defn multiply_two(a, b), do: a * b - @tag :iree_shape_mismatch_error test "*" do for {left, right} <- @tensors do assert_all_close(multiply_two(left, right), Nx.multiply(left, right)) @@ -330,7 +325,6 @@ defmodule EXLA.Defn.ExprTest do defn unary_minus(a), do: -a - @tag :iree_shape_mismatch_error test "negate" do for t <- [ Nx.tensor([-1, 0, 1], type: {:u, 8}), @@ -343,7 +337,6 @@ defmodule EXLA.Defn.ExprTest do defn max_two(a, b), do: max(a, b) - @tag :iree_shape_mismatch_error test "max" do for {left, right} <- @tensors do assert_all_close(max_two(left, right), Nx.max(left, right)) @@ -353,7 +346,6 @@ defmodule EXLA.Defn.ExprTest do defn min_two(a, b), do: min(a, b) - @tag :iree_shape_mismatch_error test "min" do for {left, right} <- @tensors do assert_all_close(min_two(left, right), Nx.min(left, right)) @@ -363,7 +355,6 @@ defmodule EXLA.Defn.ExprTest do defn power_two(a, b), do: Nx.pow(a, b) - @tag :iree_shape_mismatch_error test "pow" do for {left, right} <- @tensors do assert_all_close(power_two(left, right), Nx.pow(left, right)) @@ -385,7 +376,6 @@ defmodule EXLA.Defn.ExprTest do defn quotient_two(a, b), do: Nx.quotient(a, b) - @tag :iree_shape_mismatch_error test "quotient" do int_tensors = [ {1, 2}, @@ -516,7 +506,6 @@ defmodule EXLA.Defn.ExprTest do end describe "equal" do - @describetag :iree_shape_mismatch_error defn equal(a, b), do: Nx.equal(a, b) test "computes equality of scalars" do @@ -548,12 +537,10 @@ defmodule EXLA.Defn.ExprTest do describe "not equal" do defn not_equal(a, b), do: Nx.not_equal(a, b) - @tag :iree_shape_mismatch_error test "computes equality of scalars" do assert_equal(not_equal(Nx.tensor(1), Nx.tensor(2)), Nx.tensor(1, type: {:u, 8})) end - @tag :iree_shape_mismatch_error test "computes equality with broadcasting" do assert_equal( not_equal(Nx.tensor(1), Nx.tensor([1, 2, 3])), @@ -561,7 +548,6 @@ defmodule EXLA.Defn.ExprTest do ) end - @tag :iree_shape_mismatch_error test "computes equality with mixed types" do assert_equal( not_equal(Nx.tensor([1, 2, 3]), Nx.tensor([1.0, 2.0, 3.0])), @@ -571,7 +557,6 @@ defmodule EXLA.Defn.ExprTest do end describe "less" do - @describetag :iree_shape_mismatch_error defn less(a, b), do: Nx.less(a, b) test "compares scalars" do @@ -591,7 +576,6 @@ defmodule EXLA.Defn.ExprTest do end describe "greater" do - @describetag :iree_shape_mismatch_error defn greater(a, b), do: Nx.greater(a, b) test "compares scalars" do @@ -614,7 +598,6 @@ defmodule EXLA.Defn.ExprTest do end describe "less equal" do - @describetag :iree_shape_mismatch_error defn less_equal(a, b), do: Nx.less_equal(a, b) test "compares scalars" do @@ -637,7 +620,6 @@ defmodule EXLA.Defn.ExprTest do end describe "greater equal" do - @describetag :iree_shape_mismatch_error defn greater_equal(a, b), do: Nx.greater_equal(a, b) test "compares scalars" do @@ -660,7 +642,6 @@ defmodule EXLA.Defn.ExprTest do end describe "logical" do - @describetag :iree_shape_mismatch_error defn logical_and(a, b), do: Nx.logical_and(a, b) test "and" do @@ -747,6 +728,7 @@ defmodule EXLA.Defn.ExprTest do defn logical_not(a), do: Nx.logical_not(a) + @tag :iree_key_not_found_error test "not" do assert_equal( logical_not(Nx.tensor([-2, -1, 0, 1, 2])), @@ -763,6 +745,7 @@ defmodule EXLA.Defn.ExprTest do defnp is_finite(x), do: Nx.all(Nx.logical_not(Nx.is_infinity(x))) + @tag :iree_key_not_found_error test "logical and/not with all predicate" do assert_equal(logical_and_all_finite(1, 0, 2.0), Nx.u8(1)) end @@ -798,7 +781,6 @@ defmodule EXLA.Defn.ExprTest do ) end - @tag :iree_shape_mismatch_error test "selects with broadcasting" do assert_equal( select(Nx.tensor([1, 0, 1, 0, 1]), Nx.tensor([10]), Nx.tensor([1, 2, 3, 4, 5])), @@ -813,7 +795,6 @@ defmodule EXLA.Defn.ExprTest do end describe "unary float ops" do - @describetag :iree_shape_mismatch_error @int_tensor Nx.tensor([1, 2, 3]) @float_tensor Nx.tensor([1.0, 2.0, 3.0]) @@ -824,6 +805,7 @@ defmodule EXLA.Defn.ExprTest do defn_var = Macro.var(defn_fun, __MODULE__) defn unquote(defn_fun)(t), do: Nx.unquote(fun)(t) + @tag :iree_type_mismatch_error test "#{fun}" do assert_all_close( unquote(defn_fun)(@float_tensor), @@ -1155,7 +1137,6 @@ defmodule EXLA.Defn.ExprTest do end describe "if" do - @describetag :iree_shape_mismatch_error defn if3(a, b, c), do: if(a, do: b, else: c) test "one param per branch" do @@ -1265,6 +1246,7 @@ defmodule EXLA.Defn.ExprTest do defn if_map(a, b, c), do: if(a, do: {%{a: a, b: b, c: 1}, c}, else: {%{a: c, b: b, c: 2}, a}) + @tag :iree_segfault_error test "with map" do assert_equal( if_map(Nx.tensor(0), Nx.tensor(10), Nx.tensor(20)), @@ -1329,7 +1311,6 @@ defmodule EXLA.Defn.ExprTest do end describe "cond" do - @describetag :iree_shape_mismatch_error defn cond3(a, b, c) do d = Nx.sum(a) @@ -1380,6 +1361,7 @@ defmodule EXLA.Defn.ExprTest do end end + @tag :iree_offset_error test "computes cond with cond as parameter" do assert_equal(nested_cond(Nx.tensor(10)), Nx.tensor(1)) assert_equal(nested_cond(Nx.tensor(-10)), Nx.tensor(0)) @@ -1534,17 +1516,18 @@ defmodule EXLA.Defn.ExprTest do end describe "map" do - @describetag :iree_shape_mismatch_error defn map_plus(t), do: Nx.map(t, fn x -> x + 1 end) defn map_equal(t), do: Nx.map(t, [type: {:f, 64}], fn x -> Nx.equal(x, 1) end) defn map_exp(t), do: Nx.map(t, [type: {:f, 64}], fn x -> Nx.exp(x) end) @tag :unsupported_64_bit_op + @tag :iree_wrong_result_error test "maps a function over the tensor" do assert_equal(map_plus(Nx.tensor([[1, 2, 3], [4, 5, 6]])), Nx.tensor([[2, 3, 4], [5, 6, 7]])) end @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "maps a function with an output type" do assert_equal( map_equal(Nx.tensor([[1, 2, 3], [4, 5, 6]])), @@ -1567,6 +1550,7 @@ defmodule EXLA.Defn.ExprTest do @tag :conditional_inside_map_reduce @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "maps a function with conditional" do assert_equal( map_conditional(Nx.tensor([-2, -1, 0, 1, 2])), @@ -1587,6 +1571,7 @@ defmodule EXLA.Defn.ExprTest do end end + @tag :iree_key_not_found_error test "while inside if" do assert %{a: a, b: b} = while_inside_if(1, %{a: 1, b: 2.0}) assert_all_close(a, 1) @@ -1833,7 +1818,6 @@ defmodule EXLA.Defn.ExprTest do end describe "indexed_add" do - @describetag :iree_shape_mismatch_error defn indexed_add(t, i, u) do Nx.indexed_add(t, i, u) end @@ -2039,7 +2023,6 @@ defmodule EXLA.Defn.ExprTest do end describe "sum" do - @describetag :iree_shape_mismatch_error defn sum(t), do: Nx.sum(t) test "computes the sum across types" do @@ -2088,7 +2071,6 @@ defmodule EXLA.Defn.ExprTest do end describe "product" do - @describetag :iree_shape_mismatch_error defn product(t), do: Nx.product(t) test "computes the product across types" do @@ -2137,7 +2119,6 @@ defmodule EXLA.Defn.ExprTest do end describe "mean" do - @describetag :iree_shape_mismatch_error defn mean(t), do: Nx.mean(t) test "computes mean without axis" do @@ -2201,7 +2182,6 @@ defmodule EXLA.Defn.ExprTest do end describe "reduce_max" do - @describetag :iree_shape_mismatch_error defn reduce_max(t), do: Nx.reduce_max(t) test "computes the maximum across types" do @@ -2239,7 +2219,6 @@ defmodule EXLA.Defn.ExprTest do defn reduce_max_keep(t), do: Nx.reduce_max(t, keep_axes: true) defn reduce_max_keep_2(t), do: Nx.reduce_max(t, axes: [0, 2], keep_axes: true) - @tag :iree_shape_mismatch_error test "keeps dimensions if keep_axes" do assert_equal(Nx.tensor([1, 2, 3]) |> reduce_max_keep(), Nx.tensor([3])) assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> reduce_max_keep(), Nx.tensor([3.0])) @@ -2285,7 +2264,7 @@ defmodule EXLA.Defn.ExprTest do defn reduce_min_neg_axis(t), do: Nx.reduce_min(t, axes: [-3]) defn reduce_min_pos_neg_axis(t), do: Nx.reduce_min(t, axes: [1, -3]) - @tag :iree_shape_mismatch_error + @tag :iree_wrong_result_error test "computes the min on a given axis" do t = Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]) assert_equal(reduce_min_pos_axis(t), Nx.reduce_min(t, axes: [1])) @@ -2296,7 +2275,7 @@ defmodule EXLA.Defn.ExprTest do defn reduce_min_keep(t), do: Nx.reduce_min(t, keep_axes: true) defn reduce_min_keep_2(t), do: Nx.reduce_min(t, axes: [0, 2], keep_axes: true) - @tag :iree_shape_mismatch_error + @tag :iree_wrong_result_error test "keeps dimensions if keep_axes" do assert_equal(Nx.tensor([1, 2, 3]) |> reduce_min_keep(), Nx.tensor([1])) assert_equal(Nx.tensor([1.0, 2.0, 3.0]) |> reduce_min_keep(), Nx.tensor([1.0])) @@ -2679,7 +2658,6 @@ defmodule EXLA.Defn.ExprTest do end describe "dot product" do - @describetag :iree_shape_mismatch_error defn dot(a, b), do: Nx.dot(a, b) test "computes the dot product of scalars" do @@ -2774,7 +2752,6 @@ defmodule EXLA.Defn.ExprTest do end describe "convolution" do - @describetag :iree_shape_mismatch_error defn conv_valid_no_stride(inp, kernel), do: Nx.conv(inp, kernel) defn conv_valid_stride(inp, kernel), @@ -3439,7 +3416,6 @@ defmodule EXLA.Defn.ExprTest do end describe "put slice" do - @describetag :iree_shape_mismatch_error defn put_slice1(t1, t2), do: Nx.put_slice(t1, [2], t2) defn put_slice2(t1, t2), do: Nx.put_slice(t1, [1, 2], t2) defn put_slice3(t1, t2), do: Nx.put_slice(t1, [2, 2], t2) @@ -3478,7 +3454,6 @@ defmodule EXLA.Defn.ExprTest do end describe "take" do - @describetag :iree_shape_mismatch_error defn take_axis_0(t, idx), do: Nx.take(t, idx) defn take_axis_1(t, idx), do: Nx.take(t, idx, axis: 1) @@ -3541,7 +3516,6 @@ defmodule EXLA.Defn.ExprTest do end describe "gather" do - @describetag :iree_shape_mismatch_error defn gather(t, idx), do: Nx.gather(t, idx) test "1d result" do @@ -3664,7 +3638,6 @@ defmodule EXLA.Defn.ExprTest do end describe "concatenate" do - @describetag :iree_shape_mismatch_error defn concatenate0(t1, t2, t3), do: Nx.concatenate([t1, t2, t3], axis: 0) defn concatenate1(t1, t2, t3), do: Nx.concatenate([t1, t2, t3], axis: 1) defn concatenate2(t1, t2, t3), do: Nx.concatenate([t1, t2, t3], axis: 2) @@ -3971,9 +3944,9 @@ defmodule EXLA.Defn.ExprTest do end describe "top_k" do - @describetag :iree_shape_mismatch_error defn top_1(t), do: Nx.top_k(t, k: 1) + @tag :iree_offset_error test "returns top 1 values and indices" do a = Nx.iota({5}) assert_equal(top_1(a), {Nx.tensor([4]), Nx.tensor([4])}) @@ -3987,7 +3960,7 @@ defmodule EXLA.Defn.ExprTest do end describe "argsort" do - @describetag :iree_segfault_error + @describetag :iree_offset_error defn argsort0(t), do: Nx.argsort(t, axis: 0) defn argsort1(t), do: Nx.argsort(t, axis: 1) defn argsort1_asc(t), do: Nx.argsort(t, axis: 1, direction: :asc) @@ -4060,9 +4033,9 @@ defmodule EXLA.Defn.ExprTest do end describe "cholesky" do - @describetag :iree_shape_mismatch_error defn cholesky(t), do: Nx.LinAlg.cholesky(t) + @tag :iree_key_not_found_error test "works on 2x2 matrix" do lhs = cholesky(Nx.tensor([[20.0, 17.6], [17.6, 16.0]])) rhs = Nx.tensor([[4.47213595499958, 0.0], [3.93547964039963, 0.7155417527999305]]) @@ -4096,6 +4069,7 @@ defmodule EXLA.Defn.ExprTest do assert_all_close(lhs, rhs) end + @tag :iree_key_not_found_error test "works on a 50x50 matrix" do tensor = Nx.tensor( @@ -4115,7 +4089,6 @@ defmodule EXLA.Defn.ExprTest do describe "bfloat16" do defn add(t1, t2), do: t1 + t2 - @tag :iree_shape_mismatch_error test "accepts bfloat16 input" do lhs = Nx.tensor([1.0, 2.0, 3.0], type: {:bf, 16}) rhs = Nx.tensor([4.0, 5.0, 6.0], type: {:bf, 16}) @@ -4124,7 +4097,6 @@ defmodule EXLA.Defn.ExprTest do end describe "precision" do - @describetag :iree_shape_mismatch_error defn precision(t1, t2), do: Nx.dot(t1, t2) test "raises on bad precision" do @@ -4154,7 +4126,6 @@ defmodule EXLA.Defn.ExprTest do end describe "take_along_axis/3" do - @describetag :iree_shape_mismatch_error defn take_along_axis(t, idx, opts \\ [axis: 0]), do: Nx.take_along_axis(t, idx, opts) defn sort_with_take_along_axis(t, opts \\ []) do diff --git a/exla/test/exla/defn/vectorize_test.exs b/exla/test/exla/defn/vectorize_test.exs index 30a51fc548..a009dfe04d 100644 --- a/exla/test/exla/defn/vectorize_test.exs +++ b/exla/test/exla/defn/vectorize_test.exs @@ -4,7 +4,7 @@ defmodule EXLA.Defn.VectorizeTest do import Nx.Defn import Nx, only: :sigils - @moduletag :iree_shape_mismatch_error + @moduletag :iree_hangup_error setup do Nx.default_backend(EXLA.Backend) diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 5610669ff6..5420bbbb24 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -45,13 +45,14 @@ iree_excludes = [ :token, :iree_hangup_error, - :iree_shape_mismatch_error, + :iree_type_mismatch_error, :iree_resource_exhausted_error, :iree_key_not_found_error, :iree_wrong_result_error, :iree_unsupported_fft_error, :iree_segfault_error, :iree_illegal_op_error, + :iree_offset_error, :multi_device ] else From 77b9096a22c8b28c4009caae09100cd74ae95a3a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 12 May 2024 23:41:37 -0300 Subject: [PATCH 16/40] refactor: use single iree instance --- exla/c_src/exla/iree/iree.cc | 6 +- exla/c_src/exla/iree/runtime.cc | 119 +++++++++++++++-------------- exla/c_src/exla/iree/runtime.h | 120 +++++++++++++++++++++++++++++- exla/lib/exla/application.ex | 2 +- exla/lib/exla/defn.ex | 1 + exla/lib/exla/executable.ex | 23 ++++-- exla/lib/exla/mlir/iree.ex | 19 ++++- exla/test/exla/defn/api_test.exs | 13 ++-- exla/test/exla/defn/expr_test.exs | 2 + 9 files changed, 234 insertions(+), 71 deletions(-) diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index 7f7bfe7a06..74f61c6377 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -16,7 +16,8 @@ static ErlNifFunc iree_funcs[] = { // MLIR Builder {"global_initialize", 0, global_initialize}, {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"run_module", 2, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}}; + {"runtime_create_instance", 0, runtime_create_instance}, + {"run_module", 3, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}}; static int open_resources(ErlNifEnv *env) { const char *mod = "EXLA"; @@ -24,6 +25,9 @@ static int open_resources(ErlNifEnv *env) { if (!exla::nif::open_resource(env, mod, "ExlaMLIRModule")) { return -1; } + if (!exla::nif::open_resource(env, mod, "ExlaIREERuntimeInstance")) { + return -1; + } return 1; } diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 5435a4b188..00cef2b3c6 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -1,8 +1,7 @@ #include "runtime.h" -#include -#include -#include +#include +#include class IREEInput { public: @@ -308,16 +307,22 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector bytecode_vec = {}; std::vector input_terms = {}; std::vector inputs = {}; std::vector bytecode = {}; + exla::iree::runtime::Instance **instance; + iree_status_t status; - if (!exla::nif::get_list(env, argv[0], bytecode_vec)) { + if (!exla::nif::get(env, argv[0], instance)) { + return exla::nif::error(env, "Unable to get instance"); + } + + if (!exla::nif::get_list(env, argv[1], bytecode_vec)) { return exla::nif::error(env, "Unable to load bytecode binary"); } @@ -328,7 +333,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { bytecode[i] = static_cast(byte); } - if (!exla::nif::get_list(env, argv[1], input_terms)) { + if (!exla::nif::get_list(env, argv[2], input_terms)) { return exla::nif::error(env, "Unable to load input terms"); } @@ -336,67 +341,73 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to decode input terms"); } - iree_runtime_instance_options_t instance_options; - iree_runtime_instance_options_initialize(&instance_options); - iree_runtime_instance_options_use_all_available_drivers(&instance_options); - iree_runtime_instance_t *instance = NULL; - iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance); + exla::iree::runtime::Session *session = new exla::iree::runtime::Session(*instance); + status = session->initialize(bytecode); - iree_hal_device_t *device = NULL; - char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument - if (iree_status_is_ok(status)) { - status = iree_hal_create_device( - iree_runtime_instance_driver_registry(instance), - iree_make_cstring_view(device_uri), - iree_runtime_instance_host_allocator(instance), &device); + if (!iree_status_is_ok(status)) { + return exla::nif::error(env, "Failed to initialize IREE runtime session"); } - iree_runtime_session_t *session = NULL; - if (iree_status_is_ok(status)) { - iree_runtime_session_options_t session_options; - iree_runtime_session_options_initialize(&session_options); - status = iree_runtime_session_create_with_device( - instance, &session_options, device, - iree_runtime_instance_host_allocator(instance), &session); - } + std::vector> results; + status = call_module(session->get(), inputs, &results); - iree_const_byte_span_t span{.data = bytecode.data(), .data_length = bytecode.size()}; + if (!iree_status_is_ok(status)) { + // Dump nice status messages to stderr on failure. + // An application can route these through its own logging infrastructure as + // needed. Note that the status is a handle and must be freed! - if (iree_status_is_ok(status)) { - status = iree_runtime_session_append_bytecode_module_from_memory(session, span, iree_runtime_instance_host_allocator(instance)); - } + char *status_string = NULL; + size_t status_length = 0; - std::vector> results; - if (iree_status_is_ok(status)) { - // this is where we actually call code - // status = iree_runtime_demo_perform_mul(session); - status = call_module(session, inputs, &results); - } + auto system_allocator = iree_allocator_system(); + + iree_status_to_string(status, &system_allocator, &status_string, &status_length); + + std::stringstream ss; + ss << "Failed to execute IREE runtime due to error: "; + ss << status_string; + iree_status_ignore(status); - if (session) { - // Release the session and free all cached resources. - iree_runtime_session_release(session); + return exla::nif::error(env, ss.str().c_str()); } - if (device) { - // Release shared device once all sessions using it have been released. - iree_hal_device_release(device); + return return_results(env, results); +} + +ERL_NIF_TERM runtime_create_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 0) { + return exla::nif::error(env, "Bad argument count."); } - if (instance) { - // Release the shared instance - it will be deallocated when all sessions - // using it have been released (here it is deallocated immediately). - iree_runtime_instance_release(instance); + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + iree_runtime_instance_t *instance_ptr = NULL; + iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance_ptr); + + if (!iree_status_is_ok(status)) { + iree_runtime_instance_release(instance_ptr); + return exla::nif::error(env, "Failed to create IREE runtime instance"); } + iree_hal_device_t *device_ptr = NULL; + char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument + status = iree_hal_create_device( + iree_runtime_instance_driver_registry(instance_ptr), + iree_make_cstring_view(device_uri), + iree_runtime_instance_host_allocator(instance_ptr), &device_ptr); + if (!iree_status_is_ok(status)) { - // Dump nice status messages to stderr on failure. - // An application can route these through its own logging infrastructure as - // needed. Note that the status is a handle and must be freed! - iree_status_fprint(stderr, status); - iree_status_ignore(status); - return exla::nif::error(env, "Failed to execute IREE runtime"); + if (device_ptr) { + iree_hal_device_release(device_ptr); + } + if (instance_ptr) { + iree_runtime_instance_release(instance_ptr); + } + return exla::nif::error(env, "Failed to create IREE device instance"); } - return return_results(env, results); + exla::iree::runtime::Instance *instance = new exla::iree::runtime::Instance(instance_ptr, device_ptr); + + return exla::nif::ok(env, exla::nif::make(env, instance)); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index 240f541b5a..a083fdba4f 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -1,4 +1,122 @@ #pragma once +#include +#include +#include + +#include + #include "../exla_nif_util.h" -ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); \ No newline at end of file +ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM runtime_create_instance(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM runtime_create_session(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); + +namespace exla { +namespace iree { +namespace runtime { + +template +struct IREEDeleter { + void operator()(T* ptr) { + if (ptr) { + ReleaseFunc(ptr); // Call the specific release function + } + } +}; + +using IREEInstanceDeleter = IREEDeleter; +using IREEDeviceDeleter = IREEDeleter; +using IREESessionDeleter = IREEDeleter; + +class Instance { + public: + // Constructor + explicit Instance(iree_runtime_instance_t* instance, iree_hal_device_t* device) + : instance_(instance, IREEInstanceDeleter{}), device_(device, IREEDeviceDeleter{}) {} + + // Default destructor is fine, unique_ptr will handle the resource release + ~Instance() = default; + + // Copy and move operations are disabled to maintain unique ownership semantics + Instance(const Instance&) = delete; + Instance& operator=(const Instance&) = delete; + Instance(Instance&&) noexcept = default; + Instance& operator=(Instance&&) noexcept = default; + + iree_runtime_instance_t* get() const { + return instance_.get(); + } + + iree_runtime_instance_t* operator->() const { + return instance_.get(); + } + + iree_hal_device_t* device() const { + return device_.get(); + } + + private: + std::unique_ptr instance_; + std::unique_ptr device_; +}; + +class Session { + public: + // Constructor + explicit Session(Instance* instance) : instance_(instance) {} + + iree_status_t initialize(std::vector bytecode) { + iree_runtime_session_options_t session_options; + iree_runtime_session_options_initialize(&session_options); + + iree_runtime_session_t* session_ptr; + + iree_allocator_t host_allocator = iree_runtime_instance_host_allocator(instance_->get()); + iree_status_t status = iree_runtime_session_create_with_device( + instance_->get(), &session_options, instance_->device(), + host_allocator, &session_ptr); + + if (!iree_status_is_ok(status)) { + return status; + } + + session_.reset(session_ptr); + + iree_const_byte_span_t span{.data = bytecode.data(), .data_length = bytecode.size()}; + + status = iree_runtime_session_append_bytecode_module_from_memory(session_.get(), span, host_allocator); + + if (!iree_status_is_ok(status)) { + return status; + } + + return status; + } + + // Default destructor is fine, unique_ptr will handle the resource release + ~Session() = default; + + // Copy and move operations are disabled to maintain unique ownership semantics + Session(const Session&) = delete; + Session& operator=(const Session&) = delete; + Session(Session&&) noexcept = default; + Session& operator=(Session&&) noexcept = default; + + // Provide a way to access the underlying pointer like a raw pointer + iree_runtime_session_t* get() const { + return session_.get(); + } + + // Overload the arrow operator to enable direct member access to the iree_runtime_session_t + iree_runtime_session_t* operator->() const { + return session_.get(); + } + + private: + Instance* instance_; + std::unique_ptr session_; +}; + +} // namespace runtime +} // namespace iree +}; // namespace exla \ No newline at end of file diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 82c31210fb..c1caa5ee54 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -10,7 +10,7 @@ defmodule EXLA.Application do _ -> :os.set_signal(:sigchld, :default) end - EXLA.MLIR.IREE.global_initialize() + EXLA.MLIR.IREE.init() children = [ EXLA.Logger, diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index b5b46cffdc..49539f4227 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -483,6 +483,7 @@ defmodule EXLA.Defn do if compiler_mode == :iree do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) + {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") %EXLA.Executable{ diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 74d1d7572d..a200e58660 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -99,6 +99,8 @@ defmodule EXLA.Executable do end defp run(:iree, _client, ref, device_id, inputs, _options) do + dbg() + inputs = for subinputs <- inputs do Enum.map(subinputs, fn @@ -109,7 +111,12 @@ defmodule EXLA.Executable do target_type = {t, w2} data = - data |> Nx.from_binary(typespec.type) |> Nx.as_type(target_type) |> Nx.to_binary() + Nx.with_default_backend(Nx.BinaryBackend, fn -> + data + |> Nx.from_binary(typespec.type) + |> Nx.as_type(target_type) + |> Nx.to_binary() + end) data = <> @@ -121,7 +128,7 @@ defmodule EXLA.Executable do end ref - |> EXLA.MLIR.IREE.run_module(List.flatten(inputs)) + |> EXLA.MLIR.IREE.run(List.flatten(inputs)) |> unwrap!() |> then(&[{&1, device_id}]) end @@ -134,11 +141,13 @@ defmodule EXLA.Executable do if source_typespec == target_typespec do BinaryBuffer.from_binary(buf, target_typespec) else - buf - |> Nx.from_binary(source_typespec.type) - |> Nx.as_type(target_typespec.type) - |> Nx.to_binary() - |> BinaryBuffer.from_binary(target_typespec) + Nx.with_default_backend(Nx.BinaryBackend, fn -> + buf + |> Nx.from_binary(source_typespec.type) + |> Nx.as_type(target_typespec.type) + |> Nx.to_binary() + |> BinaryBuffer.from_binary(target_typespec) + end) end buf, typespec when is_reference(buf) -> diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index 0d6c932e0a..6416c7a110 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -7,9 +7,26 @@ defmodule EXLA.MLIR.IREE do :erlang.load_nif(path, 0) end + def init do + global_initialize() + {:ok, instance} = runtime_create_instance() + :persistent_term.put({__MODULE__, :instance}, instance) + dbg(instance) + + :ok + end + + def run(module, inputs) do + instance = :persistent_term.get({__MODULE__, :instance}) + dbg(instance) + run_module(instance, module, inputs) + end + def compile(_module, _target), do: :erlang.nif_error(:undef) def global_initialize, do: :erlang.nif_error(:undef) - def run_module(_module, _inputs), do: :erlang.nif_error(:undef) + def runtime_create_instance, do: :erlang.nif_error(:undef) + + def run_module(_instance, _module, _inputs), do: :erlang.nif_error(:undef) end diff --git a/exla/test/exla/defn/api_test.exs b/exla/test/exla/defn/api_test.exs index 2db80ca51c..314074e118 100644 --- a/exla/test/exla/defn/api_test.exs +++ b/exla/test/exla/defn/api_test.exs @@ -4,7 +4,7 @@ defmodule EXLA.Defn.APITest do import Nx.Defn import ExUnit.CaptureLog - defn(add_two(a, b), do: a + b) + defn add_two(a, b), do: a + b describe "multi-client" do @describetag :iree_key_not_found_error @@ -130,7 +130,7 @@ defmodule EXLA.Defn.APITest do describe "stream" do @describetag :token - defn(defn_sum(entry, acc), do: {acc, entry + acc}) + defn defn_sum(entry, acc), do: {acc, entry + acc} test "immediately done" do stream = EXLA.stream(&defn_sum/2, [0, 0]) @@ -182,7 +182,7 @@ defmodule EXLA.Defn.APITest do assert_equal(Nx.Stream.done(stream), {Nx.tensor(3), {Nx.tensor(2), Nx.tensor(4)}}) end - defn(stream_empty_outfeed(i, t), do: {{}, i + t}) + defn stream_empty_outfeed(i, t), do: {{}, i + t} test "send/recv with empty outfeed" do %_{} = stream = EXLA.stream(&stream_empty_outfeed/2, [0, 0.0]) @@ -195,7 +195,7 @@ defmodule EXLA.Defn.APITest do assert_equal(Nx.Stream.done(stream), Nx.tensor(3.0)) end - defn(stream_empty_acc(i, {}), do: {i * i, {}}) + defn stream_empty_acc(i, {}), do: {i * i, {}} test "send/recv with empty acc" do %_{} = stream = EXLA.stream(&stream_empty_acc/2, [0, {}]) @@ -414,7 +414,7 @@ defmodule EXLA.Defn.APITest do assert_equal(b, Nx.tensor(2)) end - defn(hook_stream(entry, acc), do: hook({acc, entry + acc}, :stream)) + defn hook_stream(entry, acc), do: hook({acc, entry + acc}, :stream) test "executes hook with stream" do %_{} = stream = EXLA.stream(&hook_stream/2, [0, 0], hooks: %{stream: send_to_self(:tag)}) @@ -438,7 +438,8 @@ defmodule EXLA.Defn.APITest do end describe "telemetry" do - defn(telemetry_add_two(a, b), do: a + b) + @describetag :iree_segfault_error + defn telemetry_add_two(a, b), do: a + b def telemetry_handler(_event_name, measurements, metadata, _config) do send(self(), {measurements, metadata}) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 640f1db58d..401d199f8a 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -2143,6 +2143,7 @@ defmodule EXLA.Defn.ExprTest do defn mean_over_multiple_axes(t), do: Nx.mean(t, axes: [0, 2]) + @tag :iree_segfault_error test "computes mean over multiple axes" do assert_equal( mean_over_multiple_axes(Nx.tensor([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])), @@ -3378,6 +3379,7 @@ defmodule EXLA.Defn.ExprTest do defn slice3_dynamic(t), do: Nx.slice(t, [Nx.tensor(0), Nx.tensor(4), Nx.tensor(11)], [2, 3, 9], strides: [2, 1, 3]) + @tag :iree_segfault_error test "works without stride" do t = Nx.iota({900}) t = Nx.reshape(t, {2, 15, 30}) From f52a6e6a62a528eebae4a6e623ec42c97c5b1276 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Mon, 13 May 2024 02:08:47 -0300 Subject: [PATCH 17/40] wip --- exla/c_src/exla/exla_nif_util.cc | 1 + exla/c_src/exla/iree/runtime.cc | 16 +++++--- exla/lib/exla/defn.ex | 48 ++++++++++++------------ exla/lib/exla/executable.ex | 2 - exla/lib/exla/mlir/function.ex | 2 +- exla/lib/exla/mlir/iree.ex | 2 - exla/test/exla/backend_test.exs | 64 +++++++++++++++++++++++++++++++- exla/test/support/exla_case.ex | 4 ++ exla/test/test_helper.exs | 4 +- 9 files changed, 105 insertions(+), 38 deletions(-) diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index 9a22dc106b..8e32621e69 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -315,6 +315,7 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector& var) { if (!enif_get_list_length(env, list, &length)) { return 0; } + var.clear(); var.reserve(length); ERL_NIF_TERM head, tail; diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 00cef2b3c6..691358ff9f 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -214,16 +214,12 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorget_function(function.module->self, IREE_VM_FUNCTION_LINKAGE_EXPORT, function.ordinal, &export_function, &export_function_name, &export_function_signature)); - iree_vm_function_signature_t signature = iree_vm_function_signature(&function); iree_string_view_t arguments; @@ -231,6 +227,9 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorclear(); + result->reserve(size); + for (iree_vm_size_t i = 0; i < size; i++) { iree_hal_buffer_view_t *buffer_view = nullptr; iree_vm_ref_t ref = iree_vm_ref_null(); @@ -287,6 +289,7 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector nif_terms; + nif_terms.clear(); nif_terms.reserve(n); for (auto [iree_type, binary] : results) { @@ -326,6 +329,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to load bytecode binary"); } + bytecode.clear(); bytecode.resize(bytecode_vec.size()); unsigned int byte; for (int i = 0; i < bytecode_vec.size(); i++) { diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 49539f4227..02a26d5af7 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -39,7 +39,7 @@ defmodule EXLA.Defn do client = EXLA.Client.fetch!(client_name) compile_options = Keyword.put(compile_options, :lazy_transfers, :never) - compile_options = Keyword.put_new(compile_options, :compiler_mode, :iree) + compile_options = Keyword.put_new(compile_options, :runtime, :iree) input_length = length(Nx.Defn.Composite.flatten_list([input])) acc_length = length(Nx.Defn.Composite.flatten_list([acc])) @@ -146,7 +146,7 @@ defmodule EXLA.Defn do outfeed, options ) do - if builder.compiler == :iree do + if builder.runtime == :iree do raise ArgumentError, "streaming not supported when compiling with IREE" end @@ -258,7 +258,7 @@ defmodule EXLA.Defn do {client_name, compile_options} = Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) - compile_options = Keyword.put_new(compile_options, :compiler_mode, :iree) + compile_options = Keyword.put_new(compile_options, :runtime, :iree) client = EXLA.Client.fetch!(client_name) @@ -302,14 +302,14 @@ defmodule EXLA.Defn do raise ArgumentError, "missing client" end - compiler_mode = Keyword.fetch!(options, :compiler_mode) + runtime = Keyword.fetch!(options, :runtime) - unless compiler_mode do - raise ArgumentError, "missing compiler_mode" + unless runtime do + raise ArgumentError, "missing runtime" end state_params = - if compiler_mode == :iree do + if runtime == :iree do Map.new(params) else Map.new(params ++ outfeed.infeeds) @@ -323,7 +323,7 @@ defmodule EXLA.Defn do scope_ids: Tree.scope_ids(expr) } - if compiler_mode == :iree do + if runtime == :iree do {res, _cache} = recur_flatten(expr, state, no_token_cache()) Value.return(function, res) {:ok, nil} @@ -397,7 +397,7 @@ defmodule EXLA.Defn do end {debug?, options} = Keyword.pop(options, :debug, false) - {compiler_mode, options} = Keyword.pop(options, :compiler_mode) + {runtime, options} = Keyword.pop(options, :runtime) {args_key, reverse_args_identifiers} = Enum.map_reduce(vars, [], fn var, acc -> @@ -462,10 +462,10 @@ defmodule EXLA.Defn do end) EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> - builder = %EXLA.MLIR.Function{builder | compiler: compiler_mode} + builder = %EXLA.MLIR.Function{builder | runtime: runtime} outfeed = - if compiler_mode != :iree do + if runtime != :iree do outfeed |> Outfeed.with_token(Value.create_token(builder)) |> Outfeed.add_infeeds(builder, reverse_infeeds) @@ -481,7 +481,7 @@ defmodule EXLA.Defn do typespecs = for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec - if compiler_mode == :iree do + if runtime == :iree do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") @@ -577,7 +577,7 @@ defmodule EXLA.Defn do [initial_arg, _arg, pred, body] = args initial_with_token = - if state.builder.compiler == :iree do + if state.builder.runtime == :iree do initial_arg else [get_token(cache), initial_arg] @@ -590,7 +590,7 @@ defmodule EXLA.Defn do output = Value.while(function, pred_computation, body_computation, List.flatten(initial)) - case state.builder.compiler do + case state.builder.runtime do :iree -> result = wrap_tuple_result(output, initial_arg) {result, cache} @@ -644,10 +644,10 @@ defmodule EXLA.Defn do ] } }, - %{client: %EXLA.Client{platform: :host}, builder: %Function{compiler: compiler}} = state, + %{client: %EXLA.Client{platform: :host}, builder: %Function{runtime: runtime}} = state, cache ) - when type_kind != :c and compiler != :iree do + when type_kind != :c and runtime != :iree do # We match only on platform: :host for MLIR, as we want to support # QR-on-cpu as a custom call only in this case {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() @@ -720,7 +720,7 @@ defmodule EXLA.Defn do {computation, Map.put(cache, key, computation)} end - if state.builder.compiler == :iree do + if state.builder.runtime == :iree do typespecs = container_to_typespecs(expr) result = Value.call(state.builder, call_args, call_body, typespecs) @@ -744,7 +744,7 @@ defmodule EXLA.Defn do defp cached_recur_operator( :token, %T{data: %Expr{args: [_token]}}, - %{builder: %Function{compiler: :iree}}, + %{builder: %Function{runtime: :iree}}, cache ) do {[], cache} @@ -1693,7 +1693,7 @@ defmodule EXLA.Defn do name, args, expr, - %{builder: %Function{compiler: compiler}} = state, + %{builder: %Function{runtime: runtime}} = state, cache ) do %Function{module: module, name: name} = subbuilder(state.builder, name) @@ -1703,14 +1703,14 @@ defmodule EXLA.Defn do out_typespecs = container_to_typespecs(expr) in_types = - if compiler == :iree do + if runtime == :iree do arg_typespecs else [token_typespec | arg_typespecs] end out_types = - if compiler == :iree do + if runtime == :iree do out_typespecs else [token_typespec | out_typespecs] @@ -1719,12 +1719,12 @@ defmodule EXLA.Defn do function = EXLA.MLIR.Module.add_function(module, name, in_types, out_types) - function = %{function | compiler: compiler} + function = %{function | runtime: runtime} [arg_token | tail] = EXLA.MLIR.Function.get_arguments(function) params = - if compiler == :iree do + if runtime == :iree do Enum.with_index(tail, fn param, i -> {i, param} end) else Enum.with_index(tail, fn param, i -> {i, param} end) @@ -1737,7 +1737,7 @@ defmodule EXLA.Defn do scope_ids: Tree.scope_ids(expr) } - if compiler == :iree do + if runtime == :iree do {res, comp_cache} = recur_composite(expr, state, cache) Value.return(function, List.flatten(res)) {function, merge_outfeed(cache, comp_cache)} diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index a200e58660..7b83d07901 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -99,8 +99,6 @@ defmodule EXLA.Executable do end defp run(:iree, _client, ref, device_id, inputs, _options) do - dbg() - inputs = for subinputs <- inputs do Enum.map(subinputs, fn diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index 83e01f9bf3..6f27b9fd7e 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -2,7 +2,7 @@ defmodule EXLA.MLIR.Function do @moduledoc false # Representation of an MLIR Function or `func.func` type. - defstruct [:module, :ref, :name, :return_typespecs, :compiler] + defstruct [:module, :ref, :name, :return_typespecs, :runtime] alias __MODULE__, as: Function alias EXLA.MLIR.Value diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index 6416c7a110..ddeb0b9bbb 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -11,14 +11,12 @@ defmodule EXLA.MLIR.IREE do global_initialize() {:ok, instance} = runtime_create_instance() :persistent_term.put({__MODULE__, :instance}, instance) - dbg(instance) :ok end def run(module, inputs) do instance = :persistent_term.get({__MODULE__, :instance}) - dbg(instance) run_module(instance, module, inputs) end diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 54ea6bfac9..6afa289fcb 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -27,10 +27,70 @@ defmodule EXLA.BackendTest do @skip_mac_arm [] end - @moduletag :iree_hangup_error + if iree_runtime?() do + @skip_iree [ + count_leading_zeros: 1, + window_min: 3, + window_max: 3, + window_mean: 3, + window_sum: 3, + window_product: 3, + population_count: 1, + window_scatter_max: 5, + window_scatter_min: 5, + fft: 2, + fft2: 2, + ifft: 2, + ifft2: 2, + all_close: 3, + take_diagonal: 2, + take_along_axis: 3, + gather: 3, + mean: 2, + sum: 2, + product: 2, + negate: 1, + reduce: 4, + reduce_min: 2, + reduce_max: 2, + equal: 2, + sigil_M: 2, + slice: 4, + atan2: 2, + weighted_mean: 3, + indexed_add: 4, + concatenate: 2, + stack: 2, + reshape_vectors: 2, + divide: 2, + mode: 2, + conv: 3, + put_slice: 3, + vectorize: 2, + argsort: 2, + sort: 2, + log2: 1, + select: 3, + pad: 3, + tile: 2, + variance: 2, + standard_deviation: 2, + cumulative_min: 2, + cumulative_max: 2, + dot: 6, + linspace: 3, + bitwise_xor: 2, + broadcast: 3, + axis_index: 2, + size: 1, + nex_axis: 3 + ] + else + @skip_iree [] + end doctest Nx, - except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm + except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_iree test "Nx.to_binary/1" do t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend) diff --git a/exla/test/support/exla_case.ex b/exla/test/support/exla_case.ex index fa12092575..20823c0fef 100644 --- a/exla/test/support/exla_case.ex +++ b/exla/test/support/exla_case.ex @@ -47,4 +47,8 @@ defmodule EXLA.Case do def is_mac_arm? do Application.fetch_env!(:exla, :is_mac_arm) end + + def iree_runtime? do + Nx.Defn.default_options()[:runtime] == :iree + end end diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 5420bbbb24..4a6f741f73 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -5,7 +5,9 @@ if System.get_env("DEBUG") in ["1", "true"] do IO.gets("Press enter to continue... -- PID: #{System.pid()}") end -Nx.Defn.global_default_options(compiler: EXLA) +runtime = if System.get_env("EXLA_RUNTIME") == "iree", do: :iree, else: :xla + +Nx.Defn.global_default_options(compiler: EXLA, runtime: runtime) exclude_multi_device = if client.device_count > 1 and client.platform == :host, do: [], else: [:multi_device] From f9324dcec6c69a55b7ec1161992bc7d37c3758e0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 14 May 2024 02:57:43 -0300 Subject: [PATCH 18/40] wip --- exla/c_src/exla/iree/runtime.cc | 38 +++++------------ exla/c_src/exla/iree/runtime.h | 54 ++++++++++-------------- exla/lib/exla/application.ex | 7 ++- exla/lib/exla/mlir/iree.ex | 15 +++---- exla/lib/exla/mlir/iree/instance_pool.ex | 37 ++++++++++++++++ 5 files changed, 82 insertions(+), 69 deletions(-) create mode 100644 exla/lib/exla/mlir/iree/instance_pool.ex diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 691358ff9f..c4d3d23116 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -210,39 +210,19 @@ iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t **arg, IREEInput *inp arg); } -iree_status_t call_module(iree_runtime_session_t *session, std::vector inputs, std::vector> *result) { +iree_status_t call_module(exla::iree::runtime::Session *session, std::vector inputs, std::vector> *result) { iree_runtime_call_t call; - iree_vm_function_t function; - std::cout << "inputs.size(): " << inputs.size() << "\n"; + IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(session->get(), iree_make_cstring_view("module.main"), &call)); - IREE_RETURN_IF_ERROR(iree_runtime_session_lookup_function(session, iree_make_cstring_view("module.main"), &function)); - - IREE_RETURN_IF_ERROR(iree_runtime_call_initialize(session, function, &call)); - - iree_vm_function_signature_t signature = iree_vm_function_signature(&function); - - iree_string_view_t arguments; - iree_string_view_t results; - IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( - &signature, &arguments, &results)); - - std::cout << "arguments" << arguments.data << "\n"; - std::cout << "results" << results.data << "\n"; - - // Append the function inputs with the HAL device allocator in use by the - // session. The buffers will be usable within the session and _may_ be usable - // in other sessions depending on whether they share a compatible device. - iree_hal_device_t *device = iree_runtime_session_device(session); - iree_hal_allocator_t *device_allocator = - iree_runtime_session_device_allocator(session); + iree_hal_allocator_t *device_allocator = iree_runtime_session_device_allocator(session->get()); for (size_t i = 0; i < inputs.size(); i++) { IREEInput *input = inputs[i]; // iree_hal_buffer_view_t *buffer_view = nullptr; iree_hal_buffer_view_t *arg = nullptr; - IREE_RETURN_IF_ERROR(iree_input_to_hal_arg(&arg, input, device, device_allocator)); + IREE_RETURN_IF_ERROR(iree_input_to_hal_arg(&arg, input, session->instance()->device(), device_allocator)); IREE_RETURN_IF_ERROR(iree_runtime_call_inputs_push_back_buffer_view(&call, arg)); iree_hal_buffer_view_release(arg); } @@ -272,8 +252,8 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorget()), + buffer, 0, binary.data, byte_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout())); @@ -282,6 +262,8 @@ iree_status_t call_module(iree_runtime_session_t *session, std::vectorpush_back({element_type, binary}); } + iree_runtime_call_reset(&call); + return iree_make_status(IREE_STATUS_OK); } @@ -353,7 +335,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { } std::vector> results; - status = call_module(session->get(), inputs, &results); + status = call_module(session, inputs, &results); if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. @@ -375,6 +357,8 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, ss.str().c_str()); } + delete session; + return return_results(env, results); } diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index a083fdba4f..06442ca44d 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -15,27 +15,17 @@ namespace exla { namespace iree { namespace runtime { -template -struct IREEDeleter { - void operator()(T* ptr) { - if (ptr) { - ReleaseFunc(ptr); // Call the specific release function - } - } -}; - -using IREEInstanceDeleter = IREEDeleter; -using IREEDeviceDeleter = IREEDeleter; -using IREESessionDeleter = IREEDeleter; - class Instance { public: // Constructor explicit Instance(iree_runtime_instance_t* instance, iree_hal_device_t* device) - : instance_(instance, IREEInstanceDeleter{}), device_(device, IREEDeviceDeleter{}) {} + : instance_(instance), device_(device) {} // Default destructor is fine, unique_ptr will handle the resource release - ~Instance() = default; + ~Instance() { + iree_hal_device_release(device_); + iree_runtime_instance_release(instance_); + } // Copy and move operations are disabled to maintain unique ownership semantics Instance(const Instance&) = delete; @@ -44,20 +34,20 @@ class Instance { Instance& operator=(Instance&&) noexcept = default; iree_runtime_instance_t* get() const { - return instance_.get(); + return instance_; } iree_runtime_instance_t* operator->() const { - return instance_.get(); + return instance_; } iree_hal_device_t* device() const { - return device_.get(); + return device_; } private: - std::unique_ptr instance_; - std::unique_ptr device_; + iree_runtime_instance_t* instance_; + iree_hal_device_t* device_; }; class Session { @@ -69,22 +59,18 @@ class Session { iree_runtime_session_options_t session_options; iree_runtime_session_options_initialize(&session_options); - iree_runtime_session_t* session_ptr; - iree_allocator_t host_allocator = iree_runtime_instance_host_allocator(instance_->get()); iree_status_t status = iree_runtime_session_create_with_device( instance_->get(), &session_options, instance_->device(), - host_allocator, &session_ptr); + host_allocator, &session_); if (!iree_status_is_ok(status)) { return status; } - session_.reset(session_ptr); - iree_const_byte_span_t span{.data = bytecode.data(), .data_length = bytecode.size()}; - status = iree_runtime_session_append_bytecode_module_from_memory(session_.get(), span, host_allocator); + status = iree_runtime_session_append_bytecode_module_from_memory(session_, span, host_allocator); if (!iree_status_is_ok(status)) { return status; @@ -93,8 +79,10 @@ class Session { return status; } - // Default destructor is fine, unique_ptr will handle the resource release - ~Session() = default; + ~Session() { + instance_ = nullptr; + iree_runtime_session_release(session_); + } // Copy and move operations are disabled to maintain unique ownership semantics Session(const Session&) = delete; @@ -104,17 +92,21 @@ class Session { // Provide a way to access the underlying pointer like a raw pointer iree_runtime_session_t* get() const { - return session_.get(); + return session_; } // Overload the arrow operator to enable direct member access to the iree_runtime_session_t iree_runtime_session_t* operator->() const { - return session_.get(); + return session_; + } + + Instance* instance() const { + return instance_; } private: Instance* instance_; - std::unique_ptr session_; + iree_runtime_session_t* session_; }; } // namespace runtime diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index c1caa5ee54..2397ececf9 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -10,10 +10,15 @@ defmodule EXLA.Application do _ -> :os.set_signal(:sigchld, :default) end - EXLA.MLIR.IREE.init() + EXLA.MLIR.IREE.global_initialize() children = [ EXLA.Logger, + {NimblePool, + worker: {EXLA.MLIR.IREE.InstancePool, :pool_state}, + pool_size: 1, + name: EXLA.MLIR.IREE.InstancePool, + lazy: true}, {NimblePool, worker: {EXLA.MLIR.ContextPool, :pool_state}, pool_size: System.schedulers_online(), diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index ddeb0b9bbb..9d42fd155f 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -1,5 +1,7 @@ defmodule EXLA.MLIR.IREE do @moduledoc false + alias EXLA.MLIR.IREE.InstancePool + @on_load :__on_load__ def __on_load__ do @@ -7,17 +9,10 @@ defmodule EXLA.MLIR.IREE do :erlang.load_nif(path, 0) end - def init do - global_initialize() - {:ok, instance} = runtime_create_instance() - :persistent_term.put({__MODULE__, :instance}, instance) - - :ok - end - def run(module, inputs) do - instance = :persistent_term.get({__MODULE__, :instance}) - run_module(instance, module, inputs) + InstancePool.checkout(fn instance -> + run_module(instance, module, inputs) + end) end def compile(_module, _target), do: :erlang.nif_error(:undef) diff --git a/exla/lib/exla/mlir/iree/instance_pool.ex b/exla/lib/exla/mlir/iree/instance_pool.ex new file mode 100644 index 0000000000..9b9b8c6ec5 --- /dev/null +++ b/exla/lib/exla/mlir/iree/instance_pool.ex @@ -0,0 +1,37 @@ +defmodule EXLA.MLIR.IREE.InstancePool do + @moduledoc false + # Internal pool for MLIRContext reference management + @behaviour NimblePool + + def checkout(fun) when is_function(fun, 1) do + NimblePool.checkout!( + __MODULE__, + :checkout, + fn _pool, context -> {fun.(context), :ok} end, + :infinity + ) + end + + @impl NimblePool + def init_worker(pool_state) do + {:ok, instance} = EXLA.MLIR.IREE.runtime_create_instance() + {:ok, instance, pool_state} + end + + @impl NimblePool + def handle_checkout(:checkout, _from, instance, pool_state) do + {:ok, instance, instance, pool_state} + end + + @impl NimblePool + def handle_checkin(:ok, _from, instance, pool_state) do + # We just keep the references around and let them die out upon worker termination/GC + {:ok, instance, pool_state} + end + + @impl NimblePool + def terminate_worker(_reason, _instance, pool_state) do + # GC will clean it up + {:ok, pool_state} + end +end From f7e20c975081af56d1b44085c5e25c74aa34fb4f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 14 May 2024 17:02:48 -0300 Subject: [PATCH 19/40] wip --- exla/c_src/exla/iree/runtime.cc | 167 ++++++++-------------------- exla/c_src/exla/iree/runtime.h | 187 ++++++++++++++++++++++++++++---- exla/lib/exla/defn.ex | 1 + exla/test/test_helper.exs | 1 + 4 files changed, 211 insertions(+), 145 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index c4d3d23116..d31edd3860 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -3,40 +3,6 @@ #include #include -class IREEInput { - public: - void *data; - size_t size; - std::vector dims; - iree_hal_element_type_t type; - - // Default constructor - IREEInput(void *data, size_t size, std::vector in_dims, iree_hal_element_type_t type) : size(size), type(type) { - dims.reserve(in_dims.size()); - - for (auto dim : in_dims) { - dims.push_back(static_cast(dim)); - } - - this->data = std::malloc(size); // Allocate memory - std::memcpy(this->data, data, size); - } - - // Destructor - ~IREEInput() { - if (data) { - std::free(data); - data = nullptr; - } - } - - // Disable copy and move semantics for simplicity - IREEInput(const IREEInput &) = delete; - IREEInput &operator=(const IREEInput &) = delete; - IREEInput(IREEInput &&) = delete; - IREEInput &operator=(IREEInput &&) = delete; -}; - bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_type_t *type) { using xla::PrimitiveType; using type_enum = iree_hal_element_types_t; @@ -146,7 +112,7 @@ bool iree_element_type_to_nx_type(iree_hal_element_type_t type, std::string &nx_ } } -int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { +int load_inputs(ErlNifEnv *env, std::vector terms, std::vector &loaded) { const ERL_NIF_TERM *tuple, *typespec; int length; ErlNifBinary bin; @@ -183,88 +149,15 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vectordata, input->size); - - iree_hal_buffer_params_t buffer_params = { - .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - .access = IREE_HAL_MEMORY_ACCESS_ALL, - .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - }; - - return iree_hal_buffer_view_allocate_buffer_copy( - device, - device_allocator, - input->dims.size(), - input->dims.data(), - input->type, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - buffer_params, - data_span, - arg); -} - -iree_status_t call_module(exla::iree::runtime::Session *session, std::vector inputs, std::vector> *result) { - iree_runtime_call_t call; - - IREE_RETURN_IF_ERROR(iree_runtime_call_initialize_by_name(session->get(), iree_make_cstring_view("module.main"), &call)); - - iree_hal_allocator_t *device_allocator = iree_runtime_session_device_allocator(session->get()); - - for (size_t i = 0; i < inputs.size(); i++) { - IREEInput *input = inputs[i]; - // iree_hal_buffer_view_t *buffer_view = nullptr; - iree_hal_buffer_view_t *arg = nullptr; - - IREE_RETURN_IF_ERROR(iree_input_to_hal_arg(&arg, input, session->instance()->device(), device_allocator)); - IREE_RETURN_IF_ERROR(iree_runtime_call_inputs_push_back_buffer_view(&call, arg)); - iree_hal_buffer_view_release(arg); - } - - IREE_RETURN_IF_ERROR(iree_runtime_call_invoke(&call, /*flags=*/0)); - - iree_vm_list_t *outputs = iree_runtime_call_outputs(&call); - - ErlNifBinary binary; - size_t size = iree_vm_list_size(outputs); - - result->clear(); - result->reserve(size); - - for (iree_vm_size_t i = 0; i < size; i++) { - iree_hal_buffer_view_t *buffer_view = nullptr; - iree_vm_ref_t ref = iree_vm_ref_null(); - IREE_RETURN_IF_ERROR(iree_vm_list_get_ref_assign(outputs, i, &ref)); - - // iree_runtime_call_outputs_pop_front_buffer_view(&call, &buffer_view); - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(ref, &buffer_view)); - iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(buffer_view); - - iree_hal_buffer_t *buffer = iree_hal_buffer_view_buffer(buffer_view); - // size_t byte_size = iree_hal_buffer_view_byte_length(buffer_view); - size_t byte_size = iree_hal_buffer_byte_length(buffer); - enif_alloc_binary(byte_size, &binary); - - IREE_RETURN_IF_ERROR(iree_hal_device_transfer_d2h( - iree_runtime_session_device(session->get()), - buffer, 0, binary.data, - byte_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout())); - - iree_hal_buffer_view_release(buffer_view); - - result->push_back({element_type, binary}); - } - - iree_runtime_call_reset(&call); - - return iree_make_status(IREE_STATUS_OK); +iree_status_t call_module(exla::iree::runtime::Session *session, std::vector inputs, std::vector> *result) { + IREE_RETURN_IF_ERROR(session->init_inputs_and_outputs(inputs)); + return session->call(result); } ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector> results) { @@ -277,7 +170,7 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector bytecode_vec = {}; std::vector input_terms = {}; - std::vector inputs = {}; + std::vector inputs = {}; std::vector bytecode = {}; - exla::iree::runtime::Instance **instance; - iree_status_t status; + // exla::iree::runtime::Instance **instance; + // iree_status_t status; + + // if (!exla::nif::get(env, argv[0], instance)) { + // return exla::nif::error(env, "Unable to get instance"); + // } + + iree_runtime_instance_options_t instance_options; + iree_runtime_instance_options_initialize(&instance_options); + iree_runtime_instance_options_use_all_available_drivers(&instance_options); + iree_runtime_instance_t *instance_ptr = NULL; + iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance_ptr); + + if (!iree_status_is_ok(status)) { + iree_runtime_instance_release(instance_ptr); + return exla::nif::error(env, "Failed to create IREE runtime instance"); + } + + iree_hal_device_t *device_ptr = NULL; + char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument + status = iree_hal_create_device( + iree_runtime_instance_driver_registry(instance_ptr), + iree_make_cstring_view(device_uri), + iree_runtime_instance_host_allocator(instance_ptr), &device_ptr); - if (!exla::nif::get(env, argv[0], instance)) { - return exla::nif::error(env, "Unable to get instance"); + if (!iree_status_is_ok(status)) { + if (device_ptr) { + iree_hal_device_release(device_ptr); + } + if (instance_ptr) { + iree_runtime_instance_release(instance_ptr); + } + return exla::nif::error(env, "Failed to create IREE device instance"); } + exla::iree::runtime::Instance *instance = new exla::iree::runtime::Instance(instance_ptr, device_ptr); + if (!exla::nif::get_list(env, argv[1], bytecode_vec)) { return exla::nif::error(env, "Unable to load bytecode binary"); } @@ -327,7 +250,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to decode input terms"); } - exla::iree::runtime::Session *session = new exla::iree::runtime::Session(*instance); + exla::iree::runtime::Session *session = new exla::iree::runtime::Session(instance); status = session->initialize(bytecode); if (!iree_status_is_ok(status)) { @@ -336,6 +259,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { std::vector> results; status = call_module(session, inputs, &results); + delete session; if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. @@ -352,13 +276,12 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { std::stringstream ss; ss << "Failed to execute IREE runtime due to error: "; ss << status_string; - iree_status_ignore(status); + iree_status_free(status); return exla::nif::error(env, ss.str().c_str()); } - delete session; - + iree_status_free(status); return return_results(env, results); } diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index 06442ca44d..48976cc35c 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -1,7 +1,10 @@ #pragma once #include +#include #include #include +#include +#include #include @@ -15,6 +18,40 @@ namespace exla { namespace iree { namespace runtime { +class IREEInput { + public: + void* data; + size_t size; + std::vector dims; + iree_hal_element_type_t type; + + // Default constructor + IREEInput(void* data, size_t size, std::vector in_dims, iree_hal_element_type_t type) : size(size), type(type) { + dims.reserve(in_dims.size()); + + for (auto dim : in_dims) { + dims.push_back(static_cast(dim)); + } + + this->data = std::malloc(size); // Allocate memory + std::memcpy(this->data, data, size); + } + + // Destructor + ~IREEInput() { + if (data) { + std::free(data); + data = nullptr; + } + } + + // Disable copy and move semantics for simplicity + IREEInput(const IREEInput&) = delete; + IREEInput& operator=(const IREEInput&) = delete; + IREEInput(IREEInput&&) = delete; + IREEInput& operator=(IREEInput&&) = delete; +}; + class Instance { public: // Constructor @@ -59,29 +96,140 @@ class Session { iree_runtime_session_options_t session_options; iree_runtime_session_options_initialize(&session_options); - iree_allocator_t host_allocator = iree_runtime_instance_host_allocator(instance_->get()); - iree_status_t status = iree_runtime_session_create_with_device( - instance_->get(), &session_options, instance_->device(), - host_allocator, &session_); + iree_vm_instance_t* vm_instance = iree_runtime_instance_vm_instance(instance_->get()); + iree_hal_device_t* device = instance_->device(); + + iree_vm_module_t* hal_module = NULL; + IREE_RETURN_IF_ERROR(iree_hal_module_create( + vm_instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, + iree_allocator_system(), &hal_module)); + + iree_const_byte_span_t module_data{.data = bytecode.data(), .data_length = bytecode.size()}; + + iree_vm_module_t* bytecode_module = NULL; + IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create( + vm_instance, module_data, iree_allocator_null(), iree_allocator_system(), + &bytecode_module)); + + iree_vm_module_t* modules[] = {hal_module, bytecode_module}; + IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules( + vm_instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], + iree_allocator_system(), &context_)); + iree_vm_module_release(hal_module); + iree_vm_module_release(bytecode_module); + + // Lookup the entry point function. + // Note that we use the synchronous variant which operates on pure type/shape + // erased buffers. + const char kMainFunctionName[] = "module.main"; + IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function( + context_, iree_make_cstring_view(kMainFunctionName), &main_function_)); - if (!iree_status_is_ok(status)) { - return status; + return iree_ok_status(); + } + + iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t** arg, IREEInput* input, iree_hal_device_t* device, iree_hal_allocator_t* device_allocator) { + const iree_const_byte_span_t data_span = iree_make_const_byte_span(input->data, input->size); + + iree_hal_buffer_params_t buffer_params = { + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .access = IREE_HAL_MEMORY_ACCESS_ALL, + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + }; + + return iree_hal_buffer_view_allocate_buffer_copy( + device, + device_allocator, + input->dims.size(), + input->dims.data(), + input->type, + IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + buffer_params, + data_span, + arg); + } + + iree_status_t init_inputs_and_outputs(std::vector inputs) { + iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(instance_->device()); + + iree_vm_function_signature_t signature = + iree_vm_function_signature(&main_function_); + iree_string_view_t arguments; + iree_string_view_t results; + + IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( + &signature, &arguments, &results)); + + inputs_ = NULL; + IREE_RETURN_IF_ERROR( + iree_vm_list_create(iree_vm_make_undefined_type_def(), + inputs.size(), iree_allocator_system(), &inputs_), + "can't allocate input vm list"); + + outputs_ = NULL; + IREE_RETURN_IF_ERROR( + iree_vm_list_create(iree_vm_make_undefined_type_def(), results.size, iree_allocator_system(), &outputs_), + "can't allocate output vm list"); + + for (size_t i = 0; i < inputs.size(); i++) { + IREEInput* input = inputs[i]; + // iree_hal_buffer_view_t *buffer_view = nullptr; + iree_hal_buffer_view_t* arg = nullptr; + IREE_RETURN_IF_ERROR(iree_input_to_hal_arg(&arg, input, instance()->device(), device_allocator)); + iree_vm_ref_t arg_ref = iree_hal_buffer_view_move_ref(arg); + IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs_, &arg_ref)); } + } + + iree_status_t call(std::vector>* result) { + // Synchronously invoke the function. + IREE_RETURN_IF_ERROR(iree_vm_invoke( + context_, main_function_, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, inputs_, outputs_, iree_allocator_system())); + + ErlNifBinary binary; + size_t size = iree_vm_list_size(outputs_); + + result->resize(size); - iree_const_byte_span_t span{.data = bytecode.data(), .data_length = bytecode.size()}; + for (iree_vm_size_t i = 0; i < size; i++) { + iree_hal_buffer_view_t* buffer_view = nullptr; + iree_vm_ref_t ref = iree_vm_ref_null(); + IREE_RETURN_IF_ERROR(iree_vm_list_get_ref_assign(outputs_, i, &ref)); - status = iree_runtime_session_append_bytecode_module_from_memory(session_, span, host_allocator); + // iree_runtime_call_outputs_pop_front_buffer_view(&call, &buffer_view); + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(ref, &buffer_view)); + iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(buffer_view); - if (!iree_status_is_ok(status)) { - return status; + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view); + // size_t byte_size = iree_hal_buffer_view_byte_length(buffer_view); + size_t byte_size = iree_hal_buffer_byte_length(buffer); + enif_alloc_binary(byte_size, &binary); + + iree_status_t status = iree_hal_device_transfer_d2h( + instance_->device(), + buffer, 0, binary.data, + byte_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()); + + if (!iree_status_is_ok(status)) { + enif_release_binary(&binary); + return status; + } + + iree_hal_buffer_view_release(buffer_view); + + (*result)[i] = {element_type, binary}; } - return status; + return iree_ok_status(); } ~Session() { instance_ = nullptr; - iree_runtime_session_release(session_); + iree_vm_list_release(inputs_); + iree_vm_list_release(outputs_); + iree_vm_context_release(context_); } // Copy and move operations are disabled to maintain unique ownership semantics @@ -90,23 +238,16 @@ class Session { Session(Session&&) noexcept = default; Session& operator=(Session&&) noexcept = default; - // Provide a way to access the underlying pointer like a raw pointer - iree_runtime_session_t* get() const { - return session_; - } - - // Overload the arrow operator to enable direct member access to the iree_runtime_session_t - iree_runtime_session_t* operator->() const { - return session_; - } - Instance* instance() const { return instance_; } private: Instance* instance_; - iree_runtime_session_t* session_; + iree_vm_context_t* context_; + iree_vm_list_t* inputs_; + iree_vm_list_t* outputs_; + iree_vm_function_t main_function_; }; } // namespace runtime diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 02a26d5af7..bd8dfd56f5 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -483,6 +483,7 @@ defmodule EXLA.Defn do if runtime == :iree do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) + dbg(module_charlist) {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 4a6f741f73..202a1841d5 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -2,6 +2,7 @@ target = System.get_env("EXLA_TARGET", "host") client = EXLAHelpers.client() if System.get_env("DEBUG") in ["1", "true"] do + dbg(System.schedulers_online()) IO.gets("Press enter to continue... -- PID: #{System.pid()}") end From 5f02ed7ef895e72c3a1a373327a9533dad2920a1 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 15 May 2024 20:42:53 -0300 Subject: [PATCH 20/40] feat: invoke not segfaulting (yet?) --- exla/c_src/exla/iree/iree.cc | 6 +- exla/c_src/exla/iree/runtime.cc | 171 ++++++++++++++------------- exla/c_src/exla/iree/runtime.h | 201 +------------------------------- exla/lib/exla/application.ex | 2 + exla/lib/exla/mlir/iree.ex | 10 +- 5 files changed, 106 insertions(+), 284 deletions(-) diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index 74f61c6377..d8005a827c 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -16,8 +16,8 @@ static ErlNifFunc iree_funcs[] = { // MLIR Builder {"global_initialize", 0, global_initialize}, {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"runtime_create_instance", 0, runtime_create_instance}, - {"run_module", 3, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}}; + {"run_module", 3, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"setup_runtime", 0, setup_runtime, ERL_NIF_DIRTY_JOB_IO_BOUND}}; static int open_resources(ErlNifEnv *env) { const char *mod = "EXLA"; @@ -25,7 +25,7 @@ static int open_resources(ErlNifEnv *env) { if (!exla::nif::open_resource(env, mod, "ExlaMLIRModule")) { return -1; } - if (!exla::nif::open_resource(env, mod, "ExlaIREERuntimeInstance")) { + if (!exla::nif::open_resource(env, mod, "ExlaIreeHalDevice")) { return -1; } return 1; diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index d31edd3860..94cd516aa1 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -1,5 +1,8 @@ #include "runtime.h" +#include +#include + #include #include @@ -155,11 +158,6 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector inputs, std::vector> *result) { - IREE_RETURN_IF_ERROR(session->init_inputs_and_outputs(inputs)); - return session->call(result); -} - ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector> results) { size_t n = results.size(); @@ -183,6 +181,82 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector bytecode, std::vector exla_inputs, std::vector> &results) { + iree_vm_instance_t *instance = nullptr; + iree_vm_module_t *hal_module = nullptr; + iree_vm_module_t *bytecode_module = nullptr; + iree_vm_context_t *context = nullptr; + const char kMainFunctionName[] = "module.main"; + iree_vm_function_t main_function; + iree_vm_list_t *inputs = nullptr; + iree_vm_list_t *outputs = nullptr; + + IREE_RETURN_IF_ERROR(iree_vm_instance_create( + IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance)); + IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance)); + + IREE_RETURN_IF_ERROR(iree_hal_module_create( + instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, + iree_allocator_system(), &hal_module)); + + // (kFloat4, sizeof(kFloat4)) + const iree_const_byte_span_t module_data = iree_make_const_byte_span(bytecode.data(), bytecode.size()); + + IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create( + instance, module_data, iree_allocator_null(), iree_allocator_system(), + &bytecode_module)); + + iree_vm_module_t *modules[] = {hal_module, bytecode_module}; + IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules( + instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], + iree_allocator_system(), &context)); + iree_vm_module_release(hal_module); + iree_vm_module_release(bytecode_module); + + IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function( + context, iree_make_cstring_view(kMainFunctionName), &main_function)); + + IREE_RETURN_IF_ERROR( + iree_vm_list_create(iree_vm_make_undefined_type_def(), exla_inputs.size(), iree_allocator_system(), &inputs), + "can't allocate input vm list"); + + for (auto input : exla_inputs) { + iree_hal_buffer_view_t *arg_buffer_view = nullptr; + IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy( + device, iree_hal_device_allocator(device), input->dims.size(), input->dims.data(), + input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + (iree_hal_buffer_params_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + }, + input->data_byte_span(), &arg_buffer_view)); + + iree_vm_ref_t arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_view); + IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref)); + } + + iree_vm_function_signature_t signature = + iree_vm_function_signature(&main_function); + iree_string_view_t input_signature; + iree_string_view_t output_signature; + + IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( + &signature, &input_signature, &output_signature)); + + IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), output_signature.size, iree_allocator_system(), &outputs), "can't allocate output vm list"); + + // Synchronously invoke the function. + IREE_RETURN_IF_ERROR(iree_vm_invoke( + context, main_function, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, inputs, outputs, iree_allocator_system())); + + iree_vm_list_release(inputs); + iree_vm_list_release(outputs); + iree_vm_context_release(context); + iree_vm_instance_release(instance); + return iree_ok_status(); +} + ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 3) { @@ -193,43 +267,12 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { std::vector input_terms = {}; std::vector inputs = {}; std::vector bytecode = {}; - // exla::iree::runtime::Instance **instance; - // iree_status_t status; - - // if (!exla::nif::get(env, argv[0], instance)) { - // return exla::nif::error(env, "Unable to get instance"); - // } - - iree_runtime_instance_options_t instance_options; - iree_runtime_instance_options_initialize(&instance_options); - iree_runtime_instance_options_use_all_available_drivers(&instance_options); - iree_runtime_instance_t *instance_ptr = NULL; - iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance_ptr); - - if (!iree_status_is_ok(status)) { - iree_runtime_instance_release(instance_ptr); - return exla::nif::error(env, "Failed to create IREE runtime instance"); - } - - iree_hal_device_t *device_ptr = NULL; - char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument - status = iree_hal_create_device( - iree_runtime_instance_driver_registry(instance_ptr), - iree_make_cstring_view(device_uri), - iree_runtime_instance_host_allocator(instance_ptr), &device_ptr); + iree_hal_device_t **device; - if (!iree_status_is_ok(status)) { - if (device_ptr) { - iree_hal_device_release(device_ptr); - } - if (instance_ptr) { - iree_runtime_instance_release(instance_ptr); - } - return exla::nif::error(env, "Failed to create IREE device instance"); + if (!exla::nif::get(env, argv[0], device)) { + return exla::nif::error(env, "Unable to load device"); } - exla::iree::runtime::Instance *instance = new exla::iree::runtime::Instance(instance_ptr, device_ptr); - if (!exla::nif::get_list(env, argv[1], bytecode_vec)) { return exla::nif::error(env, "Unable to load bytecode binary"); } @@ -250,16 +293,9 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to decode input terms"); } - exla::iree::runtime::Session *session = new exla::iree::runtime::Session(instance); - status = session->initialize(bytecode); - - if (!iree_status_is_ok(status)) { - return exla::nif::error(env, "Failed to initialize IREE runtime session"); - } - std::vector> results; - status = call_module(session, inputs, &results); - delete session; + + iree_status_t status = call(*device, bytecode, inputs, results); if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. @@ -285,40 +321,19 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return return_results(env, results); } -ERL_NIF_TERM runtime_create_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 0) { - return exla::nif::error(env, "Bad argument count."); - } - - iree_runtime_instance_options_t instance_options; - iree_runtime_instance_options_initialize(&instance_options); - iree_runtime_instance_options_use_all_available_drivers(&instance_options); - iree_runtime_instance_t *instance_ptr = NULL; - iree_status_t status = iree_runtime_instance_create(&instance_options, iree_allocator_system(), &instance_ptr); +ERL_NIF_TERM setup_runtime(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + iree_hal_device_t *device = nullptr; - if (!iree_status_is_ok(status)) { - iree_runtime_instance_release(instance_ptr); - return exla::nif::error(env, "Failed to create IREE runtime instance"); - } + iree_status_t status = iree_hal_register_all_available_drivers(iree_hal_driver_registry_default()); - iree_hal_device_t *device_ptr = NULL; char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument - status = iree_hal_create_device( - iree_runtime_instance_driver_registry(instance_ptr), - iree_make_cstring_view(device_uri), - iree_runtime_instance_host_allocator(instance_ptr), &device_ptr); - if (!iree_status_is_ok(status)) { - if (device_ptr) { - iree_hal_device_release(device_ptr); - } - if (instance_ptr) { - iree_runtime_instance_release(instance_ptr); - } - return exla::nif::error(env, "Failed to create IREE device instance"); + if (iree_status_is_ok(status)) { + status = iree_hal_create_device( + iree_hal_driver_registry_default(), + iree_make_cstring_view(device_uri), + iree_allocator_system(), &device); } - exla::iree::runtime::Instance *instance = new exla::iree::runtime::Instance(instance_ptr, device_ptr); - - return exla::nif::ok(env, exla::nif::make(env, instance)); + return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, device)) : exla::nif::error(env, "Failed to setup IREE runtime"); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index 48976cc35c..7323987e1a 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -11,8 +11,7 @@ #include "../exla_nif_util.h" ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); -ERL_NIF_TERM runtime_create_instance(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); -ERL_NIF_TERM runtime_create_session(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM setup_runtime(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); namespace exla { namespace iree { @@ -50,204 +49,10 @@ class IREEInput { IREEInput& operator=(const IREEInput&) = delete; IREEInput(IREEInput&&) = delete; IREEInput& operator=(IREEInput&&) = delete; -}; - -class Instance { - public: - // Constructor - explicit Instance(iree_runtime_instance_t* instance, iree_hal_device_t* device) - : instance_(instance), device_(device) {} - - // Default destructor is fine, unique_ptr will handle the resource release - ~Instance() { - iree_hal_device_release(device_); - iree_runtime_instance_release(instance_); - } - - // Copy and move operations are disabled to maintain unique ownership semantics - Instance(const Instance&) = delete; - Instance& operator=(const Instance&) = delete; - Instance(Instance&&) noexcept = default; - Instance& operator=(Instance&&) noexcept = default; - - iree_runtime_instance_t* get() const { - return instance_; - } - - iree_runtime_instance_t* operator->() const { - return instance_; - } - - iree_hal_device_t* device() const { - return device_; - } - - private: - iree_runtime_instance_t* instance_; - iree_hal_device_t* device_; -}; - -class Session { - public: - // Constructor - explicit Session(Instance* instance) : instance_(instance) {} - - iree_status_t initialize(std::vector bytecode) { - iree_runtime_session_options_t session_options; - iree_runtime_session_options_initialize(&session_options); - - iree_vm_instance_t* vm_instance = iree_runtime_instance_vm_instance(instance_->get()); - iree_hal_device_t* device = instance_->device(); - - iree_vm_module_t* hal_module = NULL; - IREE_RETURN_IF_ERROR(iree_hal_module_create( - vm_instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, - iree_allocator_system(), &hal_module)); - - iree_const_byte_span_t module_data{.data = bytecode.data(), .data_length = bytecode.size()}; - - iree_vm_module_t* bytecode_module = NULL; - IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create( - vm_instance, module_data, iree_allocator_null(), iree_allocator_system(), - &bytecode_module)); - - iree_vm_module_t* modules[] = {hal_module, bytecode_module}; - IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules( - vm_instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], - iree_allocator_system(), &context_)); - iree_vm_module_release(hal_module); - iree_vm_module_release(bytecode_module); - - // Lookup the entry point function. - // Note that we use the synchronous variant which operates on pure type/shape - // erased buffers. - const char kMainFunctionName[] = "module.main"; - IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function( - context_, iree_make_cstring_view(kMainFunctionName), &main_function_)); - - return iree_ok_status(); - } - - iree_status_t iree_input_to_hal_arg(iree_hal_buffer_view_t** arg, IREEInput* input, iree_hal_device_t* device, iree_hal_allocator_t* device_allocator) { - const iree_const_byte_span_t data_span = iree_make_const_byte_span(input->data, input->size); - - iree_hal_buffer_params_t buffer_params = { - .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - .access = IREE_HAL_MEMORY_ACCESS_ALL, - .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - }; - return iree_hal_buffer_view_allocate_buffer_copy( - device, - device_allocator, - input->dims.size(), - input->dims.data(), - input->type, - IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - buffer_params, - data_span, - arg); + iree_const_byte_span_t data_byte_span() const { + return iree_make_const_byte_span(static_cast(data), size); } - - iree_status_t init_inputs_and_outputs(std::vector inputs) { - iree_hal_allocator_t* device_allocator = iree_hal_device_allocator(instance_->device()); - - iree_vm_function_signature_t signature = - iree_vm_function_signature(&main_function_); - iree_string_view_t arguments; - iree_string_view_t results; - - IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( - &signature, &arguments, &results)); - - inputs_ = NULL; - IREE_RETURN_IF_ERROR( - iree_vm_list_create(iree_vm_make_undefined_type_def(), - inputs.size(), iree_allocator_system(), &inputs_), - "can't allocate input vm list"); - - outputs_ = NULL; - IREE_RETURN_IF_ERROR( - iree_vm_list_create(iree_vm_make_undefined_type_def(), results.size, iree_allocator_system(), &outputs_), - "can't allocate output vm list"); - - for (size_t i = 0; i < inputs.size(); i++) { - IREEInput* input = inputs[i]; - // iree_hal_buffer_view_t *buffer_view = nullptr; - iree_hal_buffer_view_t* arg = nullptr; - IREE_RETURN_IF_ERROR(iree_input_to_hal_arg(&arg, input, instance()->device(), device_allocator)); - iree_vm_ref_t arg_ref = iree_hal_buffer_view_move_ref(arg); - IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs_, &arg_ref)); - } - } - - iree_status_t call(std::vector>* result) { - // Synchronously invoke the function. - IREE_RETURN_IF_ERROR(iree_vm_invoke( - context_, main_function_, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/NULL, inputs_, outputs_, iree_allocator_system())); - - ErlNifBinary binary; - size_t size = iree_vm_list_size(outputs_); - - result->resize(size); - - for (iree_vm_size_t i = 0; i < size; i++) { - iree_hal_buffer_view_t* buffer_view = nullptr; - iree_vm_ref_t ref = iree_vm_ref_null(); - IREE_RETURN_IF_ERROR(iree_vm_list_get_ref_assign(outputs_, i, &ref)); - - // iree_runtime_call_outputs_pop_front_buffer_view(&call, &buffer_view); - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_check_deref(ref, &buffer_view)); - iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(buffer_view); - - iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(buffer_view); - // size_t byte_size = iree_hal_buffer_view_byte_length(buffer_view); - size_t byte_size = iree_hal_buffer_byte_length(buffer); - enif_alloc_binary(byte_size, &binary); - - iree_status_t status = iree_hal_device_transfer_d2h( - instance_->device(), - buffer, 0, binary.data, - byte_size, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout()); - - if (!iree_status_is_ok(status)) { - enif_release_binary(&binary); - return status; - } - - iree_hal_buffer_view_release(buffer_view); - - (*result)[i] = {element_type, binary}; - } - - return iree_ok_status(); - } - - ~Session() { - instance_ = nullptr; - iree_vm_list_release(inputs_); - iree_vm_list_release(outputs_); - iree_vm_context_release(context_); - } - - // Copy and move operations are disabled to maintain unique ownership semantics - Session(const Session&) = delete; - Session& operator=(const Session&) = delete; - Session(Session&&) noexcept = default; - Session& operator=(Session&&) noexcept = default; - - Instance* instance() const { - return instance_; - } - - private: - Instance* instance_; - iree_vm_context_t* context_; - iree_vm_list_t* inputs_; - iree_vm_list_t* outputs_; - iree_vm_function_t main_function_; }; } // namespace runtime diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 2397ececf9..95d460573d 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -11,6 +11,8 @@ defmodule EXLA.Application do end EXLA.MLIR.IREE.global_initialize() + {:ok, device} = EXLA.MLIR.IREE.setup_runtime() + :persistent_term.put({EXLA.MLIR.IREE, :device}, device) children = [ EXLA.Logger, diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index 9d42fd155f..78544cd099 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -10,16 +10,16 @@ defmodule EXLA.MLIR.IREE do end def run(module, inputs) do - InstancePool.checkout(fn instance -> - run_module(instance, module, inputs) - end) + device = :persistent_term.get({EXLA.MLIR.IREE, :device}) + + run_module(device, module, inputs) end def compile(_module, _target), do: :erlang.nif_error(:undef) def global_initialize, do: :erlang.nif_error(:undef) - def runtime_create_instance, do: :erlang.nif_error(:undef) - def run_module(_instance, _module, _inputs), do: :erlang.nif_error(:undef) + + def setup_runtime, do: :erlang.nif_error(:undef) end From 20ac6275d088f59805dab517d857160045373726 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 15 May 2024 21:31:58 -0300 Subject: [PATCH 21/40] feat: output retrieval working --- exla/c_src/exla/iree/runtime.cc | 71 +++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 94cd516aa1..3a1b18a2df 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -181,7 +181,13 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector bytecode, std::vector exla_inputs, std::vector> &results) { +#define RETURN_PAIR_IF_ERROR(status) \ + if (!iree_status_is_ok(status)) { \ + return {status, std::nullopt}; \ + } + +std::pair>>> +call(iree_hal_device_t *device, std::vector bytecode, std::vector exla_inputs) { iree_vm_instance_t *instance = nullptr; iree_vm_module_t *hal_module = nullptr; iree_vm_module_t *bytecode_module = nullptr; @@ -191,38 +197,36 @@ iree_status_t call(iree_hal_device_t *device, std::vector bytecode, std iree_vm_list_t *inputs = nullptr; iree_vm_list_t *outputs = nullptr; - IREE_RETURN_IF_ERROR(iree_vm_instance_create( + RETURN_PAIR_IF_ERROR(iree_vm_instance_create( IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance)); - IREE_RETURN_IF_ERROR(iree_hal_module_register_all_types(instance)); + RETURN_PAIR_IF_ERROR(iree_hal_module_register_all_types(instance)); - IREE_RETURN_IF_ERROR(iree_hal_module_create( + RETURN_PAIR_IF_ERROR(iree_hal_module_create( instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, iree_allocator_system(), &hal_module)); // (kFloat4, sizeof(kFloat4)) const iree_const_byte_span_t module_data = iree_make_const_byte_span(bytecode.data(), bytecode.size()); - IREE_RETURN_IF_ERROR(iree_vm_bytecode_module_create( + RETURN_PAIR_IF_ERROR(iree_vm_bytecode_module_create( instance, module_data, iree_allocator_null(), iree_allocator_system(), &bytecode_module)); iree_vm_module_t *modules[] = {hal_module, bytecode_module}; - IREE_RETURN_IF_ERROR(iree_vm_context_create_with_modules( + RETURN_PAIR_IF_ERROR(iree_vm_context_create_with_modules( instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], iree_allocator_system(), &context)); iree_vm_module_release(hal_module); iree_vm_module_release(bytecode_module); - IREE_RETURN_IF_ERROR(iree_vm_context_resolve_function( + RETURN_PAIR_IF_ERROR(iree_vm_context_resolve_function( context, iree_make_cstring_view(kMainFunctionName), &main_function)); - IREE_RETURN_IF_ERROR( - iree_vm_list_create(iree_vm_make_undefined_type_def(), exla_inputs.size(), iree_allocator_system(), &inputs), - "can't allocate input vm list"); + RETURN_PAIR_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), exla_inputs.size(), iree_allocator_system(), &inputs)); for (auto input : exla_inputs) { iree_hal_buffer_view_t *arg_buffer_view = nullptr; - IREE_RETURN_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy( + RETURN_PAIR_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy( device, iree_hal_device_allocator(device), input->dims.size(), input->dims.data(), input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, (iree_hal_buffer_params_t){ @@ -232,7 +236,7 @@ iree_status_t call(iree_hal_device_t *device, std::vector bytecode, std input->data_byte_span(), &arg_buffer_view)); iree_vm_ref_t arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_view); - IREE_RETURN_IF_ERROR(iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref)); + RETURN_PAIR_IF_ERROR(iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref)); } iree_vm_function_signature_t signature = @@ -240,21 +244,50 @@ iree_status_t call(iree_hal_device_t *device, std::vector bytecode, std iree_string_view_t input_signature; iree_string_view_t output_signature; - IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( + RETURN_PAIR_IF_ERROR(iree_vm_function_call_get_cconv_fragments( &signature, &input_signature, &output_signature)); - IREE_RETURN_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), output_signature.size, iree_allocator_system(), &outputs), "can't allocate output vm list"); + RETURN_PAIR_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), output_signature.size, iree_allocator_system(), &outputs)); // Synchronously invoke the function. - IREE_RETURN_IF_ERROR(iree_vm_invoke( + RETURN_PAIR_IF_ERROR(iree_vm_invoke( context, main_function, IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/NULL, inputs, outputs, iree_allocator_system())); + std::vector> results; + results.resize(output_signature.size); + for (int i = 0; i < output_signature.size; i++) { + iree_hal_buffer_view_t *output_buffer_view = iree_vm_list_get_buffer_view_assign(outputs, i); + if (!output_buffer_view) { + return {iree_make_status(IREE_STATUS_NOT_FOUND, "can't get output buffer view [index=%d]", i), std::nullopt}; + } + + size_t num_bytes = iree_hal_buffer_view_byte_length(output_buffer_view); + void *out_buffer = malloc(num_bytes); + if (!out_buffer) { + return {iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, "can't allocate output buffer [index=%d]", i), std::nullopt}; + } + + iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(output_buffer_view); + + RETURN_PAIR_IF_ERROR(iree_hal_device_transfer_d2h( + device, iree_hal_buffer_view_buffer(output_buffer_view), 0, out_buffer, + num_bytes, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout())); + + ErlNifBinary binary; + enif_alloc_binary(num_bytes, &binary); + std::memcpy(binary.data, out_buffer, num_bytes); + + // TO - DO : free out_buffer or maybe just use binary.data in its stead + results[i] = std::make_pair(element_type, binary); + } + iree_vm_list_release(inputs); iree_vm_list_release(outputs); iree_vm_context_release(context); iree_vm_instance_release(instance); - return iree_ok_status(); + return {iree_ok_status(), results}; } ERL_NIF_TERM @@ -293,9 +326,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to decode input terms"); } - std::vector> results; - - iree_status_t status = call(*device, bytecode, inputs, results); + auto [status, results] = call(*device, bytecode, inputs); if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. @@ -318,7 +349,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { } iree_status_free(status); - return return_results(env, results); + return return_results(env, results.value()); } ERL_NIF_TERM setup_runtime(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { From 2c650b5619397f40e4cc3ff4d734a65007f1114a Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 15 May 2024 22:00:49 -0300 Subject: [PATCH 22/40] feat: instance resource pooling --- exla/c_src/exla/iree/iree.cc | 9 ++++-- exla/c_src/exla/iree/runtime.cc | 36 +++++++++++++++--------- exla/c_src/exla/iree/runtime.h | 1 + exla/lib/exla/application.ex | 2 +- exla/lib/exla/defn.ex | 2 -- exla/lib/exla/mlir/iree.ex | 11 +++++--- exla/lib/exla/mlir/iree/instance_pool.ex | 2 +- 7 files changed, 40 insertions(+), 23 deletions(-) diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index d8005a827c..e083c79e65 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -16,8 +16,10 @@ static ErlNifFunc iree_funcs[] = { // MLIR Builder {"global_initialize", 0, global_initialize}, {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, - {"run_module", 3, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"setup_runtime", 0, setup_runtime, ERL_NIF_DIRTY_JOB_IO_BOUND}}; + {"run_module", 4, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"setup_runtime", 0, setup_runtime, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"create_instance", 0, create_instance, ERL_NIF_DIRTY_JOB_IO_BOUND}, +}; static int open_resources(ErlNifEnv *env) { const char *mod = "EXLA"; @@ -28,6 +30,9 @@ static int open_resources(ErlNifEnv *env) { if (!exla::nif::open_resource(env, mod, "ExlaIreeHalDevice")) { return -1; } + if (!exla::nif::open_resource(env, mod, "ExlaIreeVmInstance")) { + return -1; + } return 1; } diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 3a1b18a2df..3cb7c46780 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -187,8 +187,7 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector>>> -call(iree_hal_device_t *device, std::vector bytecode, std::vector exla_inputs) { - iree_vm_instance_t *instance = nullptr; +call(iree_vm_instance_t *instance, iree_hal_device_t *device, std::vector bytecode, std::vector exla_inputs) { iree_vm_module_t *hal_module = nullptr; iree_vm_module_t *bytecode_module = nullptr; iree_vm_context_t *context = nullptr; @@ -197,10 +196,6 @@ call(iree_hal_device_t *device, std::vector bytecode, std::vector bytecode, std::vector inputs = {}; std::vector bytecode = {}; iree_hal_device_t **device; + iree_vm_instance_t **instance; - if (!exla::nif::get(env, argv[0], device)) { + if (!exla::nif::get(env, argv[0], instance)) { return exla::nif::error(env, "Unable to load device"); } - if (!exla::nif::get_list(env, argv[1], bytecode_vec)) { + if (!exla::nif::get(env, argv[1], device)) { + return exla::nif::error(env, "Unable to load device"); + } + + if (!exla::nif::get_list(env, argv[2], bytecode_vec)) { return exla::nif::error(env, "Unable to load bytecode binary"); } @@ -318,7 +317,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { bytecode[i] = static_cast(byte); } - if (!exla::nif::get_list(env, argv[2], input_terms)) { + if (!exla::nif::get_list(env, argv[3], input_terms)) { return exla::nif::error(env, "Unable to load input terms"); } @@ -326,7 +325,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to decode input terms"); } - auto [status, results] = call(*device, bytecode, inputs); + auto [status, results] = call(*instance, *device, bytecode, inputs); if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. @@ -367,4 +366,15 @@ ERL_NIF_TERM setup_runtime(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) } return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, device)) : exla::nif::error(env, "Failed to setup IREE runtime"); -} \ No newline at end of file +} + +ERL_NIF_TERM create_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + iree_vm_instance_t *instance = nullptr; + iree_status_t status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance); + + if (iree_status_is_ok(status)) { + status = iree_hal_module_register_all_types(instance); + } + + return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, instance)) : exla::nif::error(env, "Failed to create IREE VM instance"); +} diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index 7323987e1a..90d3f84663 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -12,6 +12,7 @@ ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM setup_runtime(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM create_instance(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); namespace exla { namespace iree { diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 95d460573d..a3eaad7259 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -18,7 +18,7 @@ defmodule EXLA.Application do EXLA.Logger, {NimblePool, worker: {EXLA.MLIR.IREE.InstancePool, :pool_state}, - pool_size: 1, + pool_size: System.schedulers_online(), name: EXLA.MLIR.IREE.InstancePool, lazy: true}, {NimblePool, diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index bd8dfd56f5..d21a6d0a45 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -483,8 +483,6 @@ defmodule EXLA.Defn do if runtime == :iree do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) - dbg(module_charlist) - {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") %EXLA.Executable{ diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index 78544cd099..b570c33590 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -10,16 +10,19 @@ defmodule EXLA.MLIR.IREE do end def run(module, inputs) do - device = :persistent_term.get({EXLA.MLIR.IREE, :device}) - - run_module(device, module, inputs) + InstancePool.checkout(fn instance -> + device = :persistent_term.get({EXLA.MLIR.IREE, :device}) + run_module(instance, device, module, inputs) + end) end def compile(_module, _target), do: :erlang.nif_error(:undef) def global_initialize, do: :erlang.nif_error(:undef) - def run_module(_instance, _module, _inputs), do: :erlang.nif_error(:undef) + def run_module(_instance, _device, _module, _inputs), do: :erlang.nif_error(:undef) def setup_runtime, do: :erlang.nif_error(:undef) + + def create_instance, do: :erlang.nif_error(:undef) end diff --git a/exla/lib/exla/mlir/iree/instance_pool.ex b/exla/lib/exla/mlir/iree/instance_pool.ex index 9b9b8c6ec5..7aa2693d51 100644 --- a/exla/lib/exla/mlir/iree/instance_pool.ex +++ b/exla/lib/exla/mlir/iree/instance_pool.ex @@ -14,7 +14,7 @@ defmodule EXLA.MLIR.IREE.InstancePool do @impl NimblePool def init_worker(pool_state) do - {:ok, instance} = EXLA.MLIR.IREE.runtime_create_instance() + {:ok, instance} = EXLA.MLIR.IREE.create_instance() {:ok, instance, pool_state} end From d619e30b1fe3d224457f08eb365db35e58c6e61c Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 15 May 2024 22:04:51 -0300 Subject: [PATCH 23/40] test: skip 64 bit tests --- exla/test/exla/defn/expr_test.exs | 8 +++++++- exla/test/test_helper.exs | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 401d199f8a..e09234fb7b 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -190,6 +190,13 @@ defmodule EXLA.Defn.ExprTest do defn add_two_int(t), do: t + 2 defn add_two_float(t), do: t + 2.0 + @tag :unsupported_64_bit_op + test "constants f64" do + t = Nx.tensor([1, 2], type: {:f, 64}) + assert_equal(add_two_int(t), Nx.add(t, 2)) + assert_equal(add_two_float(t), Nx.add(t, 2.0)) + end + test "constants" do tensors = [ Nx.tensor([1, 2], type: {:u, 8}), @@ -198,7 +205,6 @@ defmodule EXLA.Defn.ExprTest do Nx.tensor([1, 2], type: {:s, 8}), Nx.tensor([1, 2], type: {:s, 32}), Nx.tensor([1, 2], type: {:f, 32}), - Nx.tensor([1, 2], type: {:f, 64}) ] for t <- tensors do diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 202a1841d5..4492d52b02 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -56,7 +56,8 @@ iree_excludes = :iree_segfault_error, :iree_illegal_op_error, :iree_offset_error, - :multi_device + :multi_device, + :unsupported_64_bit_op ] else [] From 45647560ac43cde9e03a4a9a98ad41a486f23bf0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 16 May 2024 01:50:29 -0300 Subject: [PATCH 24/40] feat: lazy read from device --- exla/c_src/exla/iree/iree.cc | 5 +- exla/c_src/exla/iree/runtime.cc | 126 ++++++++++++++++++-------------- exla/c_src/exla/iree/runtime.h | 4 + exla/lib/exla/client.ex | 10 +++ exla/lib/exla/defn.ex | 4 + exla/lib/exla/device_buffer.ex | 39 +++++++++- exla/lib/exla/executable.ex | 62 +++++++--------- exla/lib/exla/mlir/iree.ex | 7 ++ exla/test/test_helper.exs | 1 - 9 files changed, 164 insertions(+), 94 deletions(-) diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index e083c79e65..3d2940f482 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -19,7 +19,7 @@ static ErlNifFunc iree_funcs[] = { {"run_module", 4, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"setup_runtime", 0, setup_runtime, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"create_instance", 0, create_instance, ERL_NIF_DIRTY_JOB_IO_BOUND}, -}; + {"read_buffer", 3, read_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}}; static int open_resources(ErlNifEnv *env) { const char *mod = "EXLA"; @@ -33,6 +33,9 @@ static int open_resources(ErlNifEnv *env) { if (!exla::nif::open_resource(env, mod, "ExlaIreeVmInstance")) { return -1; } + if (!exla::nif::open_resource(env, mod, "ExlaIreeHallBuffer")) { + return -1; + } return 1; } diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 3cb7c46780..a831207612 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -119,6 +119,7 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector terms, std::vector(env, term, buffer)) { + loaded.push_back(std::move(new exla::iree::runtime::IREEInput(*buffer))); + continue; + } return 0; } @@ -158,27 +163,8 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector> results) { - size_t n = results.size(); - - std::vector nif_terms; - nif_terms.clear(); - nif_terms.reserve(n); - - for (auto [iree_type, binary] : results) { - std::string nx_type; - if (!iree_element_type_to_nx_type(iree_type, nx_type)) { - return exla::nif::error(env, "Unable to convert IREE type to Nx type"); - } - ERL_NIF_TERM type = exla::nif::make(env, nx_type); - ERL_NIF_TERM bin_term = enif_make_binary(env, &binary); - - nif_terms.push_back(enif_make_tuple2(env, type, bin_term)); - } - - auto data = nif_terms.data(); - auto list = enif_make_list_from_array(env, &data[0], n); - return exla::nif::ok(env, list); +ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector results) { + return exla::nif::ok(env, exla::nif::make_list(env, results)); } #define RETURN_PAIR_IF_ERROR(status) \ @@ -186,7 +172,7 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector>>> +std::pair>> call(iree_vm_instance_t *instance, iree_hal_device_t *device, std::vector bytecode, std::vector exla_inputs) { iree_vm_module_t *hal_module = nullptr; iree_vm_module_t *bytecode_module = nullptr; @@ -220,17 +206,23 @@ call(iree_vm_instance_t *instance, iree_hal_device_t *device, std::vectordims.size(), input->dims.data(), - input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - (iree_hal_buffer_params_t){ - .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - }, - input->data_byte_span(), &arg_buffer_view)); - - iree_vm_ref_t arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_view); + iree_vm_ref_t arg_buffer_view_ref; + + if (input->buffer_view) { + arg_buffer_view_ref = iree_hal_buffer_view_move_ref(input->buffer_view); + } else { + iree_hal_buffer_view_t *arg_buffer_view = nullptr; + RETURN_PAIR_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy( + device, iree_hal_device_allocator(device), input->dims.size(), input->dims.data(), + input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + (iree_hal_buffer_params_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + }, + input->data_byte_span(), &arg_buffer_view)); + + arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_view); + } RETURN_PAIR_IF_ERROR(iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref)); } @@ -249,33 +241,15 @@ call(iree_vm_instance_t *instance, iree_hal_device_t *device, std::vector> results; + std::vector results; results.resize(output_signature.size); for (int i = 0; i < output_signature.size; i++) { - iree_hal_buffer_view_t *output_buffer_view = iree_vm_list_get_buffer_view_assign(outputs, i); + iree_hal_buffer_view_t *output_buffer_view = iree_vm_list_get_buffer_view_retain(outputs, i); if (!output_buffer_view) { return {iree_make_status(IREE_STATUS_NOT_FOUND, "can't get output buffer view [index=%d]", i), std::nullopt}; } - size_t num_bytes = iree_hal_buffer_view_byte_length(output_buffer_view); - void *out_buffer = malloc(num_bytes); - if (!out_buffer) { - return {iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, "can't allocate output buffer [index=%d]", i), std::nullopt}; - } - - iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(output_buffer_view); - - RETURN_PAIR_IF_ERROR(iree_hal_device_transfer_d2h( - device, iree_hal_buffer_view_buffer(output_buffer_view), 0, out_buffer, - num_bytes, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout())); - - ErlNifBinary binary; - enif_alloc_binary(num_bytes, &binary); - std::memcpy(binary.data, out_buffer, num_bytes); - - // TO - DO : free out_buffer or maybe just use binary.data in its stead - results[i] = std::make_pair(element_type, binary); + results[i] = output_buffer_view; } iree_vm_list_release(inputs); @@ -378,3 +352,47 @@ ERL_NIF_TERM create_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[] return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, instance)) : exla::nif::error(env, "Failed to create IREE VM instance"); } + +ERL_NIF_TERM release_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + iree_vm_instance_t **instance = nullptr; + + if (!exla::nif::get(env, argv[0], instance)) { + return exla::nif::error(env, "Unable to load device"); + } + + iree_vm_instance_release(*instance); + + return exla::nif::ok(env); +} + +ERL_NIF_TERM read_buffer(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + iree_hal_buffer_view_t **buffer_view = nullptr; + iree_hal_device_t **device = nullptr; + int64_t num_bytes; + ErlNifBinary binary; + + if (!exla::nif::get(env, argv[0], buffer_view)) { + return exla::nif::error(env, "Unable to load buffer"); + } + + if (!exla::nif::get(env, argv[1], device)) { + return exla::nif::error(env, "Unable to load device"); + } + + if (!exla::nif::get(env, argv[2], &num_bytes)) { + return exla::nif::error(env, "Unable to get buffer size"); + } + + iree_hal_buffer_t *buffer = iree_hal_buffer_view_buffer(*buffer_view); + + iree_device_size_t num_bytes_actual = num_bytes == -1 ? iree_hal_buffer_byte_length(buffer) : (iree_device_size_t)num_bytes; + + enif_alloc_binary(num_bytes_actual, &binary); + + iree_status_t status = iree_hal_device_transfer_d2h( + *device, buffer, 0, binary.data, + num_bytes_actual, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()); + + return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, binary)) : exla::nif::error(env, "Failed to read buffer"); +} \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index 90d3f84663..dbb3dbda2b 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -13,6 +13,7 @@ ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM setup_runtime(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM create_instance(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM read_buffer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); namespace exla { namespace iree { @@ -24,6 +25,9 @@ class IREEInput { size_t size; std::vector dims; iree_hal_element_type_t type; + iree_hal_buffer_view_t* buffer_view; + + IREEInput(iree_hal_buffer_view_t* buffer_view) : buffer_view(buffer_view) {} // Default constructor IREEInput(void* data, size_t size, std::vector in_dims, iree_hal_element_type_t type) : size(size), type(type) { diff --git a/exla/lib/exla/client.ex b/exla/lib/exla/client.ex index 3ad4ae51b6..40af71aefa 100644 --- a/exla/lib/exla/client.ex +++ b/exla/lib/exla/client.ex @@ -56,6 +56,16 @@ defmodule EXLA.Client do @doc """ Fetches a client with the given `name` from configuration. """ + def fetch!(:iree), + do: %__MODULE__{ + ref: nil, + platform: :iree, + name: :iree, + device_count: -1, + default_device_id: -1, + automatic_transfers: true + } + def fetch!(name) when is_atom(name) do # We could use the LockedCache but that is ETS based and the clients # are static enough that we can keep them on `persistent_term`. diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index d21a6d0a45..99f1600de8 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -632,6 +632,10 @@ defmodule EXLA.Defn do {fun_computation(args, expr, type, state), cache} end + defp cached_recur_operator(:optional, _, %{builder: %{runtime: :iree}}, _cache) do + raise ArgumentError, "optional not supported yet when compiling with IREE" + end + defp cached_recur_operator( :optional, %T{ diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index dc2944a927..2f32773962 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -10,6 +10,10 @@ defmodule EXLA.DeviceBuffer do defstruct [:ref, :client_name, :device_id, :typespec] @doc false + def from_ref(ref, :iree, device_id, typespec) when is_reference(ref) do + %DeviceBuffer{ref: ref, client_name: :iree, device_id: device_id, typespec: typespec} + end + def from_ref(ref, %Client{name: name}, device_id, typespec) when is_reference(ref) do %DeviceBuffer{ref: ref, client_name: name, device_id: device_id, typespec: typespec} end @@ -47,7 +51,40 @@ defmodule EXLA.DeviceBuffer do without destroying it. If `size` is negative, then it reads the whole buffer. """ - def read(%DeviceBuffer{ref: ref}, size \\ -1) do + def read(buffer, size \\ -1) + + def read(%DeviceBuffer{typespec: typespec, ref: ref, client_name: :iree}, size) do + target_type = {s, w} = typespec.type + + size = + if size == -1 do + div(w, 8) * Tuple.product(typespec.shape) + else + size + end + + {read_size, source_type} = + if target_type in [f: 64, c: 128, s: 64, u: 64] do + {div(size, 2), {s, div(w, 2)}} + else + {size, target_type} + end + + data = EXLA.MLIR.IREE.read(ref, read_size) |> unwrap!() + + if source_type == target_type do + data + else + Nx.with_default_backend(Nx.BinaryBackend, fn -> + data + |> Nx.from_binary(source_type) + |> Nx.as_type(target_type) + |> Nx.to_binary() + end) + end + end + + def read(%DeviceBuffer{ref: ref}, size) do EXLA.NIF.read_device_mem(ref, size) |> unwrap!() end diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 7b83d07901..ffca77b19d 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -31,7 +31,9 @@ defmodule EXLA.Executable do } = executable - for data_and_device_id <- run(runtime, client, ref, device_id, inputs, options) do + client = if runtime == :iree, do: :iree, else: client + + for data_and_device_id <- run(client, ref, device_id, inputs, options) do decompose_output(data_and_device_id, output_typespecs, client) end end @@ -77,31 +79,13 @@ defmodule EXLA.Executable do end end - defp run(:xla, client, ref, device_id, inputs, _options) do + defp run(:iree, ref, device_id, inputs, _options) do inputs = for subinputs <- inputs do Enum.map(subinputs, fn %DeviceBuffer{ref: ref} -> ref - %BinaryBuffer{data: data, typespec: typespec} -> - {data, EXLA.Typespec.nif_encode(typespec)} - end) - end - - data = - case client.platform do - :host -> EXLA.NIF.run_cpu(client.ref, ref, inputs, device_id) - _ -> EXLA.NIF.run_io(client.ref, ref, inputs, device_id) - end - - unwrap!(data) - end - - defp run(:iree, _client, ref, device_id, inputs, _options) do - inputs = - for subinputs <- inputs do - Enum.map(subinputs, fn %BinaryBuffer{data: data, typespec: typespec} -> if typespec.type in [f: 64, c: 128, s: 64, u: 64] do {t, w} = typespec.type @@ -116,8 +100,6 @@ defmodule EXLA.Executable do |> Nx.to_binary() end) - data = <> - {data, EXLA.Typespec.nif_encode(typespec)} else {data, EXLA.Typespec.nif_encode(typespec)} @@ -131,23 +113,29 @@ defmodule EXLA.Executable do |> then(&[{&1, device_id}]) end + defp run(client, ref, device_id, inputs, _options) do + inputs = + for subinputs <- inputs do + Enum.map(subinputs, fn + %DeviceBuffer{ref: ref} -> + ref + + %BinaryBuffer{data: data, typespec: typespec} -> + {data, EXLA.Typespec.nif_encode(typespec)} + end) + end + + data = + case client.platform do + :host -> EXLA.NIF.run_cpu(client.ref, ref, inputs, device_id) + _ -> EXLA.NIF.run_io(client.ref, ref, inputs, device_id) + end + + unwrap!(data) + end + defp decompose_output({data, device_id}, output_typespecs, client) do Enum.zip_with(data, output_typespecs, fn - {type, buf}, target_typespec when is_binary(buf) and is_list(type) -> - source_typespec = EXLA.Typespec.nif_decode({type, target_typespec.shape}) - - if source_typespec == target_typespec do - BinaryBuffer.from_binary(buf, target_typespec) - else - Nx.with_default_backend(Nx.BinaryBackend, fn -> - buf - |> Nx.from_binary(source_typespec.type) - |> Nx.as_type(target_typespec.type) - |> Nx.to_binary() - |> BinaryBuffer.from_binary(target_typespec) - end) - end - buf, typespec when is_reference(buf) -> DeviceBuffer.from_ref(buf, client, device_id, typespec) diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index b570c33590..d677e09530 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -16,6 +16,11 @@ defmodule EXLA.MLIR.IREE do end) end + def read(buffer, size) do + device = :persistent_term.get({EXLA.MLIR.IREE, :device}) + read_buffer(buffer, device, size) + end + def compile(_module, _target), do: :erlang.nif_error(:undef) def global_initialize, do: :erlang.nif_error(:undef) @@ -25,4 +30,6 @@ defmodule EXLA.MLIR.IREE do def setup_runtime, do: :erlang.nif_error(:undef) def create_instance, do: :erlang.nif_error(:undef) + + def read_buffer(_buffer, _device, _size), do: :erlang.nif_error(:undef) end diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 4492d52b02..535133fdc8 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -2,7 +2,6 @@ target = System.get_env("EXLA_TARGET", "host") client = EXLAHelpers.client() if System.get_env("DEBUG") in ["1", "true"] do - dbg(System.schedulers_online()) IO.gets("Press enter to continue... -- PID: #{System.pid()}") end From 6f18f764d5b35ed3ec8497dc6d517bac853b0558 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 16 May 2024 12:46:27 -0300 Subject: [PATCH 25/40] fix: use func.return only on the higher level --- exla/lib/exla/defn.ex | 16 ++++------ exla/lib/exla/mlir/value.ex | 5 +++ exla/test/exla/backend_test.exs | 54 +++------------------------------ 3 files changed, 16 insertions(+), 59 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 99f1600de8..16be3f9aa8 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -241,7 +241,7 @@ defmodule EXLA.Defn do output = wrap_tuple_result(acc, acc_typespec) outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder) - Value.return(builder, output) + Value.func_return(builder, output) {{input_typespecs, input_indexes}, outfeed} end @@ -325,7 +325,7 @@ defmodule EXLA.Defn do if runtime == :iree do {res, _cache} = recur_flatten(expr, state, no_token_cache()) - Value.return(function, res) + Value.func_return(function, res) {:ok, nil} else {res, cache} = recur_flatten(expr, state, new_cache(outfeed)) @@ -333,7 +333,7 @@ defmodule EXLA.Defn do outfeed = cache |> get_outfeed() |> Outfeed.close(function) - Value.return(function, res) + Value.func_return(function, res) {:ok, outfeed} end @@ -632,10 +632,6 @@ defmodule EXLA.Defn do {fun_computation(args, expr, type, state), cache} end - defp cached_recur_operator(:optional, _, %{builder: %{runtime: :iree}}, _cache) do - raise ArgumentError, "optional not supported yet when compiling with IREE" - end - defp cached_recur_operator( :optional, %T{ @@ -1728,7 +1724,7 @@ defmodule EXLA.Defn do params = if runtime == :iree do - Enum.with_index(tail, fn param, i -> {i, param} end) + Enum.with_index([arg_token | tail], fn param, i -> {i, param} end) else Enum.with_index(tail, fn param, i -> {i, param} end) end @@ -1742,11 +1738,11 @@ defmodule EXLA.Defn do if runtime == :iree do {res, comp_cache} = recur_composite(expr, state, cache) - Value.return(function, List.flatten(res)) + Value.func_return(function, List.flatten(res)) {function, merge_outfeed(cache, comp_cache)} else {res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token)) - Value.return(function, [get_token(comp_cache) | List.flatten(res)]) + Value.func_return(function, [get_token(comp_cache) | List.flatten(res)]) {function, merge_outfeed(cache, comp_cache)} end end diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index b469048369..5dc0c515f1 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -724,6 +724,11 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.while", initial, result_types, regions: regions) end + + def func_return(func, values) when is_list(values) do + op(func, "func.return", values, []) + end + def return(func, values) when is_list(values) do op(func, "stablehlo.return", values, []) end diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 6afa289fcb..52507560a1 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -29,61 +29,17 @@ defmodule EXLA.BackendTest do if iree_runtime?() do @skip_iree [ - count_leading_zeros: 1, - window_min: 3, - window_max: 3, - window_mean: 3, window_sum: 3, window_product: 3, - population_count: 1, + window_mean: 3, + window_min: 3, + window_max: 3, window_scatter_max: 5, window_scatter_min: 5, - fft: 2, - fft2: 2, - ifft: 2, - ifft2: 2, - all_close: 3, - take_diagonal: 2, - take_along_axis: 3, - gather: 3, - mean: 2, - sum: 2, - product: 2, - negate: 1, - reduce: 4, - reduce_min: 2, - reduce_max: 2, - equal: 2, - sigil_M: 2, - slice: 4, - atan2: 2, - weighted_mean: 3, - indexed_add: 4, - concatenate: 2, - stack: 2, - reshape_vectors: 2, - divide: 2, - mode: 2, - conv: 3, - put_slice: 3, - vectorize: 2, + median: 2, argsort: 2, sort: 2, - log2: 1, - select: 3, - pad: 3, - tile: 2, - variance: 2, - standard_deviation: 2, - cumulative_min: 2, - cumulative_max: 2, - dot: 6, - linspace: 3, - bitwise_xor: 2, - broadcast: 3, - axis_index: 2, - size: 1, - nex_axis: 3 + conv: 3, ] else @skip_iree [] From 88591788db36f2e812693c0932b9942e40793a07 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 16 May 2024 12:46:40 -0300 Subject: [PATCH 26/40] chore: format --- exla/lib/exla/mlir/value.ex | 1 - exla/test/exla/backend_test.exs | 2 +- exla/test/exla/defn/expr_test.exs | 6 +++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 5dc0c515f1..d7fc25138d 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -724,7 +724,6 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.while", initial, result_types, regions: regions) end - def func_return(func, values) when is_list(values) do op(func, "func.return", values, []) end diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 52507560a1..4222f31bad 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -39,7 +39,7 @@ defmodule EXLA.BackendTest do median: 2, argsort: 2, sort: 2, - conv: 3, + conv: 3 ] else @skip_iree [] diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index e09234fb7b..40fd9b3fe0 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -196,7 +196,7 @@ defmodule EXLA.Defn.ExprTest do assert_equal(add_two_int(t), Nx.add(t, 2)) assert_equal(add_two_float(t), Nx.add(t, 2.0)) end - + test "constants" do tensors = [ Nx.tensor([1, 2], type: {:u, 8}), @@ -204,7 +204,7 @@ defmodule EXLA.Defn.ExprTest do Nx.tensor([1, 2], type: {:u, 32}), Nx.tensor([1, 2], type: {:s, 8}), Nx.tensor([1, 2], type: {:s, 32}), - Nx.tensor([1, 2], type: {:f, 32}), + Nx.tensor([1, 2], type: {:f, 32}) ] for t <- tensors do @@ -244,7 +244,7 @@ defmodule EXLA.Defn.ExprTest do end describe "//2" do - defn divide_two(a, b), do: a / b + defn divide_two(a, b), do: a / b test "parameters" do tensors = [ From b76b717bd6e2abb20b05b7757b8ef84b344bc6eb Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 19 May 2024 17:36:03 -0300 Subject: [PATCH 27/40] feat: more coverage --- exla/c_src/exla/iree/compiler.cc | 19 +++++++++++-------- exla/c_src/exla/iree/iree.cc | 8 +++++--- exla/c_src/exla/iree/runtime.cc | 16 ++++++++++++++-- exla/c_src/exla/iree/runtime.h | 1 + exla/c_src/iree_runtime/CMakeLists.txt | 2 +- exla/lib/exla/defn.ex | 13 ++++++++++++- exla/lib/exla/device_buffer.ex | 5 ++++- exla/lib/exla/executable.ex | 2 +- exla/lib/exla/mlir/iree.ex | 4 +++- exla/test/exla/defn/api_test.exs | 1 - exla/test/exla/defn/expr_test.exs | 25 +++++++------------------ exla/test/exla/defn/vectorize_test.exs | 1 + exla/test/test_helper.exs | 11 ++--------- 13 files changed, 62 insertions(+), 46 deletions(-) diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index 4f275785a9..533f07baca 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -56,11 +56,21 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { } std::string module_str; + std::vector flags_str; + std::vector flags; if (!exla::nif::get(env, argv[0], module_str)) { return exla::nif::error(env, "Unable to get module."); } + if (!exla::nif::get_list(env, argv[1], flags_str)) { + return exla::nif::error(env, "Unable to get list."); + } + + for (auto &flag : flags_str) { + flags.push_back(reinterpret_cast(flag.data)); + } + compiler_state_t state; state.session = NULL; state.source = NULL; @@ -80,14 +90,7 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { // Set flags. iree_compiler_error_t *err; - const char *flags[] = { - "--iree-hal-target-backends=metal-spirv", - "--iree-input-type=stablehlo_xla", - "--iree-execution-model=async-internal", - "--output-format=vm-bytecode", - "--iree-opt-demote-f64-to-f32=true", - "--iree-opt-demote-i64-to-i32=true"}; - err = ireeCompilerSessionSetFlags(state.session, 1, flags); + err = ireeCompilerSessionSetFlags(state.session, 1, flags.data()); if (err) { cleanup_compiler_state(state); return exla::nif::error(env, "Unable to set flags."); diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index 3d2940f482..7a81c40921 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -17,9 +17,11 @@ static ErlNifFunc iree_funcs[] = { {"global_initialize", 0, global_initialize}, {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, {"run_module", 4, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"setup_runtime", 0, setup_runtime, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"create_instance", 0, create_instance, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"read_buffer", 3, read_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}}; + {"setup_runtime", 0, setup_runtime}, + {"create_instance", 0, create_instance}, + {"read_buffer", 3, read_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"deallocate_buffer", 1, deallocate_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}, +}; static int open_resources(ErlNifEnv *env) { const char *mod = "EXLA"; diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index a831207612..8c988a4dc7 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -24,7 +24,7 @@ bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_ *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_32; return true; case PrimitiveType::S64: - *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_32; + *type = type_enum::IREE_HAL_ELEMENT_TYPE_INT_64; return true; case PrimitiveType::U8: *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_8; @@ -36,7 +36,7 @@ bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_ *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32; return true; case PrimitiveType::U64: - *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_32; + *type = type_enum::IREE_HAL_ELEMENT_TYPE_UINT_64; return true; case PrimitiveType::BF16: *type = type_enum::IREE_HAL_ELEMENT_TYPE_BFLOAT_16; @@ -395,4 +395,16 @@ ERL_NIF_TERM read_buffer(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { iree_infinite_timeout()); return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, binary)) : exla::nif::error(env, "Failed to read buffer"); +} + +ERL_NIF_TERM deallocate_buffer(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + iree_hal_buffer_view_t **buffer_view = nullptr; + + if (!exla::nif::get(env, argv[0], buffer_view)) { + return exla::nif::error(env, "Unable to load buffer"); + } + + iree_hal_buffer_view_release(*buffer_view); + + return exla::nif::ok(env); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index dbb3dbda2b..c3277bf039 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -14,6 +14,7 @@ ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM setup_runtime(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM create_instance(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM read_buffer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM deallocate_buffer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); namespace exla { namespace iree { diff --git a/exla/c_src/iree_runtime/CMakeLists.txt b/exla/c_src/iree_runtime/CMakeLists.txt index 61d86f20f2..60b7a3ecf6 100644 --- a/exla/c_src/iree_runtime/CMakeLists.txt +++ b/exla/c_src/iree_runtime/CMakeLists.txt @@ -80,7 +80,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -Wall -Wextra -Wno-unused-function set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers -DLLVM_VERSION_STRING= -std=c++17") if($ENV{DEBUG}) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O0") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") endif() diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 16be3f9aa8..ac246ce3e7 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -483,7 +483,18 @@ defmodule EXLA.Defn do if runtime == :iree do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) - {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, "metal") + + flags = [ + "--iree-hal-target-backends=metal-spirv", + "--iree-input-type=stablehlo_xla", + "--iree-execution-model=async-internal", + "--output-format=vm-bytecode", + "--iree-opt-demote-f64-to-f32=true", + "--iree-opt-demote-i64-to-i32=false", + "--iree-input-demote-i64-to-i32=false" + ] + + {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, flags) %EXLA.Executable{ client: client, diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index 2f32773962..8af13201de 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -64,7 +64,7 @@ defmodule EXLA.DeviceBuffer do end {read_size, source_type} = - if target_type in [f: 64, c: 128, s: 64, u: 64] do + if target_type in [f: 64, c: 128] do {div(size, 2), {s, div(w, 2)}} else {size, target_type} @@ -93,6 +93,9 @@ defmodule EXLA.DeviceBuffer do Returns `:ok` | `:already_deallocated`. """ + def deallocate(%DeviceBuffer{ref: ref, client_name: :iree}), + do: EXLA.MLIR.IREE.deallocate_buffer(ref) |> unwrap!() + def deallocate(%DeviceBuffer{ref: ref}), do: EXLA.NIF.deallocate_device_mem(ref) |> unwrap!() diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index ffca77b19d..5a79d26757 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -87,7 +87,7 @@ defmodule EXLA.Executable do ref %BinaryBuffer{data: data, typespec: typespec} -> - if typespec.type in [f: 64, c: 128, s: 64, u: 64] do + if typespec.type in [f: 64, c: 128] do {t, w} = typespec.type w2 = div(w, 2) target_type = {t, w2} diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index d677e09530..bb4e08d9b7 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -21,7 +21,7 @@ defmodule EXLA.MLIR.IREE do read_buffer(buffer, device, size) end - def compile(_module, _target), do: :erlang.nif_error(:undef) + def compile(_module, _flags), do: :erlang.nif_error(:undef) def global_initialize, do: :erlang.nif_error(:undef) @@ -31,5 +31,7 @@ defmodule EXLA.MLIR.IREE do def create_instance, do: :erlang.nif_error(:undef) + def deallocate_buffer(_buffer), do: :erlang.nif_error(:undef) + def read_buffer(_buffer, _device, _size), do: :erlang.nif_error(:undef) end diff --git a/exla/test/exla/defn/api_test.exs b/exla/test/exla/defn/api_test.exs index 314074e118..cac07dc01a 100644 --- a/exla/test/exla/defn/api_test.exs +++ b/exla/test/exla/defn/api_test.exs @@ -99,7 +99,6 @@ defmodule EXLA.Defn.APITest do end describe "batch" do - @tag :iree_resource_exhausted_error test "when padded" do input = Nx.tensor([[1, 2, 3]], backend: EXLA.Backend) batch = [input] |> Nx.Batch.concatenate() |> Nx.Batch.pad(1) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 40fd9b3fe0..405dd81cc3 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -11,7 +11,6 @@ defmodule EXLA.Defn.ExprTest do describe "tuples" do defn add_subtract_tuple(a, b), do: {a + b, a - b} - @tag :iree_offset_error test "on results" do assert_equal(add_subtract_tuple(2, 3), {Nx.tensor(5), Nx.tensor(-1)}) @@ -370,7 +369,6 @@ defmodule EXLA.Defn.ExprTest do defn atan2_two(a, b), do: Nx.atan2(a, b) - @tag :iree_resource_exhausted_error test "atan2" do <> = <<0x8000000000000000::64>> left = Nx.tensor([-1.0, neg_zero, 0.0, 1.0]) @@ -405,7 +403,6 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_and(a, b), do: a &&& b - @tag :iree_resource_exhausted_error test "bitwise_and" do assert Nx.shape(bitwise_and(@left, @right)) == {5, 5} assert_equal(bitwise_and(@left, @right), Nx.bitwise_and(@left, @right)) @@ -413,7 +410,6 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_or(a, b), do: a ||| b - @tag :iree_resource_exhausted_error test "bitwise_or" do assert Nx.shape(bitwise_or(@left, @right)) == {5, 5} assert_equal(bitwise_or(@left, @right), Nx.bitwise_or(@left, @right)) @@ -421,7 +417,6 @@ defmodule EXLA.Defn.ExprTest do defn bitwise_not(a), do: ~~~a - @tag :iree_resource_exhausted_error test "bitwise_not" do assert Nx.shape(bitwise_not(@left)) == {5} assert_equal(bitwise_not(@left), Nx.bitwise_not(@left)) @@ -448,7 +443,6 @@ defmodule EXLA.Defn.ExprTest do defn left_shift(a, b), do: a <<< b - @tag :iree_resource_exhausted_error test "left_shift" do assert Nx.shape(left_shift(@left, @right)) == {5, 5} assert_equal(left_shift(@left, @right), Nx.left_shift(@left, @right)) @@ -462,7 +456,6 @@ defmodule EXLA.Defn.ExprTest do defn right_shift(a, b), do: a >>> b - @tag :iree_resource_exhausted_error test "right_shift" do assert Nx.shape(right_shift(@left_signed, @right_signed)) == {9, 9} @@ -811,7 +804,9 @@ defmodule EXLA.Defn.ExprTest do defn_var = Macro.var(defn_fun, __MODULE__) defn unquote(defn_fun)(t), do: Nx.unquote(fun)(t) - @tag :iree_type_mismatch_error + if fun in [:is_nan, :is_infinity, :rsqrt] do + @tag :iree_type_mismatch_error + end test "#{fun}" do assert_all_close( unquote(defn_fun)(@float_tensor), @@ -827,7 +822,7 @@ defmodule EXLA.Defn.ExprTest do end describe "complex ops" do - @describetag :iree_unsupported_fft_error + @describetag :iree_illegal_op_error defn fft(t, opts \\ []), do: Nx.fft(t, opts) defn ifft(t, opts \\ []), do: Nx.ifft(t, opts) @@ -1870,7 +1865,6 @@ defmodule EXLA.Defn.ExprTest do ) end - @tag :iree_resource_exhausted_error test "indexed_add handles different input types" do target = Nx.tensor([0]) indices = Nx.tensor([[0]]) @@ -1903,7 +1897,6 @@ defmodule EXLA.Defn.ExprTest do Nx.indexed_put(t, i, u) end - @tag :iree_resource_exhausted_error test "indexed_add works for multi-dim tensor" do target = Nx.broadcast(0, {2, 3, 4}) @@ -1944,7 +1937,6 @@ defmodule EXLA.Defn.ExprTest do ) end - @tag :iree_resource_exhausted_error test "indexed_put handles different input types" do target = Nx.tensor([0]) indices = Nx.tensor([[0]]) @@ -2240,7 +2232,6 @@ defmodule EXLA.Defn.ExprTest do describe "reduce_min" do defn reduce_min(t), do: Nx.reduce_min(t) - @tag :iree_resource_exhausted_error test "computes the minimum across types" do assert_equal(Nx.tensor([1, 2, 3]) |> reduce_min(), Nx.tensor(1)) @@ -2262,7 +2253,6 @@ defmodule EXLA.Defn.ExprTest do ) end - @tag :iree_resource_exhausted_error test "computes the minimum across nan" do assert_equal(Nx.tensor([:nan, :nan, :nan]) |> reduce_min(), Nx.tensor(:nan)) end @@ -2539,7 +2529,7 @@ defmodule EXLA.Defn.ExprTest do end describe "window min" do - @describetag :iree_wrong_result_error + @describetag :iree_segfault_error defn window_min0(t), do: Nx.window_min(t, {2}) defn window_min1(t), do: Nx.window_min(t, {1, 2, 1}) @@ -2939,7 +2929,6 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op - @tag :iree_resource_exhausted_error test "computes the convolution with general padding, stride" do img = Nx.iota({2, 1, 12, 24}, type: {:f, 64}) kernel = Nx.iota({2, 1, 6, 6}, type: {:f, 64}) @@ -3952,9 +3941,9 @@ defmodule EXLA.Defn.ExprTest do end describe "top_k" do + @describetag :iree_segfault_error defn top_1(t), do: Nx.top_k(t, k: 1) - @tag :iree_offset_error test "returns top 1 values and indices" do a = Nx.iota({5}) assert_equal(top_1(a), {Nx.tensor([4]), Nx.tensor([4])}) @@ -3968,7 +3957,7 @@ defmodule EXLA.Defn.ExprTest do end describe "argsort" do - @describetag :iree_offset_error + @describetag :iree_segfault_error defn argsort0(t), do: Nx.argsort(t, axis: 0) defn argsort1(t), do: Nx.argsort(t, axis: 1) defn argsort1_asc(t), do: Nx.argsort(t, axis: 1, direction: :asc) diff --git a/exla/test/exla/defn/vectorize_test.exs b/exla/test/exla/defn/vectorize_test.exs index a009dfe04d..a21fbaf2c8 100644 --- a/exla/test/exla/defn/vectorize_test.exs +++ b/exla/test/exla/defn/vectorize_test.exs @@ -160,6 +160,7 @@ defmodule EXLA.Defn.VectorizeTest do end describe "cond" do + @describetag :token deftransformp send_value(val, opts \\ []) do Nx.Defn.Kernel.hook( val, diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 535133fdc8..49b97547b5 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -46,17 +46,10 @@ iree_excludes = if compiler_mode == :iree do [ :token, - :iree_hangup_error, - :iree_type_mismatch_error, - :iree_resource_exhausted_error, - :iree_key_not_found_error, - :iree_wrong_result_error, - :iree_unsupported_fft_error, :iree_segfault_error, :iree_illegal_op_error, - :iree_offset_error, - :multi_device, - :unsupported_64_bit_op + :iree_key_not_found_error, + :iree_type_mismatch_error ] else [] From c02d839ebfdbda99d695b899b64021f26e2572e4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 18:34:36 -0300 Subject: [PATCH 28/40] chore: update iree --- exla/Makefile | 22 ++++++++++++++-------- exla/test/test_helper.exs | 4 +--- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index c8029070f9..92da14a064 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -108,7 +108,7 @@ $(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS) @ mkdir -p $(EXLA_CACHE_OBJ_DIR)/iree $(CXX) $(CFLAGS) -c $< -o $@ -$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(IREE_COMPILER_LIB) $(OBJECTS) +$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) cache/iree $(OBJECTS) $(CXX) $(OBJECTS) -o $(EXLA_CACHE_SO) $(LDFLAGS) @@ -120,7 +120,19 @@ else IREE_CMAKE_CONFIG = Release endif -$(EXLA_CACHE_IREE_COMPILER_SO): +# This is gonna be extracted out to a library +# For now, we're doing it here to get things working +IREE_COMMIT := d4aa8491a755e31d590f00a507e6c3859dfa662d +cache/iree: + @mkdir -p cache + @git clone https://github.com/iree-org/iree cache/iree + @cd cache/iree && git checkout $(IREE_COMMIT) + @cd cache/iree && git submodule update --init --recursive + @mkdir -p cache/iree/build + cmake -G Ninja -B cache/iree/build -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree + cmake --build cache/iree/build + +$(EXLA_CACHE_IREE_COMPILER_SO): cache/iree @mkdir -p $(IREE_CMAKE_BUILD_DIR) @mkdir -p cache/objs/iree_cmake_out @mkdir -p cache/objs/mlir_cmake_out @@ -136,11 +148,5 @@ $(EXLA_CACHE_IREE_COMPILER_SO): cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --verbose cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache -$(IREE_COMPILER_LIB): - # TO-DO: setup proper download and caching of the iree compiler - @ln -s $(HOME)/coding/iree cache/iree - cmake -G Ninja -B cache/iree/build -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree - cmake --build cache/iree/build - clean: rm -rf cache diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 49b97547b5..4700a8655a 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -40,10 +40,8 @@ cuda_required = [:cuda_required] end -compiler_mode = :iree - iree_excludes = - if compiler_mode == :iree do + if runtime == :iree do [ :token, :iree_segfault_error, From 0fef901fd3d52a2ae5d440ac091e9ab987d064f0 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 19:13:16 -0300 Subject: [PATCH 29/40] feat: update reduce --- exla/lib/exla/mlir/module.ex | 2 ++ exla/lib/exla/mlir/value.ex | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index 3c91024c32..1005d15ed9 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -98,6 +98,8 @@ defmodule EXLA.MLIR.Module do do: -1, else: Keyword.get(options, :device_id, client.default_device_id) + # module.ref |> EXLA.NIF.mlir_module_to_string() |> elem(1) |> IO.puts() + ref = EXLA.NIF.mlir_compile( client.ref, diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index b469048369..07ba4a6cfa 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -625,7 +625,7 @@ defmodule EXLA.MLIR.Value do ) do operands = inputs ++ init_values result_types = typespecs_to_mlir_types(typespecs) - attributes = [dimensions: attr_dense_i64_elements(dimensions)] + attributes = [dimensions: attr_array_i64_elements(dimensions)] regions = [reducer] op(func, "stablehlo.reduce", operands, result_types, attributes: attributes, regions: regions) end @@ -904,6 +904,14 @@ defmodule EXLA.MLIR.Value do <> end + defp attr_array_i64_elements([]) do + "array" + end + + defp attr_array_i64_elements(list) do + "array" + end + defp attr_dense_i64_elements(list) do attr_dense_elements(list, {:s, 64}, {length(list)}) end From 7e14d030eb7f0b42f498c86b869193c778a5f6cc Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 19:18:25 -0300 Subject: [PATCH 30/40] feat: update all dense i64 to array i64 --- exla/lib/exla/mlir/value.ex | 48 +++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 07ba4a6cfa..b5f853b2da 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -186,13 +186,13 @@ defmodule EXLA.MLIR.Value do def reverse(%Value{function: func} = operand, dims, typespec) do result_types = typespecs_to_mlir_types([typespec]) - attributes = [dimensions: attr_dense_i64_elements(dims)] + attributes = [dimensions: attr_array_i64_elements(dims)] op(func, "stablehlo.reverse", [operand], result_types, attributes: attributes) |> one!() end def transpose(%Value{function: func} = operand, axes, typespec) do result_types = typespecs_to_mlir_types([typespec]) - attributes = [permutation: attr_dense_i64_elements(axes)] + attributes = [permutation: attr_array_i64_elements(axes)] op(func, "stablehlo.transpose", [operand], result_types, attributes: attributes) |> one!() end @@ -200,9 +200,9 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - start_indices: attr_dense_i64_elements(starts), - limit_indices: attr_dense_i64_elements(limits), - strides: attr_dense_i64_elements(strides) + start_indices: attr_array_i64_elements(starts), + limit_indices: attr_array_i64_elements(limits), + strides: attr_array_i64_elements(strides) ] op(func, "stablehlo.slice", [operand], result_types, attributes: attributes) |> one!() @@ -211,7 +211,7 @@ defmodule EXLA.MLIR.Value do def dynamic_slice(%Value{function: func} = operand, starts, lengths, typespec) do result_types = typespecs_to_mlir_types([typespec]) operands = [operand] ++ starts - attributes = [slice_sizes: attr_dense_i64_elements(lengths)] + attributes = [slice_sizes: attr_array_i64_elements(lengths)] op(func, "stablehlo.dynamic_slice", operands, result_types, attributes: attributes) |> one!() end @@ -303,7 +303,7 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - broadcast_dimensions: attr_dense_i64_elements(axes) + broadcast_dimensions: attr_array_i64_elements(axes) ] op(func, "stablehlo.broadcast_in_dim", [operand], result_types, attributes: attributes) @@ -347,9 +347,9 @@ defmodule EXLA.MLIR.Value do {padding_low, padding_high, padding_mid} = unzip_padding_config(padding_config) attributes = [ - edge_padding_low: attr_dense_i64_elements(padding_low), - edge_padding_high: attr_dense_i64_elements(padding_high), - interior_padding: attr_dense_i64_elements(padding_mid) + edge_padding_low: attr_array_i64_elements(padding_low), + edge_padding_high: attr_array_i64_elements(padding_high), + interior_padding: attr_array_i64_elements(padding_mid) ] op(func, "stablehlo.pad", [operand, pad], result_types, attributes: attributes) |> one!() @@ -375,7 +375,7 @@ defmodule EXLA.MLIR.Value do attributes = [ fft_type: fft_type, - fft_length: attr_dense_i64_elements(List.wrap(fft_length)) + fft_length: attr_array_i64_elements(List.wrap(fft_length)) ] op(func, "stablehlo.fft", [value], result_types, attributes: attributes) |> one!() @@ -451,8 +451,8 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - window_dimensions: attr_dense_i64_elements(window_dimensions), - window_strides: attr_dense_i64_elements(window_strides), + window_dimensions: attr_array_i64_elements(window_dimensions), + window_strides: attr_array_i64_elements(window_strides), padding: attr_padding(padding) ] @@ -501,7 +501,7 @@ defmodule EXLA.MLIR.Value do attributes = [ dimension_numbers: dimension_numbers, - slice_sizes: attr_dense_i64_elements(slice_sizes), + slice_sizes: attr_array_i64_elements(slice_sizes), indices_are_sorted: attr_boolean(false) ] @@ -546,10 +546,10 @@ defmodule EXLA.MLIR.Value do attr_precision_config = attr_precision_config(precision_config) attributes = [ - window_strides: attr_dense_i64_elements(strides), + window_strides: attr_array_i64_elements(strides), padding: attr_padding(padding), - lhs_dilation: attr_dense_i64_elements(input_dilation), - rhs_dilation: attr_dense_i64_elements(kernel_dilation), + lhs_dilation: attr_array_i64_elements(input_dilation), + rhs_dilation: attr_array_i64_elements(kernel_dilation), dimension_numbers: attr_conv_dimension_numbers(dimension_numbers), feature_group_count: attr_i64(feature_group_count), batch_group_count: attr_i64(batch_group_count), @@ -645,10 +645,10 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types(typespecs) attributes = [ - window_dimensions: attr_dense_i64_elements(window_dimensions), - window_strides: attr_dense_i64_elements(window_strides), - base_dilations: attr_dense_i64_elements(input_dilations), - window_dilations: attr_dense_i64_elements(window_dilations), + window_dimensions: attr_array_i64_elements(window_dimensions), + window_strides: attr_array_i64_elements(window_strides), + base_dilations: attr_array_i64_elements(input_dilations), + window_dilations: attr_array_i64_elements(window_dilations), padding: attr_padding(padding) ] @@ -669,7 +669,7 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - dimensions: attr_dense_i64_elements(dimensions) + dimensions: attr_array_i64_elements(dimensions) ] regions = [mapper] @@ -912,10 +912,6 @@ defmodule EXLA.MLIR.Value do "array" end - defp attr_dense_i64_elements(list) do - attr_dense_elements(list, {:s, 64}, {length(list)}) - end - defp attr_dense_elements([], type, {0} = shape) do "dense<[]> : #{type_tensor(type, shape)}" end From 7d0f6f40f2969b6b496e6b8d17deaacd27581ef4 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 19:46:43 -0300 Subject: [PATCH 31/40] feat: update stablehlo --- exla/c_src/exla/iree/runtime.cc | 4 ++-- exla/lib/exla/defn.ex | 3 +-- exla/lib/exla/device_buffer.ex | 22 ++-------------------- exla/lib/exla/executable.ex | 18 +----------------- 4 files changed, 6 insertions(+), 41 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 8c988a4dc7..1673435ce9 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -102,13 +102,13 @@ bool iree_element_type_to_nx_type(iree_hal_element_type_t type, std::string &nx_ nx_type = "f32"; return true; case type_enum::IREE_HAL_ELEMENT_TYPE_FLOAT_64: - nx_type = "f32"; + nx_type = "f64"; return true; case type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_64: nx_type = "c64"; return true; case type_enum::IREE_HAL_ELEMENT_TYPE_COMPLEX_FLOAT_128: - nx_type = "c64"; + nx_type = "c128"; return true; default: return false; diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 7dfd8fecaa..8466ba2edf 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -489,8 +489,7 @@ defmodule EXLA.Defn do "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal", "--output-format=vm-bytecode", - "--iree-opt-demote-f64-to-f32=true", - "--iree-opt-demote-i64-to-i32=false", + "--iree-input-demote-f64-to-f32=false", "--iree-input-demote-i64-to-i32=false" ] diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index 8af13201de..72c4225509 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -54,7 +54,7 @@ defmodule EXLA.DeviceBuffer do def read(buffer, size \\ -1) def read(%DeviceBuffer{typespec: typespec, ref: ref, client_name: :iree}, size) do - target_type = {s, w} = typespec.type + {_s, w} = typespec.type size = if size == -1 do @@ -63,25 +63,7 @@ defmodule EXLA.DeviceBuffer do size end - {read_size, source_type} = - if target_type in [f: 64, c: 128] do - {div(size, 2), {s, div(w, 2)}} - else - {size, target_type} - end - - data = EXLA.MLIR.IREE.read(ref, read_size) |> unwrap!() - - if source_type == target_type do - data - else - Nx.with_default_backend(Nx.BinaryBackend, fn -> - data - |> Nx.from_binary(source_type) - |> Nx.as_type(target_type) - |> Nx.to_binary() - end) - end + EXLA.MLIR.IREE.read(ref, size) |> unwrap!() end def read(%DeviceBuffer{ref: ref}, size) do diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 5a79d26757..0df40177fa 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -87,23 +87,7 @@ defmodule EXLA.Executable do ref %BinaryBuffer{data: data, typespec: typespec} -> - if typespec.type in [f: 64, c: 128] do - {t, w} = typespec.type - w2 = div(w, 2) - target_type = {t, w2} - - data = - Nx.with_default_backend(Nx.BinaryBackend, fn -> - data - |> Nx.from_binary(typespec.type) - |> Nx.as_type(target_type) - |> Nx.to_binary() - end) - - {data, EXLA.Typespec.nif_encode(typespec)} - else - {data, EXLA.Typespec.nif_encode(typespec)} - end + {data, EXLA.Typespec.nif_encode(typespec)} end) end From f198bb93abe9557f2938a7e56bd01f9c91cdecbd Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 20:00:37 -0300 Subject: [PATCH 32/40] feat: parameterize device on application setup --- exla/c_src/exla/iree/iree.cc | 2 +- exla/c_src/exla/iree/runtime.cc | 13 +++++++++++-- exla/lib/exla/application.ex | 3 ++- exla/lib/exla/defn.ex | 2 +- exla/lib/exla/mlir/iree.ex | 2 +- 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index 7a81c40921..a95d4375fe 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -17,7 +17,7 @@ static ErlNifFunc iree_funcs[] = { {"global_initialize", 0, global_initialize}, {"compile", 2, compile, ERL_NIF_DIRTY_JOB_CPU_BOUND}, {"run_module", 4, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, - {"setup_runtime", 0, setup_runtime}, + {"setup_runtime", 1, setup_runtime}, {"create_instance", 0, create_instance}, {"read_buffer", 3, read_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"deallocate_buffer", 1, deallocate_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}, diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 1673435ce9..fb7d980a3e 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -326,16 +326,25 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { } ERL_NIF_TERM setup_runtime(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 1) { + return exla::nif::error(env, "Bad argument count."); + } + iree_hal_device_t *device = nullptr; + std::string device_uri; + + if (!exla::nif::get(env, argv[0], device_uri)) { + return exla::nif::error(env, "Unable to get buffer size"); + } iree_status_t status = iree_hal_register_all_available_drivers(iree_hal_driver_registry_default()); - char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument + // char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument if (iree_status_is_ok(status)) { status = iree_hal_create_device( iree_hal_driver_registry_default(), - iree_make_cstring_view(device_uri), + iree_make_cstring_view(device_uri.c_str()), iree_allocator_system(), &device); } diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index a3eaad7259..29e8c25b0a 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -11,7 +11,8 @@ defmodule EXLA.Application do end EXLA.MLIR.IREE.global_initialize() - {:ok, device} = EXLA.MLIR.IREE.setup_runtime() + # {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"metal://0000000100000971") + {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"local-sync://") :persistent_term.put({EXLA.MLIR.IREE, :device}, device) children = [ diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 8466ba2edf..803984935f 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -485,7 +485,7 @@ defmodule EXLA.Defn do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) flags = [ - "--iree-hal-target-backends=metal-spirv", + "--iree-hal-target-backends=llvm-cpu", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal", "--output-format=vm-bytecode", diff --git a/exla/lib/exla/mlir/iree.ex b/exla/lib/exla/mlir/iree.ex index bb4e08d9b7..58173e297c 100644 --- a/exla/lib/exla/mlir/iree.ex +++ b/exla/lib/exla/mlir/iree.ex @@ -27,7 +27,7 @@ defmodule EXLA.MLIR.IREE do def run_module(_instance, _device, _module, _inputs), do: :erlang.nif_error(:undef) - def setup_runtime, do: :erlang.nif_error(:undef) + def setup_runtime(_device_uri), do: :erlang.nif_error(:undef) def create_instance, do: :erlang.nif_error(:undef) From 9ca3acb0d3cadce21c1d7841c8e54b2224189b75 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 20:21:02 -0300 Subject: [PATCH 33/40] fix: return ui8 from is_nan and is_infinity --- exla/lib/exla/mlir/value.ex | 95 +++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 40 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 07466729d1..991ee59e85 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -125,57 +125,72 @@ defmodule EXLA.MLIR.Value do end end - def is_infinity(%Value{function: func} = operand, typespec) do + def is_infinity(%Value{function: func} = operand, output_typespec) do %{type: type} = get_typespec(operand) - typespec = Typespec.to_type(typespec, {:pred, 8}) + typespec = Typespec.to_type(output_typespec, {:pred, 8}) - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_inf_real = is_infinity(real, typespec) - is_inf_imag = is_infinity(imag, typespec) - bitwise_or(is_inf_real, is_inf_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never infinity. We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) + result = + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_inf_real = is_infinity(real, typespec) + is_inf_imag = is_infinity(imag, typespec) + bitwise_or(is_inf_real, is_inf_imag, typespec) + + Nx.Type.integer?(type) -> + # Integers are never infinity. We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) + + true -> + result_types = typespecs_to_mlir_types([typespec]) + op(func, "chlo.is_inf", [operand], result_types) |> one!() + end - true -> - result_types = typespecs_to_mlir_types([typespec]) - op(func, "chlo.is_inf", [operand], result_types) |> one!() + if output_typespec.type == {:u, 8} do + convert(result, output_typespec) + else + result end end - def is_nan(%Value{function: func} = operand, typespec) do + def is_nan(%Value{function: func} = operand, output_typespec) do %{type: type} = get_typespec(operand) + %{type: output_type} = output_typespec - typespec = Typespec.to_type(typespec, {:pred, 8}) + typespec = Typespec.to_type(output_typespec, {:pred, 8}) - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_nan_real = is_nan(real, typespec) - is_nan_imag = is_nan(imag, typespec) - bitwise_or(is_nan_real, is_nan_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never nan. We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) + result = + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_nan_real = is_nan(real, typespec) + is_nan_imag = is_nan(imag, typespec) + bitwise_or(is_nan_real, is_nan_imag, typespec) + + Nx.Type.integer?(type) -> + # Integers are never nan. We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) + + true -> + result_types = typespecs_to_mlir_types([typespec]) + is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() + is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() + is_not_inf = bitwise_not(is_inf, typespec) + is_not_finite = bitwise_not(is_finite, typespec) + bitwise_and(is_not_inf, is_not_finite, typespec) + end - true -> - result_types = typespecs_to_mlir_types([typespec]) - is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() - is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() - is_not_inf = bitwise_not(is_inf, typespec) - is_not_finite = bitwise_not(is_finite, typespec) - bitwise_and(is_not_inf, is_not_finite, typespec) + if output_type == {:u, 8} do + convert(result, output_typespec) + else + result end end From 7d288ed2a7044c2314c3a336e10d09dd50d3af39 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 20:22:01 -0300 Subject: [PATCH 34/40] chore: remove type mismatch tag --- exla/test/exla/defn/expr_test.exs | 3 --- exla/test/test_helper.exs | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 5170441343..c2793dcfa6 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -804,9 +804,6 @@ defmodule EXLA.Defn.ExprTest do defn_var = Macro.var(defn_fun, __MODULE__) defn unquote(defn_fun)(t), do: Nx.unquote(fun)(t) - if fun in [:is_nan, :is_infinity, :rsqrt] do - @tag :iree_type_mismatch_error - end test "#{fun}" do assert_all_close( unquote(defn_fun)(@float_tensor), diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 4700a8655a..6998edb2b3 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -46,8 +46,7 @@ iree_excludes = :token, :iree_segfault_error, :iree_illegal_op_error, - :iree_key_not_found_error, - :iree_type_mismatch_error + :iree_key_not_found_error ] else [] From bdd34c4403d190e4d93d9d7a661e3ecc50d5971f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 20:24:22 -0300 Subject: [PATCH 35/40] chore: use metal again --- exla/lib/exla/application.ex | 4 ++-- exla/lib/exla/defn.ex | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 29e8c25b0a..00476cb025 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -11,8 +11,8 @@ defmodule EXLA.Application do end EXLA.MLIR.IREE.global_initialize() - # {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"metal://0000000100000971") - {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"local-sync://") + {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"metal://0000000100000971") + # {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"local-sync://") :persistent_term.put({EXLA.MLIR.IREE, :device}, device) children = [ diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 803984935f..98e7ee31bd 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -485,7 +485,8 @@ defmodule EXLA.Defn do {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) flags = [ - "--iree-hal-target-backends=llvm-cpu", + # "--iree-hal-target-backends=llvm-cpu", + "--iree-hal-target-backends=metal-spirv", "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal", "--output-format=vm-bytecode", From 18836360c7b112c45246acd8832e03b3a3373cb7 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 May 2024 03:59:36 -0300 Subject: [PATCH 36/40] feat: all green tests (some skipped) --- exla/c_src/exla/iree/runtime.cc | 2 -- exla/lib/exla/defn.ex | 2 +- exla/lib/exla/device_buffer.ex | 21 +++++++++++++++-- exla/lib/exla/executable.ex | 16 ++++++++++++- exla/test/exla/backend_test.exs | 39 +++++++++++++++++++++++++------ exla/test/exla/defn/expr_test.exs | 10 ++++++++ exla/test/exla/random_test.exs | 1 + exla/test/test_helper.exs | 4 +++- 8 files changed, 81 insertions(+), 14 deletions(-) diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index fb7d980a3e..e128d58112 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -339,8 +339,6 @@ ERL_NIF_TERM setup_runtime(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) iree_status_t status = iree_hal_register_all_available_drivers(iree_hal_driver_registry_default()); - // char device_uri[] = "metal://0000000100000971"; // TO-DO: change this to an argument - if (iree_status_is_ok(status)) { status = iree_hal_create_device( iree_hal_driver_registry_default(), diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 98e7ee31bd..9d739fdacd 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -490,7 +490,7 @@ defmodule EXLA.Defn do "--iree-input-type=stablehlo_xla", "--iree-execution-model=async-internal", "--output-format=vm-bytecode", - "--iree-input-demote-f64-to-f32=false", + "--iree-input-demote-f64-to-f32=true", "--iree-input-demote-i64-to-i32=false" ] diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index 72c4225509..aea8fe9a4f 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -53,8 +53,10 @@ defmodule EXLA.DeviceBuffer do """ def read(buffer, size \\ -1) + @downcast_types [f: 64, c: 128] + def read(%DeviceBuffer{typespec: typespec, ref: ref, client_name: :iree}, size) do - {_s, w} = typespec.type + {s, w} = typespec.type size = if size == -1 do @@ -63,7 +65,22 @@ defmodule EXLA.DeviceBuffer do size end - EXLA.MLIR.IREE.read(ref, size) |> unwrap!() + read_size = + if {s, w} in @downcast_types do + div(size, 2) + else + size + end + + data = EXLA.MLIR.IREE.read(ref, read_size) |> unwrap!() + + if read_size != size do + Nx.with_default_backend(Nx.BinaryBackend, fn -> + data |> Nx.from_binary({s, div(w, 2)}) |> Nx.as_type({s, w}) |> Nx.to_binary() + end) + else + data + end end def read(%DeviceBuffer{ref: ref}, size) do diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 0df40177fa..ebc77c657b 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -87,7 +87,21 @@ defmodule EXLA.Executable do ref %BinaryBuffer{data: data, typespec: typespec} -> - {data, EXLA.Typespec.nif_encode(typespec)} + case typespec do + %{type: {:f, 64}} -> + data = + Nx.with_default_backend(Nx.BinaryBackend, fn -> + data + |> Nx.from_binary(:f64) + |> Nx.as_type(:f32) + |> Nx.to_binary() + end) + + {data, EXLA.Typespec.nif_encode(%{typespec | type: {:f, 32}})} + + _ -> + {data, EXLA.Typespec.nif_encode(typespec)} + end end) end diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 860e1c0a99..8c6b627a68 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -29,17 +29,41 @@ defmodule EXLA.BackendTest do if iree_runtime?() do @skip_iree [ - window_sum: 3, - window_product: 3, - window_mean: 3, - window_min: 3, - window_max: 3, + # illegal op errors + conv: 3, + fft2: 2, + ifft2: 2, + fft: 2, + ifft: 2, + population_count: 1, window_scatter_max: 5, window_scatter_min: 5, - median: 2, + # clz is not fully supported + count_leading_zeros: 1, + # precision errors + atan2: 2, + acosh: 1, + phase: 1, + standard_deviation: 2, + weighted_mean: 3, + covariance: 3, + variance: 2, + atan: 1, + acos: 1, + cbrt: 1, + # wrong result (segfault in argsort and sort) argsort: 2, sort: 2, - conv: 3 + is_nan: 1, + top_k: 2, + # cryptic crashes + median: 2, + window_min: 3, + window_max: 3, + window_sum: 3, + window_product: 3, + window_mean: 3, + all_close: 3, ] else @skip_iree [] @@ -134,6 +158,7 @@ defmodule EXLA.BackendTest do assert_equal(result, Nx.tensor([0, 1, 1, 0])) end + @tag :iree_key_not_found_error test "Nx.LinAlg.svd/2" do t = Nx.iota({4, 4}) assert {u, s, vt} = Nx.LinAlg.svd(t, max_iter: 10_000) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index c2793dcfa6..045d7d5761 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -369,6 +369,7 @@ defmodule EXLA.Defn.ExprTest do defn atan2_two(a, b), do: Nx.atan2(a, b) + @tag :iree_wrong_result_error test "atan2" do <> = <<0x8000000000000000::64>> left = Nx.tensor([-1.0, neg_zero, 0.0, 1.0]) @@ -456,6 +457,7 @@ defmodule EXLA.Defn.ExprTest do defn right_shift(a, b), do: a >>> b + @tag :iree_wrong_result_error test "right_shift" do assert Nx.shape(right_shift(@left_signed, @right_signed)) == {9, 9} @@ -2822,6 +2824,7 @@ defmodule EXLA.Defn.ExprTest do ) end + @tag :iree_illegal_op_error test "computes a convolution with channels last" do img = Nx.iota({8, 12, 12, 3}, type: {:f, 32}, names: [:batch, :height, :width, :channels]) kernel = Nx.iota({6, 3, 2, 2}, type: {:f, 32}) @@ -2839,6 +2842,7 @@ defmodule EXLA.Defn.ExprTest do assert %{names: [:batch, :height, :width, :channels], shape: {8, 11, 11, 6}} = lhs end + @tag :iree_illegal_op_error test "computes a convolution with a permutation" do img = Nx.iota({12, 12, 3, 4}, type: {:f, 32}) kernel = Nx.iota({3, 2, 32, 2}, type: {:f, 32}) @@ -2875,6 +2879,7 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "computes the convolution with valid padding, no stride" do img = Nx.iota({5, 1, 12, 12}, type: {:f, 64}) kernel = Nx.iota({32, 1, 3, 3}, type: {:f, 64}) @@ -2885,6 +2890,7 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "computes the convolution with valid padding, {2, 2} stride" do img = Nx.iota({25, 1, 11, 8}, type: {:f, 64}) kernel = Nx.iota({32, 1, 3, 3}, type: {:f, 64}) @@ -2895,6 +2901,7 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "computes the convolution with same padding, no stride" do img = Nx.iota({13, 3, 10, 6}, type: {:f, 64}) kernel = Nx.iota({32, 3, 3, 3}, type: {:f, 64}) @@ -2905,6 +2912,7 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "computes the convolution with same padding, stride" do img = Nx.iota({32, 1, 9, 9}, type: {:f, 64}) kernel = Nx.iota({32, 1, 7, 7}, type: {:f, 64}) @@ -2958,6 +2966,7 @@ defmodule EXLA.Defn.ExprTest do end @tag :unsupported_64_bit_op + @tag :iree_illegal_op_error test "computes a dilated convolution" do img = Nx.iota({4, 3, 10, 10}, type: {:f, 64}) kernel = Nx.iota({6, 3, 2, 2}, type: {:f, 64}) @@ -2978,6 +2987,7 @@ defmodule EXLA.Defn.ExprTest do assert_all_close(lhs, rhs) end + @tag :iree_illegal_op_error test "computes a conv with both dilations" do img = Nx.iota({4, 3, 15, 15}, type: {:f, 32}) kernel = Nx.iota({6, 3, 3, 2}, type: {:f, 32}) diff --git a/exla/test/exla/random_test.exs b/exla/test/exla/random_test.exs index db1bea2ec4..7f9dcb547a 100644 --- a/exla/test/exla/random_test.exs +++ b/exla/test/exla/random_test.exs @@ -9,6 +9,7 @@ defmodule EXLA.NxRandomTest do end describe "range" do + @tag :iree_operand_does_not_dominate_error test "randint" do key = Nx.Random.key(127) diff --git a/exla/test/test_helper.exs b/exla/test/test_helper.exs index 6998edb2b3..91d692a973 100644 --- a/exla/test/test_helper.exs +++ b/exla/test/test_helper.exs @@ -46,7 +46,9 @@ iree_excludes = :token, :iree_segfault_error, :iree_illegal_op_error, - :iree_key_not_found_error + :iree_key_not_found_error, + :iree_wrong_result_error, + :iree_operand_does_not_dominate_error ] else [] From 4e05bd9f537bf54477c6abb76a148f6243f944fd Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 22 May 2024 23:24:43 -0300 Subject: [PATCH 37/40] wip: trace debugging --- exla/Makefile | 15 ++- exla/c_src/exla/iree/compiler.cc | 136 ++++++++++++++--------- exla/c_src/exla/iree/runtime.cc | 19 ++-- exla/c_src/iree_runtime/CMakeLists.txt | 8 +- exla/lib/exla/application.ex | 57 +++++++++- exla/lib/exla/defn.ex | 33 ++++-- exla/lib/exla/mlir/iree/instance_pool.ex | 62 ++++++----- exla/mix.exs | 7 +- 8 files changed, 225 insertions(+), 112 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index 6c92607766..8d73f8f86f 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -127,9 +127,9 @@ cache/iree: @mkdir -p cache @git clone https://github.com/iree-org/iree cache/iree @cd cache/iree && git checkout $(IREE_COMMIT) - @cd cache/iree && git submodule update --init --recursive + @cd cache/iree && git submodule update --init --recursive --depth 1 @mkdir -p cache/iree/build - cmake -G Ninja -B cache/iree/build -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree + cmake -G Ninja -B cache/iree/build -DTRACY_ENABLE=ON -DIREE_ENABLE_RUNTIME_TRACING=ON -DIREE_ENABLE_COMPILER_TRACING=ON -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree cmake --build cache/iree/build $(EXLA_CACHE_IREE_COMPILER_SO): cache/iree @@ -140,12 +140,19 @@ $(EXLA_CACHE_IREE_COMPILER_SO): cache/iree cmake -S c_src/iree_runtime -B $(IREE_CMAKE_BUILD_DIR) \ -DIREE_COMPILER_INCLUDE_PATH=$(IREE_COMPILER_INCLUDE_PATH) \ -DIREE_COMPILER_DIR=$(IREE_COMPILER_DIR) \ + -DIREE_ENABLE_RUNTIME_TRACING=ON \ + -DIREE_TRACING_MODE=ON \ + -DTRACY_ENABLE=ON \ -DXLA_INCLUDE_PATH=$(abspath $(XLA_INCLUDE_PATH)) \ -DIREE_INSTALL_PREFIX=$(IREE_INSTALL_PREFIX) \ -DCACHE_DIR=cache\ -DXLA_EXTENSION_LIB=$(abspath $(XLA_EXTENSION_LIB))\ - -DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG) - cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --verbose + -DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG)\ + -DTRACY_ENABLE=1\ + -DTRACY_CALLSTACK=ON\ + -DTRACY_NO_SAMPLING=OFF\ + -DTRACY_NO_VERIFY=ON + cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache clean: diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index 533f07baca..e842e04ed5 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -2,6 +2,7 @@ #include // For O_WRONLY, O_CREAT, O_TRUNC #include +#include #include #include #include @@ -12,6 +13,7 @@ #include #include +#include typedef struct compiler_state_t { iree_compiler_session_t *session; @@ -36,39 +38,38 @@ void cleanup_compiler_state(compiler_state_t s) { ireeCompilerSourceDestroy(s.source); if (s.session) ireeCompilerSessionDestroy(s.session); - // ireeCompilerGlobalShutdown(); + ireeCompilerGlobalShutdown(); } static void initializeCompiler(struct compiler_state_t *state) { - // ireeCompilerGlobalInitialize(); + ireeCompilerGlobalInitialize(); state->session = ireeCompilerSessionCreate(); state->context = ireeCompilerSessionBorrowContext(state->session); } -static void shutdownCompiler(struct compiler_state_t *state) { - ireeCompilerSessionDestroy(state->session); - // ireeCompilerGlobalShutdown(); -} - ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { + ZoneScopedN("compile main"); if (argc != 2) { return exla::nif::error(env, "Bad argument count."); } std::string module_str; - std::vector flags_str; + std::vector flags_str; std::vector flags; - if (!exla::nif::get(env, argv[0], module_str)) { - return exla::nif::error(env, "Unable to get module."); - } + { + ZoneScopedN("compiler get arguments"); + if (!exla::nif::get(env, argv[0], module_str)) { + return exla::nif::error(env, "Unable to get module."); + } - if (!exla::nif::get_list(env, argv[1], flags_str)) { - return exla::nif::error(env, "Unable to get list."); - } + if (!exla::nif::get_list(env, argv[1], flags_str)) { + return exla::nif::error(env, "Unable to get list."); + } - for (auto &flag : flags_str) { - flags.push_back(reinterpret_cast(flag.data)); + for (auto &flag : flags_str) { + flags.push_back(flag.c_str()); + } } compiler_state_t state; @@ -80,61 +81,89 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { initializeCompiler(&state); - MlirOperation module_op = mlirOperationCreateParse( - state.context, - mlirStringRefCreate(module_str.c_str(), module_str.size()), - mlirStringRefCreateFromCString("source.stablehlo")); - if (mlirOperationIsNull(module_op)) { - return exla::nif::error(env, "Unable to create MlirOperation module."); + MlirOperation module_op; + + { + ZoneScopedN("Parse module"); + module_op = mlirOperationCreateParse( + state.context, + mlirStringRefCreate(module_str.c_str(), module_str.size()), + mlirStringRefCreateFromCString("source.stablehlo")); + if (mlirOperationIsNull(module_op)) { + return exla::nif::error(env, "Unable to create MlirOperation module."); + } } // Set flags. - iree_compiler_error_t *err; - err = ireeCompilerSessionSetFlags(state.session, 1, flags.data()); - if (err) { - cleanup_compiler_state(state); - return exla::nif::error(env, "Unable to set flags."); + { + ZoneScopedN("Set flags"); + error = ireeCompilerSessionSetFlags(state.session, flags.size(), flags.data()); + if (error) { + const char *msg = ireeCompilerErrorGetMessage(error); + + cleanup_compiler_state(state); + + std::stringstream ss; + ss << "Unable to set flags due to error: "; + ss << msg; + + return exla::nif::error(env, ss.str().c_str()); + } } - state.invocation = ireeCompilerInvocationCreate(state.session); - ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation); + { + ZoneScopedN("Create invocation"); + state.invocation = ireeCompilerInvocationCreate(state.session); + ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation); + } - if (!ireeCompilerInvocationImportStealModule(state.invocation, module_op)) { - cleanup_compiler_state(state); - return exla::nif::error(env, "Unable to import module."); + { + ZoneScopedN("Import module"); + if (!ireeCompilerInvocationImportStealModule(state.invocation, module_op)) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to import module."); + } } // Compile. - if (!ireeCompilerInvocationPipeline(state.invocation, iree_compiler_pipeline_t::IREE_COMPILER_PIPELINE_STD)) { - cleanup_compiler_state(state); - return exla::nif::error(env, "Unable to compile module."); + { + ZoneScopedN("Invocation Pipeline"); + if (!ireeCompilerInvocationPipeline(state.invocation, iree_compiler_pipeline_t::IREE_COMPILER_PIPELINE_STD)) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to compile module."); + } } - fflush(stdout); - - error = ireeCompilerOutputOpenMembuffer(&state.output); - if (error) { - handle_compiler_error(error); - cleanup_compiler_state(state); - return exla::nif::error(env, "Error opening output membuffer"); + { + ZoneScopedN("Open output membuffer"); + error = ireeCompilerOutputOpenMembuffer(&state.output); + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return exla::nif::error(env, "Error opening output membuffer"); + } } - error = ireeCompilerInvocationOutputVMBytecode(state.invocation, state.output); - if (error) { - handle_compiler_error(error); - cleanup_compiler_state(state); - return exla::nif::error(env, "Failed to output VM Bytecode"); + { + ZoneScopedN("Output VM Bytecode"); + error = ireeCompilerInvocationOutputVMBytecode(state.invocation, state.output); + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return exla::nif::error(env, "Failed to output VM Bytecode"); + } } uint8_t *contents; uint64_t size; + ErlNifBinary output_binary; - error = ireeCompilerOutputMapMemory(state.output, (void **)&contents, &size); + { + ZoneScopedN("Map and copy output to binary"); + error = ireeCompilerOutputMapMemory(state.output, (void **)&contents, &size); - std::vector bytes_term; - bytes_term.resize(size); - for (size_t i = 0; i < size; i++) { - bytes_term[i] = enif_make_uint(env, static_cast(contents[i])); + enif_alloc_binary(size, &output_binary); + memcpy(output_binary.data, contents, size); } if (error) { @@ -145,5 +174,6 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { cleanup_compiler_state(state); - return exla::nif::ok(env, enif_make_list_from_array(env, bytes_term.data(), bytes_term.size())); + IREE_TRACE_ZONE_END(compile); + return exla::nif::ok(env, enif_make_binary(env, &output_binary)); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index e128d58112..5f4741e0c9 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -5,6 +5,7 @@ #include #include +#include bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_type_t *type) { using xla::PrimitiveType; @@ -173,7 +174,8 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector>> -call(iree_vm_instance_t *instance, iree_hal_device_t *device, std::vector bytecode, std::vector exla_inputs) { +call(iree_vm_instance_t *instance, iree_hal_device_t *device, ErlNifBinary bytecode, std::vector exla_inputs) { + ZoneScoped; iree_vm_module_t *hal_module = nullptr; iree_vm_module_t *bytecode_module = nullptr; iree_vm_context_t *context = nullptr; @@ -187,7 +189,7 @@ call(iree_vm_instance_t *instance, iree_hal_device_t *device, std::vector bytecode_vec = {}; std::vector input_terms = {}; std::vector inputs = {}; - std::vector bytecode = {}; + ErlNifBinary bytecode; iree_hal_device_t **device; iree_vm_instance_t **instance; @@ -279,18 +282,10 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to load device"); } - if (!exla::nif::get_list(env, argv[2], bytecode_vec)) { + if (!enif_inspect_binary(env, argv[2], &bytecode)) { return exla::nif::error(env, "Unable to load bytecode binary"); } - bytecode.clear(); - bytecode.resize(bytecode_vec.size()); - unsigned int byte; - for (int i = 0; i < bytecode_vec.size(); i++) { - enif_get_uint(env, bytecode_vec[i], &byte); - bytecode[i] = static_cast(byte); - } - if (!exla::nif::get_list(env, argv[3], input_terms)) { return exla::nif::error(env, "Unable to load input terms"); } diff --git a/exla/c_src/iree_runtime/CMakeLists.txt b/exla/c_src/iree_runtime/CMakeLists.txt index 60b7a3ecf6..1c48452080 100644 --- a/exla/c_src/iree_runtime/CMakeLists.txt +++ b/exla/c_src/iree_runtime/CMakeLists.txt @@ -46,7 +46,7 @@ target_include_directories(${_NAME} SYSTEM "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/llvm-project/mlir/include" ) -add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/objs/iree_cmake_out" EXCLUDE_FROM_ALL) +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build" EXCLUDE_FROM_ALL) install( TARGETS ${_NAME} @@ -99,4 +99,8 @@ if(NOT APPLE) target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.so") else() target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.dylib") -endif() \ No newline at end of file +endif() + +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy/build") +include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy/public") +target_link_libraries(${_NAME} TracyClient) diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 00476cb025..430a5d8223 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -15,13 +15,60 @@ defmodule EXLA.Application do # {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"local-sync://") :persistent_term.put({EXLA.MLIR.IREE, :device}, device) + {:ok, instance} = EXLA.MLIR.IREE.create_instance() + :persistent_term.put({EXLA.MLIR.IREE, :instance}, instance) + + :persistent_term.put({EXLA.Telemetry, :checkout}, {0, 0, 0, nil}) + + :telemetry.attach( + :exla_telemetry_checkout, + [:exla, :mlir, :iree, :instance_pool, :checkout], + fn _name, %{duration: duration}, _meta, _config -> + {total, count, max, min} = :persistent_term.get({EXLA.Telemetry, :checkout}) + total = total + duration + count = count + 1 + max = max(max, duration) + min = min(min, duration) + + File.write( + "/tmp/checkout.txt", + "#{total / count / 1_000} ms | #{max / 1_000} ms | #{min / 1_000} ms\n" + ) + + :persistent_term.put({EXLA.Telemetry, :checkout}, {total, count, max, min}) + end, + nil + ) + + :persistent_term.put({EXLA.Telemetry, :compile}, {0, 0, 0, nil}) + + :telemetry.attach( + :exla_telemetry_compile, + [:exla, :mlir, :iree, :compile], + fn _name, %{duration: duration}, _meta, _config -> + {total, count, max, min} = :persistent_term.get({EXLA.Telemetry, :compile}) + total = total + duration + count = count + 1 + max = max(max, duration) + min = min(min, duration) + + File.write( + "/tmp/compile.txt", + "#{total / count / 1_000} ms | #{max / 1_000} ms | #{min / 1_000} ms\n" + ) + + :persistent_term.put({EXLA.Telemetry, :compile}, {total, count, max, min}) + end, + nil + ) + children = [ EXLA.Logger, - {NimblePool, - worker: {EXLA.MLIR.IREE.InstancePool, :pool_state}, - pool_size: System.schedulers_online(), - name: EXLA.MLIR.IREE.InstancePool, - lazy: true}, + # {NimblePool, + # worker: {EXLA.MLIR.IREE.InstancePool, :pool_state}, + # pool_size: System.schedulers_online(), + # name: EXLA.MLIR.IREE.InstancePool, + # lazy: true}, {NimblePool, worker: {EXLA.MLIR.ContextPool, :pool_state}, pool_size: System.schedulers_online(), diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 9d739fdacd..0367a9a39f 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -482,19 +482,34 @@ defmodule EXLA.Defn do for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec if runtime == :iree do - {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) + {t1, {:ok, module_charlist}} = + :timer.tc(fn -> + {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) + end) + + # :telemetry.execute( + # [:exla, :mlir, :iree, :compile], + # %{duration: t1} + # ) flags = [ - # "--iree-hal-target-backends=llvm-cpu", - "--iree-hal-target-backends=metal-spirv", - "--iree-input-type=stablehlo_xla", - "--iree-execution-model=async-internal", - "--output-format=vm-bytecode", - "--iree-input-demote-f64-to-f32=true", - "--iree-input-demote-i64-to-i32=false" + # ~c"--iree-hal-target-backends=llvm-cpu", + ~c"--iree-hal-target-backends=metal-spirv", + ~c"--iree-input-type=stablehlo_xla", + ~c"--iree-execution-model=async-internal", + ~c"--iree-input-demote-f64-to-f32=true", + ~c"--iree-input-demote-i64-to-i32=false" ] - {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, flags) + {t2, {:ok, module_bytecode}} = + :timer.tc(fn -> + {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, flags) + end) + + :telemetry.execute( + [:exla, :mlir, :iree, :compile], + %{duration: t2} + ) %EXLA.Executable{ client: client, diff --git a/exla/lib/exla/mlir/iree/instance_pool.ex b/exla/lib/exla/mlir/iree/instance_pool.ex index 7aa2693d51..07cef8bbb3 100644 --- a/exla/lib/exla/mlir/iree/instance_pool.ex +++ b/exla/lib/exla/mlir/iree/instance_pool.ex @@ -1,37 +1,49 @@ defmodule EXLA.MLIR.IREE.InstancePool do @moduledoc false # Internal pool for MLIRContext reference management - @behaviour NimblePool + # @behaviour NimblePool def checkout(fun) when is_function(fun, 1) do - NimblePool.checkout!( - __MODULE__, - :checkout, - fn _pool, context -> {fun.(context), :ok} end, - :infinity + {t, r} = + :timer.tc(fn -> + fun.(:persistent_term.get({EXLA.MLIR.IREE, :instance})) + end) + + :telemetry.execute( + [:exla, :mlir, :iree, :instance_pool, :checkout], + %{duration: t} ) - end - @impl NimblePool - def init_worker(pool_state) do - {:ok, instance} = EXLA.MLIR.IREE.create_instance() - {:ok, instance, pool_state} - end + r - @impl NimblePool - def handle_checkout(:checkout, _from, instance, pool_state) do - {:ok, instance, instance, pool_state} + # NimblePool.checkout!( + # __MODULE__, + # :checkout, + # fn _pool, instance -> {fun.(instance), :ok} end, + # :infinity + # ) end - @impl NimblePool - def handle_checkin(:ok, _from, instance, pool_state) do - # We just keep the references around and let them die out upon worker termination/GC - {:ok, instance, pool_state} - end + # @impl NimblePool + # def init_worker(pool_state) do + # {:ok, instance} = EXLA.MLIR.IREE.create_instance() + # {:ok, instance, pool_state} + # end - @impl NimblePool - def terminate_worker(_reason, _instance, pool_state) do - # GC will clean it up - {:ok, pool_state} - end + # @impl NimblePool + # def handle_checkout(:checkout, _from, instance, pool_state) do + # {:ok, instance, instance, pool_state} + # end + + # @impl NimblePool + # def handle_checkin(:ok, _from, instance, pool_state) do + # # We just keep the references around and let them die out upon worker termination/GC + # {:ok, instance, pool_state} + # end + + # @impl NimblePool + # def terminate_worker(_reason, _instance, pool_state) do + # # GC will clean it up + # {:ok, pool_state} + # end end diff --git a/exla/mix.exs b/exla/mix.exs index 9c6e32e713..2d35b4637b 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -5,8 +5,10 @@ defmodule EXLA.MixProject do @version "0.7.1" def project do + n_jobs = to_string(max(System.schedulers_online() - 2, 1)) + make_args = - Application.get_env(:exla, :make_args) || ["-j#{max(System.schedulers_online() - 2, 1)}"] + Application.get_env(:exla, :make_args) || ["-j#{n_jobs}"] [ app: :exla, @@ -35,7 +37,8 @@ defmodule EXLA.MixProject do %{ "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", - "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv + "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv, + "MAKE_NUM_JOBS" => n_jobs } end, make_args: make_args From 89307247141b68f97c466e0c66a6a8e3021e547f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 May 2024 16:15:42 -0300 Subject: [PATCH 38/40] wip: enable ios runtime compilation --- exla/Makefile | 43 +++++--- exla/c_src/exla/iree/compiler.cc | 119 ++++++++--------------- exla/c_src/exla/iree/runtime.cc | 9 +- exla/c_src/exla/iree/runtime.h | 5 +- exla/c_src/iree_runtime/CMakeLists.txt | 21 ++-- exla/lib/exla/application.ex | 59 ++--------- exla/lib/exla/defn.ex | 73 +++++++------- exla/lib/exla/mlir/iree/instance_pool.ex | 62 +++++------- 8 files changed, 159 insertions(+), 232 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index 8d73f8f86f..1d98b512d0 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -108,7 +108,7 @@ $(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS) @ mkdir -p $(EXLA_CACHE_OBJ_DIR)/iree $(CXX) $(CFLAGS) -c $< -o $@ -$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) cache/iree $(OBJECTS) +$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(EXLA_CACHE_IREE_COMPILER_SO) $(OBJECTS) $(CXX) $(OBJECTS) -o $(EXLA_CACHE_SO) $(LDFLAGS) @@ -128,11 +128,14 @@ cache/iree: @git clone https://github.com/iree-org/iree cache/iree @cd cache/iree && git checkout $(IREE_COMMIT) @cd cache/iree && git submodule update --init --recursive --depth 1 + +cache/iree/build: cache/iree @mkdir -p cache/iree/build - cmake -G Ninja -B cache/iree/build -DTRACY_ENABLE=ON -DIREE_ENABLE_RUNTIME_TRACING=ON -DIREE_ENABLE_COMPILER_TRACING=ON -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree + cmake -G Ninja -B cache/iree/build -DCMAKE_INSTALL_PREFIX=cache/iree/build/install -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree cmake --build cache/iree/build + cmake --build cache/iree/build --target install -$(EXLA_CACHE_IREE_COMPILER_SO): cache/iree +$(EXLA_CACHE_IREE_COMPILER_SO): cache/iree/build @mkdir -p $(IREE_CMAKE_BUILD_DIR) @mkdir -p cache/objs/iree_cmake_out @mkdir -p cache/objs/mlir_cmake_out @@ -140,20 +143,38 @@ $(EXLA_CACHE_IREE_COMPILER_SO): cache/iree cmake -S c_src/iree_runtime -B $(IREE_CMAKE_BUILD_DIR) \ -DIREE_COMPILER_INCLUDE_PATH=$(IREE_COMPILER_INCLUDE_PATH) \ -DIREE_COMPILER_DIR=$(IREE_COMPILER_DIR) \ - -DIREE_ENABLE_RUNTIME_TRACING=ON \ - -DIREE_TRACING_MODE=ON \ - -DTRACY_ENABLE=ON \ -DXLA_INCLUDE_PATH=$(abspath $(XLA_INCLUDE_PATH)) \ -DIREE_INSTALL_PREFIX=$(IREE_INSTALL_PREFIX) \ - -DCACHE_DIR=cache\ -DXLA_EXTENSION_LIB=$(abspath $(XLA_EXTENSION_LIB))\ -DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG)\ - -DTRACY_ENABLE=1\ - -DTRACY_CALLSTACK=ON\ - -DTRACY_NO_SAMPLING=OFF\ - -DTRACY_NO_VERIFY=ON + -DIREE_BUILD_COMPILER=ON cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache +ios_runtime: cache/iree/build + @mkdir -p $(IREE_CMAKE_BUILD_DIR) + @mkdir -p cache/ios_runtime/install + @mkdir -p cache/objs/iree_cmake_out + @mkdir -p cache/objs/mlir_cmake_out + @mkdir -p cache/objs/llvm_cmake_out + cmake -G Ninja -S c_src/iree_runtime -B cache/ios_runtime \ + -DIREE_COMPILER_INCLUDE_PATH=$(IREE_COMPILER_INCLUDE_PATH) \ + -DIREE_COMPILER_DIR=$(IREE_COMPILER_DIR) \ + -DIREE_INSTALL_PREFIX=$(IREE_INSTALL_PREFIX) \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_SYSROOT=$(shell xcodebuild -version -sdk iphonesimulator Path) \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DCMAKE_SYSTEM_PROCESSOR=arm64 \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=16.0 \ + -DCMAKE_IOS_INSTALL_COMBINED=YES \ + -DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG)\ + -DIREE_HOST_BIN_DIR=build/install/bin \ + -DCMAKE_INSTALL_PREFIX=cache/ios_runtime/install \ + -DIREE_BUILD_COMPILER=OFF \ + -DXLA_INCLUDE_PATH=$(abspath $(XLA_INCLUDE_PATH))\ + -DXLA_EXTENSION_LIB=$(abspath $(XLA_EXTENSION_LIB)) + cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) + cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache/ios_runtime + clean: rm -rf cache diff --git a/exla/c_src/exla/iree/compiler.cc b/exla/c_src/exla/iree/compiler.cc index e842e04ed5..827beb5ef8 100644 --- a/exla/c_src/exla/iree/compiler.cc +++ b/exla/c_src/exla/iree/compiler.cc @@ -2,7 +2,6 @@ #include // For O_WRONLY, O_CREAT, O_TRUNC #include -#include #include #include #include @@ -13,7 +12,6 @@ #include #include -#include typedef struct compiler_state_t { iree_compiler_session_t *session; @@ -48,7 +46,6 @@ static void initializeCompiler(struct compiler_state_t *state) { } ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { - ZoneScopedN("compile main"); if (argc != 2) { return exla::nif::error(env, "Bad argument count."); } @@ -57,19 +54,16 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { std::vector flags_str; std::vector flags; - { - ZoneScopedN("compiler get arguments"); - if (!exla::nif::get(env, argv[0], module_str)) { - return exla::nif::error(env, "Unable to get module."); - } + if (!exla::nif::get(env, argv[0], module_str)) { + return exla::nif::error(env, "Unable to get module."); + } - if (!exla::nif::get_list(env, argv[1], flags_str)) { - return exla::nif::error(env, "Unable to get list."); - } + if (!exla::nif::get_list(env, argv[1], flags_str)) { + return exla::nif::error(env, "Unable to get list."); + } - for (auto &flag : flags_str) { - flags.push_back(flag.c_str()); - } + for (auto &flag : flags_str) { + flags.push_back(flag.c_str()); } compiler_state_t state; @@ -83,88 +77,62 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { MlirOperation module_op; - { - ZoneScopedN("Parse module"); - module_op = mlirOperationCreateParse( - state.context, - mlirStringRefCreate(module_str.c_str(), module_str.size()), - mlirStringRefCreateFromCString("source.stablehlo")); - if (mlirOperationIsNull(module_op)) { - return exla::nif::error(env, "Unable to create MlirOperation module."); - } + module_op = mlirOperationCreateParse( + state.context, + mlirStringRefCreate(module_str.c_str(), module_str.size()), + mlirStringRefCreateFromCString("source.stablehlo")); + if (mlirOperationIsNull(module_op)) { + return exla::nif::error(env, "Unable to create MlirOperation module."); } - // Set flags. - { - ZoneScopedN("Set flags"); - error = ireeCompilerSessionSetFlags(state.session, flags.size(), flags.data()); - if (error) { - const char *msg = ireeCompilerErrorGetMessage(error); - - cleanup_compiler_state(state); + error = ireeCompilerSessionSetFlags(state.session, flags.size(), flags.data()); + if (error) { + const char *msg = ireeCompilerErrorGetMessage(error); - std::stringstream ss; - ss << "Unable to set flags due to error: "; - ss << msg; + cleanup_compiler_state(state); - return exla::nif::error(env, ss.str().c_str()); - } - } + std::stringstream ss; + ss << "Unable to set flags due to error: "; + ss << msg; - { - ZoneScopedN("Create invocation"); - state.invocation = ireeCompilerInvocationCreate(state.session); - ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation); + return exla::nif::error(env, ss.str().c_str()); } + state.invocation = ireeCompilerInvocationCreate(state.session); + ireeCompilerInvocationEnableConsoleDiagnostics(state.invocation); - { - ZoneScopedN("Import module"); - if (!ireeCompilerInvocationImportStealModule(state.invocation, module_op)) { - cleanup_compiler_state(state); - return exla::nif::error(env, "Unable to import module."); - } + if (!ireeCompilerInvocationImportStealModule(state.invocation, module_op)) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to import module."); } // Compile. - { - ZoneScopedN("Invocation Pipeline"); - if (!ireeCompilerInvocationPipeline(state.invocation, iree_compiler_pipeline_t::IREE_COMPILER_PIPELINE_STD)) { - cleanup_compiler_state(state); - return exla::nif::error(env, "Unable to compile module."); - } + if (!ireeCompilerInvocationPipeline(state.invocation, iree_compiler_pipeline_t::IREE_COMPILER_PIPELINE_STD)) { + cleanup_compiler_state(state); + return exla::nif::error(env, "Unable to compile module."); } - { - ZoneScopedN("Open output membuffer"); - error = ireeCompilerOutputOpenMembuffer(&state.output); - if (error) { - handle_compiler_error(error); - cleanup_compiler_state(state); - return exla::nif::error(env, "Error opening output membuffer"); - } + error = ireeCompilerOutputOpenMembuffer(&state.output); + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return exla::nif::error(env, "Error opening output membuffer"); } - { - ZoneScopedN("Output VM Bytecode"); - error = ireeCompilerInvocationOutputVMBytecode(state.invocation, state.output); - if (error) { - handle_compiler_error(error); - cleanup_compiler_state(state); - return exla::nif::error(env, "Failed to output VM Bytecode"); - } + error = ireeCompilerInvocationOutputVMBytecode(state.invocation, state.output); + if (error) { + handle_compiler_error(error); + cleanup_compiler_state(state); + return exla::nif::error(env, "Failed to output VM Bytecode"); } uint8_t *contents; uint64_t size; ErlNifBinary output_binary; - { - ZoneScopedN("Map and copy output to binary"); - error = ireeCompilerOutputMapMemory(state.output, (void **)&contents, &size); + error = ireeCompilerOutputMapMemory(state.output, (void **)&contents, &size); - enif_alloc_binary(size, &output_binary); - memcpy(output_binary.data, contents, size); - } + enif_alloc_binary(size, &output_binary); + memcpy(output_binary.data, contents, size); if (error) { handle_compiler_error(error); @@ -174,6 +142,5 @@ ERL_NIF_TERM compile(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { cleanup_compiler_state(state); - IREE_TRACE_ZONE_END(compile); return exla::nif::ok(env, enif_make_binary(env, &output_binary)); } \ No newline at end of file diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 5f4741e0c9..289a25f811 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -5,7 +5,6 @@ #include #include -#include bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_type_t *type) { using xla::PrimitiveType; @@ -174,8 +173,7 @@ ERL_NIF_TERM return_results(ErlNifEnv *env, std::vector>> -call(iree_vm_instance_t *instance, iree_hal_device_t *device, ErlNifBinary bytecode, std::vector exla_inputs) { - ZoneScoped; +call(iree_vm_instance_t *instance, iree_hal_device_t *device, unsigned char *bytecode, size_t bytecode_size, std::vector exla_inputs) { iree_vm_module_t *hal_module = nullptr; iree_vm_module_t *bytecode_module = nullptr; iree_vm_context_t *context = nullptr; @@ -189,7 +187,7 @@ call(iree_vm_instance_t *instance, iree_hal_device_t *device, ErlNifBinary bytec iree_allocator_system(), &hal_module)); // (kFloat4, sizeof(kFloat4)) - const iree_const_byte_span_t module_data = iree_make_const_byte_span(bytecode.data, bytecode.size); + const iree_const_byte_span_t module_data = iree_make_const_byte_span(bytecode, bytecode_size); RETURN_PAIR_IF_ERROR(iree_vm_bytecode_module_create( instance, module_data, iree_allocator_null(), iree_allocator_system(), @@ -262,7 +260,6 @@ call(iree_vm_instance_t *instance, iree_hal_device_t *device, ErlNifBinary bytec ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { - ZoneScoped; if (argc != 4) { return exla::nif::error(env, "Bad argument count."); } @@ -294,7 +291,7 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { return exla::nif::error(env, "Unable to decode input terms"); } - auto [status, results] = call(*instance, *device, bytecode, inputs); + auto [status, results] = call(*instance, *device, reinterpret_cast(bytecode.data), reinterpret_cast(bytecode.size), inputs); if (!iree_status_is_ok(status)) { // Dump nice status messages to stderr on failure. diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index c3277bf039..9cfd123b7b 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -63,4 +63,7 @@ class IREEInput { } // namespace runtime } // namespace iree -}; // namespace exla \ No newline at end of file +}; // namespace exla + +std::pair>> +call(iree_vm_instance_t* i, iree_hal_device_t*, unsigned char*, size_t, std::vector); \ No newline at end of file diff --git a/exla/c_src/iree_runtime/CMakeLists.txt b/exla/c_src/iree_runtime/CMakeLists.txt index 1c48452080..fa3b53b6fa 100644 --- a/exla/c_src/iree_runtime/CMakeLists.txt +++ b/exla/c_src/iree_runtime/CMakeLists.txt @@ -6,16 +6,15 @@ project(${_NAME} VERSION 1.0 LANGUAGES CXX C) set_property(GLOBAL PROPERTY USE_FOLDERS ON) include(CheckCCompilerFlag) -set(LLVM_DIR "${IREE_INSTALL_PREFIX}/llvm-project/lib/cmake/llvm") -set(MLIR_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/mlir") -set(LLD_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/lld") -set(Clang_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/clang") +# set(LLVM_DIR "${IREE_INSTALL_PREFIX}/llvm-project/lib/cmake/llvm") +# set(MLIR_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/mlir") +# set(LLD_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/lld") +# set(Clang_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/clang") -set(LLVM_ABI_BREAKING_CHECKS FORCE_OFFrm ) +set(LLVM_ABI_BREAKING_CHECKS FORCE_OFF) -set(IREE_BUILD_COMPILER ON) set(IREE_INPUT_STABLEHLO ON) -set(IREE_BUILD_BUNDLED_LLVM OFF) +# set(IREE_BUILD_BUNDLED_LLVM ON) set(IREE_BUILD_TESTS OFF) set(IREE_BUILD_SAMPLES OFF) @@ -101,6 +100,8 @@ else() target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.dylib") endif() -add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy/build") -include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy/public") -target_link_libraries(${_NAME} TracyClient) +if($ENV{DEBUG}) + add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy/build") + include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/tracy/public") + target_link_libraries(${_NAME} TracyClient) +endif() diff --git a/exla/lib/exla/application.ex b/exla/lib/exla/application.ex index 430a5d8223..2c94ffa821 100644 --- a/exla/lib/exla/application.ex +++ b/exla/lib/exla/application.ex @@ -14,61 +14,14 @@ defmodule EXLA.Application do {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"metal://0000000100000971") # {:ok, device} = EXLA.MLIR.IREE.setup_runtime(~c"local-sync://") :persistent_term.put({EXLA.MLIR.IREE, :device}, device) - - {:ok, instance} = EXLA.MLIR.IREE.create_instance() - :persistent_term.put({EXLA.MLIR.IREE, :instance}, instance) - - :persistent_term.put({EXLA.Telemetry, :checkout}, {0, 0, 0, nil}) - - :telemetry.attach( - :exla_telemetry_checkout, - [:exla, :mlir, :iree, :instance_pool, :checkout], - fn _name, %{duration: duration}, _meta, _config -> - {total, count, max, min} = :persistent_term.get({EXLA.Telemetry, :checkout}) - total = total + duration - count = count + 1 - max = max(max, duration) - min = min(min, duration) - - File.write( - "/tmp/checkout.txt", - "#{total / count / 1_000} ms | #{max / 1_000} ms | #{min / 1_000} ms\n" - ) - - :persistent_term.put({EXLA.Telemetry, :checkout}, {total, count, max, min}) - end, - nil - ) - - :persistent_term.put({EXLA.Telemetry, :compile}, {0, 0, 0, nil}) - - :telemetry.attach( - :exla_telemetry_compile, - [:exla, :mlir, :iree, :compile], - fn _name, %{duration: duration}, _meta, _config -> - {total, count, max, min} = :persistent_term.get({EXLA.Telemetry, :compile}) - total = total + duration - count = count + 1 - max = max(max, duration) - min = min(min, duration) - - File.write( - "/tmp/compile.txt", - "#{total / count / 1_000} ms | #{max / 1_000} ms | #{min / 1_000} ms\n" - ) - - :persistent_term.put({EXLA.Telemetry, :compile}, {total, count, max, min}) - end, - nil - ) - + children = [ EXLA.Logger, - # {NimblePool, - # worker: {EXLA.MLIR.IREE.InstancePool, :pool_state}, - # pool_size: System.schedulers_online(), - # name: EXLA.MLIR.IREE.InstancePool, - # lazy: true}, + {NimblePool, + worker: {EXLA.MLIR.IREE.InstancePool, :pool_state}, + pool_size: System.schedulers_online(), + name: EXLA.MLIR.IREE.InstancePool, + lazy: true}, {NimblePool, worker: {EXLA.MLIR.ContextPool, :pool_state}, pool_size: System.schedulers_online(), diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 0367a9a39f..e78d185159 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -255,17 +255,8 @@ defmodule EXLA.Defn do def __compile__(key, vars, fun, options) do {run_options, compile_options} = Keyword.pop(options, :run_options, []) - {client_name, compile_options} = - Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) - - compile_options = Keyword.put_new(compile_options, :runtime, :iree) - - client = EXLA.Client.fetch!(client_name) - - callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) - {executable, used_inputs, outputs, outfeed, :ok, debug?} = - compile(client, key, vars, fun, compile_options, 0, [], callback) + compile_executable(key, vars, fun, compile_options) fn [args] -> {time, lock} = @@ -290,6 +281,26 @@ defmodule EXLA.Defn do end end + def export_executable(fun, vars, options) do + fun + |> compile_executable(vars, &Function.identity/1, Keyword.delete(options, :run_options)) + |> elem(0) + end + + defp compile_executable(key, vars, fun, compile_options) do + {client_name, compile_options} = + Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) + + compile_options = Keyword.put_new(compile_options, :runtime, :iree) + + client = EXLA.Client.fetch!(client_name) + + callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) + + {executable, used_inputs, outputs, outfeed, :ok, debug?} = + compile(client, key, vars, fun, compile_options, 0, [], callback) + end + defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> @@ -482,34 +493,20 @@ defmodule EXLA.Defn do for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec if runtime == :iree do - {t1, {:ok, module_charlist}} = - :timer.tc(fn -> - {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) - end) - - # :telemetry.execute( - # [:exla, :mlir, :iree, :compile], - # %{duration: t1} - # ) - - flags = [ - # ~c"--iree-hal-target-backends=llvm-cpu", - ~c"--iree-hal-target-backends=metal-spirv", - ~c"--iree-input-type=stablehlo_xla", - ~c"--iree-execution-model=async-internal", - ~c"--iree-input-demote-f64-to-f32=true", - ~c"--iree-input-demote-i64-to-i32=false" - ] - - {t2, {:ok, module_bytecode}} = - :timer.tc(fn -> - {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, flags) - end) - - :telemetry.execute( - [:exla, :mlir, :iree, :compile], - %{duration: t2} - ) + {:ok, module_charlist} = EXLA.NIF.mlir_module_to_string(builder.module.ref) + + flags = + options[:iree_flags] || + [ + # ~c"--iree-hal-target-backends=llvm-cpu", + ~c"--iree-hal-target-backends=metal-spirv", + ~c"--iree-input-type=stablehlo_xla", + ~c"--iree-execution-model=async-internal", + ~c"--iree-input-demote-f64-to-f32=true", + ~c"--iree-input-demote-i64-to-i32=false" + ] + + {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, flags) %EXLA.Executable{ client: client, diff --git a/exla/lib/exla/mlir/iree/instance_pool.ex b/exla/lib/exla/mlir/iree/instance_pool.ex index 07cef8bbb3..093050ae03 100644 --- a/exla/lib/exla/mlir/iree/instance_pool.ex +++ b/exla/lib/exla/mlir/iree/instance_pool.ex @@ -1,49 +1,37 @@ defmodule EXLA.MLIR.IREE.InstancePool do @moduledoc false # Internal pool for MLIRContext reference management - # @behaviour NimblePool + @behaviour NimblePool def checkout(fun) when is_function(fun, 1) do - {t, r} = - :timer.tc(fn -> - fun.(:persistent_term.get({EXLA.MLIR.IREE, :instance})) - end) - - :telemetry.execute( - [:exla, :mlir, :iree, :instance_pool, :checkout], - %{duration: t} + NimblePool.checkout!( + __MODULE__, + :checkout, + fn _pool, instance -> {fun.(instance), :ok} end, + :infinity ) - - r - - # NimblePool.checkout!( - # __MODULE__, - # :checkout, - # fn _pool, instance -> {fun.(instance), :ok} end, - # :infinity - # ) end - # @impl NimblePool - # def init_worker(pool_state) do - # {:ok, instance} = EXLA.MLIR.IREE.create_instance() - # {:ok, instance, pool_state} - # end + @impl NimblePool + def init_worker(pool_state) do + {:ok, instance} = EXLA.MLIR.IREE.create_instance() + {:ok, instance, pool_state} + end - # @impl NimblePool - # def handle_checkout(:checkout, _from, instance, pool_state) do - # {:ok, instance, instance, pool_state} - # end + @impl NimblePool + def handle_checkout(:checkout, _from, instance, pool_state) do + {:ok, instance, instance, pool_state} + end - # @impl NimblePool - # def handle_checkin(:ok, _from, instance, pool_state) do - # # We just keep the references around and let them die out upon worker termination/GC - # {:ok, instance, pool_state} - # end + @impl NimblePool + def handle_checkin(:ok, _from, instance, pool_state) do + # We just keep the references around and let them die out upon worker termination/GC + {:ok, instance, pool_state} + end - # @impl NimblePool - # def terminate_worker(_reason, _instance, pool_state) do - # # GC will clean it up - # {:ok, pool_state} - # end + @impl NimblePool + def terminate_worker(_reason, _instance, pool_state) do + # GC will clean it up + {:ok, pool_state} + end end From 0144cf723624a96f992600a799a1edc50f18d844 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Thu, 23 May 2024 16:46:18 -0300 Subject: [PATCH 39/40] chore: xla token bug --- exla/lib/exla/defn.ex | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 56e6ac5c37..cc0a72ab77 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -39,7 +39,7 @@ defmodule EXLA.Defn do client = EXLA.Client.fetch!(client_name) compile_options = Keyword.put(compile_options, :lazy_transfers, :never) - compile_options = Keyword.put_new(compile_options, :runtime, :iree) + compile_options = Keyword.put_new(compile_options, :runtime, :xla) input_length = length(Nx.Defn.Composite.flatten_list([input])) acc_length = length(Nx.Defn.Composite.flatten_list([acc])) @@ -291,7 +291,7 @@ defmodule EXLA.Defn do {client_name, compile_options} = Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) - compile_options = Keyword.put_new(compile_options, :runtime, :iree) + compile_options = Keyword.put_new(compile_options, :runtime, :xla) client = EXLA.Client.fetch!(client_name) @@ -612,8 +612,11 @@ defmodule EXLA.Defn do ) do [initial_arg, _arg, pred, body] = args + token = get_token(cache) + has_token = not is_nil(token) + initial = - if token = get_token(cache) do + if has_token do {token, initial_arg} else initial_arg @@ -627,7 +630,7 @@ defmodule EXLA.Defn do results = Value.while(function, pred_computation, body_computation, List.flatten(initial)) - if get_token(cache) do + if has_token do [token | results] = results result = wrap_tuple_result(results, initial_arg) {result, update_token(cache, token)} From bee7309c6d94ae1fbac213263b0a0ccfc017ec24 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Sun, 26 May 2024 07:53:49 -0300 Subject: [PATCH 40/40] wip: split runtime --- exla/Makefile | 22 +- exla/c_src/exla/iree/iree.cc | 2 +- exla/c_src/exla/iree/runtime.cc | 282 ++++++++++++++----------- exla/c_src/exla/iree/runtime.h | 13 +- exla/c_src/iree_runtime/CMakeLists.txt | 70 ++++-- exla/lib/exla/defn.ex | 4 + 6 files changed, 234 insertions(+), 159 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index 1d98b512d0..5fa9fdc0ba 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -151,30 +151,24 @@ $(EXLA_CACHE_IREE_COMPILER_SO): cache/iree/build cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache -ios_runtime: cache/iree/build - @mkdir -p $(IREE_CMAKE_BUILD_DIR) +ios_runtime: + @mkdir -p $(IREE_CMAKE_BUILD_DIR)/ios @mkdir -p cache/ios_runtime/install - @mkdir -p cache/objs/iree_cmake_out - @mkdir -p cache/objs/mlir_cmake_out - @mkdir -p cache/objs/llvm_cmake_out cmake -G Ninja -S c_src/iree_runtime -B cache/ios_runtime \ - -DIREE_COMPILER_INCLUDE_PATH=$(IREE_COMPILER_INCLUDE_PATH) \ - -DIREE_COMPILER_DIR=$(IREE_COMPILER_DIR) \ - -DIREE_INSTALL_PREFIX=$(IREE_INSTALL_PREFIX) \ -DCMAKE_SYSTEM_NAME=iOS \ - -DCMAKE_OSX_SYSROOT=$(shell xcodebuild -version -sdk iphonesimulator Path) \ + -DCMAKE_OSX_SYSROOT=$(shell xcodebuild -version -sdk iphonesimulator Path) \ -DCMAKE_OSX_ARCHITECTURES=arm64 \ -DCMAKE_SYSTEM_PROCESSOR=arm64 \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=16.0 \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=17.5 \ -DCMAKE_IOS_INSTALL_COMBINED=YES \ -DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG)\ -DIREE_HOST_BIN_DIR=build/install/bin \ -DCMAKE_INSTALL_PREFIX=cache/ios_runtime/install \ -DIREE_BUILD_COMPILER=OFF \ - -DXLA_INCLUDE_PATH=$(abspath $(XLA_INCLUDE_PATH))\ - -DXLA_EXTENSION_LIB=$(abspath $(XLA_EXTENSION_LIB)) - cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) - cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache/ios_runtime + -DIREE_BUILD_FOR_IOS=ON +# -DCMAKE_TOOLCHAIN_FILE=$(abspath cache/iree/third_party/llvm-project/llvm/cmake/platforms/iOS.cmake) + cmake --build cache/ios_runtime --config $(IREE_CMAKE_CONFIG) + cmake --install cache/ios_runtime --config $(IREE_CMAKE_CONFIG) --prefix cache/ios_runtime clean: rm -rf cache diff --git a/exla/c_src/exla/iree/iree.cc b/exla/c_src/exla/iree/iree.cc index a95d4375fe..3d47dd2ea5 100644 --- a/exla/c_src/exla/iree/iree.cc +++ b/exla/c_src/exla/iree/iree.cc @@ -19,7 +19,7 @@ static ErlNifFunc iree_funcs[] = { {"run_module", 4, run_module, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"setup_runtime", 1, setup_runtime}, {"create_instance", 0, create_instance}, - {"read_buffer", 3, read_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}, + {"read_buffer", 3, read_buffer_to_term, ERL_NIF_DIRTY_JOB_IO_BOUND}, {"deallocate_buffer", 1, deallocate_buffer, ERL_NIF_DIRTY_JOB_IO_BOUND}, }; diff --git a/exla/c_src/exla/iree/runtime.cc b/exla/c_src/exla/iree/runtime.cc index 289a25f811..8d7ff50057 100644 --- a/exla/c_src/exla/iree/runtime.cc +++ b/exla/c_src/exla/iree/runtime.cc @@ -6,6 +6,161 @@ #include #include +#define RETURN_PAIR_IF_ERROR(status) \ + if (!iree_status_is_ok(status)) { \ + return {status, std::nullopt}; \ + } + +iree_vm_instance_t *create_instance() { + iree_vm_instance_t *instance = nullptr; + iree_status_t status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance); + if (!iree_status_is_ok(status)) { + return nullptr; + } + + status = iree_hal_module_register_all_types(instance); + if (!iree_status_is_ok(status)) { + iree_vm_instance_release(instance); + return nullptr; + } + + return instance; +} + +iree_hal_device_t *create_device(const std::string &device_uri) { + iree_hal_device_t *device = nullptr; + iree_status_t status = iree_hal_register_all_available_drivers(iree_hal_driver_registry_default()); + + if (!iree_status_is_ok(status)) { + return nullptr; + } + + status = iree_hal_create_device( + iree_hal_driver_registry_default(), + iree_make_cstring_view(device_uri.c_str()), + iree_allocator_system(), &device); + + if (!iree_status_is_ok(status)) { + return nullptr; + } + + return device; +} + +std::pair>> +call(iree_vm_instance_t *instance, iree_hal_device_t *device, unsigned char *bytecode, size_t bytecode_size, std::vector exla_inputs) { + iree_vm_module_t *hal_module = nullptr; + iree_vm_module_t *bytecode_module = nullptr; + iree_vm_context_t *context = nullptr; + const char kMainFunctionName[] = "module.main"; + iree_vm_function_t main_function; + iree_vm_list_t *inputs = nullptr; + iree_vm_list_t *outputs = nullptr; + + RETURN_PAIR_IF_ERROR(iree_hal_module_create( + instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, + iree_allocator_system(), &hal_module)); + + // (kFloat4, sizeof(kFloat4)) + const iree_const_byte_span_t module_data = iree_make_const_byte_span(bytecode, bytecode_size); + + RETURN_PAIR_IF_ERROR(iree_vm_bytecode_module_create( + instance, module_data, iree_allocator_null(), iree_allocator_system(), + &bytecode_module)); + + iree_vm_module_t *modules[] = {hal_module, bytecode_module}; + RETURN_PAIR_IF_ERROR(iree_vm_context_create_with_modules( + instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], + iree_allocator_system(), &context)); + iree_vm_module_release(hal_module); + iree_vm_module_release(bytecode_module); + + RETURN_PAIR_IF_ERROR(iree_vm_context_resolve_function( + context, iree_make_cstring_view(kMainFunctionName), &main_function)); + + RETURN_PAIR_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), exla_inputs.size(), iree_allocator_system(), &inputs)); + + for (auto input : exla_inputs) { + iree_vm_ref_t arg_buffer_view_ref; + + if (input->buffer_view) { + arg_buffer_view_ref = iree_hal_buffer_view_move_ref(input->buffer_view); + } else { + iree_hal_buffer_view_t *arg_buffer_view = nullptr; + RETURN_PAIR_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy( + device, iree_hal_device_allocator(device), input->dims.size(), input->dims.data(), + input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + (iree_hal_buffer_params_t){ + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + }, + input->data_byte_span(), &arg_buffer_view)); + + arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_view); + } + RETURN_PAIR_IF_ERROR(iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref)); + } + + iree_vm_function_signature_t signature = + iree_vm_function_signature(&main_function); + iree_string_view_t input_signature; + iree_string_view_t output_signature; + + RETURN_PAIR_IF_ERROR(iree_vm_function_call_get_cconv_fragments( + &signature, &input_signature, &output_signature)); + + RETURN_PAIR_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), output_signature.size, iree_allocator_system(), &outputs)); + + // Synchronously invoke the function. + RETURN_PAIR_IF_ERROR(iree_vm_invoke( + context, main_function, IREE_VM_INVOCATION_FLAG_NONE, + /*policy=*/NULL, inputs, outputs, iree_allocator_system())); + + std::vector results; + results.resize(output_signature.size); + for (int i = 0; i < output_signature.size; i++) { + iree_hal_buffer_view_t *output_buffer_view = iree_vm_list_get_buffer_view_retain(outputs, i); + if (!output_buffer_view) { + return {iree_make_status(IREE_STATUS_NOT_FOUND, "can't get output buffer view [index=%d]", i), std::nullopt}; + } + + results[i] = output_buffer_view; + } + + iree_vm_list_release(inputs); + iree_vm_list_release(outputs); + iree_vm_context_release(context); + return {iree_ok_status(), results}; +} + +iree_status_t read_buffer(iree_hal_device_t *device, iree_hal_buffer_view_t *buffer_view, void *output_buffer, size_t num_bytes) { + iree_hal_buffer_t *buffer = iree_hal_buffer_view_buffer(buffer_view); + + iree_device_size_t num_bytes_actual = num_bytes == -1 ? iree_hal_buffer_byte_length(buffer) : (iree_device_size_t)num_bytes; + + return iree_hal_device_transfer_d2h( + device, buffer, 0, output_buffer, + num_bytes_actual, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()); +} + +std::string get_status_message(iree_status_t status) { + char *status_string = NULL; + size_t status_length = 0; + + auto system_allocator = iree_allocator_system(); + + iree_status_to_string(status, &system_allocator, &status_string, &status_length); + + std::stringstream ss; + ss << "Failed to execute IREE runtime due to error: "; + ss << status_string; + iree_status_free(status); + return ss.str(); +} + +#ifndef BUILD_FOR_IOS + bool primitive_type_to_iree_element_type(xla::PrimitiveType t, iree_hal_element_type_t *type) { using xla::PrimitiveType; using type_enum = iree_hal_element_types_t; @@ -166,98 +321,6 @@ int load_inputs(ErlNifEnv *env, std::vector terms, std::vector results) { return exla::nif::ok(env, exla::nif::make_list(env, results)); } - -#define RETURN_PAIR_IF_ERROR(status) \ - if (!iree_status_is_ok(status)) { \ - return {status, std::nullopt}; \ - } - -std::pair>> -call(iree_vm_instance_t *instance, iree_hal_device_t *device, unsigned char *bytecode, size_t bytecode_size, std::vector exla_inputs) { - iree_vm_module_t *hal_module = nullptr; - iree_vm_module_t *bytecode_module = nullptr; - iree_vm_context_t *context = nullptr; - const char kMainFunctionName[] = "module.main"; - iree_vm_function_t main_function; - iree_vm_list_t *inputs = nullptr; - iree_vm_list_t *outputs = nullptr; - - RETURN_PAIR_IF_ERROR(iree_hal_module_create( - instance, /*device_count=*/1, &device, IREE_HAL_MODULE_FLAG_SYNCHRONOUS, - iree_allocator_system(), &hal_module)); - - // (kFloat4, sizeof(kFloat4)) - const iree_const_byte_span_t module_data = iree_make_const_byte_span(bytecode, bytecode_size); - - RETURN_PAIR_IF_ERROR(iree_vm_bytecode_module_create( - instance, module_data, iree_allocator_null(), iree_allocator_system(), - &bytecode_module)); - - iree_vm_module_t *modules[] = {hal_module, bytecode_module}; - RETURN_PAIR_IF_ERROR(iree_vm_context_create_with_modules( - instance, IREE_VM_CONTEXT_FLAG_NONE, IREE_ARRAYSIZE(modules), &modules[0], - iree_allocator_system(), &context)); - iree_vm_module_release(hal_module); - iree_vm_module_release(bytecode_module); - - RETURN_PAIR_IF_ERROR(iree_vm_context_resolve_function( - context, iree_make_cstring_view(kMainFunctionName), &main_function)); - - RETURN_PAIR_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), exla_inputs.size(), iree_allocator_system(), &inputs)); - - for (auto input : exla_inputs) { - iree_vm_ref_t arg_buffer_view_ref; - - if (input->buffer_view) { - arg_buffer_view_ref = iree_hal_buffer_view_move_ref(input->buffer_view); - } else { - iree_hal_buffer_view_t *arg_buffer_view = nullptr; - RETURN_PAIR_IF_ERROR(iree_hal_buffer_view_allocate_buffer_copy( - device, iree_hal_device_allocator(device), input->dims.size(), input->dims.data(), - input->type, IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, - (iree_hal_buffer_params_t){ - .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, - }, - input->data_byte_span(), &arg_buffer_view)); - - arg_buffer_view_ref = iree_hal_buffer_view_move_ref(arg_buffer_view); - } - RETURN_PAIR_IF_ERROR(iree_vm_list_push_ref_move(inputs, &arg_buffer_view_ref)); - } - - iree_vm_function_signature_t signature = - iree_vm_function_signature(&main_function); - iree_string_view_t input_signature; - iree_string_view_t output_signature; - - RETURN_PAIR_IF_ERROR(iree_vm_function_call_get_cconv_fragments( - &signature, &input_signature, &output_signature)); - - RETURN_PAIR_IF_ERROR(iree_vm_list_create(iree_vm_make_undefined_type_def(), output_signature.size, iree_allocator_system(), &outputs)); - - // Synchronously invoke the function. - RETURN_PAIR_IF_ERROR(iree_vm_invoke( - context, main_function, IREE_VM_INVOCATION_FLAG_NONE, - /*policy=*/NULL, inputs, outputs, iree_allocator_system())); - - std::vector results; - results.resize(output_signature.size); - for (int i = 0; i < output_signature.size; i++) { - iree_hal_buffer_view_t *output_buffer_view = iree_vm_list_get_buffer_view_retain(outputs, i); - if (!output_buffer_view) { - return {iree_make_status(IREE_STATUS_NOT_FOUND, "can't get output buffer view [index=%d]", i), std::nullopt}; - } - - results[i] = output_buffer_view; - } - - iree_vm_list_release(inputs); - iree_vm_list_release(outputs); - iree_vm_context_release(context); - return {iree_ok_status(), results}; -} - ERL_NIF_TERM run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 4) { @@ -298,19 +361,9 @@ run_module(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { // An application can route these through its own logging infrastructure as // needed. Note that the status is a handle and must be freed! - char *status_string = NULL; - size_t status_length = 0; - - auto system_allocator = iree_allocator_system(); - - iree_status_to_string(status, &system_allocator, &status_string, &status_length); - - std::stringstream ss; - ss << "Failed to execute IREE runtime due to error: "; - ss << status_string; - iree_status_free(status); + std::string status_msg = get_status_message(status); - return exla::nif::error(env, ss.str().c_str()); + return exla::nif::error(env, status_msg.c_str()); } iree_status_free(status); @@ -342,14 +395,9 @@ ERL_NIF_TERM setup_runtime(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) } ERL_NIF_TERM create_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { - iree_vm_instance_t *instance = nullptr; - iree_status_t status = iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, iree_allocator_system(), &instance); - - if (iree_status_is_ok(status)) { - status = iree_hal_module_register_all_types(instance); - } + iree_vm_instance_t *instance = create_instance(); - return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, instance)) : exla::nif::error(env, "Failed to create IREE VM instance"); + return instance != nullptr ? exla::nif::ok(env, exla::nif::make(env, instance)) : exla::nif::error(env, "Failed to create IREE VM instance"); } ERL_NIF_TERM release_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { @@ -364,7 +412,7 @@ ERL_NIF_TERM release_instance(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[ return exla::nif::ok(env); } -ERL_NIF_TERM read_buffer(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { +ERL_NIF_TERM read_buffer_to_term(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { iree_hal_buffer_view_t **buffer_view = nullptr; iree_hal_device_t **device = nullptr; int64_t num_bytes; @@ -383,15 +431,10 @@ ERL_NIF_TERM read_buffer(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) { } iree_hal_buffer_t *buffer = iree_hal_buffer_view_buffer(*buffer_view); - iree_device_size_t num_bytes_actual = num_bytes == -1 ? iree_hal_buffer_byte_length(buffer) : (iree_device_size_t)num_bytes; enif_alloc_binary(num_bytes_actual, &binary); - - iree_status_t status = iree_hal_device_transfer_d2h( - *device, buffer, 0, binary.data, - num_bytes_actual, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout()); + iree_status_t status = read_buffer(*device, *buffer_view, binary.data, num_bytes_actual); return iree_status_is_ok(status) ? exla::nif::ok(env, exla::nif::make(env, binary)) : exla::nif::error(env, "Failed to read buffer"); } @@ -406,4 +449,5 @@ ERL_NIF_TERM deallocate_buffer(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv iree_hal_buffer_view_release(*buffer_view); return exla::nif::ok(env); -} \ No newline at end of file +} +#endif // BUILD_FOR_IOS diff --git a/exla/c_src/exla/iree/runtime.h b/exla/c_src/exla/iree/runtime.h index 9cfd123b7b..1877d3a623 100644 --- a/exla/c_src/exla/iree/runtime.h +++ b/exla/c_src/exla/iree/runtime.h @@ -7,14 +7,18 @@ #include #include +#include +#include +#ifndef BUILD_FOR_IOS #include "../exla_nif_util.h" ERL_NIF_TERM run_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM setup_runtime(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM create_instance(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); -ERL_NIF_TERM read_buffer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM read_buffer_to_term(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM deallocate_buffer(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +#endif namespace exla { namespace iree { @@ -65,5 +69,10 @@ class IREEInput { } // namespace iree }; // namespace exla + +iree_vm_instance_t* create_instance(); +iree_hal_device_t* create_device(const std::string& device_uri); std::pair>> -call(iree_vm_instance_t* i, iree_hal_device_t*, unsigned char*, size_t, std::vector); \ No newline at end of file +call(iree_vm_instance_t* i, iree_hal_device_t*, unsigned char*, size_t, std::vector); +iree_status_t read_buffer(iree_hal_device_t* device, iree_hal_buffer_view_t* buffer_view, void* output_buffer, size_t num_bytes); +std::string get_status_message(iree_status_t status); \ No newline at end of file diff --git a/exla/c_src/iree_runtime/CMakeLists.txt b/exla/c_src/iree_runtime/CMakeLists.txt index fa3b53b6fa..e98f412d3b 100644 --- a/exla/c_src/iree_runtime/CMakeLists.txt +++ b/exla/c_src/iree_runtime/CMakeLists.txt @@ -6,6 +6,10 @@ project(${_NAME} VERSION 1.0 LANGUAGES CXX C) set_property(GLOBAL PROPERTY USE_FOLDERS ON) include(CheckCCompilerFlag) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + # set(LLVM_DIR "${IREE_INSTALL_PREFIX}/llvm-project/lib/cmake/llvm") # set(MLIR_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/mlir") # set(LLD_DIR "${IREE_INSTALL_PREFIX}/lib/cmake/lld") @@ -27,25 +31,26 @@ if(CMAKE_BUILD_TYPE MATCHES MinSizeRel) set(IREE_SIZE_OPTIMIZED ON) endif() - set(C_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../exla/iree") file(GLOB iree_runtime_sources CONFIGURE_DEPENDS "${C_SRC}/*.cc" "${C_SRC}/*.h" "${C_SRC}/../exla_nif_util.cc" "${C_SRC}/../exla_nif_util.h") + +if(IREE_BUILD_FOR_IOS) + set(BUILD_IREE_COMPILER OFF) + file(GLOB iree_runtime_sources CONFIGURE_DEPENDS "${C_SRC}/runtime.cc" "${C_SRC}/runtime.h") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_FOR_IOS") +endif() + add_library(${_NAME} SHARED ${iree_runtime_sources}) set_property(TARGET ${_NAME} PROPERTY CXX_STANDARD 17) -target_include_directories(${_NAME} PUBLIC $ENV{ERTS_INCLUDE_DIR}) -target_include_directories(${_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/../../${XLA_INCLUDE_PATH}") -target_include_directories(${_NAME} SYSTEM - PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/../../${IREE_COMPILER_INCLUDE_PATH}/iree/compiler" -) -target_include_directories(${_NAME} SYSTEM - PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/llvm-project/mlir/include" -) +if (IREE_BUILD_FOR_IOS) + set(__BUILD_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/ios_runtime/build") +else() + set(__BUILD_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build") +endif() -add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree" "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build" EXCLUDE_FROM_ALL) +add_subdirectory("${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree" ${__BUILD_DIR} EXCLUDE_FROM_ALL) install( TARGETS ${_NAME} @@ -54,10 +59,12 @@ install( set_target_properties(${_NAME} PROPERTIES SUFFIX ".so") +if (NOT IREE_BUILD_FOR_IOS) set_target_properties(${_NAME} PROPERTIES INSTALL_RPATH_USE_LINK_PATH TRUE BUILD_WITH_INSTALL_RPATH TRUE ) +endif() if(NOT APPLE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -shared") @@ -70,8 +77,11 @@ else() if(ARM64_SUPPORTED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMAC_ARM64") endif() - # set(CMAKE_SHARED_LINKER_FLAGS "-bundle -flat_namespace -undefined suppress") - set_target_properties(${_NAME} PROPERTIES INSTALL_RPATH "@loader_path/${IREE_COMPILER_DIR}") + + if(IREE_BUILD_FOR_IOS) + else() + set_target_properties(${_NAME} PROPERTIES INSTALL_RPATH "@loader_path/${IREE_COMPILER_DIR}") + endif() endif() target_compile_options(${_NAME} PRIVATE ${IREE_DEFAULT_COPTS}) @@ -86,18 +96,32 @@ endif() add_definitions(-DLLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING=1) -set(XLA_EXTENSION_LIB_PATH ${XLA_EXTENSION_LIB}) -set(XLA_EXTENSION_INCLUDE_PATH ${XLA_INCLUDE_PATH}) -include_directories(${XLA_EXTENSION_INCLUDE_PATH}) -target_link_libraries(${_NAME} "${XLA_EXTENSION_LIB}/libxla_extension.so") - target_link_libraries(${_NAME} iree_runtime_runtime) -target_link_libraries(${_NAME} iree_compiler_bindings_c_loader) -if(NOT APPLE) - target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.so") +if(IREE_BUILD_FOR_IOS) else() - target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.dylib") + if(NOT APPLE) + target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.so") + else() + target_link_libraries(${_NAME} "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/build/lib/libIREECompiler.dylib") + endif() + + target_include_directories(${_NAME} PUBLIC $ENV{ERTS_INCLUDE_DIR}) + target_include_directories(${_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/../../${XLA_INCLUDE_PATH}") + target_include_directories(${_NAME} SYSTEM + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/../../${IREE_COMPILER_INCLUDE_PATH}/iree/compiler" + ) + target_include_directories(${_NAME} SYSTEM + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/../../cache/iree/third_party/llvm-project/mlir/include" + ) + + set(XLA_EXTENSION_LIB_PATH ${XLA_EXTENSION_LIB}) + set(XLA_EXTENSION_INCLUDE_PATH ${XLA_INCLUDE_PATH}) + include_directories(${XLA_EXTENSION_INCLUDE_PATH}) + target_link_libraries(${_NAME} "${XLA_EXTENSION_LIB}/libxla_extension.so") + target_link_libraries(${_NAME} iree_compiler_bindings_c_loader) endif() if($ENV{DEBUG}) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index cc0a72ab77..5ba91bbe76 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -522,6 +522,10 @@ defmodule EXLA.Defn do {:ok, module_bytecode} = EXLA.MLIR.IREE.compile(module_charlist, flags) + if filename = options[:iree_filename] do + File.write!(filename, module_bytecode) + end + %EXLA.Executable{ client: client, ref: module_bytecode,