diff --git a/exla/Makefile b/exla/Makefile
index 91b0cf0d21..8efb4c1fc0 100644
--- a/exla/Makefile
+++ b/exla/Makefile
@@ -61,8 +61,8 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
 		ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
 	fi

-SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/mlir/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
-HEADERS = $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/mlir/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
+SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
+HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h

 OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o

diff --git a/exla/c_src/exla/mlir/custom_calls.cc b/exla/c_src/exla/custom_calls.cc
similarity index 86%
rename from exla/c_src/exla/mlir/custom_calls.cc
rename to exla/c_src/exla/custom_calls.cc
index 4040f238f0..eb2ce90d27 100644
--- a/exla/c_src/exla/mlir/custom_calls.cc
+++ b/exla/c_src/exla/custom_calls.cc
@@ -1,9 +1,10 @@
 #include "custom_calls.h"
+#include "exla_nif_util.h"

-#include <Eigen/Dense>
-#include <Eigen/QR>
+#include "xla/service/custom_call_target_registry.h"

-#include "builder.h"
+#include "Eigen/Dense"
+#include "Eigen/QR"

 template <typename DataType>
 void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
@@ -102,4 +103,9 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]) {

 void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
   qr_cpu_custom_call<double>(out, in);
-}
\ No newline at end of file
+}
+
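+// Register the typed QR entry points with XLA's CPU custom-call registry so
+// that compiled executables can resolve them by symbol name.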
-#include "xla/primitive_util.h" #include "xla/service/platform_util.h" -#include "xla/shape_util.h" + +#include "mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" // All of these are created with calls to `new` and subsequently // passed to the VM as pointers-to-pointers so we balance it out @@ -28,14 +26,6 @@ void free_exla_executable(ErlNifEnv* env, void* obj) { } } -void free_xla_builder(ErlNifEnv* env, void* obj) { - xla::XlaBuilder** builder = reinterpret_cast(obj); - if (*builder != nullptr) { - delete *builder; - *builder = nullptr; - } -} - void free_exla_client(ErlNifEnv* env, void* obj) { exla::ExlaClient** client = reinterpret_cast(obj); if (*client != nullptr) { @@ -55,15 +45,9 @@ void free_exla_buffer(ErlNifEnv* env, void* obj) { static int open_resources(ErlNifEnv* env) { const char* mod = "EXLA"; - if (!exla::nif::open_resource(env, mod, "Shape")) { - return -1; - } if (!exla::nif::open_resource(env, mod, "Executable", free_exla_executable)) { return -1; } - if (!exla::nif::open_resource(env, mod, "Builder", free_xla_builder)) { - return -1; - } if (!exla::nif::open_resource(env, mod, "ExlaClient", free_exla_client)) { return -1; } @@ -71,7 +55,7 @@ static int open_resources(ErlNifEnv* env) { return -1; } // MLIR - if (!exla::nif::open_resource(env, mod, "MLIRBlock")) { + if (!exla::nif::open_resource(env, mod, "MLIRFunction")) { return -1; } if (!exla::nif::open_resource(env, mod, "MLIRValue")) { @@ -96,41 +80,308 @@ static int load(ErlNifEnv* env, void** priv, ERL_NIF_TERM load_info) { return 0; } -// XlaBuilder Functions +// MLIR Functions + +ERL_NIF_TERM type_parsing_error(ErlNifEnv* env, std::string type_string) { + return exla::nif::make(env, "Unable to parse MLIR type: " + type_string); +} + +ERL_NIF_TERM attribute_parsing_error(ErlNifEnv* env, std::string attribute_string) { + return exla::nif::make(env, "Unable to parse MLIR attribute: " + attribute_string); +} + +ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 7) { + return exla::nif::error(env, "Bad argument count."); + } + + exla::ExlaClient** client; + exla::MLIRModule** module; + std::vector argument_layouts; + xla::ExecutableBuildOptions build_options; + int num_replicas; + int num_partitions; + bool use_spmd; + int device_id; + + if (!exla::nif::get(env, argv[0], client)) { + return exla::nif::error(env, "Unable to get client."); + } + if (!exla::nif::get(env, argv[1], module)) { + return exla::nif::error(env, "Unable to get module."); + } + if (!exla::nif::get_list(env, argv[2], argument_layouts)) { + return exla::nif::error(env, "Unable to get argument layouts."); + } + if (!exla::nif::get(env, argv[3], &num_replicas)) { + return exla::nif::error(env, "Unable to get Number of Replicas."); + } + if (!exla::nif::get(env, argv[4], &num_partitions)) { + return exla::nif::error(env, "Unable to get Number of Partitions."); + } + if (!exla::nif::get(env, argv[5], &use_spmd)) { + return exla::nif::error(env, "Unable to get SPMD Partitioning Flag."); + } + if (!exla::nif::get(env, argv[6], &device_id)) { + return exla::nif::error(env, "Unable to get device ID."); + } + + build_options.set_num_replicas(num_replicas); + build_options.set_num_partitions(num_partitions); + build_options.set_use_spmd_partitioning(use_spmd); + + bool compile_portable_executable = false; + if (device_id >= 0) { + compile_portable_executable = true; + build_options.set_device_ordinal(device_id); + 
+  bool compile_portable_executable = false;
+  if (device_id >= 0) {
+    compile_portable_executable = true;
+    build_options.set_device_ordinal(device_id);
+  }
+
+  EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaExecutable * executable,
+                            (*client)->Compile((*module)->module(), argument_layouts, build_options, compile_portable_executable), env);
+
+  return exla::nif::ok(env, exla::nif::make<exla::ExlaExecutable*>(env, executable));
+}
+
+ERL_NIF_TERM mlir_new_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  if (argc != 0) {
+    return exla::nif::error(env, "Bad argument count.");
+  }
+
+  mlir::MLIRContext* context = new mlir::MLIRContext();
+  context->getOrLoadDialect<mlir::func::FuncDialect>();
+  context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
+  context->getOrLoadDialect<mlir::mhlo::MhloDialect>();
+  context->getOrLoadDialect<mlir::chlo::ChloDialect>();
+
+  auto ret = exla::nif::make<mlir::MLIRContext*>(env, context);
+  return exla::nif::ok(env, ret);
+}
+
+ERL_NIF_TERM mlir_new_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  if (argc != 1) {
+    return exla::nif::error(env, "Bad argument count.");
+  }
+
+  mlir::MLIRContext** ctx;
+
+  if (!exla::nif::get<mlir::MLIRContext*>(env, argv[0], ctx)) {
+    return exla::nif::error(env, "Unable to get context.");
+  }
+
+  exla::MLIRModule* module = new exla::MLIRModule(*ctx);
+
+  return exla::nif::ok(env, exla::nif::make<exla::MLIRModule*>(env, module));
+}
+
+ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  if (argc != 5) {
+    return exla::nif::error(env, "Bad argument count.");
+  }
+
+  exla::MLIRModule** module;
+  std::string func_name;
+  std::vector<std::string> arg_type_strings;
+  std::vector<std::string> ret_type_strings;
+  bool is_public;
+
+  if (!exla::nif::get<exla::MLIRModule*>(env, argv[0], module)) {
+    return exla::nif::error(env, "Unable to get module.");
+  }
+  if (!exla::nif::get(env, argv[1], func_name)) {
+    return exla::nif::error(env, "Unable to get function name.");
+  }
+  if (!exla::nif::get_list(env, argv[2], arg_type_strings)) {
+    return exla::nif::error(env, "Unable to get args.");
+  }
+  if (!exla::nif::get_list(env, argv[3], ret_type_strings)) {
+    return exla::nif::error(env, "Unable to get return.");
+  }
+  if (!exla::nif::get(env, argv[4], &is_public)) {
+    return exla::nif::error(env, "Unable to get is_public.");
+  }
+
+  auto arg_types = std::vector<mlir::Type>{};
+
+  for (auto const & type_string : arg_type_strings) {
+    auto type = (*module)->ParseType(type_string);
+    if(type == nullptr) {
+      return type_parsing_error(env, type_string);
+    }
+    arg_types.push_back(type);
+  }
+
+  auto ret_types = std::vector<mlir::Type>{};
+
+  for (auto const & type_string : ret_type_strings) {
+    auto type = (*module)->ParseType(type_string);
+    if(type == nullptr) {
+      return type_parsing_error(env, type_string);
+    }
+    ret_types.push_back(type);
+  }

-ERL_NIF_TERM new_builder(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  exla::MLIRFunction* func = (*module)->CreateFunction(func_name, arg_types, ret_types, is_public);
+
+  return exla::nif::ok(env, exla::nif::make<exla::MLIRFunction*>(env, func));
+}
+
+ERL_NIF_TERM mlir_get_function_arguments(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   if (argc != 1) {
     return exla::nif::error(env, "Bad argument count.");
   }

-  std::string name;
-  if (!exla::nif::get(env, argv[0], name)) {
-    return exla::nif::error(env, "Unable to get builder name.");
+  exla::MLIRFunction** function;
+
+  if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
+    return exla::nif::error(env, "Unable to get function.");
+  }
+
+  llvm::MutableArrayRef<mlir::BlockArgument> args = (*function)->GetArguments();
+  std::vector<ERL_NIF_TERM> terms;
+  terms.reserve(args.size());
+
+  for (auto arg : args) {
+    ERL_NIF_TERM term = exla::nif::make<mlir::Value>(env, arg);
+    terms.push_back(term);
+  }
+
+  return exla::nif::ok(env, enif_make_list_from_array(env, terms.data(), terms.size()));
+}
+
+ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+  if (argc != 6) {
+    return exla::nif::error(env, "Bad argument count.");
+  }
+
+  exla::MLIRFunction** function;
+  std::string op_name;
+  std::vector<mlir::Value> operands;
+  std::vector<std::string> result_type_strings;
+  std::vector<std::pair<std::string, std::string>> attributes_kwlist;
+  std::vector<mlir::Region*> regions;
+
+  if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
+    return exla::nif::error(env, "Unable to get function.");
+  }
+  if (!exla::nif::get(env, argv[1], op_name)) {
+    return exla::nif::error(env, "Unable to get op name.");
+  }
+  if (!exla::nif::get_list(env, argv[2], operands)) {
+    return exla::nif::error(env, "Unable to get operands.");
+  }
+  if (!exla::nif::get_list(env, argv[3], result_type_strings)) {
+    return exla::nif::error(env, "Unable to get result types.");
+  }
+  if (!exla::nif::get_keyword_list<std::string>(env, argv[4], attributes_kwlist)) {
+    return exla::nif::error(env, "Unable to get attributes.");
+  }
+  if (!exla::nif::get_list(env, argv[5], regions)) {
+    return exla::nif::error(env, "Unable to get regions.");
   }

-  xla::XlaBuilder* builder = new xla::XlaBuilder(name);
+  auto result_types = std::vector<mlir::Type>{};

-  return exla::nif::ok(env, exla::nif::make<xla::XlaBuilder*>(env, builder));
+  for (auto const & type_string : result_type_strings) {
+    auto type = (*function)->module()->ParseType(type_string);
+    if(type == nullptr) {
+      return type_parsing_error(env, type_string);
+    }
+    result_types.push_back(type);
+  }
+
+  auto attributes = std::vector<std::pair<std::string, mlir::Attribute>>{};
+
+  for (auto const & pair : attributes_kwlist) {
+    auto attribute_value = (*function)->module()->ParseAttribute(pair.second);
+    if(attribute_value == nullptr) {
+      return attribute_parsing_error(env, pair.second);
+    }
+    attributes.push_back(std::pair{pair.first, attribute_value});
+  }
+
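+  // Build the op generically from its printed name, operands, result types,
+  // parsed attributes, and any regions collected via mlir_push_region.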
+  auto results = (*function)->Op(op_name, operands, result_types, attributes, regions);
+
+  return exla::nif::ok(env, exla::nif::make_list(env, results));
+}

-ERL_NIF_TERM create_sub_builder(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
+
+ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   if (argc != 2) {
     return exla::nif::error(env, "Bad argument count.");
   }

-  xla::XlaBuilder** builder;
-  std::string name;
+  exla::MLIRFunction** function;
+  std::vector<std::string> arg_types;

-  if (!exla::nif::get(env, argv[0], builder)) {
-    return exla::nif::error(env, "Unable to get builder.");
+  if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
+    return exla::nif::error(env, "Unable to get function.");
+  }
+  if (!exla::nif::get_list(env, argv[1], arg_types)) {
+    return exla::nif::error(env, "Unable to get arg types.");
+  }
+
+  auto types = std::vector<mlir::Type>{};
+
+  for (auto const & type_string : arg_types) {
+    auto type = (*function)->module()->ParseType(type_string);
+    if(type == nullptr) {
+      return type_parsing_error(env, type_string);
+    }
+    types.push_back(type);
+  }
+
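+  // PushRegion creates a detached region holding a single block with the
+  // given argument types; ops are emitted into it until mlir_pop_region.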
count."); } - if (!exla::nif::get(env, argv[1], name)) { - return exla::nif::error(env, "Unable to get name."); + + mlir::Value* t; + + if (!exla::nif::get(env, argv[0], t)) { + return exla::nif::error(env, "Unable to get tensor."); } - auto uniq_sub_builder = (*builder)->CreateSubBuilder(name); - xla::XlaBuilder* sub_builder = uniq_sub_builder.release(); - return exla::nif::ok(env, exla::nif::make(env, sub_builder)); + mlir::Type type = t->getType(); + + return exla::nif::ok(env, exla::nif::make_typespec(env, type)); +} + +ERL_NIF_TERM mlir_module_to_string(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 1) { + return exla::nif::error(env, "Bad argument count."); + } + + exla::MLIRModule** module; + + if (!exla::nif::get(env, argv[0], module)) { + return exla::nif::error(env, "Unable to get builder."); + } + + auto string = (*module)->ToString(); + + return exla::nif::ok(env, exla::nif::make(env, string)); } // ExlaBuffer Functions @@ -191,7 +442,7 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E exla::ExlaClient** client; std::vector pointer_vec; - xla::Shape* shape; + xla::Shape shape; int device_id; std::string pointer_kind; @@ -204,7 +455,7 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E if (!exla::nif::get_atom(env, argv[2], pointer_kind)) { return exla::nif::error(env, "Unable to get device pointer kind."); } - if (!exla::nif::get(env, argv[3], shape)) { + if (!exla::nif::get_typespec_as_xla_shape(env, argv[3], &shape)) { return exla::nif::error(env, "Unable to get shape."); } if (!exla::nif::get(env, argv[4], &device_id)) { @@ -233,7 +484,7 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(device_id), env); std::function on_delete_callback = []() {}; - EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, *shape, device, on_delete_callback), env); + EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env); exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer)); return exla::nif::ok(env, exla::nif::make(env, exla_buffer)); } @@ -244,7 +495,7 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a } ErlNifBinary bin; - xla::Shape* shape; + xla::Shape shape; exla::ExlaClient** client; int device_id; @@ -254,7 +505,7 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a if (!exla::nif::get_binary(env, argv[1], &bin)) { return exla::nif::error(env, "Unable to get data."); } - if (!exla::nif::get(env, argv[2], shape)) { + if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) { return exla::nif::error(env, "Unable to get shape."); } if (!exla::nif::get(env, argv[3], &device_id)) { @@ -262,7 +513,7 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a } EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaBuffer * buffer, - (*client)->BufferFromBinary(env, argv[1], *shape, device_id), env); + (*client)->BufferFromBinary(env, argv[1], shape, device_id), env); return exla::nif::ok(env, exla::nif::make(env, buffer)); } @@ -306,51 +557,6 @@ ERL_NIF_TERM deallocate_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM } } -// Shape Functions - -ERL_NIF_TERM make_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return 
-  if (!exla::nif::get(env, argv[3], shape)) {
+  if (!exla::nif::get_typespec_as_xla_shape(env, argv[3], &shape)) {
     return exla::nif::error(env, "Unable to get shape.");
   }
   if (!exla::nif::get(env, argv[4], &device_id)) {
@@ -233,7 +484,7 @@ ERL_NIF_TERM create_buffer_from_device_pointer(ErlNifEnv* env, int argc, const E
   EXLA_ASSIGN_OR_RETURN_NIF(xla::PjRtDevice * device, (*client)->client()->LookupDevice(device_id), env);
   std::function<void()> on_delete_callback = []() {};
-  EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, *shape, device, on_delete_callback), env);
+  EXLA_ASSIGN_OR_RETURN_NIF(std::unique_ptr<xla::PjRtBuffer> buffer, (*client)->client()->CreateViewOfDeviceBuffer(ptr, shape, device, on_delete_callback), env);
   exla::ExlaBuffer* exla_buffer = new exla::ExlaBuffer(std::move(buffer));
   return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, exla_buffer));
 }
@@ -244,7 +495,7 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
   }

   ErlNifBinary bin;
-  xla::Shape* shape;
+  xla::Shape shape;
   exla::ExlaClient** client;
   int device_id;

@@ -254,7 +505,7 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
   if (!exla::nif::get_binary(env, argv[1], &bin)) {
     return exla::nif::error(env, "Unable to get data.");
   }
-  if (!exla::nif::get(env, argv[2], shape)) {
+  if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) {
     return exla::nif::error(env, "Unable to get shape.");
   }
   if (!exla::nif::get(env, argv[3], &device_id)) {
@@ -262,7 +513,7 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
   }

   EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaBuffer * buffer,
-                            (*client)->BufferFromBinary(env, argv[1], *shape, device_id), env);
+                            (*client)->BufferFromBinary(env, argv[1], shape, device_id), env);
   return exla::nif::ok(env, exla::nif::make<exla::ExlaBuffer*>(env, buffer));
 }

@@ -306,51 +557,6 @@ ERL_NIF_TERM deallocate_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM
   }
 }

-// Shape Functions
-
-ERL_NIF_TERM make_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 2) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  xla::PrimitiveType element_type;
-  std::vector<exla::int64> dims;
-
-  if (!exla::nif::get_primitive_type(env, argv[0], &element_type)) {
-    return exla::nif::error(env, "Unable to get type.");
-  }
-  if (!exla::nif::get_tuple(env, argv[1], dims)) {
-    return exla::nif::error(env, "Unable to get dimensions.");
-  }
-
-  xla::Shape shape = xla::ShapeUtil::MakeShape(element_type, dims);
-
-  return exla::nif::ok(env, exla::nif::make<xla::Shape>(env, shape));
-}
-
-ERL_NIF_TERM make_token_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 0) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  xla::Shape shape = xla::ShapeUtil::MakeTokenShape();
-  return exla::nif::ok(env, exla::nif::make<xla::Shape>(env, shape));
-}
-
-ERL_NIF_TERM get_shape_info(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 1) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  xla::Shape* shape;
-
-  if (!exla::nif::get<xla::Shape>(env, argv[0], shape)) {
-    return exla::nif::error(env, "Unable to get shape.");
-  }
-
-  return exla::nif::ok(env, exla::nif::make_shape_info(env, *shape));
-}
-
 ERL_NIF_TERM transfer_to_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
   if (argc != 3) {
     return exla::nif::error(env, "Bad argument count.");
@@ -371,17 +577,17 @@ ERL_NIF_TERM transfer_to_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM arg
   while (enif_get_list_cell(env, data, &head, &tail)) {
     const ERL_NIF_TERM* terms;
     int count;
-    xla::Shape* shape;
+    xla::Shape shape;

     if (!enif_get_tuple(env, head, &count, &terms) && count != 2) {
       return exla::nif::error(env, "Unable to binary-shape tuple.");
     }

-    if (!exla::nif::get(env, terms[1], shape)) {
+    if (!exla::nif::get_typespec_as_xla_shape(env, terms[1], &shape)) {
       return exla::nif::error(env, "Unable to get shape.");
     }

-    xla::Status transfer_status = (*client)->TransferToInfeed(env, terms[0], *shape, device_id);
+    xla::Status transfer_status = (*client)->TransferToInfeed(env, terms[0], shape, device_id);

     if (!transfer_status.ok()) {
       return exla::nif::error(env, transfer_status.message().data());
@@ -415,15 +621,15 @@ ERL_NIF_TERM transfer_from_outfeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM

   ERL_NIF_TERM data = argv[2];
   ERL_NIF_TERM head, tail;
   while (enif_get_list_cell(env, data, &head, &tail)) {
-    xla::Shape* shape;
+    xla::Shape shape;

-    if (!exla::nif::get(env, head, shape)) {
+    if (!exla::nif::get_typespec_as_xla_shape(env, head, &shape)) {
       return exla::nif::error(env, "Unable to get shape.");
     }

     ErlNifEnv* penv = enif_alloc_env();
     ERL_NIF_TERM ref = enif_make_copy(penv, argv[4]);
-    auto statusor = (*client)->TransferFromOutfeed(penv, device_id, *shape);
+    auto statusor = (*client)->TransferFromOutfeed(penv, device_id, shape);

     if (!statusor.ok()) {
       enif_clear_env(penv);
@@ -685,109 +891,15 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])

 static ErlNifFunc exla_funcs[] = {
     // MLIR Builder
-    {"new_mlir_context", 0, new_mlir_context},
-    {"new_mlir_module", 1, new_mlir_module},
-    {"create_mlir_function", 5, create_mlir_function},
-    {"get_mlir_function_arguments", 1, get_mlir_function_arguments},
-    {"mlir_add", 3, mlir_add},
-    {"mlir_subtract", 3, mlir_subtract},
-    {"mlir_multiply", 3, mlir_multiply},
-    {"mlir_min", 3, mlir_min},
-    {"mlir_max", 3, mlir_max},
-    {"mlir_remainder", 3, mlir_remainder},
-    {"mlir_pow", 3, mlir_pow},
-    {"mlir_divide", 3, mlir_divide},
-    {"mlir_atan2", 3, mlir_atan2},
-    {"mlir_equal", 3, mlir_equal},
-    {"mlir_not_equal", 3, mlir_not_equal},
-    {"mlir_less", 3, mlir_less},
-    {"mlir_less_equal", 3, mlir_less_equal},
-    {"mlir_greater", 3, mlir_greater},
-    {"mlir_greater_equal", 3, mlir_greater_equal},
-    {"dump_mlir_module", 1, dump_mlir_module},
-    {"mlir_get_shape", 1, mlir_get_shape},
-    {"mlir_convert", 3, mlir_convert},
-    {"mlir_bitcast_convert", 4, mlir_bitcast_convert},
-    {"mlir_abs", 2, mlir_abs},
-    {"mlir_exp", 2, mlir_exp},
-    {"mlir_expm1", 2, mlir_expm1},
-    {"mlir_floor", 2, mlir_floor},
-    {"mlir_ceil", 2, mlir_ceil},
-    {"mlir_round", 2, mlir_round},
-    {"mlir_log", 2, mlir_log},
-    {"mlir_sigmoid", 2, mlir_sigmoid},
-    {"mlir_log1p", 2, mlir_log1p},
-    {"mlir_sign", 2, mlir_sign},
-    {"mlir_cos", 2, mlir_cos},
-    {"mlir_sin", 2, mlir_sin},
-    {"mlir_tan", 2, mlir_tan},
-    {"mlir_acos", 2, mlir_acos},
-    {"mlir_asin", 2, mlir_asin},
-    {"mlir_atan", 2, mlir_atan},
-    {"mlir_cosh", 2, mlir_cosh},
-    {"mlir_sinh", 2, mlir_sinh},
-    {"mlir_tanh", 2, mlir_tanh},
-    {"mlir_acosh", 2, mlir_acosh},
-    {"mlir_asinh", 2, mlir_asinh},
-    {"mlir_atanh", 2, mlir_atanh},
-    {"mlir_sqrt", 2, mlir_sqrt},
-    {"mlir_cbrt", 2, mlir_cbrt},
-    {"mlir_iota", 3, mlir_iota},
-    {"mlir_top_k", 3, mlir_top_k},
-    {"mlir_sort", 5, mlir_sort},
-    {"mlir_scatter", 9, mlir_scatter},
-    {"mlir_select_and_scatter", 8, mlir_select_and_scatter},
-    {"mlir_gather", 8, mlir_gather},
-    {"mlir_reshape", 3, mlir_reshape},
-    {"mlir_reverse", 3, mlir_reverse},
-    {"mlir_transpose", 3, mlir_transpose},
-    {"mlir_slice", 5, mlir_slice},
-    {"mlir_dynamic_slice", 4, mlir_dynamic_slice},
-    {"mlir_constant_r0", 3, mlir_constant_r0},
-    {"mlir_constant_from_binary", 4, mlir_constant_from_binary},
-    {"mlir_bitwise_and", 3, mlir_bitwise_and},
-    {"mlir_bitwise_or", 3, mlir_bitwise_or},
-    {"mlir_bitwise_xor", 3, mlir_bitwise_xor},
-    {"mlir_bitwise_not", 2, mlir_bitwise_not},
-    {"mlir_left_shift", 3, mlir_shift_left},
-    {"mlir_right_shift_logical", 3, mlir_shift_right_logical},
-    {"mlir_right_shift_arithmetic", 3, mlir_shift_right_arithmetic},
-    {"mlir_negate", 2, mlir_negate},
-    {"mlir_erf", 2, mlir_erf},
-    {"mlir_erfc", 2, mlir_erfc},
-    {"mlir_erf_inv", 2, mlir_erf_inv},
-    {"mlir_is_infinity", 2, mlir_is_infinity},
-    {"mlir_is_nan", 2, mlir_is_nan},
-    {"mlir_rsqrt", 2, mlir_rsqrt},
-    {"mlir_count_leading_zeros", 2, mlir_clz},
-    {"mlir_real", 2, mlir_real},
-    {"mlir_imag", 2, mlir_imag},
-    {"mlir_conjugate", 2, mlir_conjugate},
-    {"mlir_dot_general", 6, mlir_dot_general},
-    {"mlir_clamp", 4, mlir_clamp},
-    {"mlir_population_count", 2, mlir_population_count},
-    {"mlir_broadcast_in_dim", 4, mlir_broadcast_in_dim},
-    {"mlir_concatenate", 3, mlir_concatenate},
-    {"mlir_optimization_barrier", 2, mlir_optimization_barrier},
-    {"mlir_select", 4, mlir_select},
-    {"mlir_pad", 6, mlir_pad},
-    {"mlir_fft", 4, mlir_fft},
-    {"mlir_convolution", 12, mlir_convolution},
-    {"mlir_create_token", 1, mlir_create_token},
-    {"mlir_triangular_solve", 6, mlir_triangular_solve},
-    {"mlir_dynamic_update_slice", 4, mlir_dynamic_update_slice},
-    {"mlir_infeed", 3, mlir_infeed},
-    {"mlir_outfeed", 3, mlir_outfeed},
-    {"mlir_call", 3, mlir_call},
-    {"mlir_reduce", 5, mlir_reduce},
-    {"mlir_window_reduce", 9, mlir_window_reduce},
-    {"mlir_map", 4, mlir_map},
-    {"mlir_if", 3, mlir_if},
+    {"mlir_new_context", 0, mlir_new_context},
+    {"mlir_new_module", 1, mlir_new_module},
+    {"mlir_create_function", 5, mlir_create_function},
+    {"mlir_get_function_arguments", 1, mlir_get_function_arguments},
+    {"mlir_op", 6, mlir_op},
     {"mlir_push_region", 2, mlir_push_region},
+    {"mlir_get_typespec", 1, mlir_get_typespec},
     {"mlir_pop_region", 1, mlir_pop_region},
-    {"mlir_while", 2, mlir_while},
-    {"mlir_return", 2, mlir_return},
-    {"mlir_qr", 4, mlir_qr},
+    {"mlir_module_to_string", 1, mlir_module_to_string},
     // ExlaClient
     {"get_host_client", 0, get_host_client},
     {"get_gpu_client", 2, get_gpu_client},
@@ -809,10 +921,6 @@ static ErlNifFunc exla_funcs[] = {
     // ExlaExecutable
     {"run_io", 4, run, ERL_NIF_DIRTY_JOB_IO_BOUND},
    {"run_cpu", 4, run, ERL_NIF_DIRTY_JOB_CPU_BOUND},
-    // Shape
-    {"make_shape", 2, make_shape},
-    {"make_token_shape", 0, make_token_shape},
-    {"get_shape_info", 1, get_shape_info},
     // Log Sink
     {"start_log_sink", 1, start_log_sink},
     // Serialization
diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc
index 0d473c66a8..4ed775ff06 100644
--- a/exla/c_src/exla/exla_client.cc
+++ b/exla/c_src/exla/exla_client.cc
@@ -98,17 +98,17 @@ UnpackReplicaArguments(ErlNifEnv* env,
     if (enif_get_tuple(env, head, &arity, &tuple)) {
       // if the term is a tuple, that means it represents a {shape, binary}
       // tuple which we must convert into an exla buffer for use in the computation
-      xla::Shape* shape;
+      xla::Shape shape;

-      if (!nif::get<xla::Shape>(env, tuple[1], shape)) {
-        return xla::InvalidArgument("Expected argument to be shape reference.");
+      if (!nif::get_typespec_as_xla_shape(env, tuple[1], &shape)) {
+        return xla::InvalidArgument("Expected argument to be a typespec.");
       }

       // we convert the binary into a buffer and transfer it to the correct device,
       // this buffer is not managed by the erlang vm so it must be deallocated explicitly
       // after use by the execution
       EXLA_ASSIGN_OR_RETURN(std::unique_ptr<ExlaBuffer> buf,
-                            PjRtBufferFromBinary(client->client(), env, tuple[0], *shape, device));
+                            PjRtBufferFromBinary(client->client(), env, tuple[0], shape, device));
       replica_buffers.push_back(buf.release());
     } else if (nif::get<ExlaBuffer*>(env, head, buffer)) {
       // if the buffer is not a tuple it must be a reference to an exla buffer
@@ -364,13 +364,13 @@ xla::StatusOr<ExlaExecutable*> ExlaClient::DeserializeExecutable(std::string des
 }

 xla::StatusOr<ExlaExecutable*> ExlaClient::Compile(const mlir::OwningOpRef<mlir::ModuleOp>& module,
-                                                   std::vector<xla::Shape*> argument_layouts,
+                                                   std::vector<xla::Shape> argument_layouts,
                                                    xla::ExecutableBuildOptions& options,
                                                    bool compile_portable_executable) {
   std::vector<xla::Shape> layouts;
   layouts.reserve(argument_layouts.size());
   for (auto shape : argument_layouts) {
-    xla::Shape cpy_shape = xla::ShapeUtil::MakeShape(shape->element_type(), shape->dimensions());
+    xla::Shape cpy_shape = xla::ShapeUtil::MakeShape(shape.element_type(), shape.dimensions());
     xla::LayoutUtil::ClearLayout(&cpy_shape);
     layouts.push_back(cpy_shape);
   }
diff --git a/exla/c_src/exla/exla_client.h b/exla/c_src/exla/exla_client.h
index 6d7907acc5..ef531b3351 100644
--- a/exla/c_src/exla/exla_client.h
+++ b/exla/c_src/exla/exla_client.h
@@ -79,7 +79,7 @@ class ExlaClient {

   // Compiles the given computation with the given compile options
   xla::StatusOr<ExlaExecutable*> Compile(const mlir::OwningOpRef<mlir::ModuleOp>& computation,
-                                         std::vector<xla::Shape*> argument_layouts,
+                                         std::vector<xla::Shape> argument_layouts,
                                          xla::ExecutableBuildOptions& options,
                                          bool compile_portable_executable);

diff --git a/exla/c_src/exla/exla_mlir.cc b/exla/c_src/exla/exla_mlir.cc
new file mode 100644
index 0000000000..8b83cacae5
--- /dev/null
+++ b/exla/c_src/exla/exla_mlir.cc
@@ -0,0 +1,134 @@
+#include "exla_mlir.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/AsmParser/AsmParser.h"
+
+namespace exla {
+MLIRFunction::MLIRFunction(MLIRModule *module, std::unique_ptr<mlir::func::FuncOp> func)
+    : func_(std::move(func)),
+      module_(module) {}
+
+std::vector<mlir::Value> MLIRFunction::Op(
+    std::string op_name, std::vector<mlir::Value> operands,
+    std::vector<mlir::Type> result_types,
+    std::vector<std::pair<std::string, mlir::Attribute>> attributes,
+    std::vector<mlir::Region *> regions) {
+  auto builder = module_->builder();
+  auto context = builder->getContext();
+
+  auto types_range = mlir::TypeRange{llvm::ArrayRef<mlir::Type>{result_types}};
+
+  auto named_attributes = std::vector<mlir::NamedAttribute>{};
+  for (auto const &pair : attributes) {
+    auto attribute = builder->getNamedAttr(pair.first, pair.second);
+    named_attributes.push_back(attribute);
+  }
+
+  auto operands_range = mlir::ValueRange{
+      llvm::ArrayRef<mlir::Value>(operands.data(), operands.size())};
+  auto attributes_array = llvm::ArrayRef<mlir::NamedAttribute>{named_attributes};
+
+  setInsertionPoint();
+
+  auto op_state = mlir::OperationState{mlir::UnknownLoc::get(context),
+                                       builder->getStringAttr(op_name),
+                                       operands_range,
+                                       types_range,
+                                       attributes_array,
+                                       {},
+                                       {}};
+
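+  // Splice the blocks out of each detached region (built via PushRegion)
+  // into the op's own regions, transferring ownership to the op.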
+  for (auto region : regions) {
+    auto new_region = op_state.addRegion();
+    new_region->getBlocks().splice(new_region->end(), region->getBlocks());
+  }
+
+  auto op = builder->create(op_state);
+
+  auto results = op->getResults();
+  return std::vector<mlir::Value>(results.begin(), results.end());
+}
+
+std::pair<mlir::Region *, std::vector<mlir::Value>> MLIRFunction::PushRegion(std::vector<mlir::Type> types) {
+  auto context = module_->builder()->getContext();
+
+  auto region = new mlir::Region();
+  auto & block = region->emplaceBlock();
+
+  for (mlir::Type type : types) {
+    block.addArgument(type, mlir::UnknownLoc::get(context));
+  }
+
+  auto args = std::vector<mlir::Value>{};
+  for (auto &arg : block.getArguments()) {
+    args.push_back(arg);
+  }
+
+  region_stack.push(std::move(region));
+  setInsertionPoint();
+
+  return {region, args};
+}
+
+void MLIRFunction::PopRegion() {
+  region_stack.pop();
+  setInsertionPoint();
+}
+
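+// Ops are emitted at the end of the innermost pushed region; with no region
+// on the stack they go at the end of the function body.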
"public" : "nested"; + + auto funcType = builder_->getFunctionType(arg_types, ret_types); + auto loc = builder_->getUnknownLoc(); + auto funcOp = std::make_unique(mlir::func::FuncOp::create(loc, name, funcType)); + funcOp->setSymVisibility(visibility); + module_->push_back(*funcOp); + funcOp->addEntryBlock(); + builder_->setInsertionPointToStart(&funcOp->getBody().front()); + + return new MLIRFunction(this, std::move(funcOp)); +} + +std::string MLIRModule::ToString() { + auto output_string = std::string{}; + auto output_stream = llvm::raw_string_ostream{output_string}; + module_->print(output_stream); + return output_string; +} + +mlir::Type MLIRModule::ParseType(std::string string) { + return mlir::parseType(string, context_); +} + +mlir::Attribute MLIRModule::ParseAttribute(std::string string) { + auto attribute = mlir::parseAttribute(string, context_); + + if (attribute == nullptr) { + std::cerr << "Unable to parse MLIR attribute: " << string << std::endl; + exit(1); + } + + return attribute; +} + +} // namespace exla diff --git a/exla/c_src/exla/exla_mlir.h b/exla/c_src/exla/exla_mlir.h new file mode 100644 index 0000000000..49073b87b1 --- /dev/null +++ b/exla/c_src/exla/exla_mlir.h @@ -0,0 +1,76 @@ +#ifndef EXLA_MLIR_BUILDER_H_ +#define EXLA_MLIR_BUILDER_H_ + +#include + +#include "exla_nif_util.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/reference/Types.h" +#include "xla/shape.h" +#include "xla/types.h" + +namespace exla { + +class MLIRModule; + +class MLIRFunction { + public: + MLIRFunction(MLIRModule *module, std::unique_ptr func); + + std::vector Op( + std::string op_name, + std::vector operands, + std::vector result_types, + std::vector> attributes, + std::vector regions); + + std::pair> PushRegion(std::vector types); + void PopRegion(); + + llvm::MutableArrayRef GetArguments() { return func_->getBody().front().getArguments(); } + + std::shared_ptr module() { return module_; } + + private: + std::shared_ptr module_; + std::unique_ptr func_; + + std::stack region_stack; + + void setInsertionPoint(); +}; + +class MLIRModule { + public: + MLIRModule(mlir::MLIRContext *context); + + MLIRFunction *CreateFunction( + std::string name, + std::vector arg_types, + std::vector ret_types, + bool is_public); + + std::string ToString(); + + // Note: returns nullptr if the parsing fails + mlir::Type ParseType(std::string); + mlir::Attribute ParseAttribute(std::string); + + mlir::ModuleOp module() { return module_.get(); } + mlir::OpBuilder *builder() { return builder_.get(); } + + private: + mlir::MLIRContext *context_; + mlir::OwningOpRef module_; + std::unique_ptr builder_; +}; + +} // namespace exla + +#endif diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index a65ed8cf75..563a4cd8eb 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -1,7 +1,10 @@ #include "exla_nif_util.h" +#include "mlir/IR/BuiltinTypes.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" +#include "mlir/IR/Builders.h" +#include "stablehlo/dialect/StablehloOps.h" namespace exla { namespace nif { @@ -260,256 +263,110 @@ ERL_NIF_TERM make_map(ErlNifEnv* env, std::map& map) { return term; } -// Protobuf types - -int get_padding_config(ErlNifEnv* env, - ERL_NIF_TERM list, - xla::PaddingConfig* padding_config) { - 
+int get_typespec_as_xla_shape(ErlNifEnv* env, ERL_NIF_TERM term, xla::Shape* shape) {
+  int arity;
+  const ERL_NIF_TERM* tuple;

-  ERL_NIF_TERM lhs_batch, lhs_batch_tail;
-  list = terms[1];
-  while (enif_get_list_cell(env, list, &lhs_batch, &lhs_batch_tail)) {
-    int64 dim;
-    if (!get(env, lhs_batch, &dim)) return 0;
-    dims->add_lhs_batch_dimensions(dim);
+  if (!enif_get_tuple(env, term, &arity, &tuple)) return 0;

-    list = lhs_batch_tail;
-  }
+  xla::PrimitiveType element_type;
+  std::vector<int64> dims;

-  ERL_NIF_TERM rhs_contract, rhs_contract_tail;
-  list = terms[2];
-  while (enif_get_list_cell(env, list, &rhs_contract, &rhs_contract_tail)) {
-    int64 dim;
-    if (!get(env, rhs_contract, &dim)) return 0;
-    dims->add_rhs_contracting_dimensions(dim);
+  if (!get_primitive_type(env, tuple[0], &element_type)) return 0;
+  if (!get_tuple(env, tuple[1], dims)) return 0;

-    list = rhs_contract_tail;
-  }
-
-  ERL_NIF_TERM rhs_batch, rhs_batch_tail;
-  list = terms[3];
-  while (enif_get_list_cell(env, list, &rhs_batch, &rhs_batch_tail)) {
-    int64 dim;
-    if (!get(env, rhs_batch, &dim)) return 0;
-    dims->add_rhs_batch_dimensions(dim);
-
-    list = rhs_batch_tail;
-  }
+  *shape = xla::ShapeUtil::MakeShape(element_type, dims);

   return 1;
 }

-int get_precision_config(ErlNifEnv* env,
-                         ERL_NIF_TERM config_term,
-                         int num_operands,
-                         xla::PrecisionConfig* config) {
-  int config_int;
-  if (!get(env, config_term, &config_int)) return 0;
-
-  switch (config_int) {
-    case 0:
-      for (int i = 0; i < num_operands; i++) {
-        config->add_operand_precision(xla::PrecisionConfig::DEFAULT);
-      }
-      break;
-    case 1:
-      for (int i = 0; i < num_operands; i++) {
-        config->add_operand_precision(xla::PrecisionConfig::HIGH);
-      }
-      break;
-    case 2:
-      for (int i = 0; i < num_operands; i++) {
-        config->add_operand_precision(xla::PrecisionConfig::HIGHEST);
-      }
-      break;
-    default:
-      return 0;
+int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<xla::Shape>& var) {
+  unsigned int length;
+  if (!enif_get_list_length(env, list, &length)) {
+    return 0;
   }

+  var.reserve(length);
+  ERL_NIF_TERM head, tail;
+
+  while (enif_get_list_cell(env, list, &head, &tail)) {
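+    // Each element is an encoded EXLA.Typespec; decode it into an xla::Shape.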
+    xla::Shape elem;
+    if (!get_typespec_as_xla_shape(env, head, &elem)) {
+      return 0;
+    }
+    var.push_back(elem);
+    list = tail;
+  }
   return 1;
 }

-int get_conv_dimension_numbers(ErlNifEnv* env,
-                               ERL_NIF_TERM tuple,
-                               xla::ConvolutionDimensionNumbers* dimension_numbers) {
-  const ERL_NIF_TERM* terms;
-  int count;
-
-  if (!enif_get_tuple(env, tuple, &count, &terms)) return 0;
-  if (count != 3) return 0;
-
-  const ERL_NIF_TERM* input_dims;
-  int input_count;
-  if (!enif_get_tuple(env, terms[0], &input_count, &input_dims)) return 0;
-  if (count < 3) return 0;
-
-  int64 input_batch_dimension;
-  int64 input_feature_dimension;
-  if (!get(env, input_dims[0], &input_batch_dimension)) return 0;
-  if (!get(env, input_dims[1], &input_feature_dimension)) return 0;
-
-  dimension_numbers->set_input_batch_dimension(input_batch_dimension);
-  dimension_numbers->set_input_feature_dimension(input_feature_dimension);
-  for (int i = 2; i < input_count; i++) {
-    int64 value;
-    if (!get(env, input_dims[i], &value)) return 0;
-    dimension_numbers->add_input_spatial_dimensions(value);
+std::string mlir_numeric_type_to_string(mlir::Type type) {
+  if (type.isSignlessInteger(1)) {
+    return "pred";
   }
-
-  const ERL_NIF_TERM* kernel_dims;
-  int kernel_count;
-  if (!enif_get_tuple(env, terms[1], &kernel_count, &kernel_dims)) return 0;
-  if (kernel_count < 3) return 0;
-
-  int64 kernel_input_feature_dimension;
-  int64 kernel_output_feature_dimension;
-  if (!get(env, kernel_dims[0], &kernel_input_feature_dimension)) return 0;
-  if (!get(env, kernel_dims[1], &kernel_output_feature_dimension)) return 0;
-
-  dimension_numbers->set_kernel_output_feature_dimension(kernel_output_feature_dimension);
-  dimension_numbers->set_kernel_input_feature_dimension(kernel_input_feature_dimension);
-  for (int i = 2; i < kernel_count; i++) {
-    int64 value;
-    if (!get(env, kernel_dims[i], &value)) return 0;
-    dimension_numbers->add_kernel_spatial_dimensions(value);
+  if (auto integer_type = type.dyn_cast<mlir::IntegerType>()) {
+    if (integer_type.isUnsigned()) {
+      return "u" + std::to_string(integer_type.getWidth());
+    } else {
+      return "s" + std::to_string(integer_type.getWidth());
+    }
   }
-
-  const ERL_NIF_TERM* output_dims;
-  int output_count;
-  if (!enif_get_tuple(env, terms[2], &output_count, &output_dims)) return 0;
-  if (output_count < 3) return 0;
-
-  int64 output_batch_dimension;
-  int64 output_feature_dimension;
-  if (!get(env, output_dims[0], &output_batch_dimension)) return 0;
-  if (!get(env, output_dims[1], &output_feature_dimension)) return 0;
-
-  dimension_numbers->set_output_batch_dimension(output_batch_dimension);
-  dimension_numbers->set_output_feature_dimension(output_feature_dimension);
-  for (int i = 2; i < output_count; i++) {
-    int64 value;
-    if (!get(env, output_dims[i], &value)) return 0;
-    dimension_numbers->add_output_spatial_dimensions(value);
+  if (type.isBF16()) {
+    return "bf16";
   }
+  if (auto float_type = type.dyn_cast<mlir::FloatType>()) {
+    return "f" + std::to_string(float_type.getWidth());
+  }
+  if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
+    auto element_type = complex_type.getElementType();
+    return "c" + std::to_string(element_type.cast<mlir::FloatType>().getWidth() * 2);
  }

-  return 1;
-}
-
-int get_general_padding(ErlNifEnv* env,
-                        ERL_NIF_TERM padding_term,
-                        std::vector<std::pair<int64, int64>>& padding) {
-  unsigned int length;
-  if (!enif_get_list_length(env, padding_term, &length)) return 0;
-
-  padding.reserve(length);
-  ERL_NIF_TERM head, tail;
-
-  while (enif_get_list_cell(env, padding_term, &head, &tail)) {
-    const ERL_NIF_TERM* terms;
-    int count;
-
-    if (!enif_get_tuple(env, head, &count, &terms)) return 0;
-    if (count != 2) return 0;
-
-    int64 lo, hi;
-    if (!get(env, terms[0], &lo)) return 0;
-    if (!get(env, terms[1], &hi)) return 0;
-
-    padding.push_back(std::pair<int64, int64>(lo, hi));
-
-    padding_term = tail;
+  std::cerr << "Unexpected mlir type" << std::endl;
+  exit(1);
   }
-
-  return 1;
 }

-int get_primitive_type(ErlNifEnv* env, ERL_NIF_TERM term, xla::PrimitiveType* type) {
-  std::string type_str;
-  if (!get(env, term, type_str)) return 0;
-
-  xla::StatusOr<xla::PrimitiveType> type_status =
-      xla::primitive_util::StringToPrimitiveType(type_str);
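+// Encodes an mlir::Type as a typespec tuple {type_name, shape_tuple}; the
+// term-level counterpart of get_typespec_as_xla_shape.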
+ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type) {
+  if (type.isa<mlir::stablehlo::TokenType>()) {
+    auto type_term = make(env, "token");
+    auto shape_term = enif_make_tuple(env, 0);

-  if (!type_status.ok()) {
-    return 0;
+    return enif_make_tuple(env, 2, type_term, shape_term);
   }

-  *type = type_status.value();
-  return 1;
-}
+  if (type.isa<mlir::RankedTensorType>()) {
+    auto tensor_type = type.cast<mlir::RankedTensorType>();
+    auto dims = tensor_type.getShape();
+    auto element_type = tensor_type.getElementType();

-ERL_NIF_TERM make_shape_info(ErlNifEnv* env, xla::Shape shape) {
-  if (shape.IsTuple()) {
-    std::cerr << "Unexpected tuple shape" << std::endl;
-    exit(1);
-  } else if (shape.IsArray()) {
-    xla::PrimitiveType type = shape.element_type();
-    absl::Span<const int64_t> dims = shape.dimensions();
-    std::string name = xla::primitive_util::LowercasePrimitiveTypeName(type);
+    auto dims_array = std::vector<ERL_NIF_TERM>{};
+    dims_array.reserve(dims.size());

-    std::vector<ERL_NIF_TERM> dim_arr;
-    dim_arr.reserve(dims.size());
-    for (int i = 0; i < dims.size(); i++) {
-      int copy;
-      copy = dims.at(i);
-      dim_arr.push_back(make(env, copy));
+    for (auto dim : dims) {
+      dims_array.push_back(enif_make_int(env, dim));
     }
-    ERL_NIF_TERM dims_term = enif_make_tuple_from_array(env, &dim_arr[0], dims.size());
-    ERL_NIF_TERM type_term = make(env, name);
-
-    return enif_make_tuple(env, 2, dims_term, type_term);
-  } else {
-    // Shape is probably a token or opaque type, with no dims
-    // calling `rank` fails a check in TF
-    xla::PrimitiveType type = shape.element_type();
-
-    std::string name = xla::primitive_util::LowercasePrimitiveTypeName(type);
+    auto type_term = make(env, mlir_numeric_type_to_string(element_type));
+    auto shape_term = enif_make_tuple_from_array(env, dims_array.data(), dims_array.size());

-    ERL_NIF_TERM empty_tuple = enif_make_tuple(env, 0);
-    ERL_NIF_TERM type_term = make(env, name);
-
-    return enif_make_tuple(env, 2, empty_tuple, type_term);
+    return enif_make_tuple(env, 2, type_term, shape_term);
   }
+
+  std::cerr << "Unexpected mlir type" << std::endl;
+  exit(1);
 }

 }  // namespace nif
diff --git a/exla/c_src/exla/exla_nif_util.h b/exla/c_src/exla/exla_nif_util.h
index eb8878ecf8..5abf7e3cda 100644
--- a/exla/c_src/exla/exla_nif_util.h
+++ b/exla/c_src/exla/exla_nif_util.h
@@ -12,6 +12,7 @@
 #include "xla/shape.h"
 #include "xla/types.h"
 #include "xla/xla_data.pb.h"
+#include "mlir/IR/Builders.h"

 #if !defined(__GNUC__) && (defined(__WIN32__) || defined(_WIN32) || defined(_WIN32_))
 typedef unsigned __int64 nif_uint64_t;
@@ -80,7 +81,7 @@ ERL_NIF_TERM ok(ErlNifEnv* env);
 // Numeric types
 //
 // Floating/Complex types will never get used, except
-// when defining scalar-constants with `constant_r0`.
+// when defining scalar-constants with `constant`.

 int get(ErlNifEnv* env, ERL_NIF_TERM term, int8* var);
 int get(ErlNifEnv* env, ERL_NIF_TERM term, int16* var);
@@ -244,6 +245,8 @@ int get_list(ErlNifEnv* env,
 int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<mlir::Value>& var);

+int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<xla::Shape>& var);
+
 template <typename T>
 int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<T>& var) {
   unsigned int length;
@@ -276,6 +279,32 @@ int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<T>& var) {
   return 1;
 }

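+// Gets a keyword list (a list of two-element {atom, value} tuples) as a
+// vector of {name, value} pairs.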
+template <typename T>
+int get_keyword_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector<std::pair<std::string, T>>& var) {
+  unsigned int length;
+  if (!enif_get_list_length(env, list, &length)) return 0;
+  var.reserve(length);
+  ERL_NIF_TERM head, tail;
+
+  while (enif_get_list_cell(env, list, &head, &tail)) {
+    const ERL_NIF_TERM* terms;
+    int count;
+
+    if (!enif_get_tuple(env, head, &count, &terms)) return 0;
+    if (count != 2) return 0;
+
+    std::string lo;
+    T hi;
+    if (!get_atom(env, terms[0], lo)) return 0;
+    if (!get(env, terms[1], hi)) return 0;
+
+    var.push_back(std::pair<std::string, T>(lo, hi));
+
+    list = tail;
+  }
+  return 1;
+}
+
 int get_binary(ErlNifEnv* env, ERL_NIF_TERM term, ErlNifBinary* var);

 ERL_NIF_TERM make_map(ErlNifEnv* env, std::map<std::string, ERL_NIF_TERM>& map);
@@ -285,52 +314,13 @@ ERL_NIF_TERM make_map(ErlNifEnv* env, std::map<std::string, ERL_NIF_TERM>& map);

 // See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto
 // for more details on each type and additional types not listed here.

-// Gets a padding configuration from `list`. A padding configuration
-// is a list of 3-tuples representing edge high, edge low, and interior
-// padding.
-int get_padding_config(ErlNifEnv* env,
-                       ERL_NIF_TERM list,
-                       xla::PaddingConfig* padding_config);
-
-// Gets dimension numbers for usage in the XLA DotGeneral operation.
-// Dot dimension numbers are a 2-tuple of lists. The first list
-// represents the lhs contraction dimensions. The second list
-// represents the rhs contraction dimensions. We do not match
-// on the batch dimensions for now.
-int get_dot_dimension_numbers(ErlNifEnv* env,
-                              ERL_NIF_TERM tuple,
-                              xla::DotDimensionNumbers* dims);
-
-// Gets a precision configuration from the configuration term.
-// The term should be an integer `0`, `1`, or `2` corresponding
-// to default, high, or highest precision respectively. Precision
-// configuration is set for each term in an operation.
-int get_precision_config(ErlNifEnv* env,
-                         ERL_NIF_TERM config_term,
-                         int num_operands,
-                         xla::PrecisionConfig* precision_config);
-
-// Gets the convolution dimension numbers. Convolutions are determined
-// based on input, kernel, and output batch and feature dimensions.
-// We receive the dimension numbers as a 3-tuple of tuples. Each tuple
-// corresponds to input batch/feature dimensions, kernel input/output
-// feature dimensions, and output batch/feature dimensions respectively.
-int get_conv_dimension_numbers(ErlNifEnv* env,
-                               ERL_NIF_TERM tuple,
-                               xla::ConvolutionDimensionNumbers* dimension_numbers);
-
-// Gets a general padding configuration. This is slightly different from
-// get_padding_config for usage in a convolution. The convolution only
-// supports passing padding as a vector of pairs of edge high, edge low padding
-// values. We receive the padding configuration as a list of 2-tuples.
-int get_general_padding(ErlNifEnv* env,
-                        ERL_NIF_TERM padding_term,
-                        std::vector<std::pair<int64, int64>>& padding);
-
 // Gets the primitive type from the given term. The term is a string
 // encoding one of the XLA primitive types.
 int get_primitive_type(ErlNifEnv* env, ERL_NIF_TERM term, xla::PrimitiveType* type);

+// Gets encoded EXLA.Typespec as xla::Shape.
+int get_typespec_as_xla_shape(ErlNifEnv* env, ERL_NIF_TERM term, xla::Shape* shape);
+
 // Template for retrieving a value from a scalar. This is
 // necessary to avoid having to use templates in the NIF.
 template <
@@ -343,7 +333,7 @@ T get_value(ErlNifEnv* env, ERL_NIF_TERM term) {
 }

 // Extracts information from `GetShape` into a usable term.
-ERL_NIF_TERM make_shape_info(ErlNifEnv* env, xla::Shape shape);
+ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type);

 }  // namespace nif
 }  // namespace exla
diff --git a/exla/c_src/exla/mlir/builder.cc b/exla/c_src/exla/mlir/builder.cc
deleted file mode 100644
index 7567162da6..0000000000
--- a/exla/c_src/exla/mlir/builder.cc
+++ /dev/null
@@ -1,1466 +0,0 @@
-#include "builder.h"
-
-#include
-
-#include "../exla_nif_util.h"
-#include "custom_calls.h"
-#include "mhlo/IR/hlo_ops.h"
-#include "mhlo/transforms/rewriters.h"
-#include "mhlo/utils/type_conversion.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/OwningOpRef.h"
-#include "mlir/IR/PatternMatch.h"
-#include "stablehlo/dialect/Base.h"
-#include "stablehlo/dialect/ChloOps.h"
-#include "stablehlo/dialect/StablehloOps.h"
-#include "xla/primitive_util.h"
-#include "xla/service/custom_call_target_registry.h"
-#include "xla/types.h"
-
-namespace exla {
-mlir::Type TypeIntToMLIRType(mlir::OpBuilder *builder, xla::PrimitiveType type_int) {
-  // type_int comes from the xla::PrimitiveType enum
-  using xla::PrimitiveType;
-  switch (type_int) {
-    case PrimitiveType::S8:
-      return builder->getIntegerType(8);
-    case PrimitiveType::S16:
-      return builder->getIntegerType(16);
-    case PrimitiveType::S32:
-      return builder->getIntegerType(32);
-    case PrimitiveType::S64:
-      return builder->getIntegerType(64);
-    case PrimitiveType::PRED:
-      return builder->getIntegerType(1);
-    case PrimitiveType::U8:
-      return builder->getIntegerType(8, false);
-    case PrimitiveType::U16:
-      return builder->getIntegerType(16, false);
-    case PrimitiveType::U32:
-      return builder->getIntegerType(32, false);
-    case PrimitiveType::U64:
-      return builder->getIntegerType(64, false);
-    case PrimitiveType::F16:
-      return builder->getF16Type();
-    case PrimitiveType::F32:
-      return builder->getF32Type();
-    case PrimitiveType::F64:
-      return builder->getF64Type();
-    case PrimitiveType::BF16:
-      return builder->getBF16Type();
-    case PrimitiveType::C64:
-      return mlir::ComplexType::get(builder->getF32Type());
-    case PrimitiveType::C128:
-      return mlir::ComplexType::get(builder->getF64Type());
-    default:
-      std::cerr << "Unknown type: " << type_int << std::endl;
-      exit(1);
-  }
-}
-
-mlir::TensorType
-GetMLIRType(mlir::OpBuilder *builder, std::vector<exla::int64> dims, xla::PrimitiveType type_int) {
-  if (type_int == xla::PrimitiveType::TUPLE) {
-    std::cerr << "Tuples are not supported yet" << std::endl;
-    exit(1);
-  }
-  auto type = TypeIntToMLIRType(builder, type_int);
-  return mlir::RankedTensorType::get(dims, type);
-}
-
element.IsToken()) {
-        element_type = GetMLIRFunctionType(builder, &element);
-      } else {
-        auto span = element.dimensions();
-        std::vector<tsl::int64> dims(span.begin(), span.end());
-        element_type = GetMLIRType(builder, dims, element.element_type());
-      }
-      element_types.push_back(element_type);
-    }
-
-    mlir::TupleType tuple = mlir::TupleType::get(builder->getContext(), mlir::TypeRange(element_types));
-    return tuple;
-  }
-
-  auto span = shape->dimensions();
-  std::vector<tsl::int64> dims(span.begin(), span.end());
-  return GetMLIRType(builder, dims, shape->element_type());
-}
-
-mlir::stablehlo::DotDimensionNumbersAttr ConvertDotDimensionNumbersToAttr(mlir::OpBuilder *builder, const xla::DotDimensionNumbers &dotDimNumbers) {
-  std::vector<int64_t> lhsContractingVec(dotDimNumbers.lhs_contracting_dimensions().begin(),
-                                         dotDimNumbers.lhs_contracting_dimensions().end());
-  std::vector<int64_t> rhsContractingVec(dotDimNumbers.rhs_contracting_dimensions().begin(),
-                                         dotDimNumbers.rhs_contracting_dimensions().end());
-  std::vector<int64_t> lhsBatchVec(dotDimNumbers.lhs_batch_dimensions().begin(),
-                                   dotDimNumbers.lhs_batch_dimensions().end());
-  std::vector<int64_t> rhsBatchVec(dotDimNumbers.rhs_batch_dimensions().begin(),
-                                   dotDimNumbers.rhs_batch_dimensions().end());
-
-  return mlir::stablehlo::DotDimensionNumbersAttr::get(
-      builder->getContext(),
-      lhsBatchVec,
-      rhsBatchVec,
-      lhsContractingVec,
-      rhsContractingVec);
-}
-
-void MLIRFunction::dump_mlir_module() {
-  module_->module().dump();
-}
-
-int MLIRFunction::get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type) {
-  auto builder = module_->builder();
-  std::string type_str;
-  if (!exla::nif::get(env, term, type_str)) return 1;
-
-  if (type_str == "pred") {
-    *type = builder->getIntegerType(1);
-    return 0;
-  }
-  if (type_str == "u8") {
-    *type = builder->getIntegerType(8, false);
-    return 0;
-  }
-  if (type_str == "u16") {
-    *type = builder->getIntegerType(16, false);
-    return 0;
-  }
-  if (type_str == "u32") {
-    *type = builder->getIntegerType(32, false);
-    return 0;
-  }
-  if (type_str == "u64") {
-    *type = builder->getIntegerType(64, false);
-    return 0;
-  }
-  if (type_str == "s8") {
-    *type = builder->getIntegerType(8);
-    return 0;
-  }
-  if (type_str == "s16") {
-    *type = builder->getIntegerType(16);
-    return 0;
-  }
-  if (type_str == "s32") {
-    *type = builder->getIntegerType(32);
-    return 0;
-  }
-  if (type_str == "s64") {
-    *type = builder->getIntegerType(64);
-    return 0;
-  }
-  if (type_str == "f16") {
-    *type = builder->getF16Type();
-    return 0;
-  }
-  if (type_str == "f32") {
-    *type = builder->getF32Type();
-    return 0;
-  }
-  if (type_str == "f64") {
-    *type = builder->getF64Type();
-    return 0;
-  }
-  if (type_str == "bf16") {
-    *type = builder->getBF16Type();
-    return 0;
-  }
-  if (type_str == "c64") {
-    *type = mlir::ComplexType::get(builder->getF32Type());
-    return 0;
-  }
-  if (type_str == "c128") {
-    *type = mlir::ComplexType::get(builder->getF64Type());
-    return 0;
-  }
-
-  return 1;
-}
-
-mlir::DenseIntElementsAttr Int64ToDenseIntElementsAttr(mlir::OpBuilder *builder, std::vector<int64_t> vec) {
-  int64_t num_entries[] = {static_cast<int64_t>(vec.size())};
-  auto type = mlir::RankedTensorType::get(llvm::ArrayRef(num_entries, 1), builder->getIntegerType(64));
-  auto dense_attr = mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec.data(), vec.size()));
-  return llvm::cast<mlir::DenseIntElementsAttr>(dense_attr);
-}
-
-mlir::DenseIntElementsAttr Int64ToDenseIntElementsAttr(mlir::OpBuilder *builder, std::vector<std::pair<int64_t, int64_t>> vec_in) {
-  std::vector<int64_t> vec;
-  int64_t num_pairs = vec_in.size();
-  vec.reserve(num_pairs * 2);
-  for (auto pair : vec_in) {
-    vec.push_back(pair.first);
-    vec.push_back(pair.second);
-  }
-
-  int64_t num_entries[] = {num_pairs, 2};
-  auto type = mlir::RankedTensorType::get(llvm::ArrayRef(num_entries, 2), builder->getIntegerType(64));
-  auto dense_attr = mlir::DenseElementsAttr::get(type, llvm::ArrayRef(vec.data(), vec.size()));
-  return llvm::cast<mlir::DenseIntElementsAttr>(dense_attr);
-}
-
-MLIRFunction::MLIRFunction(MLIRModule *module, std::unique_ptr<mlir::func::FuncOp> func)
-    : func_(std::move(func)),
-      module_(module) {}
-
-mlir::Value MLIRFunction::SubtractOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::SubtractOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::ConvertOp(mlir::Value operand, mlir::Type type) {
-  auto builder = module_->builder();
-  setInsertionPoint();
-
-  if (operand.getType().isa<mlir::ComplexType>() && !type.isa<mlir::ComplexType>()) {
-    // get the real part of the operand in case we're downcasting from complex to something else
-    operand = builder->create<mlir::stablehlo::RealOp>(builder->getUnknownLoc(), operand);
-  }
-
-  auto op = builder->create<mlir::stablehlo::ConvertOp>(builder->getUnknownLoc(), operand, type);
-  return op;
-}
-
-mlir::Value MLIRFunction::BitcastConvertOp(mlir::Value operand, xla::Shape shape) {
-  auto builder = module_->builder();
-  setInsertionPoint();
-
-  absl::Span<const int64_t>
-      dimensions_span = shape.dimensions();
-  std::vector<int64_t> dimensions(dimensions_span.begin(), dimensions_span.end());
-
-  mlir::Type type = GetMLIRFunctionType(module_->builder(), &shape);
-
-  auto op = builder->create<mlir::stablehlo::BitcastConvertOp>(builder->getUnknownLoc(), type, operand);
-  return op;
-}
-
-mlir::Value MLIRFunction::AddOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::AddOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::MulOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::MulOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::MinOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::MinOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::MaxOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::MaxOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::RemOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::RemOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::PowOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::PowOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::DivOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::DivOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::Atan2Op(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::Atan2Op>(module_->builder()->getUnknownLoc(), lhs, rhs);
-  return op;
-}
-
-mlir::Value MLIRFunction::PadOp(mlir::Value op, mlir::Value pad, std::vector<int64_t> padding_low, std::vector<int64_t> padding_high, std::vector<int64_t> padding_mid) {
-  setInsertionPoint();
-
-  auto padding_low_attr = Int64ToDenseIntElementsAttr(module_->builder(), padding_low);
-  auto padding_high_attr = Int64ToDenseIntElementsAttr(module_->builder(), padding_high);
-  auto padding_mid_attr = Int64ToDenseIntElementsAttr(module_->builder(), padding_mid);
-
-  return module_->builder()->create<mlir::stablehlo::PadOp>(module_->builder()->getUnknownLoc(), op, pad, padding_low_attr, padding_high_attr, padding_mid_attr);
-}
-
-mlir::Value compare_and_return_bool(mlir::OpBuilder *builder, mlir::Value lhs, mlir::Value rhs, mlir::stablehlo::ComparisonDirection direction) {
-  mlir::stablehlo::ComparisonType comparison_type;
-  mlir::RankedTensorType ranked_type = llvm::cast<mlir::RankedTensorType>(lhs.getType());
-  mlir::Type left_type = mlir::RankedTensorType::get({}, ranked_type.getElementType());
-
-  ranked_type = llvm::cast<mlir::RankedTensorType>(rhs.getType());
-  mlir::Type right_type = mlir::RankedTensorType::get({}, ranked_type.getElementType());
-  if (left_type.isa<mlir::FloatType>() || right_type.isa<mlir::FloatType>()) {
-    comparison_type = mlir::stablehlo::symbolizeComparisonType("TOTALORDER").value();
-  } else {
-    comparison_type = mlir::stablehlo::ComparisonType::NOTYPE;
-  }
-
-  auto direction_attr = mlir::stablehlo::ComparisonDirectionAttr::get(builder->getContext(), direction);
-  auto comparison_type_attr = mlir::stablehlo::ComparisonTypeAttr::get(builder->getContext(), comparison_type);
-  mlir::Type mlir_bool = builder->getIntegerType(1);
-  auto shape = llvm::cast<mlir::RankedTensorType>(lhs.getType()).getShape();
-
-  mlir::Type out_type = mlir::RankedTensorType::get(shape, builder->getIntegerType(1));
-  auto op = builder->create<mlir::stablehlo::CompareOp>(builder->getUnknownLoc(), out_type, lhs, rhs, direction_attr, comparison_type_attr);
-  return op;
-}
-
-mlir::Value MLIRFunction::EqualOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return compare_and_return_bool(module_->builder(), lhs, rhs, mlir::stablehlo::ComparisonDirection::EQ);
-}
-
-mlir::Value MLIRFunction::NotEqualOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return compare_and_return_bool(module_->builder(), lhs, rhs, mlir::stablehlo::ComparisonDirection::NE);
-}
-
-mlir::Value MLIRFunction::LessOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return compare_and_return_bool(module_->builder(), lhs, rhs, mlir::stablehlo::ComparisonDirection::LT);
-}
-
-mlir::Value MLIRFunction::LessEqualOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return compare_and_return_bool(module_->builder(), lhs, rhs, mlir::stablehlo::ComparisonDirection::LE);
-}
-
-mlir::Value MLIRFunction::GreaterOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return compare_and_return_bool(module_->builder(), lhs, rhs, mlir::stablehlo::ComparisonDirection::GT);
-}
-
-mlir::Value MLIRFunction::GreaterEqualOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return compare_and_return_bool(module_->builder(), lhs, rhs, mlir::stablehlo::ComparisonDirection::GE);
-}
-
-mlir::Value MLIRFunction::ShiftLeftOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::ShiftLeftOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-}
-
-mlir::Value MLIRFunction::ShiftRightLogicalOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::ShiftRightLogicalOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-}
-
-mlir::Value MLIRFunction::ShiftRightArithmeticOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::ShiftRightArithmeticOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-}
-
-mlir::Value MLIRFunction::BitwiseAndOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::AndOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-}
-
-mlir::Value MLIRFunction::BitwiseOrOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::OrOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-}
-
-mlir::Value MLIRFunction::BitwiseNotOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::NotOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::BitwiseXorOp(mlir::Value lhs, mlir::Value rhs) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::XorOp>(module_->builder()->getUnknownLoc(), lhs, rhs);
-}
-
-mlir::Value MLIRFunction::AbsOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::AbsOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::ExpOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::ExpOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::Expm1Op(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::Expm1Op>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::FloorOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::FloorOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::CeilOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::CeilOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::RoundOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::RoundOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::LogOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::LogOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::LogisticOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::LogisticOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::Log1pOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::Log1pOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::SignOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::SignOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::CosOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::CosineOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::SinOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::SineOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::TanOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::TanOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::AcosOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::AcosOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::AsinOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::AsinOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::AtanOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::AtanOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::CoshOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::CoshOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::SinhOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::SinhOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::TanhOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::TanhOp>(module_->builder()->getUnknownLoc(), operand);
-}
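Every wrapper in this run repeats the same three lines; a minimal consolidation sketch, assuming only the builder discipline already used in this file (unary_op and the call site are illustrative, not part of the original source):

    template <typename StablehloOp>
    mlir::Value unary_op(mlir::OpBuilder *builder, mlir::Value operand) {
      // Identical shape to the wrappers above: unknown location, one operand.
      return builder->create<StablehloOp>(builder->getUnknownLoc(), operand);
    }
    // usage: unary_op<mlir::stablehlo::TanhOp>(builder, operand)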
-
-mlir::Value MLIRFunction::AcoshOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::AcoshOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::AsinhOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::AsinhOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::AtanhOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::AtanhOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::SqrtOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::SqrtOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::CbrtOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::CbrtOp>(module_->builder()->getUnknownLoc(), operand);
-}
-mlir::Value MLIRFunction::NegateOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::NegOp>(module_->builder()->getUnknownLoc(), operand);
-}
-mlir::Value MLIRFunction::ErfOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::ErfOp>(module_->builder()->getUnknownLoc(), operand);
-}
-mlir::Value MLIRFunction::ErfInvOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::ErfInvOp>(module_->builder()->getUnknownLoc(), operand);
-}
-mlir::Value MLIRFunction::ErfcOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::ErfcOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::IsInfOp(mlir::Value operand) {
-  setInsertionPoint();
-  mlir::Value result;
-
-  mlir::RankedTensorType type = llvm::cast<mlir::RankedTensorType>(operand.getType());
-  mlir::Type element_type = type.getElementType();
-
-  if (element_type.isa<mlir::ComplexType>()) {
-    auto real_op = module_->builder()->create<mlir::stablehlo::RealOp>(module_->builder()->getUnknownLoc(), operand);
-    auto imag_op = module_->builder()->create<mlir::stablehlo::ImagOp>(module_->builder()->getUnknownLoc(), operand);
-
-    auto is_inf_real_op = this->ConvertOp(this->IsInfOp(real_op), element_type);
-    auto is_inf_imag_op = this->ConvertOp(this->IsInfOp(imag_op), element_type);
-    result = this->AddOp(is_inf_real_op, is_inf_imag_op);
-  } else if (element_type.isa<mlir::IntegerType>()) {
-    // integers are never infinity
-    return this->NotEqualOp(operand, operand);
-  } else {
-    result = module_->builder()->create<mlir::chlo::IsInfOp>(module_->builder()->getUnknownLoc(), operand);
-  }
-  mlir::Type mlir_bool = module_->builder()->getIntegerType(1);
-  return module_->builder()->create<mlir::stablehlo::ConvertOp>(module_->builder()->getUnknownLoc(), result, mlir_bool);
-}
-
-mlir::Value MLIRFunction::IsNanOp(mlir::Value operand) {
-  setInsertionPoint();
-  mlir::Type mlir_bool = module_->builder()->getI1Type();
-
-  mlir::RankedTensorType type = llvm::cast<mlir::RankedTensorType>(operand.getType());
-  mlir::Type element_type = type.getElementType();
-
-  mlir::Value result;
-
-  if (element_type.isa<mlir::ComplexType>()) {
-    auto real_op = module_->builder()->create<mlir::stablehlo::RealOp>(module_->builder()->getUnknownLoc(), operand);
-    auto imag_op = module_->builder()->create<mlir::stablehlo::ImagOp>(module_->builder()->getUnknownLoc(), operand);
-
-    auto is_inf_real_op = this->ConvertOp(this->IsNanOp(real_op), element_type);
-    auto is_inf_imag_op = this->ConvertOp(this->IsNanOp(imag_op), element_type);
-    result = this->AddOp(is_inf_real_op, is_inf_imag_op);
-    return module_->builder()->create<mlir::stablehlo::ConvertOp>(module_->builder()->getUnknownLoc(), result, mlir_bool);
-  } else if (element_type.isa<mlir::IntegerType>()) {
-    // integers are never nan
-    return this->NotEqualOp(operand, operand);
-  } else {
-    mlir::Value is_finite_op = module_->builder()->create<mlir::stablehlo::IsFiniteOp>(module_->builder()->getUnknownLoc(), operand);
-    is_finite_op = module_->builder()->create<mlir::stablehlo::ConvertOp>(module_->builder()->getUnknownLoc(), is_finite_op, mlir_bool);
-
-    mlir::Value is_inf_op = this->IsInfOp(operand);
-    is_inf_op = module_->builder()->create<mlir::stablehlo::ConvertOp>(module_->builder()->getUnknownLoc(), is_inf_op, mlir_bool);
-
-    return this->BitwiseAndOp(this->BitwiseNotOp(is_inf_op), this->BitwiseNotOp(is_finite_op));
-  }
-}
-mlir::Value MLIRFunction::RsqrtOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::RsqrtOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::RealOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::RealOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::ImagOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::ImagOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::ConjOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::chlo::ConjOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::TransposeOp(mlir::Value operand, std::vector<int64_t> axes) {
-  setInsertionPoint();
-  auto axes_attr = Int64ToDenseIntElementsAttr(module_->builder(), axes);
-  return module_->builder()->create<mlir::stablehlo::TransposeOp>(module_->builder()->getUnknownLoc(), operand, axes_attr);
-}
-
-mlir::Value MLIRFunction::ReshapeOp(mlir::Value operand, std::vector<int64_t> target_shape) {
-  setInsertionPoint();
-  mlir::RankedTensorType t_in = llvm::cast<mlir::RankedTensorType>(operand.getType());
-  mlir::RankedTensorType type = mlir::RankedTensorType::get(target_shape, t_in.getElementType());
-  return module_->builder()->create<mlir::stablehlo::ReshapeOp>(module_->builder()->getUnknownLoc(), type, operand);
-}
-
-mlir::Value MLIRFunction::ReverseOp(mlir::Value operand, std::vector<int64_t> dims) {
-  setInsertionPoint();
-  auto dims_attr = Int64ToDenseIntElementsAttr(module_->builder(), dims);
-  return module_->builder()->create<mlir::stablehlo::ReverseOp>(module_->builder()->getUnknownLoc(), operand, dims_attr);
-}
-
-class PublicPatternRewriter : public mlir::PatternRewriter {
- public:
-  PublicPatternRewriter(mlir::MLIRContext *context) : mlir::PatternRewriter(context) {}
-};
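For orientation: for a single f32 operand pair, the comparator region assembled by buildSortComparisonBody below prints roughly as the following StableHLO (illustrative sketch; note that type_attr is computed but, as written, never attached to the compare op):

    ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
      %0 = stablehlo.compare GT, %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<i1>
      stablehlo.return %0 : tensor<i1>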
-
-static void buildSortComparisonBody(llvm::ArrayRef<mlir::Type> elementTypes,
-                                    mlir::stablehlo::ComparisonDirection direction,
-                                    std::optional<std::string> compare_type,
-                                    mlir::Region *body, mlir::OpBuilder *builder) {
-  mlir::OpBuilder::InsertionGuard insertionPointGuard(*builder);
-  mlir::Location loc = body->getLoc();
-  mlir::Block *block = builder->createBlock(body);
-  // Add two arguments for each element type.
-  for (mlir::Type elementType : elementTypes) {
-    // mlir::ShapedType shapedType = mlir::RankedTensorType::get({}, elementType);
-    block->addArguments({elementType, elementType}, {loc, loc});
-  }
-  mlir::stablehlo::ComparisonType type_attr;
-  if (compare_type) {
-    type_attr = mlir::stablehlo::symbolizeComparisonType(*compare_type).value();
-  } else {
-    type_attr = mlir::stablehlo::ComparisonType::NOTYPE;
-  }
-  mlir::BlockArgument arg0 = block->getArgument(0);
-  mlir::BlockArgument arg1 = block->getArgument(1);
-  mlir::Value compare = builder->create<mlir::stablehlo::CompareOp>(loc, arg0, arg1, direction);
-  builder->create<mlir::stablehlo::ReturnOp>(loc, compare);
-}
-
-std::vector<mlir::Value> MLIRFunction::TopKOp(mlir::Value operand, int64_t k) {
-  auto builder = module_->builder();
-  setInsertionPoint();
-
-  mlir::chlo::TopKOp top_k_op = builder->create<mlir::chlo::TopKOp>(builder->getUnknownLoc(), operand, k);
-  mlir::Operation::result_range results = top_k_op.getResults();
-
-  auto results_vec = std::vector<mlir::Value>(results.begin(), results.end());
-
-  mlir::Value idx = builder->create<mlir::stablehlo::ConvertOp>(builder->getUnknownLoc(), results_vec[1], builder->getI64Type());
-  results_vec[1] = idx;
-  return results_vec;
-}
-
-std::vector<mlir::Value> MLIRFunction::SortOp(MLIRFunction *comparator, std::vector<mlir::Value> operands, int64_t dim, bool stable) {
-  auto builder = module_->builder();
-  setInsertionPoint();
-  mlir::ValueRange value_range(operands);
-  mlir::stablehlo::SortOp sort_op = builder->create<mlir::stablehlo::SortOp>(
-      builder->getUnknownLoc(),
-      value_range,
-      dim,
-      stable);
-
-  mlir::Region &compareBody = sort_op.getComparator();
-  mlir::Region &comparatorBody = comparator->function()->getBody();
-  compareBody.getBlocks().splice(compareBody.end(), comparatorBody.getBlocks());
-  comparator->function()->erase();
-
-  mlir::Operation::result_range results = sort_op.getResults();
-  return std::vector<mlir::Value>(results.begin(), results.end());
-}
-
-mlir::Value MLIRFunction::SliceOp(mlir::Value operand, std::vector<int64_t> starts, std::vector<int64_t> limits, std::vector<int64_t> strides) {
-  setInsertionPoint();
-  auto idx_attr = Int64ToDenseIntElementsAttr(module_->builder(), starts);
-  auto lim_attr = Int64ToDenseIntElementsAttr(module_->builder(), limits);
-  auto strides_attr = Int64ToDenseIntElementsAttr(module_->builder(), strides);
-
-  return module_->builder()->create<mlir::stablehlo::SliceOp>(
-      module_->builder()->getUnknownLoc(),
-      operand,
-      idx_attr,
-      lim_attr,
-      strides_attr);
-}
-
-mlir::Value MLIRFunction::DynamicSliceOp(mlir::Value operand, std::vector<mlir::Value> starts, std::vector<int64_t> lengths) {
-  setInsertionPoint();
-  auto len_attr = Int64ToDenseIntElementsAttr(module_->builder(), lengths);
-  mlir::ValueRange starts_range(llvm::ArrayRef<mlir::Value>(starts.data(), starts.size()));
-
-  return module_->builder()
-      ->create<mlir::stablehlo::DynamicSliceOp>(
-          module_->builder()->getUnknownLoc(),
-          operand,
-          starts_range,
-          len_attr);
-}
-
-mlir::Value MLIRFunction::ClzOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::ClzOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::PopulationCountOp(mlir::Value operand) {
-  setInsertionPoint();
-  return module_->builder()->create<mlir::stablehlo::PopulationCountOp>(module_->builder()->getUnknownLoc(), operand);
-}
-
-mlir::Value MLIRFunction::TupleOp(std::vector<mlir::Value> vals) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::TupleOp>(module_->builder()->getUnknownLoc(), vals);
-  return op;
-}
-
-mlir::Value MLIRFunction::GetTupleElementOp(mlir::Value tuple, tsl::int64 index) {
-  setInsertionPoint();
-  auto op = module_->builder()->create<mlir::stablehlo::GetTupleElementOp>(module_->builder()->getUnknownLoc(), tuple, index);
-  return op;
-}
-
-mlir::Value MLIRFunction::IotaOp(xla::Shape shape, int64_t dimension) {
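IotaOp lowers to stablehlo.iota, which counts along a single dimension of the requested shape; an illustrative instance (not part of the original source):

    %0 = stablehlo.iota dim = 0 : tensor<3x2xf32>
    // yields [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]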
setInsertionPoint(); - - absl::Span dimensions_span = shape.dimensions(); - std::vector dimensions(dimensions_span.begin(), dimensions_span.end()); - - mlir::Type type = GetMLIRFunctionType(module_->builder(), &shape); - - return module_->builder()->create(module_->builder()->getUnknownLoc(), type, dimension); -} - -mlir::Value MLIRFunction::DotGeneralOp( - xla::Shape output_shape, - mlir::Value lhs, - mlir::Value rhs, - xla::DotDimensionNumbers dnums, - xla::PrecisionConfig config) { - setInsertionPoint(); - - absl::Span dimensions_span = output_shape.dimensions(); - std::vector dimensions(dimensions_span.begin(), dimensions_span.end()); - - mlir::Type output_type = GetMLIRFunctionType(module_->builder(), &output_shape); - auto mlir_dnums = ConvertDotDimensionNumbersToAttr(module_->builder(), dnums); - - auto op = module_->builder()->create( - module_->builder()->getUnknownLoc(), - output_type, - lhs, - rhs, - mlir_dnums, - nullptr); - - return op; -} - -mlir::Value MLIRFunction::BroadcastInDimOp(mlir::Value operand, xla::Shape shape, std::vector axes) { - setInsertionPoint(); - - absl::Span dimensions_span = shape.dimensions(); - std::vector dimensions(dimensions_span.begin(), dimensions_span.end()); - mlir::Type result_type = GetMLIRFunctionType(module_->builder(), &shape); - - auto axes_attr = Int64ToDenseIntElementsAttr(module_->builder(), axes); - - auto op = module_->builder()->create(module_->builder()->getUnknownLoc(), result_type, operand, axes_attr); - return op; -} - -mlir::Value MLIRFunction::ConcatenateOp(std::vector operands, int64_t dimension) { - setInsertionPoint(); - mlir::ValueRange operands_range(llvm::ArrayRef(operands.data(), operands.size())); - auto op = module_->builder()->create(module_->builder()->getUnknownLoc(), operands_range, dimension); - return op; -} - -mlir::Value MLIRFunction::OptimizationBarrierOp(mlir::Value operand) { - setInsertionPoint(); - auto op = module_->builder()->create(module_->builder()->getUnknownLoc(), operand); - return op.getResult()[0]; -} - -mlir::Value MLIRFunction::ClampOp(mlir::Value min, mlir::Value operand, mlir::Value max) { - setInsertionPoint(); - auto op = module_->builder()->create(module_->builder()->getUnknownLoc(), min, operand, max); - return op; -} - -mlir::Value MLIRFunction::SelectOp(mlir::Value pred, mlir::Value on_true, mlir::Value on_false) { - setInsertionPoint(); - auto op = module_->builder()->create(module_->builder()->getUnknownLoc(), pred, on_true, on_false); - return op; -} - -static void buildScatterComputation(mlir::Type element_type, bool add_or_put, mlir::Region *body, mlir::OpBuilder *builder) { - mlir::OpBuilder::InsertionGuard insertionPointGuard(*builder); - mlir::Location loc = body->getLoc(); - mlir::Block *block = builder->createBlock(body); - // Add two arguments for each element type. 
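The update region assembled by buildScatterComputation is one of two one-line bodies; sketched for an f32 element type (illustrative IR):

    // add_or_put == true (scatter-add):
    ^bb0(%old: tensor<f32>, %update: tensor<f32>):
      %0 = stablehlo.add %old, %update : tensor<f32>
      stablehlo.return %0 : tensor<f32>
    // add_or_put == false (put): the region instead returns %update directly.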
- block->addArguments({element_type, element_type}, {loc, loc}); - - if (add_or_put) { - mlir::BlockArgument arg0 = block->getArgument(0); - mlir::BlockArgument arg1 = block->getArgument(1); - mlir::Value add = builder->create(loc, arg0, arg1); - builder->create(loc, add); - } else { - mlir::BlockArgument arg1 = block->getArgument(1); - builder->create(loc, arg1); - } -} - -mlir::Value MLIRFunction::ScatterOp(mlir::Value target, mlir::Value indices, mlir::Value updates, bool add_or_put, int64_t indices_rank, std::vector update_window_dims, std::vector inserted_window_dims, std::vector index_dims_to_window_dims) { - auto builder = module_->builder(); - setInsertionPoint(); - mlir::RankedTensorType type = llvm::cast(target.getType()); - auto scatter_dimension_numbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(builder->getContext(), update_window_dims, inserted_window_dims, index_dims_to_window_dims, indices_rank); - - mlir::stablehlo::ScatterOp scatter_op = builder->create(builder->getUnknownLoc(), target, indices, updates, scatter_dimension_numbers); - mlir::Type computation_operand_type = mlir::RankedTensorType::get({}, type.getElementType()); - buildScatterComputation(computation_operand_type, add_or_put, &scatter_op.getUpdateComputation(), builder); - return scatter_op.getResult(0); -} - -std::vector MLIRFunction::WindowReduceOp( - MLIRFunction *reducer, - std::vector init_values, - std::vector inputs, - std::vector window_dimensions, - std::vector window_strides, - std::vector input_dilations, - std::vector window_dilations, - std::vector> padding) { - auto builder = module_->builder(); - setInsertionPoint(); - - mlir::ValueRange init_values_range(init_values); - mlir::ValueRange inputs_range(inputs); - mlir::DenseIntElementsAttr window_dimensions_attr = Int64ToDenseIntElementsAttr(builder, window_dimensions); - mlir::DenseIntElementsAttr window_strides_attr = Int64ToDenseIntElementsAttr(builder, window_strides); - mlir::DenseIntElementsAttr input_dilations_attr = Int64ToDenseIntElementsAttr(builder, input_dilations); - mlir::DenseIntElementsAttr window_dilations_attr = Int64ToDenseIntElementsAttr(builder, window_dilations); - mlir::DenseIntElementsAttr padding_attr = Int64ToDenseIntElementsAttr(builder, padding); - - mlir::stablehlo::ReduceWindowOp reduce_window_op = builder->create( - builder->getUnknownLoc(), - inputs_range, - init_values_range, - window_dimensions_attr, - window_strides_attr, - input_dilations_attr, - window_dilations_attr, - padding_attr); - - mlir::Region &reduceBody = reduce_window_op.getRegion(); - mlir::Region &funcBody = reducer->function()->getBody(); - reduceBody.getBlocks().splice(reduceBody.end(), funcBody.getBlocks()); - reducer->function()->erase(); - - mlir::Operation::result_range results = reduce_window_op.getResults(); - return std::vector(results.begin(), results.end()); -} - -std::vector MLIRFunction::ReduceOp( - MLIRFunction *reducer, - std::vector init_values, - std::vector inputs, - std::vector dimensions) { - auto builder = module_->builder(); - setInsertionPoint(); - - mlir::ValueRange init_values_range(init_values); - mlir::ValueRange inputs_range(inputs); - mlir::DenseIntElementsAttr dimensions_attr = Int64ToDenseIntElementsAttr(builder, dimensions); - - mlir::stablehlo::ReduceOp reduce_op = builder->create(builder->getUnknownLoc(), inputs_range, init_values_range, dimensions_attr); - mlir::Region &reduceBody = reduce_op.getRegion(); - mlir::Region &funcBody = reducer->function()->getBody(); - 
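WindowReduceOp's padding arrives as (low, high) pairs, and the pair overload of Int64ToDenseIntElementsAttr defined near the top of this file flattens them into the N x 2 i64 tensor that stablehlo.reduce_window expects. A small grounded example, assuming a builder in scope:

    // {{1, 2}, {0, 0}} becomes dense<[[1, 2], [0, 0]]> : tensor<2x2xi64>
    std::vector<std::pair<int64_t, int64_t>> padding = {{1, 2}, {0, 0}};
    mlir::DenseIntElementsAttr attr = Int64ToDenseIntElementsAttr(builder, padding);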
reduceBody.getBlocks().splice(reduceBody.end(), funcBody.getBlocks()); - reducer->function()->erase(); - - mlir::Operation::result_range results = reduce_op.getResults(); - return std::vector(results.begin(), results.end()); -} - -mlir::Value MLIRFunction::MapOp( - MLIRFunction *mapper, - std::vector inputs, - std::vector dimensions) { - auto builder = module_->builder(); - setInsertionPoint(); - - mlir::ValueRange inputs_range(inputs); - mlir::DenseIntElementsAttr dimensions_attr = Int64ToDenseIntElementsAttr(builder, dimensions); - - mlir::stablehlo::MapOp map_op = builder->create(builder->getUnknownLoc(), inputs[0].getType(), inputs_range, dimensions_attr); - - mlir::Region &mapBody = map_op.getComputation(); - mlir::Region &funcBody = mapper->function()->getBody(); - mapBody.getBlocks().splice(mapBody.end(), funcBody.getBlocks()); - mapper->function()->erase(); - - return map_op; -} - -std::pair, std::pair> MLIRFunction::IfOp(mlir::Value pred, std::vector output_shapes) { - auto builder = module_->builder(); - setInsertionPoint(); - - std::vector - output_types; - output_types.reserve(output_shapes.size()); - - for (auto shape : output_shapes) { - auto type = GetMLIRFunctionType(builder, &shape); - output_types.push_back(type); - } - - mlir::Type pred_type = llvm::cast(pred.getType()).getElementType(); - if (!pred_type.isInteger(1)) { - pred = builder->create(builder->getUnknownLoc(), pred, builder->getIntegerType(1)); - } - - mlir::stablehlo::IfOp if_op = builder->create(builder->getUnknownLoc(), mlir::TypeRange(output_types), pred); - - mlir::Operation::result_range result_range = if_op.getResults(); - std::vector results(result_range.begin(), result_range.end()); - - mlir::Region *true_region = &if_op.getTrueBranch(); - true_region->emplaceBlock(); - mlir::Region *false_region = &if_op.getFalseBranch(); - false_region->emplaceBlock(); - - return std::make_pair(results, std::make_pair(true_region, false_region)); -} - -std::vector MLIRFunction::PushRegion(mlir::Region *region) { - std::vector args; - mlir::Block &block = region->front(); - for (auto &arg : block.getArguments()) { - args.push_back(arg); - } - - regions.push(std::move(region)); - setInsertionPoint(); - - return args; -} - -void MLIRFunction::PopRegion() { - regions.pop(); - setInsertionPoint(); -} - -mlir::Value MLIRFunction::SelectAndScatterOp( - mlir::Value target, - mlir::Value source, - mlir::Value init_value, - bool gt_or_lt, - std::vector window_dimensions, - std::vector window_strides, - std::vector padding) { - auto builder = module_->builder(); - setInsertionPoint(); - mlir::RankedTensorType type = llvm::cast(target.getType()); - int64_t rank = type.getShape().size(); - std::vector axes(rank); - for (int64_t i = 0; i < rank; i++) { - axes[i] = i; - } - auto scatter_dimension_numbers = mlir::stablehlo::ScatterDimensionNumbersAttr::get(builder->getContext(), {}, axes, axes, rank); - - mlir::DenseIntElementsAttr window_dimensions_attr = Int64ToDenseIntElementsAttr(module_->builder(), window_dimensions); - mlir::DenseIntElementsAttr window_strides_attr = Int64ToDenseIntElementsAttr(module_->builder(), window_strides); - - auto dense_attr_type = mlir::RankedTensorType::get({static_cast(padding.size() / 2), 2}, builder->getIntegerType(64)); - auto dense_attr = mlir::DenseElementsAttr::get(dense_attr_type, llvm::ArrayRef(padding.data(), padding.size())); - auto padding_attr = llvm::cast(dense_attr); - - mlir::stablehlo::SelectAndScatterOp op = builder->create( - builder->getUnknownLoc(), - target, - source, - 
init_value, - window_dimensions_attr, - window_strides_attr, - padding_attr); - - mlir::Type computation_operand_type = mlir::RankedTensorType::get({}, type.getElementType()); - buildScatterComputation(computation_operand_type, true, &op.getScatter(), builder); - - mlir::stablehlo::ComparisonDirection direction = gt_or_lt ? mlir::stablehlo::ComparisonDirection::GT : mlir::stablehlo::ComparisonDirection::LT; - std::optional compare_type = std::nullopt; - if (type.isa()) { - compare_type.emplace("TOTALORDER"); - } - - buildSortComparisonBody({computation_operand_type}, direction, compare_type, &op.getSelect(), builder); - return op.getResult(); -} - -mlir::Value MLIRFunction::GatherOp(mlir::Value source, mlir::Value indices, std::vector offset_dims, std::vector collapsed_slice_dims, std::vector start_index_map, std::vector slice_sizes, int64_t index_vector_dim) { - auto builder = module_->builder(); - setInsertionPoint(); - auto gather_dimension_numbers = mlir::stablehlo::GatherDimensionNumbersAttr::get(builder->getContext(), offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim); - auto slice_sizes_attr = Int64ToDenseIntElementsAttr(module_->builder(), slice_sizes); - return builder->create(builder->getUnknownLoc(), source, indices, gather_dimension_numbers, slice_sizes_attr, false); -} - -mlir::Value MLIRFunction::FFTOp(mlir::Value tensor, bool forward_fft, std::vector fft_length) { - auto builder = module_->builder(); - setInsertionPoint(); - - auto fft_type = mlir::stablehlo::FftTypeAttr::get(builder->getContext(), forward_fft ? mlir::stablehlo::FftType::FFT : mlir::stablehlo::FftType::IFFT); - return builder->create(builder->getUnknownLoc(), tensor, fft_type, Int64ToDenseIntElementsAttr(builder, fft_length)); -} - -template -ERL_NIF_TERM ConstantOpImpl(mlir::OpBuilder *builder, mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM term, std::optional> dims_opt) { - bool scalar = !dims_opt; - std::vector dims = scalar ? std::vector(0) : dims_opt.value(); - - mlir::RankedTensorType ty = mlir::RankedTensorType::get(dims, type); - mlir::DenseElementsAttr attr; - - if (scalar) { - T value; - if (!exla::nif::get(env, term, &value)) { - return exla::nif::error(env, "Unable to cast scalar to type."); - } - attr = mlir::DenseElementsAttr::get(ty, value); - } else { - // non-scalar case. we'll assume our data - // is in the form of a raw buffer - ErlNifBinary binary; - if (!exla::nif::get_binary(env, term, &binary)) { - return exla::nif::error(env, "Unable to get binary data."); - } - char *data = const_cast(reinterpret_cast(binary.data)); - llvm::ArrayRef values(data, binary.size); - - attr = mlir::DenseElementsAttr::getFromRawBuffer(ty, values); - } - - // We set a fixed scalar shape because we're using single values here. 
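In the non-scalar branch above, the Erlang binary is handed to MLIR as an untyped byte buffer; condensed sketch, assuming four host-endian f32 values (variable names follow the surrounding code, the snippet itself is illustrative):

    mlir::RankedTensorType ty = mlir::RankedTensorType::get({4}, builder->getF32Type());
    llvm::ArrayRef<char> raw(reinterpret_cast<char *>(binary.data), binary.size);
    mlir::DenseElementsAttr attr = mlir::DenseElementsAttr::getFromRawBuffer(ty, raw);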
- mlir::Value op = builder->create(builder->getUnknownLoc(), attr); - return exla::nif::ok(env, exla::nif::make(env, op)); -} - -ERL_NIF_TERM MLIRFunction::ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM term, std::optional> dims) { - auto builder = module_->builder(); - setInsertionPoint(); - - if (type.isSignlessInteger(1)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - if (type.isUnsignedInteger(8)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isUnsignedInteger(16)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isUnsignedInteger(32)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isUnsignedInteger(64)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isSignlessInteger(8)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isSignlessInteger(16)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isSignlessInteger(32)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isSignlessInteger(64)) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isBF16()) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isF16()) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isa()) { - mlir::ComplexType complex_type = llvm::cast(type); - if (complex_type.getElementType().isF32()) { - return ConstantOpImpl(module_->builder(), complex_type, env, term, dims); - } else { - return ConstantOpImpl(module_->builder(), complex_type, env, term, dims); - } - } - - if (type.isF32()) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - if (type.isF64()) { - return ConstantOpImpl(module_->builder(), type, env, term, dims); - } - - return exla::nif::error(env, "invalid type received"); -} - -MLIRModule::MLIRModule(mlir::MLIRContext *context) { - context_ = context; - module_ = mlir::OwningOpRef(mlir::ModuleOp::create(mlir::UnknownLoc::get(context_))); - builder_ = std::make_unique(context_); - builder_->setInsertionPointToStart(module_->getBody()); -} - -xla::PrimitiveType MLIRTypeToPrimitiveType(mlir::Type type) { - if (!type.getAsOpaquePointer()) { - std::cerr << "Type with no implementation received" << std::endl; - exit(1); - } - if (type.isa()) { - return xla::PrimitiveType::TOKEN; - } - if (type.isa()) { - return xla::PrimitiveType::TUPLE; - } - if (type.isSignlessInteger(1)) { - return xla::PrimitiveType::PRED; - } - if (type.isUnsignedInteger(8)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isUnsignedInteger(16)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isUnsignedInteger(32)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isUnsignedInteger(64)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isSignlessInteger(8)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isSignlessInteger(16)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isSignlessInteger(32)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isSignlessInteger(64)) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isBF16()) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isF16()) { - return 
xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isF32()) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isF64()) { - return xla::primitive_util::NativeToPrimitiveType(); - } - if (type.isa()) { - mlir::ComplexType complex_type = llvm::cast(type); - if (complex_type.getElementType().isF32()) { - return xla::primitive_util::NativeToPrimitiveType(); - } else { - return xla::primitive_util::NativeToPrimitiveType(); - } - } - - std::cerr << "Invalid type received" << std::endl; - exit(1); -} - -MLIRFunction *MLIRModule::CreateFunction( - std::string name, - std::vector arg_shapes, - std::vector ret_shapes, - bool is_public) { - std::vector types; - types.reserve(arg_shapes.size()); - for (auto arg_shape : arg_shapes) { - mlir::Type type = GetMLIRFunctionType(builder_.get(), arg_shape); - types.push_back(type); - } - - std::vector return_types; - return_types.reserve(ret_shapes.size()); - for (auto ret_shape : ret_shapes) { - mlir::Type type = GetMLIRFunctionType(builder_.get(), ret_shape); - return_types.push_back(type); - } - - auto visibility = is_public ? "public" : "nested"; - - auto funcType = builder_->getFunctionType(types, return_types); - auto loc = builder_->getUnknownLoc(); - auto funcOp = std::make_unique(mlir::func::FuncOp::create(loc, name, funcType)); - funcOp->setSymVisibility(visibility); - module_->push_back(*funcOp); - funcOp->addEntryBlock(); - builder_->setInsertionPointToStart(&funcOp->getBody().front()); - - return new MLIRFunction(this, std::move(funcOp)); -} - -mlir::Value MLIRFunction::ConvOp( - mlir::Value tensor, - mlir::Value kernel, - std::vector window_strides, - std::vector padding, - std::vector tensor_dilation, - std::vector kernel_dilation, - xla::ConvolutionDimensionNumbers dimension_numbers, - uint64_t feature_group_count, - uint64_t batch_group_count, - uint64_t precision_config, - std::vector output_dims) { - auto builder = module_->builder(); - setInsertionPoint(); - - mlir::RankedTensorType t_in = llvm::cast(tensor.getType()); - mlir::RankedTensorType result_type = mlir::RankedTensorType::get(output_dims, t_in.getElementType()); - - auto window_strides_attr = Int64ToDenseIntElementsAttr(module_->builder(), window_strides); - auto tensor_dilation_attr = Int64ToDenseIntElementsAttr(module_->builder(), tensor_dilation); - auto kernel_dilation_attr = Int64ToDenseIntElementsAttr(module_->builder(), kernel_dilation); - auto dimension_numbers_attr = mlir::stablehlo::ConvDimensionNumbersAttr::get( - builder->getContext(), - dimension_numbers.input_batch_dimension(), - dimension_numbers.input_feature_dimension(), - llvm::ArrayRef(dimension_numbers.input_spatial_dimensions().data(), dimension_numbers.input_spatial_dimensions_size()), - dimension_numbers.kernel_input_feature_dimension(), - dimension_numbers.kernel_output_feature_dimension(), - llvm::ArrayRef(dimension_numbers.kernel_spatial_dimensions().data(), dimension_numbers.kernel_spatial_dimensions_size()), - dimension_numbers.output_batch_dimension(), - dimension_numbers.output_feature_dimension(), - llvm::ArrayRef(dimension_numbers.output_spatial_dimensions().data(), dimension_numbers.output_spatial_dimensions_size())); - - auto dense_attr_type = mlir::RankedTensorType::get({static_cast(padding.size() / 2), 2}, builder->getIntegerType(64)); - auto dense_attr = mlir::DenseElementsAttr::get(dense_attr_type, llvm::ArrayRef(padding.data(), padding.size())); - auto padding_attr = llvm::cast(dense_attr); - - return builder->create( - builder->getUnknownLoc(), - result_type, - 
tensor, - kernel, - window_strides_attr, - padding_attr, - tensor_dilation_attr, - kernel_dilation_attr, - nullptr, - dimension_numbers_attr, - feature_group_count, - batch_group_count, - nullptr); -} - -mlir::Value MLIRFunction::CreateTokenOp() { - auto builder = module_->builder(); - setInsertionPoint(); - return builder->create(builder->getUnknownLoc()); -} - -mlir::Value MLIRFunction::TriangularSolveOp(mlir::Value a, mlir::Value b, bool left_side, bool lower, bool transpose_a) { - auto builder = module_->builder(); - setInsertionPoint(); - mlir::stablehlo::Transpose transpose = mlir::stablehlo::Transpose::NO_TRANSPOSE; - - if (a.getType().isa() and transpose_a) { - transpose = mlir::stablehlo::Transpose::ADJOINT; - } else if (transpose_a) { - transpose = mlir::stablehlo::Transpose::TRANSPOSE; - } - - return builder->create(builder->getUnknownLoc(), a, b, left_side, lower, false, transpose); -} - -mlir::Value MLIRFunction::DynamicUpdateSliceOp(mlir::Value operand, mlir::Value update, std::vector start_indices) { - auto builder = module_->builder(); - setInsertionPoint(); - return builder->create(builder->getUnknownLoc(), operand, update, mlir::ValueRange(start_indices)); -} - -void MLIRModule::LowerPatterns() { - mlir::ConversionTarget target(*context()); - target.addIllegalDialect(); - target.addIllegalDialect(); - target.addLegalDialect(); - - mlir::stablehlo::StablehloToHloTypeConverter converter; - mlir::RewritePatternSet patterns(context()); - - mlir::stablehlo::registerFuncOpsForTypeConversion(target, patterns, converter); - mlir::stablehlo::populateStablehloToHloPatterns(&patterns, &converter, context()); - - mlir::applyPartialConversion(module(), target, std::move(patterns)); -} - -std::pair> MLIRFunction::InfeedOp(mlir::Value token, std::vector shapes) { - auto builder = module_->builder(); - setInsertionPoint(); - - std::vector types; - for (auto shape : shapes) { - types.push_back(GetMLIRFunctionType(builder, &shape)); - } - types.push_back(token.getType()); - - auto infeed_op = builder->create(builder->getUnknownLoc(), mlir::TypeRange(types), token); - - mlir::Operation::result_range results = infeed_op.getResults(); - - std::vector output(results.begin(), results.end()); - - mlir::Value out_token = output.back(); - output.pop_back(); - - return std::make_pair(out_token, output); -} - -mlir::Value MLIRFunction::OutfeedOp(std::vector inputs, mlir::Value token) { - auto builder = module_->builder(); - setInsertionPoint(); - return builder->create(builder->getUnknownLoc(), mlir::ValueRange(inputs), token); -} - -std::vector MLIRFunction::CallOp(std::vector inputs, MLIRFunction *computation) { - auto builder = module_->builder(); - setInsertionPoint(); - auto call_op = builder->create(builder->getUnknownLoc(), *computation->function(), mlir::ValueRange(inputs)); - - mlir::Operation::result_range results = call_op.getResults(); - return std::vector(results.begin(), results.end()); -} - -void addRegionArguments(mlir::OpBuilder *builder, mlir::Region *region, std::vector args) { - mlir::OpBuilder::InsertionGuard insertionPointGuard(*builder); - mlir::Location loc = region->getLoc(); - mlir::Block *block = builder->createBlock(region); - // Add two arguments for each element type. 
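addRegionArguments seeds the condition and body regions of the stablehlo.while created by WhileOp just below; for a single tensor<i64> state the printed form is roughly (sketch, syntax approximated from the StableHLO pretty printer):

    %next = stablehlo.while(%iterArg = %init) : tensor<i64>
     cond {
      %c = stablehlo.constant dense<10> : tensor<i64>
      %0 = stablehlo.compare LT, %iterArg, %c : (tensor<i64>, tensor<i64>) -> tensor<i1>
      stablehlo.return %0 : tensor<i1>
     } do {
      %0 = stablehlo.add %iterArg, %iterArg : tensor<i64>
      stablehlo.return %0 : tensor<i64>
     }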
- for (mlir::Value arg : args) { - block->addArgument(arg.getType(), loc); - } -} - -std::pair, std::pair> MLIRFunction::WhileOp(std::vector initial) { - auto builder = module_->builder(); - setInsertionPoint(); - - auto while_op = builder->create(builder->getUnknownLoc(), mlir::ValueRange(initial)); - - mlir::Region *cond = &while_op.getCond(); - addRegionArguments(builder, cond, initial); - - mlir::Region *body = &while_op.getBody(); - addRegionArguments(builder, body, initial); - - mlir::Operation::result_range results = while_op.getResults(); - std::vector output(results.begin(), results.end()); - - return std::make_pair(output, std::make_pair(cond, body)); -} - -std::vector MLIRFunction::ReturnOp(std::vector operands) { - setInsertionPoint(); - auto ret_op = module_->builder()->create(module_->builder()->getUnknownLoc(), mlir::ValueRange(operands)); - - mlir::Operation::operand_range results = ret_op.getResults(); - return std::vector(results.begin(), results.end()); -} - -void MLIRFunction::setInsertionPoint() { - if (regions.size() == 0) { - module_->builder()->setInsertionPointToEnd(&func_->getBody().back()); - } else { - module_->builder()->setInsertionPointToEnd(®ions.top()->back()); - } -} - -std::pair MLIRFunction::QRCpuCustomCall(mlir::Value operand, std::vector q_shape, std::vector r_shape) { - auto builder = module_->builder(); - setInsertionPoint(); - - mlir::RankedTensorType op_type = llvm::cast(operand.getType()); - - auto op_shape = op_type.getShape(); - - mlir::Value dim_sizes = builder->create(builder->getUnknownLoc(), Int64ToDenseIntElementsAttr(builder, std::vector({static_cast(op_shape.size()), static_cast(q_shape.size()), static_cast(r_shape.size())}))); - mlir::Value operand_dims = builder->create(builder->getUnknownLoc(), Int64ToDenseIntElementsAttr(builder, op_shape)); - mlir::Value q_dims = builder->create(builder->getUnknownLoc(), Int64ToDenseIntElementsAttr(builder, q_shape)); - mlir::Value r_dims = builder->create(builder->getUnknownLoc(), Int64ToDenseIntElementsAttr(builder, r_shape)); - - auto element_type = op_type.getElementType(); - std::string call_target_name = "qr_cpu_custom_call_f32"; - - if (element_type.isF32()) { - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(call_target_name, qr_cpu_custom_call_f32); - } else if (element_type.isF64()) { - call_target_name = "qr_cpu_custom_call_f64"; - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(call_target_name, qr_cpu_custom_call_f64); - } else if (element_type.isF16()) { - call_target_name = "qr_cpu_custom_call_f16"; - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(call_target_name, qr_cpu_custom_call_f16); - } else if (element_type.isBF16()) { - call_target_name = "qr_cpu_custom_call_bf16"; - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(call_target_name, qr_cpu_custom_call_bf16); - } else { - std::cerr << "Unsupported type for QR decomposition" << std::endl; - exit(1); - } - - auto call_target_name_attr = mlir::NamedAttribute(builder->getStringAttr("call_target_name"), builder->getStringAttr(call_target_name)); - auto backend_config_attr = mlir::NamedAttribute(builder->getStringAttr("backend_config"), builder->getStringAttr("Host")); - auto named_attrs = {call_target_name_attr, backend_config_attr}; - - mlir::Type q_type = mlir::RankedTensorType::get(q_shape, op_type.getElementType()); - mlir::Type r_type = mlir::RankedTensorType::get(r_shape, op_type.getElementType()); - - mlir::TupleType out_tuple_type = mlir::TupleType::get(builder->getContext(), mlir::TypeRange({q_type, r_type})); - - auto custom_call = 
builder->create( - builder->getUnknownLoc(), - mlir::TypeRange({out_tuple_type}), - mlir::ValueRange({operand, dim_sizes, operand_dims, q_dims, r_dims}), - llvm::ArrayRef(named_attrs)); - - mlir::Value out_tuple = custom_call.getResult(0); - mlir::Value q = this->GetTupleElementOp(out_tuple, 0); - mlir::Value r = this->GetTupleElementOp(out_tuple, 1); - - return std::make_pair(q, r); -} - -} // namespace exla diff --git a/exla/c_src/exla/mlir/builder.h b/exla/c_src/exla/mlir/builder.h deleted file mode 100644 index d3d6585c2c..0000000000 --- a/exla/c_src/exla/mlir/builder.h +++ /dev/null @@ -1,170 +0,0 @@ -#ifndef EXLA_MLIR_BUILDER_H_ -#define EXLA_MLIR_BUILDER_H_ - -#include - -#include "../exla_nif_util.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OwningOpRef.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/reference/Types.h" -#include "xla/shape.h" -#include "xla/types.h" - -namespace exla { - -class MLIRModule; - -class MLIRFunction { - public: - MLIRFunction(MLIRModule *module, std::unique_ptr func); - - mlir::Value AddOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value SubtractOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value TupleOp(std::vector vals); - mlir::Value GetTupleElementOp(mlir::Value tuple, tsl::int64 index); - mlir::Value MulOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value MinOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value MaxOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value RemOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value PowOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value DivOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value Atan2Op(mlir::Value lhs, mlir::Value rhs); - mlir::Value EqualOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value NotEqualOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value LessOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value LessEqualOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value GreaterOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value GreaterEqualOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value BitwiseAndOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value BitwiseOrOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value BitwiseXorOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value BitwiseNotOp(mlir::Value operand); - mlir::Value ShiftLeftOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value ShiftRightLogicalOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value ShiftRightArithmeticOp(mlir::Value lhs, mlir::Value rhs); - mlir::Value ConvertOp(mlir::Value operand, mlir::Type type); - mlir::Value BitcastConvertOp(mlir::Value operand, xla::Shape shape); - mlir::Value PadOp(mlir::Value op, mlir::Value pad, std::vector padding_low, std::vector padding_high, std::vector padding_mid); - mlir::Value AbsOp(mlir::Value operand); - mlir::Value RealOp(mlir::Value operand); - mlir::Value ImagOp(mlir::Value operand); - mlir::Value ConjOp(mlir::Value operand); - mlir::Value ExpOp(mlir::Value operand); - mlir::Value Expm1Op(mlir::Value operand); - mlir::Value FloorOp(mlir::Value operand); - mlir::Value CeilOp(mlir::Value operand); - mlir::Value RoundOp(mlir::Value operand); - mlir::Value LogOp(mlir::Value operand); - mlir::Value LogisticOp(mlir::Value operand); - mlir::Value Log1pOp(mlir::Value operand); - mlir::Value SignOp(mlir::Value operand); - mlir::Value CosOp(mlir::Value operand); - mlir::Value SinOp(mlir::Value operand); - mlir::Value TanOp(mlir::Value 
operand); - mlir::Value AcosOp(mlir::Value operand); - mlir::Value AsinOp(mlir::Value operand); - mlir::Value AtanOp(mlir::Value operand); - mlir::Value CoshOp(mlir::Value operand); - mlir::Value SinhOp(mlir::Value operand); - mlir::Value TanhOp(mlir::Value operand); - mlir::Value AcoshOp(mlir::Value operand); - mlir::Value AsinhOp(mlir::Value operand); - mlir::Value AtanhOp(mlir::Value operand); - mlir::Value SqrtOp(mlir::Value operand); - mlir::Value CbrtOp(mlir::Value operand); - mlir::Value NegateOp(mlir::Value operand); - mlir::Value ErfOp(mlir::Value operand); - mlir::Value ErfInvOp(mlir::Value operand); - mlir::Value ErfcOp(mlir::Value operand); - mlir::Value IsFiniteOp(mlir::Value operand); - mlir::Value IsInfOp(mlir::Value operand); - mlir::Value IsNanOp(mlir::Value operand); - mlir::Value RsqrtOp(mlir::Value operand); - mlir::Value ClzOp(mlir::Value operand); - mlir::Value PopulationCountOp(mlir::Value operand); - mlir::Value IotaOp(xla::Shape shape, int64_t dimension); - mlir::Value TransposeOp(mlir::Value operand, std::vector axes); - mlir::Value ReshapeOp(mlir::Value operand, std::vector target_shape); - mlir::Value ReverseOp(mlir::Value operand, std::vector dims); - mlir::Value SliceOp(mlir::Value operand, std::vector starts, std::vector limites, std::vector strides); - std::vector TopKOp(mlir::Value operand, int64_t k); - std::vector SortOp(MLIRFunction *comparator, std::vector operand, int64_t dim, bool status); - mlir::Value DynamicSliceOp(mlir::Value operand, std::vector starts, std::vector lengths); - mlir::Value BroadcastInDimOp(mlir::Value operand, xla::Shape result_shape, std::vector axes); - mlir::Value DotGeneralOp(xla::Shape output_shape, mlir::Value lhs, mlir::Value rhs, xla::DotDimensionNumbers dnums, xla::PrecisionConfig config); - mlir::Value ConcatenateOp(std::vector operands, int64_t dimension); - mlir::Value OptimizationBarrierOp(mlir::Value operand); - mlir::Value ClampOp(mlir::Value min, mlir::Value operand, mlir::Value max); - mlir::Value SelectOp(mlir::Value pred, mlir::Value on_true, mlir::Value on_false); - mlir::Value ScatterOp(mlir::Value target, mlir::Value indices, mlir::Value updates, bool add_or_put, int64_t indices_rank, std::vector update_window_dims, std::vector inserted_window_dims, std::vector index_dims_to_window_dims); - mlir::Value SelectAndScatterOp(mlir::Value target, mlir::Value source, mlir::Value init_value, bool gt_or_lt, std::vector window_dimensions, std::vector window_strides, std::vector padding); - mlir::Value GatherOp(mlir::Value source, mlir::Value indices, std::vector offset_dims, std::vector collapsed_slice_dims, std::vector start_index_map, std::vector slice_sizes, int64_t index_vector_dim); - mlir::Value FFTOp(mlir::Value tensor, bool forward_fft, std::vector fft_length); - mlir::Value ConvOp(mlir::Value tensor, mlir::Value kernel, std::vector window_strides, std::vector padding, std::vector tensor_dilation, std::vector kernel_dilation, xla::ConvolutionDimensionNumbers dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, uint64_t precision_config, std::vector output_dims); - mlir::Value CreateTokenOp(); - mlir::Value TriangularSolveOp(mlir::Value a, mlir::Value b, bool left_side, bool lower, bool transpose_a); - mlir::Value DynamicUpdateSliceOp(mlir::Value operand, mlir::Value update, std::vector start_indices); - std::vector ReduceOp(MLIRFunction *function, std::vector init_values, std::vector inputs, std::vector dimensions); - std::vector WindowReduceOp(MLIRFunction *function, std::vector 
init_values, std::vector inputs, std::vector window_dimensions, std::vector window_strides, std::vector input_dilations, std::vector window_dilations, std::vector> padding); - mlir::Value MapOp(MLIRFunction *function, std::vector inputs, std::vector dimensions); - std::pair, std::pair> IfOp(mlir::Value pred, std::vector output_shape); - ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::optional> dims = std::nullopt); - std::pair> InfeedOp(mlir::Value token, std::vector shapes); - mlir::Value OutfeedOp(std::vector inputs, mlir::Value token); - std::vector CallOp(std::vector inputs, MLIRFunction *computation); - std::pair, std::pair> WhileOp(std::vector initial); - std::vector ReturnOp(std::vector values); - int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type); - std::vector PushRegion(mlir::Region *region); - std::pair QRCpuCustomCall(mlir::Value operand, std::vector q_shape, std::vector r_shape); - void PopRegion(); - void Build(mlir::Value root); - - llvm::MutableArrayRef get_arguments() { return func_->getBody().front().getArguments(); } - - mlir::func::FuncOp *function() { return func_.get(); } - - private: - std::shared_ptr module_; - std::unique_ptr func_; - - std::stack regions; - - void dump_mlir_module(); - void setInsertionPoint(); -}; - -class MLIRModule { - public: - MLIRModule(mlir::MLIRContext *context); - - MLIRFunction *CreateFunction( - std::string name, - std::vector arg_shapes, - std::vector ret_shape, - bool is_public); - - mlir::ModuleOp module() { return module_.get(); } - mlir::OpBuilder *builder() { return builder_.get(); } - mlir::MLIRContext *context() { return context_; } - - void LowerPatterns(); - - private: - mlir::MLIRContext *context_; - mlir::OwningOpRef module_; - std::unique_ptr builder_; -}; - -mlir::Type -TypeIntToMLIRType(mlir::OpBuilder *builder, xla::PrimitiveType type_int); - -xla::PrimitiveType MLIRTypeToPrimitiveType(mlir::Type); -} // namespace exla - -#endif \ No newline at end of file diff --git a/exla/c_src/exla/mlir/ops.cc b/exla/c_src/exla/mlir/ops.cc deleted file mode 100644 index d53d038a99..0000000000 --- a/exla/c_src/exla/mlir/ops.cc +++ /dev/null @@ -1,1594 +0,0 @@ - -#include "ops.h" - -#include - -#include "../exla_client.h" -#include "../exla_nif_util.h" -#include "mhlo/IR/hlo_ops.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "stablehlo/dialect/ChloOps.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "xla/shape_util.h" - -// MLIR Functions - -ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 7) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::ExlaClient** client; - exla::MLIRModule** module; - std::vector argument_layouts; - xla::ExecutableBuildOptions build_options; - int num_replicas; - int num_partitions; - bool use_spmd; - int device_id; - - if (!exla::nif::get(env, argv[0], client)) { - return exla::nif::error(env, "Unable to get client."); - } - if (!exla::nif::get(env, argv[1], module)) { - return exla::nif::error(env, "Unable to get module."); - } - if (!exla::nif::get_list(env, argv[2], argument_layouts)) { - return exla::nif::error(env, "Unable to get argument layouts."); - } - if (!exla::nif::get(env, argv[3], &num_replicas)) { - return exla::nif::error(env, "Unable to get Number of Replicas."); - } - if (!exla::nif::get(env, argv[4], &num_partitions)) { - return exla::nif::error(env, "Unable to get Number of Partitions."); - } - if (!exla::nif::get(env, argv[5], &use_spmd)) { - return 
exla::nif::error(env, "Unable to get SPMD Partitioning Flag."); - } - if (!exla::nif::get(env, argv[6], &device_id)) { - return exla::nif::error(env, "Unable to get device ID."); - } - - build_options.set_num_replicas(num_replicas); - build_options.set_num_partitions(num_partitions); - build_options.set_use_spmd_partitioning(use_spmd); - - bool compile_portable_executable = false; - if (device_id >= 0) { - compile_portable_executable = true; - build_options.set_device_ordinal(device_id); - } - - EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaExecutable * executable, - (*client)->Compile((*module)->module(), argument_layouts, build_options, compile_portable_executable), env); - - return exla::nif::ok(env, exla::nif::make(env, executable)); -} - -ERL_NIF_TERM new_mlir_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 0) { - return exla::nif::error(env, "Bad argument count."); - } - - mlir::MLIRContext* context = new mlir::MLIRContext(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - context->getOrLoadDialect(); - - auto ret = exla::nif::make(env, context); - return exla::nif::ok(env, ret); -} - -ERL_NIF_TERM new_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - mlir::MLIRContext** ctx; - - if (!exla::nif::get(env, argv[0], ctx)) { - return exla::nif::error(env, "Unable to get context."); - } - - exla::MLIRModule* module = new exla::MLIRModule(*ctx); - - return exla::nif::ok(env, exla::nif::make(env, module)); -} - -ERL_NIF_TERM create_mlir_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRModule** module; - std::string func_name; - std::vector, xla::PrimitiveType>> arg_types; - std::pair, xla::PrimitiveType> ret_type; - std::vector arg_shapes; - std::vector ret_shapes; - bool is_public; - - if (!exla::nif::get(env, argv[0], module)) { - return exla::nif::error(env, "Unable to get module."); - } - if (!exla::nif::get(env, argv[1], func_name)) { - return exla::nif::error(env, "Unable to get function name."); - } - if (!exla::nif::get_list(env, argv[2], arg_shapes)) { - return exla::nif::error(env, "Unable to get args."); - } - if (!exla::nif::get_list(env, argv[3], ret_shapes)) { - return exla::nif::error(env, "Unable to get return."); - } - if (!exla::nif::get(env, argv[4], &is_public)) { - return exla::nif::error(env, "Unable to get is_public."); - } - - exla::MLIRFunction* func = (*module)->CreateFunction(func_name, arg_shapes, ret_shapes, is_public); - - return exla::nif::ok(env, exla::nif::make(env, func)); -} - -ERL_NIF_TERM get_mlir_function_arguments(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - - llvm::MutableArrayRef args = (*function)->get_arguments(); - std::vector terms; - terms.reserve(args.size()); - - for (auto arg : args) { - ERL_NIF_TERM term = exla::nif::make(env, arg); - terms.push_back(term); - } - - return exla::nif::ok(env, enif_make_list_from_array(env, terms.data(), terms.size())); -} - -ERL_NIF_TERM mlir_binary_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[], std::function op) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - 
exla::MLIRFunction** function; - mlir::Value* lhs; - mlir::Value* rhs; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], lhs)) { - return exla::nif::error(env, "Unable to get lhs."); - } - if (!exla::nif::get(env, argv[2], rhs)) { - return exla::nif::error(env, "Unable to get rhs."); - } - - mlir::Value res = op(*function, lhs, rhs); - - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -#define MLIR_BIN_OP(OP) mlir_binary_op(env, argc, argv, [](exla::MLIRFunction* f, mlir::Value* lhs, mlir::Value* rhs) -> mlir::Value { return f->OP(*lhs, *rhs); }) - -ERL_NIF_TERM mlir_add(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(AddOp); -} - -ERL_NIF_TERM mlir_subtract(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(SubtractOp); -} - -ERL_NIF_TERM mlir_multiply(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(MulOp); -} - -ERL_NIF_TERM mlir_min(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(MinOp); -} - -ERL_NIF_TERM mlir_max(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(MaxOp); -} - -ERL_NIF_TERM mlir_remainder(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(RemOp); -} - -ERL_NIF_TERM mlir_pow(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(PowOp); -} - -ERL_NIF_TERM mlir_divide(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(DivOp); -} - -ERL_NIF_TERM mlir_atan2(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(Atan2Op); -} - -ERL_NIF_TERM mlir_equal(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(EqualOp); -} - -ERL_NIF_TERM mlir_not_equal(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(NotEqualOp); -} - -ERL_NIF_TERM mlir_less(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(LessOp); -} - -ERL_NIF_TERM mlir_less_equal(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(LessEqualOp); -} - -ERL_NIF_TERM mlir_greater(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(GreaterOp); -} - -ERL_NIF_TERM mlir_greater_equal(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(GreaterEqualOp); -} - -ERL_NIF_TERM mlir_bitwise_and(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(BitwiseAndOp); -} - -ERL_NIF_TERM mlir_bitwise_or(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(BitwiseOrOp); -} - -ERL_NIF_TERM mlir_bitwise_xor(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(BitwiseXorOp); -} - -ERL_NIF_TERM mlir_shift_left(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(ShiftLeftOp); -} - -ERL_NIF_TERM mlir_shift_right_logical(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(ShiftRightLogicalOp); -} - -ERL_NIF_TERM mlir_shift_right_arithmetic(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_BIN_OP(ShiftRightArithmeticOp); -} - -ERL_NIF_TERM mlir_unary_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[], std::function op) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable 
to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - - mlir::Value res = op(*function, operand); - - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -#define MLIR_UNARY_OP(OP) mlir_unary_op(env, argc, argv, [](exla::MLIRFunction* f, mlir::Value* operand) -> mlir::Value { return f->OP(*operand); }) - -ERL_NIF_TERM mlir_abs(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AbsOp); -} -ERL_NIF_TERM mlir_exp(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ExpOp); -} -ERL_NIF_TERM mlir_expm1(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(Expm1Op); -} -ERL_NIF_TERM mlir_floor(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(FloorOp); -} -ERL_NIF_TERM mlir_ceil(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(CeilOp); -} -ERL_NIF_TERM mlir_round(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(RoundOp); -} -ERL_NIF_TERM mlir_log(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(LogOp); -} -ERL_NIF_TERM mlir_sigmoid(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(LogisticOp); -} -ERL_NIF_TERM mlir_log1p(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(Log1pOp); -} -ERL_NIF_TERM mlir_sign(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(SignOp); -} -ERL_NIF_TERM mlir_cos(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(CosOp); -} -ERL_NIF_TERM mlir_tan(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(TanOp); -} -ERL_NIF_TERM mlir_sin(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(SinOp); -} -ERL_NIF_TERM mlir_acos(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AcosOp); -} -ERL_NIF_TERM mlir_asin(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AsinOp); -} -ERL_NIF_TERM mlir_atan(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AtanOp); -} -ERL_NIF_TERM mlir_cosh(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(CoshOp); -} -ERL_NIF_TERM mlir_sinh(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(SinhOp); -} -ERL_NIF_TERM mlir_tanh(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(TanhOp); -} -ERL_NIF_TERM mlir_acosh(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AcoshOp); -} -ERL_NIF_TERM mlir_asinh(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AsinhOp); -} -ERL_NIF_TERM mlir_atanh(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(AtanhOp); -} -ERL_NIF_TERM mlir_sqrt(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(SqrtOp); -} -ERL_NIF_TERM mlir_cbrt(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(CbrtOp); -} - -ERL_NIF_TERM mlir_bitwise_not(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(BitwiseNotOp); -} - -ERL_NIF_TERM mlir_negate(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(NegateOp); -} - -ERL_NIF_TERM mlir_erf(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ErfOp); -} - -ERL_NIF_TERM mlir_erfc(ErlNifEnv* env, int argc, 
const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ErfcOp); -} - -ERL_NIF_TERM mlir_erf_inv(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ErfInvOp); -} - -ERL_NIF_TERM mlir_is_infinity(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(IsInfOp); -} -ERL_NIF_TERM mlir_is_nan(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(IsNanOp); -} -ERL_NIF_TERM mlir_rsqrt(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(RsqrtOp); -} -ERL_NIF_TERM mlir_clz(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ClzOp); -} -ERL_NIF_TERM mlir_population_count(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(PopulationCountOp); -} -ERL_NIF_TERM mlir_real(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(RealOp); -} -ERL_NIF_TERM mlir_imag(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ImagOp); -} -ERL_NIF_TERM mlir_conjugate(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - return MLIR_UNARY_OP(ConjOp); -} - -ERL_NIF_TERM mlir_iota(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - xla::Shape* shape; - exla::int64 dimension; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - if (!exla::nif::get(env, argv[2], &dimension)) { - return exla::nif::error(env, "Unable to get dimension"); - } - - mlir::Value res = (*function)->IotaOp(*shape, dimension); - return exla::nif::ok(env, exla::nif::make(env, res)); -} -ERL_NIF_TERM mlir_reshape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector shape; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get_tuple(env, argv[2], shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - - mlir::Value res = (*function)->ReshapeOp(*operand, shape); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_reverse(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector dims; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get_list(env, argv[2], dims)) { - return exla::nif::error(env, "Unable to get dims."); - } - - mlir::Value res = (*function)->ReverseOp(*operand, dims); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_transpose(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector axes; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, 
"Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get_list(env, argv[2], axes)) { - return exla::nif::error(env, "Unable to get axes."); - } - - mlir::Value res = (*function)->TransposeOp(*operand, axes); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_slice(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector starts, limits, strides; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get_list(env, argv[2], starts)) { - return exla::nif::error(env, "Unable to get starts."); - } - if (!exla::nif::get_list(env, argv[3], limits)) { - return exla::nif::error(env, "Unable to get lengths."); - } - if (!exla::nif::get_list(env, argv[4], strides)) { - return exla::nif::error(env, "Unable to get strides."); - } - - mlir::Value res = (*function)->SliceOp(*operand, starts, limits, strides); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_dynamic_slice(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector starts; - std::vector lengths; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get_list(env, argv[2], starts)) { - return exla::nif::error(env, "Unable to get starts."); - } - if (!exla::nif::get_list(env, argv[3], lengths)) { - return exla::nif::error(env, "Unable to get lengths."); - } - - mlir::Value res = (*function)->DynamicSliceOp(*operand, starts, lengths); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_constant_r0(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Type type; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if ((*function)->get_mlir_type(env, argv[2], &type)) { - return exla::nif::error(env, "Unable to get type string."); - } - - return (*function)->ConstantOp(type, env, argv[1]); -} -ERL_NIF_TERM mlir_constant_from_binary(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Type type; - std::vector dims; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if ((*function)->get_mlir_type(env, argv[2], &type)) { - return exla::nif::error(env, "Unable to get type string."); - } - if (!exla::nif::get_tuple(env, argv[3], dims)) { - return exla::nif::error(env, "Unable to get dims."); - } - - return (*function)->ConstantOp(type, env, argv[1], dims); -} - -ERL_NIF_TERM mlir_dot_general(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 6) { - return exla::nif::error(env, "Bad argument count."); - } - - 
exla::MLIRFunction** function; - xla::Shape* output_shape; - mlir::Value* lhs; - mlir::Value* rhs; - xla::DotDimensionNumbers dnums; - xla::PrecisionConfig config; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], output_shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - if (!exla::nif::get(env, argv[2], lhs)) { - return exla::nif::error(env, "Unable to get lhs."); - } - if (!exla::nif::get(env, argv[3], rhs)) { - return exla::nif::error(env, "Unable to get rhs."); - } - if (!exla::nif::get_dot_dimension_numbers(env, argv[4], &dnums)) { - return exla::nif::error(env, "Unable to get dot dimensions."); - } - if (!exla::nif::get_precision_config(env, argv[5], 2, &config)) { - return exla::nif::error(env, "Unable to get precision configuration."); - } - - mlir::Value res = (*function)->DotGeneralOp(*output_shape, *lhs, *rhs, dnums, config); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_select(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* pred; - mlir::Value* on_true; - mlir::Value* on_false; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], pred)) { - return exla::nif::error(env, "Unable to get pred."); - } - if (!exla::nif::get(env, argv[2], on_true)) { - return exla::nif::error(env, "Unable to get on true."); - } - if (!exla::nif::get(env, argv[3], on_false)) { - return exla::nif::error(env, "Unable to get on false."); - } - - mlir::Value res = (*function)->SelectOp(*pred, *on_true, *on_false); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* t; - mlir::Type type; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], t)) { - return exla::nif::error(env, "Unable to get tensor."); - } - if ((*function)->get_mlir_type(env, argv[2], &type)) { - return exla::nif::error(env, "Unable to get type string."); - } - - mlir::Value result = (*function)->ConvertOp(*t, type); - - return exla::nif::ok(env, exla::nif::make(env, result)); -} -ERL_NIF_TERM mlir_top_k(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* operand; - int64_t k; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[2], &k)) { - return exla::nif::error(env, "Unable to get k."); - } - - std::vector result = (*function)->TopKOp(*operand, k); - return exla::nif::ok(env, exla::nif::make_list(env, result)); -} - -ERL_NIF_TERM mlir_sort(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector operands; - exla::int64 axis; - exla::MLIRFunction** comparator; - bool stable; - - if 
(!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get_list(env, argv[1], operands)) { - return exla::nif::error(env, "Unable to get operands."); - } - if (!exla::nif::get(env, argv[2], &axis)) { - return exla::nif::error(env, "Unable to get axis."); - } - if (!exla::nif::get(env, argv[3], comparator)) { - return exla::nif::error(env, "Unable to get comparator."); - } - if (!exla::nif::get(env, argv[4], &stable)) { - return exla::nif::error(env, "Unable to get stable flag."); - } - - std::vector res = (*function)->SortOp(*comparator, operands, axis, stable); - return exla::nif::ok(env, exla::nif::make_list(env, res)); -} - -ERL_NIF_TERM mlir_reduce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 5) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - exla::MLIRFunction** reducer; - std::vector init_values; - std::vector inputs; - std::vector dimensions; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], reducer)) { - return exla::nif::error(env, "Unable to get reducer."); - } - if (!exla::nif::get_list(env, argv[2], init_values)) { - return exla::nif::error(env, "Unable to get init_values."); - } - if (!exla::nif::get_list(env, argv[3], inputs)) { - return exla::nif::error(env, "Unable to get inputs."); - } - if (!exla::nif::get_tuple(env, argv[4], dimensions)) { - return exla::nif::error(env, "Unable to get dimensions."); - } - - std::vector res = (*function)->ReduceOp(*reducer, init_values, inputs, dimensions); - return exla::nif::ok(env, exla::nif::make_list(env, res)); -} - -ERL_NIF_TERM mlir_window_reduce(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 9) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - exla::MLIRFunction** reducer; - std::vector init_values; - std::vector inputs; - std::vector window_dimensions; - std::vector window_strides; - std::vector input_dilations; - std::vector window_dilations; - std::vector> padding; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], reducer)) { - return exla::nif::error(env, "Unable to get reducer."); - } - if (!exla::nif::get_list(env, argv[2], init_values)) { - return exla::nif::error(env, "Unable to get init_values."); - } - if (!exla::nif::get_list(env, argv[3], inputs)) { - return exla::nif::error(env, "Unable to get inputs."); - } - if (!exla::nif::get_tuple(env, argv[4], window_dimensions)) { - return exla::nif::error(env, "Unable to get window_dimensions."); - } - if (!exla::nif::get_tuple(env, argv[5], window_strides)) { - return exla::nif::error(env, "Unable to get window_strides."); - } - if (!exla::nif::get_tuple(env, argv[6], input_dilations)) { - return exla::nif::error(env, "Unable to get input_dilations."); - } - if (!exla::nif::get_tuple(env, argv[7], window_dilations)) { - return exla::nif::error(env, "Unable to get window_dilations."); - } - if (!exla::nif::get_general_padding(env, argv[8], padding)) { - return exla::nif::error(env, "Unable to get padding."); - } - - std::vector res = (*function)->WindowReduceOp(*reducer, - init_values, - inputs, - window_dimensions, - window_strides, - input_dilations, - window_dilations, - padding); - - return exla::nif::ok(env, exla::nif::make_list(env, res)); -} - -ERL_NIF_TERM 
mlir_map(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - exla::MLIRFunction** mapper; - std::vector inputs; - std::vector dimensions; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], mapper)) { - return exla::nif::error(env, "Unable to get mapper."); - } - if (!exla::nif::get_list(env, argv[2], inputs)) { - return exla::nif::error(env, "Unable to get inputs."); - } - if (!exla::nif::get_tuple(env, argv[3], dimensions)) { - return exla::nif::error(env, "Unable to get dimensions."); - } - - mlir::Value result = (*function)->MapOp(*mapper, inputs, dimensions); - - return exla::nif::ok(env, exla::nif::make(env, result)); -} - -ERL_NIF_TERM mlir_if(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* pred; - std::vector output_shapes; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], pred)) { - return exla::nif::error(env, "Unable to get pred."); - } - if (!exla::nif::get_list(env, argv[2], output_shapes)) { - return exla::nif::error(env, "Unable to get output shapes."); - } - - auto result = (*function)->IfOp(*pred, output_shapes); - - ERL_NIF_TERM res = exla::nif::make_list(env, result.first); - ERL_NIF_TERM true_region = exla::nif::make(env, result.second.first); - ERL_NIF_TERM false_region = exla::nif::make(env, result.second.second); - return exla::nif::ok(env, enif_make_tuple3(env, res, true_region, false_region)); -} - -ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Region** region; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], region)) { - return exla::nif::error(env, "Unable to get region."); - } - - std::vector args = (*function)->PushRegion(*region); - return exla::nif::ok(env, exla::nif::make_list(env, args)); -} - -ERL_NIF_TERM -mlir_pop_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - - (*function)->PopRegion(); - return exla::nif::ok(env); -} - -ERL_NIF_TERM mlir_bitcast_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* t; - mlir::Type type; - xla::PrimitiveType element_type; - std::vector dims; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], t)) { - return exla::nif::error(env, "Unable to get tensor."); - } - if (!exla::nif::get_primitive_type(env, argv[2], &element_type)) { - return exla::nif::error(env, "Unable to get type."); - } - if (!exla::nif::get_tuple(env, argv[3], dims)) { - return exla::nif::error(env, "Unable to get dimensions."); - } - - xla::Shape shape = 
xla::ShapeUtil::MakeShape(element_type, dims); - - mlir::Value result = (*function)->BitcastConvertOp(*t, shape); - - return exla::nif::ok(env, exla::nif::make(env, result)); -} - -ERL_NIF_TERM mlir_pad(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 6) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector padding_high, padding_low, padding_mid; - mlir::Value *operand, *pad_value; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[2], pad_value)) { - return exla::nif::error(env, "Unable to get pad value."); - } - if (!exla::nif::get_list(env, argv[3], padding_low)) { - return exla::nif::error(env, "Unable to get padding_low."); - } - if (!exla::nif::get_list(env, argv[4], padding_high)) { - return exla::nif::error(env, "Unable to get padding_high."); - } - if (!exla::nif::get_list(env, argv[5], padding_mid)) { - return exla::nif::error(env, "Unable to get padding_mid."); - } - - mlir::Value res = (*function)->PadOp(*operand, *pad_value, padding_low, padding_high, padding_mid); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_optimization_barrier(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 2) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* t; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], t)) { - return exla::nif::error(env, "Unable to get tensor."); - } - - mlir::Value result = (*function)->OptimizationBarrierOp(*t); - return exla::nif::ok(env, exla::nif::make(env, result)); -} - -ERL_NIF_TERM mlir_clamp(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* operand; - mlir::Value* min; - mlir::Value* max; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[2], min)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[3], max)) { - return exla::nif::error(env, "Unable to get operand."); - } - - mlir::Value result = (*function)->ClampOp(*min, *operand, *max); - - return exla::nif::ok(env, exla::nif::make(env, result)); -} - -xla::Shape mlir_type_to_xla_shape(mlir::Type type) { - if (type.isa()) { - auto tensorType = type.cast(); - // Get the shape (dimensions) of the tensor - std::vector dims = tensorType.getShape(); - auto element_type = tensorType.getElementType(); - return xla::ShapeUtil::MakeShape(exla::MLIRTypeToPrimitiveType(element_type), dims); - } - - if (type.isa()) { - auto tupleType = type.cast(); - std::vector subshapes; - - for (mlir::Type subType : tupleType.getTypes()) { - // Handle each sub-type in the tuple - subshapes.push_back(mlir_type_to_xla_shape(subType)); - } - return xla::ShapeUtil::MakeTupleShape(subshapes); - } - - auto element_type = exla::MLIRTypeToPrimitiveType(type); - - if (element_type == xla::PrimitiveType::TOKEN) { - return xla::ShapeUtil::MakeTokenShape(); - } - - return 
xla::ShapeUtil::MakeShape(element_type, {}); -} - -ERL_NIF_TERM mlir_get_shape(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - mlir::Value* t; - - if (!exla::nif::get(env, argv[0], t)) { - return exla::nif::error(env, "Unable to get tensor."); - } - - mlir::Type type = t->getType(); - xla::Shape shape = mlir_type_to_xla_shape(type); - - return exla::nif::ok(env, exla::nif::make(env, shape)); -} - -ERL_NIF_TERM mlir_broadcast_in_dim(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector axes; - xla::Shape* output_shape; - mlir::Value* operand; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], output_shape)) { - return exla::nif::error(env, "Unable to get shape."); - } - if (!exla::nif::get(env, argv[2], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get_tuple(env, argv[3], axes)) { - return exla::nif::error(env, "Unable to get broadcast dimensions."); - } - - mlir::Value res = (*function)->BroadcastInDimOp(*operand, *output_shape, axes); - - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_concatenate(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - std::vector vals; - exla::int64 dimension; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get_list(env, argv[1], vals)) { - return exla::nif::error(env, "Unable to get values."); - } - if (!exla::nif::get(env, argv[2], &dimension)) { - return exla::nif::error(env, "Unable to get dimension"); - } - - mlir::Value res = (*function)->ConcatenateOp(vals, dimension); - - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM dump_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRModule** builder; - - if (!exla::nif::get(env, argv[0], builder)) { - return exla::nif::error(env, "Unable to get builder."); - } - - (*builder)->module().dump(); - - return exla::nif::ok(env); -} - -ERL_NIF_TERM mlir_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 9) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value *target, *indices, *updates; - bool add_or_put; - int64_t indices_rank; - std::vector update_window_dims, inserted_window_dims, index_dims_to_window_dims; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], target)) { - return exla::nif::error(env, "Unable to get target."); - } - if (!exla::nif::get(env, argv[2], indices)) { - return exla::nif::error(env, "Unable to get indices."); - } - if (!exla::nif::get(env, argv[3], updates)) { - return exla::nif::error(env, "Unable to get updates."); - } - if (!exla::nif::get(env, argv[4], &add_or_put)) { - return exla::nif::error(env, "Unable to get add_or_put."); - } - if (!exla::nif::get(env, argv[5], &indices_rank)) { - return exla::nif::error(env, "Unable to get indices_rank."); - } - if (!exla::nif::get_list(env, argv[6], 
update_window_dims)) { - return exla::nif::error(env, "Unable to get update_window_dims."); - } - if (!exla::nif::get_list(env, argv[7], inserted_window_dims)) { - return exla::nif::error(env, "Unable to get inserted_window_dims."); - } - if (!exla::nif::get_list(env, argv[8], index_dims_to_window_dims)) { - return exla::nif::error(env, "Unable to get index_dims_to_window_dims."); - } - - mlir::Value res = (*function)->ScatterOp( - *target, - *indices, - *updates, - add_or_put, - indices_rank, - update_window_dims, - inserted_window_dims, - index_dims_to_window_dims); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_select_and_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 8) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value *target, *source, *init_value; - bool add_or_put, gt_or_lt; - - std::vector window_dimensions, window_strides; - std::vector> padding_config; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], target)) { - return exla::nif::error(env, "Unable to get target."); - } - if (!exla::nif::get(env, argv[2], source)) { - return exla::nif::error(env, "Unable to get source."); - } - if (!exla::nif::get(env, argv[3], init_value)) { - return exla::nif::error(env, "Unable to get init_value."); - } - if (!exla::nif::get(env, argv[4], >_or_lt)) { - return exla::nif::error(env, "Unable to get gt_or_lt."); - } - if (!exla::nif::get_list(env, argv[5], window_dimensions)) { - return exla::nif::error(env, "Unable to get window_dimensions."); - } - if (!exla::nif::get_list(env, argv[6], window_strides)) { - return exla::nif::error(env, "Unable to get window_strides."); - } - if (!exla::nif::get_general_padding(env, argv[7], padding_config)) { - return exla::nif::error(env, "Unable to get padding configuration."); - } - - std::vector padding; - - for (std::pair item : padding_config) { - padding.push_back(item.first); - padding.push_back(item.second); - } - - mlir::Value res = (*function)->SelectAndScatterOp(*target, *source, *init_value, gt_or_lt, window_dimensions, window_strides, padding); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_gather(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 8) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value *source, *indices; - - int64_t index_vector_dim; - std::vector slice_sizes, offset_dims, collapsed_slice_dims, start_index_map; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], source)) { - return exla::nif::error(env, "Unable to get source."); - } - if (!exla::nif::get(env, argv[2], indices)) { - return exla::nif::error(env, "Unable to get indices."); - } - if (!exla::nif::get_list(env, argv[3], slice_sizes)) { - return exla::nif::error(env, "Unable to get slice_sizes."); - } - if (!exla::nif::get_list(env, argv[4], offset_dims)) { - return exla::nif::error(env, "Unable to get offset_dims."); - } - if (!exla::nif::get_list(env, argv[5], collapsed_slice_dims)) { - return exla::nif::error(env, "Unable to get collapsed_slice_dims."); - } - if (!exla::nif::get_list(env, argv[6], start_index_map)) { - return exla::nif::error(env, "Unable to get start_index_map."); - } - if (!exla::nif::get(env, argv[7], &index_vector_dim)) { - 
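// Gather's dimension numbers are easiest to read off a concrete case.
// Sketch (assumed shapes, for illustration): gathering rows of a [5,4]
// source with [3,1] indices and index_vector_dim = 1:
//
//   slice_sizes          = {1, 4}  // each slice is one full row
//   offset_dims          = {1}     // the sliced row axis in the output
//   collapsed_slice_dims = {0}     // drop the size-1 sliced dimension
//   start_index_map      = {0}     // index component 0 selects source dim 0
//
// which yields a [3, 4] result.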
return exla::nif::error(env, "Unable to get index_vector_dim."); - } - - mlir::Value res = (*function)->GatherOp(*source, *indices, offset_dims, collapsed_slice_dims, start_index_map, slice_sizes, index_vector_dim); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* operand; - bool forward_fft; - - std::vector fft_length; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[2], &forward_fft)) { - return exla::nif::error(env, "Unable to get forward_fft."); - } - if (!exla::nif::get_list(env, argv[3], fft_length)) { - return exla::nif::error(env, "Unable to get fft_length."); - } - - mlir::Value res = (*function)->FFTOp(*operand, forward_fft, fft_length); - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_convolution(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 12) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value *tensor, *kernel; - std::vector strides; - std::vector> padding_config; - std::vector tensor_dilation; - std::vector kernel_dilation; - xla::ConvolutionDimensionNumbers dimension_numbers; - uint64_t feature_group_count, batch_group_count, precision_config; - std::vector output_dims; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], tensor)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[2], kernel)) { - return exla::nif::error(env, "Unable to get kernel."); - } - if (!exla::nif::get_list(env, argv[3], strides)) { - return exla::nif::error(env, "Unable to get strides."); - } - if (!exla::nif::get_general_padding(env, argv[4], padding_config)) { - return exla::nif::error(env, "Unable to get padding_config."); - } - if (!exla::nif::get_list(env, argv[5], tensor_dilation)) { - return exla::nif::error(env, "Unable to get operand dilation."); - } - if (!exla::nif::get_list(env, argv[6], kernel_dilation)) { - return exla::nif::error(env, "Unable to get kernel dilation."); - } - if (!exla::nif::get_conv_dimension_numbers(env, argv[7], &dimension_numbers)) { - return exla::nif::error(env, "Unable to get conv dimension numbers."); - } - if (!exla::nif::get(env, argv[8], &feature_group_count)) { - return exla::nif::error(env, "Unable to get feature groups."); - } - if (!exla::nif::get(env, argv[9], &batch_group_count)) { - return exla::nif::error(env, "Unable to get batch groups."); - } - if (!exla::nif::get(env, argv[10], &precision_config)) { - return exla::nif::error(env, "Unable to get precision config."); - } - if (!exla::nif::get_list(env, argv[11], output_dims)) { - return exla::nif::error(env, "Unable to get output_dims."); - } - - std::vector padding; - - for (std::pair item : padding_config) { - padding.push_back(item.first); - padding.push_back(item.second); - } - - mlir::Value res = (*function)->ConvOp( - *tensor, - *kernel, - strides, - padding, - tensor_dilation, - kernel_dilation, - dimension_numbers, - feature_group_count, - batch_group_count, - precision_config, - output_dims); - - return exla::nif::ok(env, 
exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 1) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - - mlir::Value token = (*function)->CreateTokenOp(); - - return exla::nif::ok(env, exla::nif::make(env, token)); -} - -ERL_NIF_TERM mlir_triangular_solve(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 6) { - return exla::nif::error(env, "Bad argument count."); - } - // mlir::Value TriangularSolveOp(mlir::Value a, mlir::Value b, bool left_side, bool lower, bool transpose_a); - - exla::MLIRFunction** function; - mlir::Value *a, *b; - bool left_side, lower, transpose_a; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], a)) { - return exla::nif::error(env, "Unable to get a."); - } - if (!exla::nif::get(env, argv[2], b)) { - return exla::nif::error(env, "Unable to get b."); - } - if (!exla::nif::get(env, argv[3], &left_side)) { - return exla::nif::error(env, "Unable to get left_side."); - } - if (!exla::nif::get(env, argv[4], &lower)) { - return exla::nif::error(env, "Unable to get lower."); - } - if (!exla::nif::get(env, argv[5], &transpose_a)) { - return exla::nif::error(env, "Unable to get transpose_a."); - } - - mlir::Value res = (*function)->TriangularSolveOp(*a, *b, left_side, lower, transpose_a); - - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_dynamic_update_slice(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 4) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value *operand, *updates; - std::vector starts; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], operand)) { - return exla::nif::error(env, "Unable to get operand."); - } - if (!exla::nif::get(env, argv[2], updates)) { - return exla::nif::error(env, "Unable to get updates."); - } - if (!exla::nif::get_list(env, argv[3], starts)) { - return exla::nif::error(env, "Unable to get starts."); - } - - mlir::Value res = (*function)->DynamicUpdateSliceOp(*operand, *updates, starts); - - return exla::nif::ok(env, exla::nif::make(env, res)); -} - -ERL_NIF_TERM mlir_infeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - mlir::Value* token; - std::vector shapes; - - if (!exla::nif::get(env, argv[0], function)) { - return exla::nif::error(env, "Unable to get function."); - } - if (!exla::nif::get(env, argv[1], token)) { - return exla::nif::error(env, "Unable to get token."); - } - if (!exla::nif::get_list(env, argv[2], shapes)) { - return exla::nif::error(env, "Unable to get shapes."); - } - - std::pair> infeed = (*function)->InfeedOp(*token, shapes); - - ERL_NIF_TERM out_token = exla::nif::make(env, infeed.first); - ERL_NIF_TERM results = exla::nif::make_list(env, infeed.second); - - return exla::nif::ok(env, enif_make_tuple2(env, out_token, results)); -} - -ERL_NIF_TERM mlir_outfeed(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { - if (argc != 3) { - return exla::nif::error(env, "Bad argument count."); - } - - exla::MLIRFunction** function; - 
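// Infeed/outfeed ordering is enforced by threading a token value: each op
// consumes the current token and returns a fresh one. A sketch of the
// intended sequence on the builder (method names as used in this file):
//
//   mlir::Value token = function->CreateTokenOp();
//   auto infeed = function->InfeedOp(token, shapes);     // {token', values}
//   token = function->OutfeedOp(outputs, infeed.first);  // consumes token'
//
// Reusing a stale token drops the data dependency and lets XLA reorder
// the transfers.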
mlir::Value* token;
-  std::vector<mlir::Value> inputs;
-
-  if (!exla::nif::get(env, argv[0], function)) {
-    return exla::nif::error(env, "Unable to get function.");
-  }
-  if (!exla::nif::get(env, argv[1], token)) {
-    return exla::nif::error(env, "Unable to get token.");
-  }
-  if (!exla::nif::get_list(env, argv[2], inputs)) {
-    return exla::nif::error(env, "Unable to get inputs.");
-  }
-
-  mlir::Value result = (*function)->OutfeedOp(inputs, *token);
-
-  return exla::nif::ok(env, exla::nif::make(env, result));
-}
-
-ERL_NIF_TERM mlir_call(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 3) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  exla::MLIRFunction **function, **computation;
-  std::vector<mlir::Value> arguments;
-
-  if (!exla::nif::get(env, argv[0], function)) {
-    return exla::nif::error(env, "Unable to get function.");
-  }
-  if (!exla::nif::get_list(env, argv[1], arguments)) {
-    return exla::nif::error(env, "Unable to get arguments.");
-  }
-  if (!exla::nif::get(env, argv[2], computation)) {
-    return exla::nif::error(env, "Unable to get computation.");
-  }
-
-  std::vector<mlir::Value> result = (*function)->CallOp(arguments, *computation);
-
-  return exla::nif::ok(env, exla::nif::make_list(env, result));
-}
-
-ERL_NIF_TERM mlir_while(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 2) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  exla::MLIRFunction** function;
-  std::vector<mlir::Value> initial;
-
-  if (!exla::nif::get(env, argv[0], function)) {
-    return exla::nif::error(env, "Unable to get function.");
-  }
-  if (!exla::nif::get_list(env, argv[1], initial)) {
-    return exla::nif::error(env, "Unable to get initial.");
-  }
-
-  auto result = (*function)->WhileOp(initial);
-
-  ERL_NIF_TERM res = exla::nif::make_list(env, result.first);
-  ERL_NIF_TERM pred_region = exla::nif::make(env, result.second.first);
-  ERL_NIF_TERM body_region = exla::nif::make(env, result.second.second);
-  return exla::nif::ok(env, enif_make_tuple3(env, res, pred_region, body_region));
-}
-
-ERL_NIF_TERM mlir_return(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 2) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  exla::MLIRFunction** function;
-  std::vector<mlir::Value> operands;
-
-  if (!exla::nif::get(env, argv[0], function)) {
-    return exla::nif::error(env, "Unable to get function.");
-  }
-  if (!exla::nif::get_list(env, argv[1], operands)) {
-    return exla::nif::error(env, "Unable to get operands.");
-  }
-
-  std::vector<mlir::Value> res = (*function)->ReturnOp(operands);
-  return exla::nif::ok(env, exla::nif::make_list(env, res));
-}
-
-ERL_NIF_TERM mlir_qr(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
-  if (argc != 4) {
-    return exla::nif::error(env, "Bad argument count.");
-  }
-
-  exla::MLIRFunction** function;
-  mlir::Value* value;
-  std::vector<exla::int64> q_shape, r_shape;
-
-  if (!exla::nif::get(env, argv[0], function)) {
-    return exla::nif::error(env, "Unable to get function.");
-  }
-  if (!exla::nif::get(env, argv[1], value)) {
-    return exla::nif::error(env, "Unable to get value.");
-  }
-  if (!exla::nif::get_list(env, argv[2], q_shape)) {
-    return exla::nif::error(env, "Unable to get Q shape.");
-  }
-  if (!exla::nif::get_list(env, argv[3], r_shape)) {
-    return exla::nif::error(env, "Unable to get R shape.");
-  }
-
-  std::pair<mlir::Value, mlir::Value> result = (*function)->QRCpuCustomCall(*value, q_shape, r_shape);
-
-  ERL_NIF_TERM q = exla::nif::make(env, result.first);
-  ERL_NIF_TERM r = exla::nif::make(env, result.second);
-
-  return exla::nif::ok(env, enif_make_tuple2(env, q, r));
-}
\ No newline at end of file
diff --git a/exla/c_src/exla/mlir/ops.h b/exla/c_src/exla/mlir/ops.h
deleted file mode 100644
index 68a76e728a..0000000000
--- a/exla/c_src/exla/mlir/ops.h
+++ /dev/null
@@ -1,116 +0,0 @@
-#pragma once
-#include "../exla_nif_util.h"
-#include "builder.h"
-
-#define DEFINE_NIF(FUNCTION_NAME) ERL_NIF_TERM FUNCTION_NAME(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
-
-DEFINE_NIF(mlir_compile);
-DEFINE_NIF(new_mlir_module);
-DEFINE_NIF(new_mlir_context);
-DEFINE_NIF(create_mlir_function);
-DEFINE_NIF(get_mlir_function_arguments);
-
-// Binary Ops
-DEFINE_NIF(mlir_add);
-DEFINE_NIF(mlir_subtract);
-DEFINE_NIF(mlir_multiply);
-DEFINE_NIF(mlir_min);
-DEFINE_NIF(mlir_max);
-DEFINE_NIF(mlir_remainder);
-DEFINE_NIF(mlir_pow);
-DEFINE_NIF(mlir_divide);
-DEFINE_NIF(mlir_atan2);
-DEFINE_NIF(mlir_equal);
-DEFINE_NIF(mlir_not_equal);
-DEFINE_NIF(mlir_less);
-DEFINE_NIF(mlir_less_equal);
-DEFINE_NIF(mlir_greater);
-DEFINE_NIF(mlir_greater_equal);
-DEFINE_NIF(mlir_bitwise_and);
-DEFINE_NIF(mlir_bitwise_or);
-DEFINE_NIF(mlir_bitwise_xor);
-DEFINE_NIF(mlir_shift_left);
-DEFINE_NIF(mlir_shift_right_logical);
-DEFINE_NIF(mlir_shift_right_arithmetic);
-
-// Unary Ops
-DEFINE_NIF(mlir_abs);
-DEFINE_NIF(mlir_exp);
-DEFINE_NIF(mlir_expm1);
-DEFINE_NIF(mlir_floor);
-DEFINE_NIF(mlir_ceil);
-DEFINE_NIF(mlir_round);
-DEFINE_NIF(mlir_log);
-DEFINE_NIF(mlir_sigmoid);
-DEFINE_NIF(mlir_log1p);
-DEFINE_NIF(mlir_sign);
-DEFINE_NIF(mlir_cos);
-DEFINE_NIF(mlir_sin);
-DEFINE_NIF(mlir_tan);
-DEFINE_NIF(mlir_acos);
-DEFINE_NIF(mlir_asin);
-DEFINE_NIF(mlir_atan);
-DEFINE_NIF(mlir_cosh);
-DEFINE_NIF(mlir_sinh);
-DEFINE_NIF(mlir_tanh);
-DEFINE_NIF(mlir_acosh);
-DEFINE_NIF(mlir_asinh);
-DEFINE_NIF(mlir_atanh);
-DEFINE_NIF(mlir_sqrt);
-DEFINE_NIF(mlir_cbrt);
-DEFINE_NIF(mlir_bitwise_not);
-DEFINE_NIF(mlir_negate);
-DEFINE_NIF(mlir_erf);
-DEFINE_NIF(mlir_erfc);
-DEFINE_NIF(mlir_erf_inv);
-DEFINE_NIF(mlir_is_infinity);
-DEFINE_NIF(mlir_is_nan);
-DEFINE_NIF(mlir_rsqrt);
-DEFINE_NIF(mlir_clz);
-DEFINE_NIF(mlir_real);
-DEFINE_NIF(mlir_imag);
-DEFINE_NIF(mlir_conjugate);
-DEFINE_NIF(mlir_population_count);
-DEFINE_NIF(mlir_convolution);
-
-//
-DEFINE_NIF(mlir_iota);
-DEFINE_NIF(mlir_reshape);
-DEFINE_NIF(mlir_reverse);
-DEFINE_NIF(mlir_transpose);
-DEFINE_NIF(mlir_slice);
-DEFINE_NIF(mlir_dynamic_slice);
-DEFINE_NIF(mlir_constant_r0);
-DEFINE_NIF(mlir_constant_from_binary);
-DEFINE_NIF(mlir_dot_general);
-DEFINE_NIF(mlir_select);
-DEFINE_NIF(mlir_convert);
-DEFINE_NIF(mlir_top_k);
-DEFINE_NIF(mlir_sort);
-DEFINE_NIF(mlir_bitcast_convert);
-DEFINE_NIF(mlir_pad);
-DEFINE_NIF(mlir_optimization_barrier);
-DEFINE_NIF(mlir_clamp);
-DEFINE_NIF(mlir_get_shape);
-DEFINE_NIF(mlir_broadcast_in_dim);
-DEFINE_NIF(mlir_concatenate);
-DEFINE_NIF(dump_mlir_module);
-DEFINE_NIF(mlir_scatter);
-DEFINE_NIF(mlir_select_and_scatter);
-DEFINE_NIF(mlir_gather);
-DEFINE_NIF(mlir_fft);
-DEFINE_NIF(mlir_create_token);
-DEFINE_NIF(mlir_triangular_solve);
-DEFINE_NIF(mlir_dynamic_update_slice);
-DEFINE_NIF(mlir_reduce);
-DEFINE_NIF(mlir_window_reduce);
-DEFINE_NIF(mlir_map);
-DEFINE_NIF(mlir_if);
-DEFINE_NIF(mlir_push_region);
-DEFINE_NIF(mlir_pop_region);
-DEFINE_NIF(mlir_infeed);
-DEFINE_NIF(mlir_outfeed);
-DEFINE_NIF(mlir_call);
-DEFINE_NIF(mlir_while);
-DEFINE_NIF(mlir_return);
-DEFINE_NIF(mlir_qr);
\ No newline at end of file
diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex
index 54bcffa9f7..8e0b87f362 100644
--- a/exla/lib/exla/backend.ex
+++ b/exla/lib/exla/backend.ex
@@ -42,8 +42,8 @@ defmodule EXLA.Backend do
   @impl true
  def from_binary(%T{shape: shape, type: type} = tensor, binary, backend_options) do
     {client, device_id} = client_and_device_id(backend_options)
-    shape = EXLA.Shape.make_shape(type, shape)
-    buffer = EXLA.DeviceBuffer.place_on_device(binary, shape, client, device_id)
+    typespec = EXLA.Typespec.tensor(type, shape)
+    buffer = EXLA.DeviceBuffer.place_on_device(binary, typespec, client, device_id)
     put_in(tensor.data, %B{buffer: buffer})
   end
@@ -126,20 +126,20 @@ defmodule EXLA.Backend do
     device_id = backend_opts[:device_id] || client.default_device_id

-    shape = EXLA.Shape.make_shape(type, dims)
+    typespec = EXLA.Typespec.tensor(type, dims)

     result =
       EXLA.NIF.create_buffer_from_device_pointer(
         client.ref,
         pointer,
         opts[:mode],
-        shape.ref,
+        EXLA.Typespec.nif_encode(typespec),
         device_id
       )

     case result do
       {:ok, ref} ->
-        buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, shape)
+        buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec)
         {:ok, %{template | data: %EXLA.Backend{buffer: buffer}}}

       error ->
diff --git a/exla/lib/exla/binary_buffer.ex b/exla/lib/exla/binary_buffer.ex
index 6a429bf860..2ed3749b8c 100644
--- a/exla/lib/exla/binary_buffer.ex
+++ b/exla/lib/exla/binary_buffer.ex
@@ -3,10 +3,10 @@ defmodule EXLA.BinaryBuffer do
   A buffer where data is kept in a binary.
   """

-  @enforce_keys [:data, :shape]
-  defstruct [:data, :shape]
+  @enforce_keys [:data, :typespec]
+  defstruct [:data, :typespec]

-  def from_binary(data, shape) do
-    %EXLA.BinaryBuffer{data: data, shape: shape}
+  def from_binary(data, typespec) do
+    %EXLA.BinaryBuffer{data: data, typespec: typespec}
   end
 end
diff --git a/exla/lib/exla/client.ex b/exla/lib/exla/client.ex
index 9bb2084e41..3ad4ae51b6 100644
--- a/exla/lib/exla/client.ex
+++ b/exla/lib/exla/client.ex
@@ -86,37 +86,33 @@ defmodule EXLA.Client do
   end

   @doc """
-  Sends `data_and_shapes` to device infeed.
+  Sends `data_and_typespecs` to device infeed.

-  `data_and_shapes` must be a list of two element tuples where the
+  `data_and_typespecs` must be a list of two element tuples where the
   first element is a binary or a flat list of binaries and the second
-  element is a `EXLA.Shape`.
-
-  > Note: XLA does not support tuple infeed shapes when running on
-  > host. Passing one will simply block the operation indefinitely.
-  > Instead, convert the tuple into multiple infeed operations.
+  element is a `EXLA.Typespec`.
   """
-  def to_infeed(%EXLA.Client{ref: client}, device_id, data_and_shapes)
-      when is_list(data_and_shapes) do
-    data_and_shapes =
-      Enum.map(data_and_shapes, fn
-        {binary, %EXLA.Shape{ref: shape}} when is_binary(binary) -> {[binary], shape}
-        {[binary | _] = data, %EXLA.Shape{ref: shape}} when is_binary(binary) -> {data, shape}
+  def to_infeed(%EXLA.Client{ref: client}, device_id, data_and_typespecs)
+      when is_list(data_and_typespecs) do
+    data_and_typespecs =
+      Enum.map(data_and_typespecs, fn
+        {binary, typespec} when is_binary(binary) ->
+          {[binary], EXLA.Typespec.nif_encode(typespec)}
+
+        {[binary | _] = data, typespec} when is_binary(binary) ->
+          {data, EXLA.Typespec.nif_encode(typespec)}
       end)

-    EXLA.NIF.transfer_to_infeed(client, device_id, data_and_shapes) |> unwrap!()
+    EXLA.NIF.transfer_to_infeed(client, device_id, data_and_typespecs) |> unwrap!()
   end

   @doc """
   Sends buffer from device outfeed to the given process tagged by `ref`.
-
-  > Note: XLA does not support tuple outfeed shapes. Passing one will simply
-  > block the operation indefinitely. Instead, convert the tuple into multiple
-  > outfeed operations.
""" - def from_outfeed(%EXLA.Client{ref: client}, device_id, shapes, pid, ref) when is_list(shapes) do - shape_refs = Enum.map(shapes, fn %EXLA.Shape{ref: shape_ref} -> shape_ref end) - EXLA.NIF.transfer_from_outfeed(client, device_id, shape_refs, pid, ref) |> unwrap!() + def from_outfeed(%EXLA.Client{ref: client}, device_id, typespecs, pid, ref) + when is_list(typespecs) do + typespecs = Enum.map(typespecs, &EXLA.Typespec.nif_encode/1) + EXLA.NIF.transfer_from_outfeed(client, device_id, typespecs, pid, ref) |> unwrap!() end ## Callbacks diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 1772ec17c2..72b75a371f 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -6,6 +6,7 @@ defmodule EXLA.Defn do alias Nx.Defn.{Composite, Expr, Tree} alias Nx.Tensor, as: T + alias EXLA.Typespec alias EXLA.MLIR.Value alias EXLA.MLIR.Function @@ -62,7 +63,7 @@ defmodule EXLA.Defn do comp_fun ) - {input_shape, input_indexes} = extra + {input_typespecs, input_indexes} = extra # Also discard the stream inputs from used inputs, similar to how it is done to buffers # Note we discard all lazy transfers too, as they are not possible with streams @@ -107,7 +108,7 @@ defmodule EXLA.Defn do # The outfeed reader will redirect all outputs with flag 1 to the current # process. Once flag 0 is emitted, we know the stream is done. - {output_shapes, outfeed} = Outfeed.configure_stream_hook(outfeed, self(), lock) + {output_typespecs, outfeed} = Outfeed.configure_stream_hook(outfeed, self(), lock) {:ok, outfeed_pid} = Outfeed.start_child(executable, outfeed, Process.group_leader()) stream = @@ -117,10 +118,10 @@ defmodule EXLA.Defn do runner, outfeed_pid, input, - input_shape, + input_typespecs, input_indexes, output, - output_shapes, + output_typespecs, acc_output ) @@ -140,19 +141,20 @@ defmodule EXLA.Defn do acc_length, %Function{} = builder, expr, - used_shapes, + used_typespecs, outfeed, options ) do %{token: root_token, infeeds: []} = outfeed - {input_shapes, used_shapes} = Enum.split_while(used_shapes, fn {i, _} -> i < input_length end) + {input_typespecs, used_typespecs} = + Enum.split_while(used_typespecs, fn {i, _} -> i < input_length end) # Get all input indexes and shape - input_indexes = Enum.map(input_shapes, &elem(&1, 0)) + input_indexes = Enum.map(input_typespecs, &elem(&1, 0)) - # Drop all accumulator entries from used_shapes as we will handle it separately. - {acc_shapes, used_shapes} = Enum.split(used_shapes, acc_length) + # Drop all accumulator entries from used_typespecs as we will handle it separately. + {acc_typespecs, used_typespecs} = Enum.split(used_typespecs, acc_length) # The stream loop will be a three element tuple: # @@ -161,35 +163,31 @@ defmodule EXLA.Defn do # The looping constants. # # The input will be read as part of the infeed. 
-    acc_shapes_l = Enum.map(acc_shapes, &elem(&1, 1))
-    acc_shape = List.to_tuple(acc_shapes_l)
+    acc_typespecs_l = Enum.map(acc_typespecs, &elem(&1, 1))
+    acc_typespec = List.to_tuple(acc_typespecs_l)

-    flag_shape = EXLA.Shape.make_shape({:pred, 8}, {})
+    flag_typespec = Typespec.tensor({:pred, 8}, {})

     args = EXLA.MLIR.Function.get_arguments(builder)

-    {token, [flag]} = Value.infeed(root_token, flag_shape)
+    {token, [flag]} = Value.infeed(root_token, [flag_typespec])

     init = [flag, token | args]

-    {[_flag, out_token | results], pred_region, body_region} = Value.while(builder, init)
-
-    acc = Enum.take(results, acc_length)
-
-    output = wrap_tuple_result(acc, acc_shape)
-
-    [flag | _] = Function.push_region(builder, pred_region)
-    r0 = Value.constant_r0(builder, 1, {:pred, 8})
-    pred_op = Value.equal(builder, flag, r0)
-    Value.variadic_return(builder, [pred_op])
+    arg_typespecs = Enum.map(init, &Value.get_typespec/1)
+    {pred_computation, [flag | _]} = Function.push_region(builder, arg_typespecs)
+    typespec = Typespec.tensor({:pred, 8}, {})
+    r0 = Value.constant(builder, [1], typespec)
+    pred_op = Value.equal(flag, r0, typespec)
+    Value.return(builder, [pred_op])
     Function.pop_region(builder)

-    [_flag, token | args] = Function.push_region(builder, body_region)
+    {body_computation, [_flag, token | args]} = Function.push_region(builder, arg_typespecs)

     {acc, constant} = Enum.split(args, acc_length)

-    {indices, input_shape} = Enum.unzip(input_shapes)
-    {token, input} = Value.infeed(token, input_shape)
+    {indices, input_typespecs} = Enum.unzip(input_typespecs)
+    {token, input} = Value.infeed(token, input_typespecs)

     input_params = Enum.zip(indices, input)

@@ -197,12 +195,12 @@ defmodule EXLA.Defn do
       case expr do
         {output_expr, acc_expr} ->
           acc_params =
-            Enum.map(acc_shapes, fn {pos, _shape} ->
+            Enum.map(acc_typespecs, fn {pos, _typespec} ->
               {pos, Enum.fetch!(acc, pos - input_length)}
             end)

           constant_params =
-            Enum.with_index(used_shapes, fn {pos, _shape}, index ->
+            Enum.with_index(used_typespecs, fn {pos, _typespec}, index ->
               {pos, Enum.fetch!(constant, index)}
             end)

@@ -226,16 +224,21 @@ defmodule EXLA.Defn do
     end

     # Emit the stream hook to signal loop output
-    {token, [flag]} = Value.infeed(token, flag_shape)
+    {token, [flag]} = Value.infeed(token, [flag_typespec])

-    Value.variadic_return(flag.function, [flag, token | acc] ++ List.flatten(constant))
+    Value.return(flag.function, [flag, token | acc] ++ List.flatten(constant))
     Function.pop_region(builder)

+    [_flag, out_token | results] = Value.while(builder, pred_computation, body_computation, init)
+
+    acc = Enum.take(results, acc_length)
+    output = wrap_tuple_result(acc, acc_typespec)
+
     outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
-    Value.variadic_return(builder, output)
+    Value.return(builder, output)

-    {{input_shape, input_indexes}, outfeed}
+    {{input_typespecs, input_indexes}, outfeed}
   end

   @doc false
@@ -280,9 +283,9 @@ defmodule EXLA.Defn do
     end
   end

-  defp to_root_computation(%Function{} = function, expr, used_shapes, outfeed, options) do
+  defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do
     params =
-      Enum.zip_with(used_shapes, Function.get_arguments(function), fn {pos, _shape}, arg ->
+      Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg ->
         {pos, arg}
       end)

@@ -303,7 +306,7 @@ defmodule EXLA.Defn do
     {res, cache} = recur_flatten(expr, state, new_cache(outfeed))
     outfeed = cache |> get_outfeed() |> Outfeed.close(function)
Value.variadic_return(function, res) + Value.return(function, res) {:ok, outfeed} end @@ -406,29 +409,29 @@ defmodule EXLA.Defn do {comp_time, {evaled, {xla_time, executable, extra, outfeed}}} = :timer.tc(fn -> comp_cache_fun.(comp_key, fn -> - {reverse_inputs_and_shapes, reverse_infeeds} = + {reverse_inputs_and_typespecs, reverse_infeeds} = reverse_args_identifiers |> Enum.reverse() |> EXLA.Defn.Buffers.split_by_value(used_inputs, fn - {type, shape, _names}, i, nil -> {i, EXLA.Shape.make_shape(type, shape)} - {type, shape, _names}, i, depth -> {i, depth, EXLA.Shape.make_shape(type, shape)} + {type, shape, _names}, i, nil -> {i, Typespec.tensor(type, shape)} + {type, shape, _names}, i, depth -> {i, depth, Typespec.tensor(type, shape)} end) - inputs_and_shapes = Enum.reverse(reverse_inputs_and_shapes) + inputs_and_typespecs = Enum.reverse(reverse_inputs_and_typespecs) - comp_arg_shapes = - for {i, shape} <- inputs_and_shapes, i >= used_buffers, do: shape + comp_arg_typespecs = + for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec - out_types = + out_typespecs = [outputs] |> Nx.Defn.Composite.flatten_list() |> Enum.map(fn t -> t |> Nx.devectorize() - |> then(&EXLA.Shape.make_shape(&1.type, &1.shape)) + |> then(&Typespec.tensor(&1.type, &1.shape)) end) - EXLA.MLIR.Module.new(comp_arg_shapes, out_types, fn builder -> + EXLA.MLIR.Module.new(comp_arg_typespecs, out_typespecs, fn builder -> outfeed = outfeed |> Outfeed.with_token(Value.create_token(builder)) @@ -437,17 +440,18 @@ defmodule EXLA.Defn do expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) {extra, outfeed} = - to_computation.(builder, expr, inputs_and_shapes, outfeed) + to_computation.(builder, expr, inputs_and_typespecs, outfeed) {xla_time, executable} = :timer.tc(fn -> - shapes = for {i, shape} <- inputs_and_shapes, i >= used_buffers, do: shape + typespecs = + for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec EXLA.MLIR.Module.compile( builder.module, client, - shapes, - builder.return_shape, + typespecs, + builder.return_typespecs, options ) end) @@ -518,14 +522,15 @@ defmodule EXLA.Defn do [initial_arg, _arg, pred, body] = args initial_with_token = {get_token(cache), initial_arg} - {initial, cache} = - recur_composite(initial_with_token, &cast_pred_to_u8/1, state, cache) + {initial, cache} = recur_composite(initial_with_token, state, cache) - {[token | results], pred_region, body_region} = Value.while(function, initial) - result = wrap_tuple_result(results, initial_arg) + {pred_computation, cache} = mlir_while_computation(pred, initial, {:pred, 8}, state, cache) + {body_computation, cache} = mlir_while_computation(body, initial, :with_token, state, cache) + + [token | results] = + Value.while(function, pred_computation, body_computation, List.flatten(initial)) - cache = mlir_while_computation(pred_region, pred, {:pred, 8}, state, cache) - cache = mlir_while_computation(body_region, body, :with_token, state, cache) + result = wrap_tuple_result(results, initial_arg) {result, update_token(cache, token)} end @@ -557,8 +562,8 @@ defmodule EXLA.Defn do end defp cached_recur_operator(:fun, %T{data: %Expr{args: args}, type: type}, state, cache) do - [args, expr, {_, name, _}] = args - {fun_computation(name, args, expr, type, state), cache} + [args, expr, {_, _, _}] = args + {fun_computation(args, expr, type, state), cache} end defp cached_recur_operator( @@ -587,20 +592,21 @@ defmodule EXLA.Defn do tensor end - {q, r} = Value.qr(tensor, q_expr.shape, r_expr.shape) + {q, 
r} = Value.qr(tensor, expr_to_typespec(q_expr), expr_to_typespec(r_expr)) {[q, r], cache} end defp cached_recur_operator( :optional, - %T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, _expr, _callback]}} = + %T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, expr, _callback]}} = _out, state, cache ) do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - - results = Value.top_k(tensor, opts[:k]) + {values, idx} = expr + typespecs = [expr_to_typespec(values), expr_to_typespec(idx)] + results = Value.top_k(tensor, opts[:k], typespecs) {results, cache} end @@ -613,7 +619,7 @@ defmodule EXLA.Defn do ) do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - {fft2(&Value.fft(&1, :fft, &2), [tensor, opts], out, state), cache} + {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], out, state), cache} end defp cached_recur_operator( @@ -625,7 +631,7 @@ defmodule EXLA.Defn do ) do {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - {fft2(&Value.fft(&1, :ifft, &2), [tensor, opts], out, state), cache} + {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], out, state), cache} end defp cached_recur_operator(:optional, %T{data: %Expr{args: args}}, state, cache) do @@ -647,7 +653,11 @@ defmodule EXLA.Defn do {computation, Map.put(cache, key, computation)} end - [token | result] = Value.call(state.builder, [get_token(cache) | call_args], call_body) + typespecs = [Typespec.token() | container_to_typespecs(expr)] + + [token | result] = + Value.call(state.builder, [get_token(cache) | call_args], call_body, typespecs) + {wrap_tuple_result(result, expr), update_token(cache, token)} end @@ -688,7 +698,7 @@ defmodule EXLA.Defn do if ans.shape == {} do op else - Value.broadcast_in_dim(op, EXLA.Shape.make_shape(ans.type, ans.shape), {}) + Value.broadcast_in_dim(op, [], expr_to_typespec(ans)) end end @@ -700,47 +710,55 @@ defmodule EXLA.Defn do to_constant(state.builder, Nx.to_number(tensor), tensor.type) shape -> - shape = EXLA.Shape.make_shape(tensor.type, shape) - Value.constant_from_binary(state.builder, Nx.to_binary(tensor), shape) + Value.constant( + state.builder, + Nx.to_flat_list(tensor), + Typespec.tensor(tensor.type, shape) + ) end end - defp to_operator(:iota, [axis], %{type: type, shape: shape}, state) do - shape = EXLA.Shape.make_shape(type, shape) - EXLA.Lib.iota(state.builder, shape, axis) + defp to_operator(:iota, [axis], ans, state) do + EXLA.Lib.iota(state.builder, axis, expr_to_typespec(ans)) end defp to_operator(:eye, [], %{type: type, shape: shape}, state) do iota_type = Nx.Type.merge_number({:u, 8}, Tuple.product(shape)) - iota_shape = EXLA.Shape.make_shape(iota_type, shape) + iota_typespec = Typespec.tensor(iota_type, shape) rank = tuple_size(shape) - i0 = Value.iota(state.builder, iota_shape, rank - 2) - i1 = Value.iota(state.builder, iota_shape, rank - 1) - to_type(Value.equal(state.builder, i0, i1), type) + i0 = Value.iota(state.builder, rank - 2, iota_typespec) + i1 = Value.iota(state.builder, rank - 1, iota_typespec) + + typespec = Typespec.tensor({:pred, 8}, shape) + Value.equal(i0, i1, typespec) |> to_type(type) end ## to_operator shape - defp to_operator(:reshape, [%Value{} = op], %{shape: shape}, _state) do - Value.reshape(op, shape) + defp to_operator(:reshape, [%Value{} = op], ans, _state) do + Value.reshape(op, expr_to_typespec(ans)) end - defp to_operator(:pad, [%Value{} = op, %Value{} = value, padding_config], %{type: type}, _state) do - 
Value.pad(to_type(op, type), to_type(value, type), padding_config) + defp to_operator(:pad, [%Value{} = op, %Value{} = value, padding_config], ans, _state) do + Value.pad( + to_type(op, ans.type), + to_type(value, ans.type), + padding_config, + expr_to_typespec(ans) + ) end defp to_operator(:broadcast, [%Value{} = op, _shape, axes], ans, _state) do - out_shape = EXLA.Shape.make_shape(ans.type, ans.shape) - Value.broadcast_in_dim(to_type(op, ans.type), out_shape, List.to_tuple(axes)) + Value.broadcast_in_dim(to_type(op, ans.type), axes, expr_to_typespec(ans)) end - defp to_operator(:transpose, [%Value{} = op, axes], _ans, _state) do - Value.transpose(op, axes) + defp to_operator(:transpose, [%Value{} = op, axes], ans, _state) do + Value.transpose(op, axes, expr_to_typespec(ans)) end defp to_operator(:squeeze, [%Value{} = op, _axes], ans, _state) do - Value.reshape(op, ans.shape) + Value.reshape(op, expr_to_typespec(ans)) end ## to_operator others @@ -769,18 +787,17 @@ defmodule EXLA.Defn do contract_axes2, batch_axes2 ], - %{type: type, shape: shape}, + ans, state ) do precision = state.precision - output_shape = EXLA.Shape.make_shape(type, shape) Value.dot_general( - output_shape, left, right, {contract_axes1, batch_axes1, contract_axes2, batch_axes2}, - precision + precision, + expr_to_typespec(ans) ) end @@ -799,16 +816,8 @@ defmodule EXLA.Defn do %{type: output_type} = ans - # Build general conv dims - input_permutation = List.to_tuple(opts[:input_permutation]) - [out_features, in_features | spatial_features] = opts[:kernel_permutation] - kernel_permutation = List.to_tuple([in_features, out_features | spatial_features]) - - output_permutation = - opts[:output_permutation] - |> List.to_tuple() - - dimension_numbers = {input_permutation, kernel_permutation, output_permutation} + dimension_numbers = + {opts[:input_permutation], opts[:kernel_permutation], opts[:output_permutation]} # Ensure both types are floating operand = to_type(operand, output_type) @@ -825,52 +834,58 @@ defmodule EXLA.Defn do feature_group_count, batch_group_count, state.precision, - ans.shape + expr_to_typespec(ans) ) end defp to_operator( :select, [%Value{} = pred, %Value{} = on_true, %Value{} = on_false], - %{type: type, shape: shape}, + %{type: type, shape: shape} = ans, _state ) do pred = to_type(pred, {:pred, 8}) - out_shape = EXLA.Shape.make_shape(type, shape) + typespec = expr_to_typespec(ans) on_true = on_true |> to_type(type) - |> Value.broadcast_in_dim(out_shape, broadcast_axes(op_shape(on_true), shape)) + |> Value.broadcast_in_dim(broadcast_axes(op_shape(on_true), shape), typespec) on_false = on_false |> to_type(type) - |> Value.broadcast_in_dim(out_shape, broadcast_axes(op_shape(on_false), shape)) + |> Value.broadcast_in_dim(broadcast_axes(op_shape(on_false), shape), typespec) - Value.select(pred, on_true, on_false) + Value.select(pred, on_true, on_false, typespec) end - defp to_operator(:triangular_solve, [%Value{} = a, b, opts], %{type: type}, _state) do + defp to_operator(:triangular_solve, [%Value{} = a, b, opts], %{type: type} = ans, _state) do left_side = Keyword.fetch!(opts, :left_side) lower = Keyword.fetch!(opts, :lower) transform = Keyword.fetch!(opts, :transform_a) - case Value.get_shape(b).dims do + case Value.get_typespec(b).shape do {_} = b_shape -> + b_shape = Tuple.append(b_shape, 1) + b = b |> to_type(type) - |> Value.reshape(Tuple.append(b_shape, 1)) + |> Value.reshape(Typespec.tensor(type, b_shape)) + + typespec = Typespec.tensor(type, b_shape) to_type(a, type) - |> 
Value.triangular_solve(b, left_side, lower, transform) - |> Value.reshape(b_shape) + |> Value.triangular_solve(b, left_side, lower, transform, typespec) + |> Value.reshape(Typespec.tensor(type, ans.shape)) _ -> + typespec = Typespec.tensor(type, ans.shape) + to_type(a, type) - |> Value.triangular_solve(to_type(b, type), left_side, lower, transform) + |> Value.triangular_solve(to_type(b, type), left_side, lower, transform, typespec) end end @@ -880,49 +895,42 @@ defmodule EXLA.Defn do ## to_operator element-wise - defp to_operator(:negate, [%Value{} = op], _ans, _state), do: Value.negate(op) - - defp to_operator(:abs, [%Value{} = op], _ans, _state), do: Value.abs(op) + defp to_operator(:negate, [%Value{} = op], ans, _state), + do: Value.negate(op, expr_to_typespec(ans)) - defp to_operator(:sign, [%Value{} = op], %{shape: shape, type: type}, state) do - case type do - {:u, _} -> - ones_shape = Tuple.duplicate(1, tuple_size(shape)) + defp to_operator(:abs, [%Value{} = op], ans, _state), do: Value.abs(op, expr_to_typespec(ans)) - one = Enum.reduce(1..tuple_size(shape), 1, fn _, acc -> [acc] end) + defp to_operator(:sign, [%Value{} = op], ans, state) do + typespec = expr_to_typespec(ans) - one = - one - |> Nx.tensor(type: type, backend: Nx.BinaryBackend) - |> Nx.to_binary() - |> then( - &Value.constant_from_binary(state.builder, &1, %{dtype: type, dims: ones_shape}) - ) + case typespec.type do + {:u, _} -> + one = Value.constant(state.builder, [1], Typespec.to_shape(typespec, {})) one - |> Value.broadcast_in_dim(Value.get_shape(op), List.to_tuple(Nx.axes(shape))) - |> then(&Value.min(state.builder, &1, op)) + |> Value.broadcast_in_dim([], typespec) + |> Value.min(op, typespec) _ -> - Value.sign(op) + Value.sign(op, typespec) end end - defp to_operator(:right_shift, [%Value{} = left, %Value{} = right], out, state) do + defp to_operator(:right_shift, [%Value{} = left, %Value{} = right], out, _state) do op = if match?({:u, _}, out.type), do: :right_shift_logical, else: :right_shift_arithmetic - apply_mlir_broadcasted_bin_op(state.builder, op, out, left, right) + apply_mlir_broadcasted_bin_op(op, out, left, right) end @bin_op [:add, :subtract, :multiply, :min, :max, :remainder, :pow, :divide, :atan2] ++ [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift] - defp to_operator(op, [%Value{} = left, %Value{} = right], out, state) + defp to_operator(op, [%Value{} = left, %Value{} = right], out, _state) when op in @bin_op do - apply_mlir_broadcasted_bin_op(state.builder, op, out, left, right) + apply_mlir_broadcasted_bin_op(op, out, left, right) end defp to_operator(:quotient, [left, right], ans, state) do @@ -931,17 +939,16 @@ defmodule EXLA.Defn do @bin_comp_op [:equal, :not_equal, :greater, :less, :greater_equal, :less_equal] - defp to_operator(op, [%Value{} = left, %Value{} = right], ans, state) + defp to_operator(op, [%Value{} = left, %Value{} = right], ans, _state) when op in @bin_comp_op do - apply_mlir_broadcasted_bin_op(state.builder, op, ans, left, right) + apply_mlir_broadcasted_bin_op(op, ans, left, right) end @bin_pred_op [logical_and: :bitwise_and, logical_or: :bitwise_or, logical_xor: :bitwise_xor] for {logical, bitwise} <- @bin_pred_op do - defp to_operator(unquote(logical), [%Value{} = left, %Value{} = right], ans, state) do + defp to_operator(unquote(logical), [%Value{} = left, %Value{} = right], ans, _state) do apply_mlir_broadcasted_bin_op( - state.builder, unquote(bitwise), ans, to_mlir_logical(left), @@ -950,58 +957,54 @@ defmodule EXLA.Defn do end end - @unary_op [:exp, :expm1, 
:log, :log1p, :sigmoid, :cos, :sin, :tanh, :sqrt, :rsqrt, :cbrt] ++ + @unary_op [:exp, :expm1, :log, :log1p, :sigmoid, :cos, :sin, :tan, :tanh, :sqrt, :rsqrt, :cbrt] ++ [:bitwise_not, :count_leading_zeros, :population_count, :cosh, :sinh, :acos] ++ [:asin, :atan, :floor, :ceil, :round, :acosh, :asinh, :atanh, :erf] ++ [:erfc, :erf_inv, :conjugate] - defp to_operator(op, [%Value{} = arg], %{type: type}, _state) when op in @unary_op do - apply(Value, op, [to_type(arg, type)]) + defp to_operator(op, [%Value{} = arg], %{type: type} = ans, _state) + when op in @unary_op do + apply(Value, op, [to_type(arg, type), expr_to_typespec(ans)]) end defp to_operator(:fft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :fft, &2), args, out, state) + do: fft(&Value.fft(&1, :fft, &2, &3), args, out, state) defp to_operator(:ifft, [%Value{} | _] = args, out, state), - do: fft(&Value.fft(&1, :ifft, &2), args, out, state) + do: fft(&Value.fft(&1, :ifft, &2, &3), args, out, state) - defp to_operator(:is_nan, [%Value{} = arg], _out, _state), - do: Value.is_nan(arg) + defp to_operator(:is_nan, [%Value{} = arg], out, _state), + do: Value.is_nan(arg, expr_to_typespec(out)) - defp to_operator(:is_infinity, [%Value{} = arg], _out, _state), - do: Value.is_infinity(arg) + defp to_operator(:is_infinity, [%Value{} = arg], out, _state), + do: Value.is_infinity(arg, expr_to_typespec(out)) # These operations do the type conversion implicitly, and so # we cannot mess with the output type (e.g. the to_type conversion) # because it will throw an error @complex_op [:real, :imag] - defp to_operator(op, [%Value{} = arg], %{type: type}, _state) when op in @complex_op do + defp to_operator(op, [%Value{} = arg], ans, _state) + when op in @complex_op do maybe_cast_arg = if Nx.Type.integer?(op_type(arg)) do - to_type(arg, type) + to_type(arg, ans.type) else arg end - apply(Value, op, [maybe_cast_arg]) - end - - @unary_lib_op [:tan] - - defp to_operator(op, [arg], %{type: type}, _state) when op in @unary_lib_op do - apply(EXLA.Lib, op, [to_type(arg, type)]) + apply(Value, op, [maybe_cast_arg, expr_to_typespec(ans)]) end defp to_operator(:as_type, [arg], %{type: type}, _state) do to_type(arg, type) end - defp to_operator(:bitcast, [%Value{} = arg], %{type: type}, _state) do - if op_type(arg) == type do + defp to_operator(:bitcast, [%Value{} = arg], ans, _state) do + if op_type(arg) == ans.type do arg else - Value.bitcast_convert(arg, type) + Value.bitcast_convert(arg, expr_to_typespec(ans)) end end @@ -1049,37 +1052,40 @@ defmodule EXLA.Defn do ) do arg = to_type(arg, type) keep_axes = opts[:keep_axes] - [result] = Value.reduce(fun, [to_type(acc, type)], [arg], reduce_axes(arg, opts[:axes])) + reduce_axes = reduce_axes(arg, opts[:axes]) + + typespec = Typespec.tensor(type, remove_axes(op_shape(arg), reduce_axes)) + [result] = Value.reduce(fun, [to_type(acc, type)], [arg], reduce_axes, [typespec]) if keep_axes do - Value.reshape(result, shape) + Value.reshape(result, Typespec.tensor(type, shape)) else result end end - defp to_operator(:window_sum, [arg, window_dims, opts], %{type: type}, state) do - to_window_aggregate(:add, type, arg, 0, window_dims, opts, state) + defp to_operator(:window_sum, [arg, window_dims, opts], ans, state) do + to_window_aggregate(:add, ans, arg, 0, window_dims, opts, state) end - defp to_operator(:window_max, [arg, window_dims, opts], %{type: type}, state) do + defp to_operator(:window_max, [arg, window_dims, opts], %{type: type} = ans, state) do min_number = EXLA.Lib.min_number(state.builder, 
type) - to_window_aggregate(:max, type, arg, min_number, window_dims, opts, state) + to_window_aggregate(:max, ans, arg, min_number, window_dims, opts, state) end - defp to_operator(:window_min, [arg, window_dims, opts], %{type: type}, state) do + defp to_operator(:window_min, [arg, window_dims, opts], %{type: type} = ans, state) do max_number = EXLA.Lib.max_number(state.builder, type) - to_window_aggregate(:min, type, arg, max_number, window_dims, opts, state) + to_window_aggregate(:min, ans, arg, max_number, window_dims, opts, state) end - defp to_operator(:window_product, [arg, window_dims, opts], %{type: type}, state) do - to_window_aggregate(:multiply, type, arg, 1, window_dims, opts, state) + defp to_operator(:window_product, [arg, window_dims, opts], ans, state) do + to_window_aggregate(:multiply, ans, arg, 1, window_dims, opts, state) end defp to_operator( :window_reduce, [arg, acc, window_dimensions, opts, fun], - %{type: type}, + %{type: type} = ans, %{builder: %Function{}} ) do padding_config = opts[:padding] @@ -1093,11 +1099,12 @@ defmodule EXLA.Defn do fun, [acc], [arg], - window_dimensions, - List.to_tuple(strides), - Tuple.duplicate(1, tuple_size(op_shape(arg))), - List.to_tuple(window_dilations), - padding_config + Tuple.to_list(window_dimensions), + strides, + List.duplicate(1, tuple_size(op_shape(arg))), + window_dilations, + padding_config, + [expr_to_typespec(ans)] ) result @@ -1106,7 +1113,7 @@ defmodule EXLA.Defn do defp to_operator( :window_scatter_max, [%Value{} = arg, %Value{} = source, %Value{} = init_value, window_dimensions, opts], - %{type: type}, + %{type: type} = ans, _state ) do padding_config = opts[:padding] @@ -1123,14 +1130,15 @@ defmodule EXLA.Defn do :gt, Tuple.to_list(window_dimensions), strides, - padding_config + padding_config, + expr_to_typespec(ans) ) end defp to_operator( :window_scatter_min, [%Value{} = arg, %Value{} = source, %Value{} = init_value, window_dimensions, opts], - %{type: type}, + %{type: type} = ans, _state ) do padding_config = opts[:padding] @@ -1147,7 +1155,8 @@ defmodule EXLA.Defn do :lt, Tuple.to_list(window_dimensions), strides, - padding_config + padding_config, + expr_to_typespec(ans) ) end @@ -1159,10 +1168,9 @@ defmodule EXLA.Defn do mlir_scatter(tensors, out, :put) end - defp to_operator(:map, [%Value{} = arg, _opts, fun], %{shape: shape, type: type}, _state) do - arg = to_type(arg, type) - - Value.map(fun, [arg], Nx.axes(shape) |> List.to_tuple()) + defp to_operator(:map, [%Value{} = arg, _opts, fun], ans, _state) do + arg = to_type(arg, ans.type) + Value.map(fun, [arg], Nx.axes(ans.shape), expr_to_typespec(ans)) end defp to_operator(op, [arg, opts], ans, state) when op in [:argmax, :argmin] do @@ -1174,7 +1182,7 @@ defmodule EXLA.Defn do max = to_type(max, ans.type) operand = to_type(operand, ans.type) - Value.clamp(operand, min, max) + Value.clamp(operand, min, max, expr_to_typespec(ans)) end defp to_operator(:slice, [%Value{} = tensor, start_indices, lengths, strides], ans, _state) do @@ -1182,7 +1190,7 @@ defmodule EXLA.Defn do if all_static? 
do limit_indices = Enum.zip_with(start_indices, lengths, fn i, len -> i + len end) - Value.slice(tensor, start_indices, limit_indices, strides) + Value.slice(tensor, start_indices, limit_indices, strides, expr_to_typespec(ans)) else sample = Enum.find(start_indices, &(not is_integer(&1))) @@ -1194,12 +1202,14 @@ defmodule EXLA.Defn do start_indices = Enum.map(start_indices, &to_type(&1, type)) zeros = List.duplicate(0, tuple_size(ans.shape)) - slice = Value.dynamic_slice(tensor, start_indices, lengths) + + typespec = Typespec.tensor(ans.type, List.to_tuple(lengths)) + slice = Value.dynamic_slice(tensor, start_indices, lengths, typespec) if Enum.all?(strides, &(&1 == 1)) do slice else - Value.slice(slice, zeros, lengths, strides) + Value.slice(slice, zeros, lengths, strides, expr_to_typespec(ans)) end end end @@ -1207,10 +1217,10 @@ defmodule EXLA.Defn do defp to_operator(:put_slice, [%Value{} = tensor, start_indices, slice], ans, _state) do tensor = to_type(tensor, ans.type) slice = to_type(slice, ans.type) - Value.dynamic_update_slice(tensor, slice, start_indices) + Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans)) end - defp to_operator(:take, [%Value{} = tensor, indices, axis], _ans, _state) do + defp to_operator(:take, [%Value{} = tensor, indices, axis], ans, _state) do tensor_rank = tensor |> op_shape() |> tuple_size() indices_rank = indices |> op_shape() |> tuple_size() result_rank = tensor_rank - 1 + indices_rank @@ -1228,12 +1238,13 @@ defmodule EXLA.Defn do slice_sizes, offset_dims, collapsed_slice_dims, - start_index_map + start_index_map, + expr_to_typespec(ans) ) end - defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], _ans, state) do - indices_shape = op_shape(indices) + defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do + %{shape: indices_shape} = indices_typespec = Value.get_typespec(indices) indices_rank = tuple_size(indices_shape) axes_range = 0..(indices_rank - 1)//1 @@ -1244,33 +1255,32 @@ defmodule EXLA.Defn do collapsed_slice_dims = Enum.to_list(axes_range) start_index_map = Enum.to_list(axes_range) - indices_exla_shape = Value.get_shape(indices) + new_axis_typespec = Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, 1)) - iotas = - Enum.map(axes_range, fn axis -> - Value.iota(state.builder, indices_exla_shape, axis) - end) - - new_axis_shape = Tuple.append(indices_shape, 1) + full_indices_typespec = + Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, indices_rank)) - indices = - iotas - |> List.replace_at(axis, indices) - |> Enum.map(&Value.reshape(&1, new_axis_shape)) - |> Value.concatenate(indices_rank) + full_indices = + axes_range + |> Enum.map(fn + ^axis -> Value.reshape(indices, new_axis_typespec) + axis -> Value.iota(state.builder, axis, new_axis_typespec) + end) + |> Value.concatenate(indices_rank, full_indices_typespec) Value.gather( tensor, - indices, + full_indices, index_vector_dim, slice_sizes, offset_dims, collapsed_slice_dims, - start_index_map + start_index_map, + expr_to_typespec(ans) ) end - defp to_operator(:gather, [%Value{} = tensor, indices, opts], _ans, _state) do + defp to_operator(:gather, [%Value{} = tensor, indices, opts], ans, _state) do axes = Keyword.fetch!(opts, :axes) tensor_shape = op_shape(tensor) tensor_rank = tuple_size(tensor_shape) @@ -1284,11 +1294,21 @@ defmodule EXLA.Defn do batch_size = tensor_rank - length(axes) offset_dims = count_up(batch_size, batch_size) - Value.gather(tensor, indices, index_vector_dim, 
slice_sizes, offset_dims, axes, axes) + + Value.gather( + tensor, + indices, + index_vector_dim, + slice_sizes, + offset_dims, + axes, + axes, + expr_to_typespec(ans) + ) end - defp to_operator(:reverse, [%Value{} = tensor, axes], _ans, _state) do - Value.reverse(tensor, axes) + defp to_operator(:reverse, [%Value{} = tensor, axes], ans, _state) do + Value.reverse(tensor, axes, expr_to_typespec(ans)) end defp to_operator(:concatenate, [[%Value{} | _rest] = tensors, axis], ans, _state) do @@ -1296,7 +1316,7 @@ defmodule EXLA.Defn do tensors |> Enum.map(&to_type(&1, ans.type)) - Value.concatenate(tensors, axis) + Value.concatenate(tensors, axis, expr_to_typespec(ans)) end defp to_operator(:sort, [%Value{} = tensor, opts], ans, state) do @@ -1308,11 +1328,12 @@ defmodule EXLA.Defn do :desc -> :greater end - args = [%{type: ans.type, shape: {}}, %{type: ans.type, shape: {}}] + arg_typespec = Typespec.tensor(ans.type, {}) + arg_typespecs = [arg_typespec, arg_typespec] - comp = sort_computation(op, ans.type, args, state) + comp = sort_computation(op, ans.type, arg_typespecs, state) - Value.sort(tensor, comp, dimension, opts[:stable] == true) + Value.sort([tensor], comp, dimension, opts[:stable] == true, [expr_to_typespec(ans)]) |> hd() end defp to_operator(:argsort, [tensor, opts], ans, state) do @@ -1327,19 +1348,16 @@ defmodule EXLA.Defn do type = op_type(tensor) - args = [ - %{type: type, shape: {}}, - %{type: type, shape: {}}, - %{type: ans.type, shape: {}}, - %{type: ans.type, shape: {}} - ] + value_typespec = Typespec.tensor(type, {}) + idx_typespec = Typespec.tensor(ans.type, {}) + arg_typespecs = [value_typespec, value_typespec, idx_typespec, idx_typespec] - comp = sort_computation(op, type, args, state) + comp = sort_computation(op, type, arg_typespecs, state) EXLA.Lib.argsort(state.builder, tensor, dimension, stable, comp, ans.type) end - defp fft(exla_op, [%Value{} = tensor, opts], %{type: type}, state) do + defp fft(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do n = opts[:length] axis = opts[:axis] output_type = Nx.Type.to_complex(type) @@ -1359,18 +1377,20 @@ defmodule EXLA.Defn do ^last_axis -> axis ax -> ax end) - |> List.to_tuple() + + {transposed_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) + transposed_typespec = Typespec.tensor(ans.type, transposed_shape) tensor - |> Value.transpose(permutation) - |> exla_op.([n]) - |> Value.transpose(permutation) + |> Value.transpose(permutation, transposed_typespec) + |> exla_op.([n], transposed_typespec) + |> Value.transpose(permutation, expr_to_typespec(ans)) else - exla_op.(tensor, [n]) + exla_op.(tensor, [n], expr_to_typespec(ans)) end end - defp fft2(exla_op, [%Value{} = tensor, opts], %{type: type}, state) do + defp fft2(exla_op, [%Value{} = tensor, opts], %{type: type} = ans, state) do [l1, l2] = lengths = opts[:lengths] [ax1, ax2] = axes = opts[:axes] output_type = Nx.Type.to_complex(type) @@ -1396,14 +1416,16 @@ defmodule EXLA.Defn do ^last_axis -> ax2 ax -> ax end) - |> List.to_tuple() + + {transposed_shape, _} = Nx.Shape.transpose(ans.shape, permutation, ans.names) + transposed_typespec = Typespec.tensor(ans.type, transposed_shape) tensor - |> Value.transpose(permutation) - |> exla_op.(lengths) - |> Value.transpose(permutation) + |> Value.transpose(permutation, transposed_typespec) + |> exla_op.(lengths, transposed_typespec) + |> Value.transpose(permutation, expr_to_typespec(ans)) else - exla_op.(tensor, lengths) + exla_op.(tensor, lengths, expr_to_typespec(ans)) end end @@ -1423,22 +1445,27 
@@ defmodule EXLA.Defn do strides = List.duplicate(1, tuple_size(shape)) limit_indices = Enum.zip_with(starts, lengths, fn i, len -> i + len end) - Value.slice(tensor, starts, limit_indices, strides) + + {_, shape} = Nx.Shape.slice(shape, starts, limit_indices, strides) + typespec = Typespec.tensor(output_type, shape) + Value.slice(tensor, starts, limit_indices, strides, typespec) m < n -> zero = - Value.constant_r0(state.builder, Complex.new(0), output_type) + Value.constant(state.builder, [Complex.new(0)], Typespec.tensor(output_type, {})) padding_config = {0, 0, 0} |> List.duplicate(tuple_size(shape)) |> List.replace_at(axis, {0, n - m, 0}) - Value.pad(tensor, zero, padding_config) + shape = Nx.Shape.pad(shape, padding_config) + typespec = Typespec.tensor(output_type, shape) + Value.pad(tensor, zero, padding_config, typespec) end end - defp mlir_scatter([target, indices, updates, opts], %{type: type}, kind) + defp mlir_scatter([target, indices, updates, opts], %{type: type} = ans, kind) when kind in [:add, :put] do target = to_type(target, type) updates = to_type(updates, type) @@ -1446,7 +1473,17 @@ defmodule EXLA.Defn do update_axes = tl(axes_for_rank(update_rank)) index_axes = Keyword.fetch!(opts, :axes) - Value.scatter(target, indices, updates, kind, 1, update_axes, index_axes, index_axes) + Value.scatter( + target, + indices, + updates, + kind, + 1, + update_axes, + index_axes, + index_axes, + expr_to_typespec(ans) + ) end ## Cache and hook helpers helpers @@ -1474,68 +1511,48 @@ defmodule EXLA.Defn do ## Computation helpers - defp sort_computation(op, type, args, %{builder: %EXLA.MLIR.Function{} = builder}) do - %{module: module, name: name} = subbuilder(builder, Atom.to_string(op)) + defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do + {region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs) - arg_shapes = Enum.map(args, &EXLA.Shape.make_shape(&1.type, &1.shape)) - - function = - EXLA.MLIR.Module.add_function(module, name, arg_shapes, [ - EXLA.Shape.make_shape({:pred, 8}, {}) - ]) - - [lhs, rhs | _] = EXLA.MLIR.Function.get_arguments(function) + typespec = Typespec.tensor({:pred, 8}, {}) op = cond do Nx.Type.integer?(type) -> - apply(Value, op, [function, lhs, rhs]) + apply(Value, op, [lhs, rhs, typespec]) op == :less -> - is_nan = Value.is_nan(rhs) - Value.bitwise_or(function, is_nan, Value.less(function, lhs, rhs)) + is_nan = Value.is_nan(rhs, typespec) + Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec) op == :greater -> - is_nan = Value.is_nan(lhs) - Value.bitwise_or(function, is_nan, Value.greater(function, lhs, rhs)) + is_nan = Value.is_nan(lhs, typespec) + Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec) end - Value.variadic_return(function, [op]) - function + Value.return(function, [op]) + Function.pop_region(function) + region end defp op_computation( op, - arg_shapes, - out, + arg_typespecs, %{builder: %EXLA.MLIR.Function{} = builder}, prepare_args ) do - %{module: module, name: name} = subbuilder(builder, Atom.to_string(op)) - - function = EXLA.MLIR.Module.add_function(module, name, arg_shapes, out) - - args = EXLA.MLIR.Function.get_arguments(function) - - op = apply(Value, op, [function | prepare_args.(args)]) - Value.variadic_return(function, [op]) - function + {region, args} = Function.push_region(builder, arg_typespecs) + op = apply(Value, op, prepare_args.(args) ++ [hd(arg_typespecs)]) + Value.return(builder, [op]) + Function.pop_region(builder) + region end - defp 
fun_computation( - name, - args, - expr, - type, - %{builder: %EXLA.MLIR.Function{module: module}} = state - ) do - arg_shapes = - Enum.map(args, fn %{type: type, shape: shape} -> EXLA.Shape.make_shape(type, shape) end) - - out_type = container_to_exla_shape(expr) + defp fun_computation(args, expr, type, %{builder: %Function{} = function} = state) do + arg_typespecs = + Enum.map(args, fn %{type: type, shape: shape} -> Typespec.tensor(type, shape) end) - function = EXLA.MLIR.Module.add_function(module, Atom.to_string(name), arg_shapes, out_type) - mlir_args = EXLA.MLIR.Function.get_arguments(function) + {region, mlir_args} = Function.push_region(function, arg_typespecs) arg_params = Enum.zip(args, mlir_args) @@ -1549,18 +1566,15 @@ defmodule EXLA.Defn do } {res, _} = recur_composite(expr, state, no_token_cache()) - Value.variadic_return(function, Enum.map(res, &to_type(&1, type))) - function + Value.return(function, Enum.map(res, &to_type(&1, type))) + Function.pop_region(function) + region end - defp mlir_while_computation( - region, - expr, - type, - %{builder: %Function{} = function} = state, - cache - ) do - [arg_token | arg_params] = Function.push_region(function, region) + defp mlir_while_computation(expr, initial, type, state, cache) do + arg_typespecs = Enum.map(List.flatten(initial), &Value.get_typespec/1) + + {region, [arg_token | arg_params]} = Function.push_region(state.builder, arg_typespecs) params = Enum.with_index(arg_params, &{&2, &1}) @@ -1586,24 +1600,22 @@ defmodule EXLA.Defn do Enum.map(res, &to_type(&1, type)) end - Value.variadic_return(function, res) - Function.pop_region(function) + Value.return(state.builder, res) + Function.pop_region(state.builder) - merge_outfeed(cache, comp_cache) + {region, merge_outfeed(cache, comp_cache)} end defp token_computation(name, args, expr, %{builder: %Function{}} = state, cache) do %Function{module: module, name: name} = subbuilder(state.builder, name) - token_shape = EXLA.Shape.make_token_shape() - - arg_shapes = Enum.map(args, &Value.get_shape/1) - - out_shapes = container_to_exla_shape(expr) + token_typespec = Typespec.token() + arg_typespecs = Enum.map(args, &Value.get_typespec/1) + out_typespecs = container_to_typespecs(expr) function = - EXLA.MLIR.Module.add_function(module, name, [token_shape | arg_shapes], [ - token_shape | out_shapes + EXLA.MLIR.Module.add_function(module, name, [token_typespec | arg_typespecs], [ + token_typespec | out_typespecs ]) [arg_token | tail] = EXLA.MLIR.Function.get_arguments(function) @@ -1619,7 +1631,7 @@ defmodule EXLA.Defn do {res, comp_cache} = recur_composite(expr, state, reset_token(cache, arg_token)) - Value.variadic_return(function, [get_token(comp_cache) | List.flatten(res)]) + Value.return(function, [get_token(comp_cache) | List.flatten(res)]) {function, merge_outfeed(cache, comp_cache)} end @@ -1629,8 +1641,8 @@ defmodule EXLA.Defn do keys = Enum.map(args, fn %Value{} = op -> - %EXLA.Shape{dims: dims, dtype: dtype} = Value.get_shape(op) - {dims, dtype} + %Typespec{type: type, shape: shape} = Value.get_typespec(op) + {shape, type} opts -> opts @@ -1720,25 +1732,30 @@ defmodule EXLA.Defn do acc = case initial do - %Value{} = initial -> initial - initial when is_number(initial) -> Value.constant_r0(state.builder, initial, type) + %Value{} = initial -> + initial + + initial when is_number(initial) -> + Value.constant(state.builder, [initial], Typespec.tensor(type, {})) end - arg_shape = EXLA.Shape.make_shape(type, {}) - args = [arg_shape, arg_shape] - comp = op_computation(op, args, 
[EXLA.Shape.make_shape(type, shape)], state, &Enum.reverse/1) + args = [Typespec.tensor(type, {}), Typespec.tensor(type, {})] + comp = op_computation(op, args, state, &Enum.reverse/1) keep_axes = opts[:keep_axes] - [result] = Value.reduce(comp, [acc], [arg], reduce_axes(arg, opts[:axes])) + reduce_axes = reduce_axes(arg, opts[:axes]) + + typespec = Typespec.tensor(type, remove_axes(op_shape(arg), reduce_axes)) + [result] = Value.reduce(comp, [acc], [arg], reduce_axes, [typespec]) if keep_axes do - Value.reshape(result, shape) + Value.reshape(result, Typespec.tensor(type, shape)) else result end end - defp to_window_aggregate(op, type, arg, initial, window_dimensions, opts, state) do + defp to_window_aggregate(op, %{type: type} = ans, arg, initial, window_dimensions, opts, state) do arg = to_type(arg, type) acc = @@ -1747,23 +1764,15 @@ defmodule EXLA.Defn do initial initial when is_number(initial) -> - Value.constant_r0(state.builder, initial, type) + Value.constant(state.builder, [initial], Typespec.tensor(type, {})) end - arg_shape = EXLA.Shape.make_shape(type, {}) - args = [arg_shape, arg_shape] + args = [Typespec.tensor(type, {}), Typespec.tensor(type, {})] # We reverse the argument order because :nan + :infinity # returns :nan but :infinity + :nan returns :infinity. # So we want to keep the current value as first argument # to preserve such properties. - comp = - op_computation( - op, - args, - [arg_shape], - state, - &Enum.reverse/1 - ) + comp = op_computation(op, args, state, &Enum.reverse/1) strides = opts[:strides] padding = opts[:padding] @@ -1774,11 +1783,12 @@ defmodule EXLA.Defn do comp, [acc], [arg], - window_dimensions, - List.to_tuple(strides), - Tuple.duplicate(1, tuple_size(op_shape(arg))), - List.to_tuple(window_dilations), - padding + Tuple.to_list(window_dimensions), + strides, + List.duplicate(1, tuple_size(op_shape(arg))), + window_dilations, + padding, + [expr_to_typespec(ans)] ) result @@ -1795,22 +1805,20 @@ defmodule EXLA.Defn do cache = recur_shared_ids(on_true, false_ids, state, cache) cache = recur_shared_ids(on_false, true_ids, state, cache) - out_shape = container_to_exla_shape(on_true) + out_typespecs = container_to_typespecs(on_true) in_token = get_token(cache) - result_shapes = + result_typespecs = if in_token do - [EXLA.Shape.make_token_shape() | out_shape] + [Typespec.token() | out_typespecs] else - out_shape + out_typespecs end - {if_results, true_region, false_region} = Value.if_op(pred_op, result_shapes) - - cache = to_mlir_if_branch(true_region, on_true, true_ids, state, cache) - - cache = to_mlir_if_branch(false_region, on_false, false_ids, state, cache) + {true_computation, cache} = to_mlir_if_branch(on_true, true_ids, state, cache) + {false_computation, cache} = to_mlir_if_branch(on_false, false_ids, state, cache) + if_results = Value.if_op(pred_op, true_computation, false_computation, result_typespecs) if in_token do [token | results] = if_results @@ -1873,20 +1881,22 @@ defmodule EXLA.Defn do end end - defp to_mlir_if_branch(region, expr, current_ids, state, cache) do + defp to_mlir_if_branch(expr, current_ids, state, cache) do + {region, []} = Function.push_region(state.builder, []) + comp_state = %{state | scope_ids: current_ids} - [] = Function.push_region(state.builder, region) {res, res_cache} = recur_composite(expr, & &1, comp_state, cache) if token = get_token(cache) do - Value.variadic_return(state.builder, [token | List.flatten(res)]) + Value.return(state.builder, [token | List.flatten(res)]) else - Value.variadic_return(state.builder, 
List.flatten(res)) + Value.return(state.builder, List.flatten(res)) end Function.pop_region(state.builder) - merge_outfeed(cache, res_cache) + + {region, merge_outfeed(cache, res_cache)} end ## Axes helpers @@ -1897,16 +1907,14 @@ defmodule EXLA.Defn do max_size = tuple_size(max) # To reproduce Nx broadcast, we simply match the lower dimensions to the highest ones. - List.to_tuple(count_up(min_size, max_size - min_size)) + count_up(min_size, max_size - min_size) end defp reduce_axes(op, axes) do if axes do - axes - |> Enum.sort() - |> List.to_tuple() + Enum.sort(axes) else - List.to_tuple(Nx.axes(op_shape(op))) + Nx.axes(op_shape(op)) end end @@ -1921,26 +1929,20 @@ defmodule EXLA.Defn do ## Op Helpers - defp op_type(%Value{} = op), do: Value.get_shape(op).dtype + defp op_type(%Value{} = op), do: Value.get_typespec(op).type - defp op_shape(%Value{} = op), do: Value.get_shape(op).dims + defp op_shape(%Value{} = op), do: Value.get_typespec(op).shape defp to_type(%Value{} = op, type) do - if op_type(op) == type do + typespec = Value.get_typespec(op) + + if typespec.type == type do op else - Value.convert(op, type) + Value.convert(op, Typespec.to_type(typespec, type)) end end - # Inside cond/while, we need to convert pred to u8. - # We could do so lazily by comparing the versions of - # the branches, but that gets tricky with cond/if, - # so we always perform the operation. - defp cast_pred_to_u8(%Value{} = op) do - op - end - defp merge_type({:pred, 8}, {:pred, 8}), do: {:pred, 8} defp merge_type(left, right), do: Nx.Type.merge(to_nx_type(left), to_nx_type(right)) @@ -1948,12 +1950,12 @@ defmodule EXLA.Defn do defp to_nx_type(type), do: type defp to_constant(%EXLA.MLIR.Function{} = function, constant, type) do - Value.constant_r0(function, constant, type) + Value.constant(function, [constant], Typespec.tensor(type, {})) end defp subbuilder(%EXLA.MLIR.Function{name: name} = function, description) do suffix = System.unique_integer([:positive]) - %{function | name: name <> "-" <> description <> "-" <> Integer.to_string(suffix)} + %{function | name: name <> "_" <> description <> "_" <> Integer.to_string(suffix)} end # Helpers @@ -1963,38 +1965,35 @@ defmodule EXLA.Defn do left ++ Enum.drop(right, length) end - defp apply_mlir_broadcasted_bin_op(function, op, out, left, right) do - left_shape = Value.get_shape(left) - right_shape = Value.get_shape(right) - out_shape = EXLA.Shape.make_shape(out.type, out.shape) - left_dims = broadcast_axes(left_shape.dims, out_shape.dims) - right_dims = broadcast_axes(right_shape.dims, out_shape.dims) + defp apply_mlir_broadcasted_bin_op(op, out, left, right) do + left_typespec = Value.get_typespec(left) + right_typespec = Value.get_typespec(right) + left_dims = broadcast_axes(left_typespec.shape, out.shape) + right_dims = broadcast_axes(right_typespec.shape, out.shape) - type = merge_type(left_shape.dtype, right_shape.dtype) + type = merge_type(left_typespec.type, right_typespec.type) type = merge_type(type, out.type) - broadcast_shape = EXLA.Shape.make_shape(type, out_shape.dims) - left = to_type(left, type) left = - if left_shape.dims == broadcast_shape.dims do + if left_typespec.shape == out.shape do left else - Value.broadcast_in_dim(left, broadcast_shape, left_dims) + Value.broadcast_in_dim(left, left_dims, Typespec.tensor(type, out.shape)) end right = to_type(right, type) right = - if right_shape.dims == broadcast_shape.dims do + if right_typespec.shape == out.shape do right else - Value.broadcast_in_dim(right, broadcast_shape, right_dims) + 
Value.broadcast_in_dim(right, right_dims, Typespec.tensor(type, out.shape)) end Value - |> apply(op, [function, left, right]) + |> apply(op, [left, right, Typespec.tensor(type, out.shape)]) |> to_type(out.type) end @@ -2002,15 +2001,15 @@ defmodule EXLA.Defn do to_type(value, {:pred, 8}) end - defp container_to_exla_shape(container) do + defp container_to_typespecs(container) do [container] |> Nx.Defn.Composite.flatten_list() |> Enum.flat_map(fn %Nx.Tensor{type: {:tuple, _}, data: %{args: values}} -> - Enum.flat_map(values, &container_to_exla_shape/1) + Enum.flat_map(values, &container_to_typespecs/1) t -> - [EXLA.Shape.make_shape(t.type, t.shape)] + [Typespec.tensor(t.type, t.shape)] end) end @@ -2026,4 +2025,14 @@ defmodule EXLA.Defn do defp unwrap_single_tensor!({[%Value{} = op], cache}), do: {op, cache} defp unwrap_single_tensor!({%Value{} = op, cache}), do: {op, cache} + + defp remove_axes(shape, axes) do + axes + |> Enum.reverse() + |> Enum.reduce(shape, &Tuple.delete_at(&2, &1)) + end + + defp expr_to_typespec(expr) do + Typespec.tensor(expr.type, expr.shape) + end end diff --git a/exla/lib/exla/defn/buffers.ex b/exla/lib/exla/defn/buffers.ex index 2394d59d0b..d8bce0a4a5 100644 --- a/exla/lib/exla/defn/buffers.ex +++ b/exla/lib/exla/defn/buffers.ex @@ -70,21 +70,21 @@ defmodule EXLA.Defn.Buffers do %Nx.BinaryBackend{state: buffer} end - defp buffer_to_data(tensor, %EXLA.DeviceBuffer{shape: exla_shape} = buffer) do - validate_shape!(tensor, exla_shape) + defp buffer_to_data(tensor, %EXLA.DeviceBuffer{typespec: typespec} = buffer) do + validate_shape!(tensor, typespec) %EXLA.Backend{buffer: buffer} end - defp buffer_to_data(tensor, %EXLA.BinaryBuffer{data: data, shape: exla_shape}) do - validate_shape!(tensor, exla_shape) + defp buffer_to_data(tensor, %EXLA.BinaryBuffer{data: data, typespec: typespec}) do + validate_shape!(tensor, typespec) %Nx.BinaryBackend{state: data} end - defp validate_shape!(%Nx.Tensor{} = t, exla_shape) do + defp validate_shape!(%Nx.Tensor{} = t, typespec) do %{type: type, shape: shape} = Nx.devectorize(t) - nx_type = to_nx_type(exla_shape.dtype) - nx_shape = exla_shape.dims + nx_type = to_nx_type(typespec.type) + nx_shape = typespec.shape if type != nx_type do raise "internal bug! Nx.Defn expected a tensor with type #{inspect(type)} " <> @@ -110,7 +110,7 @@ defmodule EXLA.Defn.Buffers do %EXLA.Backend{buffer: %EXLA.DeviceBuffer{ref: ref} = buffer} when node(ref) != node() -> binary = :erpc.call(node(ref), EXLA.DeviceBuffer, :read, [buffer]) - EXLA.BinaryBuffer.from_binary(binary, to_exla_shape(tensor)) + EXLA.BinaryBuffer.from_binary(binary, to_typespec(tensor)) %EXLA.Backend{buffer: %EXLA.DeviceBuffer{} = buffer} when transfer? 
and buffer.client_name != executable.client.name @@ -151,9 +151,9 @@ defmodule EXLA.Defn.Buffers do "cannot pass a tensor expression as argument to defn, got: #{inspect(tensor)}" _ -> - EXLA.BinaryBuffer.from_binary(Nx.to_binary(tensor), to_exla_shape(tensor)) + EXLA.BinaryBuffer.from_binary(Nx.to_binary(tensor), to_typespec(tensor)) end end - defp to_exla_shape(%Nx.Tensor{type: type, shape: shape}), do: EXLA.Shape.make_shape(type, shape) + defp to_typespec(%Nx.Tensor{type: type, shape: shape}), do: EXLA.Typespec.tensor(type, shape) end diff --git a/exla/lib/exla/defn/outfeed.ex b/exla/lib/exla/defn/outfeed.ex index 8b9fec1897..101946800b 100644 --- a/exla/lib/exla/defn/outfeed.ex +++ b/exla/lib/exla/defn/outfeed.ex @@ -120,12 +120,13 @@ defmodule EXLA.Defn.Outfeed do {infeeds, {compiled_hooks, token}} = entries |> List.keysort(1, :desc) - |> Enum.map_reduce({compiled_hooks, token}, fn {pos, _, shape}, {compiled_hooks, token} -> + |> Enum.map_reduce({compiled_hooks, token}, fn {pos, _, typespec}, + {compiled_hooks, token} -> next_flag = next_hook(compiled_hooks) - compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, shape}) + compiled_hooks = Map.put(compiled_hooks, next_flag, {:infeed, pos, typespec}) - token = Value.outfeed(Value.constant_r0(builder, next_flag, {:u, 16}), token) - {token, [input]} = Value.infeed(token, shape) + token = Value.outfeed(Value.constant(builder, [next_flag], flag_typespec()), token) + {token, [input]} = Value.infeed(token, [typespec]) {{pos, input}, {compiled_hooks, token}} end) @@ -133,14 +134,16 @@ defmodule EXLA.Defn.Outfeed do %{outfeed | compiled_hooks: compiled_hooks, token: token, infeeds: infeeds} end + defp flag_typespec(), do: EXLA.Typespec.tensor({:u, 16}, {}) + @doc """ Adds a function hook if it has a callback defined for it. """ def maybe_add_function_hook(%Outfeed{} = outfeed, builder, tuple, name, expr) do cond do name in outfeed.used_hooks -> - {outfeed, flag, shapes} = outfeed_flat_tuple(outfeed, builder, tuple) - put_in(outfeed.compiled_hooks[flag], {:function, shapes, name, Nx.to_template(expr)}) + {outfeed, flag, typespecs} = outfeed_flat_tuple(outfeed, builder, tuple) + put_in(outfeed.compiled_hooks[flag], {:function, typespecs, name, Nx.to_template(expr)}) outfeed.token -> outfeed @@ -156,15 +159,15 @@ defmodule EXLA.Defn.Outfeed do Used by streams. Only one is allowed. Requires configuration. 
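A rough usage sketch (illustrative only; both calls are taken from this diff):

    outfeed = Outfeed.add_stream_hook(outfeed, builder, tuple)
    # later, once the receiving process is known:
    {typespecs, outfeed} = Outfeed.configure_stream_hook(outfeed, self(), ref)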
""" def add_stream_hook(%Outfeed{} = outfeed, builder, tuple) do - {outfeed, flag, shapes} = outfeed_flat_tuple(outfeed, builder, tuple) + {outfeed, flag, typespecs} = outfeed_flat_tuple(outfeed, builder, tuple) # We don't know the pid+ref pair for the stream, so we store it # under a special key called :stream and revert to the flag once configured - put_in(outfeed.compiled_hooks[:stream], {flag, shapes}) + put_in(outfeed.compiled_hooks[:stream], {flag, typespecs}) end def configure_stream_hook(%Outfeed{} = outfeed, pid, ref) when is_pid(pid) do - {{flag, shapes}, outfeed} = pop_in(outfeed.compiled_hooks[:stream]) - {shapes, put_in(outfeed.compiled_hooks[flag], {:stream, shapes, pid, ref})} + {{flag, typespecs}, outfeed} = pop_in(outfeed.compiled_hooks[:stream]) + {typespecs, put_in(outfeed.compiled_hooks[flag], {:stream, typespecs, pid, ref})} end @doc """ @@ -175,22 +178,23 @@ defmodule EXLA.Defn.Outfeed do def close(outfeed, builder) def close(%Outfeed{} = outfeed, %Function{} = builder) when will_outfeed(outfeed), - do: update_in(outfeed.token, &Value.outfeed(Value.constant_r0(builder, 0, {:u, 16}), &1)) + do: + update_in(outfeed.token, &Value.outfeed(Value.constant(builder, [0], flag_typespec()), &1)) def close(%Outfeed{} = outfeed, _builder), do: outfeed defp outfeed_flat_tuple(%Outfeed{token: token, compiled_hooks: ch} = outfeed, builder, tuple) do flag = next_hook(ch) - token = Value.outfeed(Value.constant_r0(builder, flag, {:u, 16}), token) - shapes = Enum.map(tuple, &Value.get_shape/1) + token = Value.outfeed(Value.constant(builder, [flag], flag_typespec()), token) + typespecs = Enum.map(tuple, &Value.get_typespec/1) token = Enum.reduce(tuple, token, fn elem, token -> Value.outfeed(elem, token) end) - {%{outfeed | token: token}, flag, shapes} + {%{outfeed | token: token}, flag, typespecs} end # The index 0 is served for closing streams @@ -200,7 +204,7 @@ defmodule EXLA.Defn.Outfeed do @doc """ Receives a client, device_id, and mappings of u16 to - `{shapes, {pid, ref} | {fun, template}}` pairs to + `{typespecs, {pid, ref} | {fun, template}}` pairs to deliver/execute the outputs. The computation must emit a 0 flag on exit. 
""" @@ -227,12 +231,12 @@ defmodule EXLA.Defn.Outfeed do # Copy the group leader so we report to the proper device Process.group_leader(self(), group_leader) ref = make_ref() - shape = EXLA.Shape.make_shape({:u, 16}, {}) - loop(client, device_id, ref, shape, hooks, compiled_hooks, infeeds) + typespec = EXLA.Typespec.tensor({:u, 16}, {}) + loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) end - defp loop(client, device_id, ref, shape, hooks, compiled_hooks, infeeds) do - :ok = EXLA.Client.from_outfeed(client, device_id, [shape], self(), ref) + defp loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) do + :ok = EXLA.Client.from_outfeed(client, device_id, [typespec], self(), ref) receive do {^ref, <<0::native-unsigned-16>>} -> @@ -240,30 +244,30 @@ defmodule EXLA.Defn.Outfeed do {^ref, <>} -> case Map.fetch!(compiled_hooks, flag) do - {:infeed, index, data_shape} -> + {:infeed, index, data_typespec} -> data = case Map.fetch!(infeeds, index) do %EXLA.DeviceBuffer{} = buffer -> EXLA.DeviceBuffer.read(buffer) %EXLA.BinaryBuffer{data: data} -> data end - EXLA.Client.to_infeed(client, device_id, [{data, data_shape}]) - loop(client, device_id, ref, shape, hooks, compiled_hooks, infeeds) + EXLA.Client.to_infeed(client, device_id, [{data, data_typespec}]) + loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) - {:stream, shapes, recv_pid, recv_ref} -> - :ok = EXLA.Client.from_outfeed(client, device_id, shapes, recv_pid, recv_ref) - loop(client, device_id, ref, shape, hooks, compiled_hooks, infeeds) + {:stream, typespecs, recv_pid, recv_ref} -> + :ok = EXLA.Client.from_outfeed(client, device_id, typespecs, recv_pid, recv_ref) + loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) - {:function, shapes, name, template} -> + {:function, typespecs, name, template} -> fun = Map.fetch!(hooks, name) - length = length(shapes) + length = length(typespecs) parent = self() ref = make_ref() pid = spawn(fn -> apply_hook(parent, ref, length, fun, template) end) - :ok = EXLA.Client.from_outfeed(client, device_id, shapes, pid, ref) + :ok = EXLA.Client.from_outfeed(client, device_id, typespecs, pid, ref) receive do - ^ref -> loop(client, device_id, ref, shape, hooks, compiled_hooks, infeeds) + ^ref -> loop(client, device_id, ref, typespec, hooks, compiled_hooks, infeeds) end end end diff --git a/exla/lib/exla/defn/stream.ex b/exla/lib/exla/defn/stream.ex index a341b27095..801baecaf8 100644 --- a/exla/lib/exla/defn/stream.ex +++ b/exla/lib/exla/defn/stream.ex @@ -2,7 +2,7 @@ defmodule EXLA.Defn.Stream do @moduledoc false keys = - [:lock, :outfeed, :pid, :runner, :send, :send_shape, :send_indexes] ++ + [:lock, :outfeed, :pid, :runner, :send, :send_typespec, :send_indexes] ++ [:recv, :recv_length, :done, :client, :device_id] @derive {Inspect, only: [:pid, :client, :device_id, :send, :recv]} @@ -15,10 +15,10 @@ defmodule EXLA.Defn.Stream do runner, outfeed, send, - send_shape, + send_typespec, send_indexes, recv, - recv_shapes, + recv_typespecs, done ) do %{client: client, device_id: device_id} = executable @@ -39,10 +39,10 @@ defmodule EXLA.Defn.Stream do outfeed: outfeed, lock: lock, send: send, - send_shape: send_shape, + send_typespec: send_typespec, send_indexes: send_indexes, recv: recv, - recv_length: length(recv_shapes), + recv_length: length(recv_typespecs), client: client, device_id: device_id, done: done @@ -52,7 +52,7 @@ defmodule EXLA.Defn.Stream do # It is time to halt the stream, we do it by sending 0 for the loop infeed. 
# Then we wait for the outfeed process to read all. defp halt_stream(client, device_id, outfeed) do - pred = EXLA.Shape.make_shape({:pred, 8}, {}) + pred = EXLA.Typespec.tensor({:pred, 8}, {}) :ok = EXLA.Client.to_infeed(client, device_id, [{<<0::8-native>>, pred}]) {:transfer, outfeed} end @@ -64,7 +64,7 @@ defmodule EXLA.Defn.Stream do client: client, device_id: device_id, send: send, - send_shape: send_shape, + send_typespec: send_typespec, send_indexes: send_indexes } = stream @@ -86,15 +86,17 @@ defmodule EXLA.Defn.Stream do """ end - data_and_shapes = + data_and_typespecs = if client.platform == :host do - Enum.zip(buffers, send_shape) + Enum.zip(buffers, send_typespec) else - [{buffers, send_shape}] + [{buffers, send_typespec}] end - pred = EXLA.Shape.make_shape({:pred, 8}, {}) - :ok = EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred} | data_and_shapes]) + pred = EXLA.Typespec.tensor({:pred, 8}, {}) + + :ok = + EXLA.Client.to_infeed(client, device_id, [{<<1::8-native>>, pred} | data_and_typespecs]) end defp nx_to_io(container, indexes) do diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index 8fe70d7a55..dc2944a927 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -4,32 +4,40 @@ defmodule EXLA.DeviceBuffer do """ alias __MODULE__ - alias EXLA.{Client, Shape} + alias EXLA.Client - @enforce_keys [:ref, :client_name, :device_id, :shape] - defstruct [:ref, :client_name, :device_id, :shape] + @enforce_keys [:ref, :client_name, :device_id, :typespec] + defstruct [:ref, :client_name, :device_id, :typespec] @doc false - def from_ref(ref, %Client{name: name}, device_id, shape) when is_reference(ref) do - %DeviceBuffer{ref: ref, client_name: name, device_id: device_id, shape: shape} + def from_ref(ref, %Client{name: name}, device_id, typespec) when is_reference(ref) do + %DeviceBuffer{ref: ref, client_name: name, device_id: device_id, typespec: typespec} end @doc """ Places the given binary `buffer` on the given `device` using `client`. """ - def place_on_device(data, %Shape{} = shape, client = %Client{}, device_id) + def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id) when is_integer(device_id) and is_binary(data) do - ref = client.ref |> EXLA.NIF.binary_to_device_mem(data, shape.ref, device_id) |> unwrap!() - %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, shape: shape} + ref = + client.ref + |> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id) + |> unwrap!() + + %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, typespec: typespec} end @doc """ Copies buffer to device with given device ID. 
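For example (illustrative only; both functions are defined in this module):

    buffer = DeviceBuffer.place_on_device(data, typespec, client, 0)
    copy = DeviceBuffer.copy_to_device(buffer, client, 1)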
""" - def copy_to_device(%DeviceBuffer{ref: buffer, shape: shape}, %Client{} = client, device_id) + def copy_to_device( + %DeviceBuffer{ref: buffer, typespec: typespec}, + %Client{} = client, + device_id + ) when is_integer(device_id) do ref = client.ref |> EXLA.NIF.copy_buffer_to_device(buffer, device_id) |> unwrap!() - %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, shape: shape} + %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, typespec: typespec} end @doc """ diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index 0ad98ea14d..a6a0c8cbdf 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -4,26 +4,27 @@ defmodule EXLA.Executable do """ alias __MODULE__ - alias EXLA.{BinaryBuffer, DeviceBuffer, Shape} + alias EXLA.{BinaryBuffer, DeviceBuffer} - @enforce_keys [:client, :ref, :output_shape, :num_replicas, :num_partitions, :device_id] - defstruct [:client, :ref, :output_shape, :num_replicas, :num_partitions, :device_id] + @enforce_keys [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] + defstruct [:client, :ref, :output_typespecs, :num_replicas, :num_partitions, :device_id] @doc """ Runs the given executable with a list of lists as inputs and the given options. """ def run(%Executable{} = executable, [subinputs | _] = inputs, options \\ []) when is_list(subinputs) do - %{client: client, device_id: device_id, output_shape: output_shape, ref: ref} = executable + %{client: client, device_id: device_id, output_typespecs: output_typespecs, ref: ref} = + executable for data_and_device_id <- run(client, ref, device_id, inputs, options) do - decompose_output(data_and_device_id, output_shape, client) + decompose_output(data_and_device_id, output_typespecs, client) end end def serialize(%Executable{ ref: executable, - output_shape: out_shape, + output_typespecs: output_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, device_id: device_id @@ -34,11 +35,9 @@ defmodule EXLA.Executable do |> unwrap!() |> IO.iodata_to_binary() - stripped_shape = strip_shape(out_shape) - %{ serialized: serialized_exec, - output_shape: stripped_shape, + output_typespecs: output_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, device_id: device_id @@ -57,7 +56,6 @@ defmodule EXLA.Executable do exec_data |> Map.put(:ref, ref) |> Map.put(:client, client) - |> Map.update!(:output_shape, &reconstruct_shapes/1) |> then(&struct(__MODULE__, &1)) _other -> @@ -69,8 +67,11 @@ defmodule EXLA.Executable do inputs = for subinputs <- inputs do Enum.map(subinputs, fn - %DeviceBuffer{ref: ref} -> ref - %BinaryBuffer{data: data, shape: shape} -> {data, shape.ref} + %DeviceBuffer{ref: ref} -> + ref + + %BinaryBuffer{data: data, typespec: typespec} -> + {data, EXLA.Typespec.nif_encode(typespec)} end) end @@ -83,28 +84,14 @@ defmodule EXLA.Executable do unwrap!(data) end - defp decompose_output({data, device_id}, shapes, client) do - shapes = - Enum.flat_map(List.wrap(shapes), fn shape -> - case shape do - shapes when is_list(shapes) -> shapes - %Shape{} -> [shape] - end - end) - - Enum.zip_with(data, shapes, fn - buf, subshape when is_reference(buf) -> - DeviceBuffer.from_ref(buf, client, device_id, subshape) - - buf, subshape when is_binary(buf) -> - BinaryBuffer.from_binary(buf, subshape) - end) - end - - defp strip_shape(%Shape{dtype: dtype, dims: dims}), do: %{dtype: dtype, dims: dims} + defp decompose_output({data, device_id}, output_typespecs, client) do + 
Enum.zip_with(data, output_typespecs, fn + buf, typespec when is_reference(buf) -> + DeviceBuffer.from_ref(buf, client, device_id, typespec) - defp reconstruct_shapes(%{dtype: dtype, dims: dims}) do - EXLA.Shape.make_shape(dtype, dims) + buf, typespec when is_binary(buf) -> + BinaryBuffer.from_binary(buf, typespec) + end) end defp unwrap!(:ok), do: :ok diff --git a/exla/lib/exla/lib.ex b/exla/lib/exla/lib.ex index e248d421e9..2b54e3a5da 100644 --- a/exla/lib/exla/lib.ex +++ b/exla/lib/exla/lib.ex @@ -2,32 +2,24 @@ defmodule EXLA.Lib do @moduledoc false # High-level operations built on top of `EXLA.MLIR.Value`. - alias EXLA.Shape - + alias EXLA.Typespec alias EXLA.MLIR.Function alias EXLA.MLIR.Value - @doc """ - Element-wise tangent function. - """ - def tan(%Value{} = op) do - Value.tan(op) - end - @doc """ Builds iota along axis. """ - def iota(%EXLA.MLIR.Function{} = function, shape, nil) do - total_elems = Nx.size(shape.dims) + def iota(%EXLA.MLIR.Function{} = function, nil, typespec) do + total_elems = Nx.size(typespec.shape) Value.reshape( - Value.iota(function, EXLA.Shape.make_shape(shape.dtype, {total_elems}), 0), - shape.dims + Value.iota(function, 0, Typespec.to_shape(typespec, {total_elems})), + typespec ) end - def iota(%EXLA.MLIR.Function{} = function, shape, axis) do - Value.iota(function, shape, axis) + def iota(%EXLA.MLIR.Function{} = function, axis, typespec) do + Value.iota(function, axis, typespec) end @doc """ @@ -63,76 +55,84 @@ defmodule EXLA.Lib do defp argmin_or_max(builder, %Value{} = op, is_min?, type, opts) do tie_break = opts[:tie_break] || :low keep_axis = opts[:keep_axis] || false - op_shape = Value.get_shape(op) + + op_typespec = Value.get_typespec(op) init_value = if is_min?, - do: max_number(builder, op_shape.dtype), - else: min_number(builder, op_shape.dtype) + do: max_number(builder, op_typespec.type), + else: min_number(builder, op_typespec.type) axis = opts[:axis] - index_init_value = Value.constant_r0(builder, 0, type) - iota = iota(builder, Shape.make_shape(type, op_shape.dims), axis) - reduction = create_min_max_computation(builder, op_shape.dtype, type, is_min?, tie_break) + index_init_value = Value.constant(builder, [0], Typespec.tensor(type, {})) + iota = iota(builder, axis, Typespec.to_type(op_typespec, type)) + reduction = create_min_max_computation(builder, op_typespec.type, type, is_min?, tie_break) dims = if axis do - {axis} + [axis] else - List.to_tuple(Nx.axes(op_shape.dims)) + Nx.axes(op_typespec.shape) end + shape = remove_axes(op_typespec.shape, dims) + typespecs = [Typespec.tensor(op_typespec.type, shape), Typespec.tensor(type, shape)] + [_, result] = - Value.reduce(reduction, [init_value, index_init_value], [op, iota], dims) + Value.reduce(reduction, [init_value, index_init_value], [op, iota], dims, typespecs) if keep_axis do - Value.reshape(result, put_elem(op_shape.dims, axis, 1)) + Value.reshape(result, Typespec.tensor(type, put_elem(op_typespec.shape, axis, 1))) else result end end - defp create_min_max_computation(%Function{} = builder, type, index_type, is_min?, tie_break) do - %{module: module, name: name} = subbuilder(builder, "min-max") + defp remove_axes(shape, axes) do + axes + |> Enum.reverse() + |> Enum.reduce(shape, &Tuple.delete_at(&2, &1)) + end + + defp create_min_max_computation(%Function{} = function, type, index_type, is_min?, tie_break) do + arg_typespecs = [ + Typespec.tensor(type, {}), + Typespec.tensor(index_type, {}), + Typespec.tensor(type, {}), + Typespec.tensor(index_type, {}) + ] - function = - 
EXLA.MLIR.Module.add_function( - module, - name, - [ - EXLA.Shape.make_shape(type, {}), - EXLA.Shape.make_shape(index_type, {}), - EXLA.Shape.make_shape(type, {}), - EXLA.Shape.make_shape(index_type, {}) - ], - [EXLA.Shape.make_shape(type, {}), EXLA.Shape.make_shape(index_type, {})] - ) + {region, args} = Function.push_region(function, arg_typespecs) + [lhs_value, lhs_index, rhs_value, rhs_index] = args - [lhs_value, lhs_index, rhs_value, rhs_index] = EXLA.MLIR.Function.get_arguments(function) + pred_typespec = Typespec.tensor({:pred, 8}, {}) + value_typespec = Typespec.tensor(type, {}) + idx_typespec = Typespec.tensor(index_type, {}) cmp = if is_min?, - do: Value.less_equal(function, lhs_value, rhs_value), - else: Value.greater_equal(function, lhs_value, rhs_value) + do: Value.less_equal(lhs_value, rhs_value, pred_typespec), + else: Value.greater_equal(lhs_value, rhs_value, pred_typespec) - max = Value.select(cmp, lhs_value, rhs_value) - arg_max = Value.select(cmp, lhs_index, rhs_index) + max = Value.select(cmp, lhs_value, rhs_value, value_typespec) + arg_max = Value.select(cmp, lhs_index, rhs_index, idx_typespec) arg_max = case tie_break do :low -> - eq? = Value.equal(function, lhs_value, rhs_value) - id = Value.min(function, lhs_index, rhs_index) - Value.select(eq?, id, arg_max) + eq? = Value.equal(lhs_value, rhs_value, pred_typespec) + id = Value.min(lhs_index, rhs_index, idx_typespec) + Value.select(eq?, id, arg_max, idx_typespec) :high -> - eq? = Value.equal(function, lhs_value, rhs_value) - id = Value.max(function, lhs_index, rhs_index) - Value.select(eq?, id, arg_max) + eq? = Value.equal(lhs_value, rhs_value, pred_typespec) + id = Value.max(lhs_index, rhs_index, idx_typespec) + Value.select(eq?, id, arg_max, idx_typespec) end - Value.variadic_return(function, [max, arg_max]) - function + Value.return(function, [max, arg_max]) + Function.pop_region(function) + region end @doc """ @@ -141,7 +141,18 @@ defmodule EXLA.Lib do It will be negative infinity for floating point types. """ def min_number(%Function{} = builder, type) do - Value.constant_from_binary(builder, min_binary(type), Shape.make_shape(type, {})) + number = + case type do + {:pred, 8} -> + 0 + + type -> + type + |> Nx.Constants.min(backend: Nx.BinaryBackend) + |> Nx.to_number() + end + + Value.constant(builder, [number], Typespec.tensor(type, {})) end @doc """ @@ -150,27 +161,30 @@ defmodule EXLA.Lib do Maximum values are defined in `Nx.Type.max_finite_binary/1`. """ def max_number(builder, type) do - Value.constant_from_binary(builder, max_binary(type), Shape.make_shape(type, {})) - end + number = + case type do + {:pred, 8} -> + 1 + + type -> + type + |> Nx.Constants.max(backend: Nx.BinaryBackend) + |> Nx.to_number() + end - defp subbuilder(%EXLA.MLIR.Function{name: name} = function, description) do - suffix = System.unique_integer([:positive]) - %{function | name: name <> "-" <> description <> "-" <> Integer.to_string(suffix)} + Value.constant(builder, [number], Typespec.tensor(type, {})) end - defp min_binary({:pred, 8}), do: <<0>> - defp min_binary(type), do: Nx.Type.min_binary(type) - defp max_binary({:pred, 8}), do: <<1>> - defp max_binary(type), do: Nx.Type.max_binary(type) - @doc """ Sorts a tensor and returns the corresponding indices in the new positions. 
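+
+  A call sketch (names hypothetical): `comparator` is a previously built
+  comparator region and `{:s, 64}` is the iota index type:
+
+      argsort(builder, operand, 0, true, comparator, {:s, 64})
+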
""" def argsort(builder, %Value{} = operand, dimension, stable, comparator, iota_type) do - shape = Value.get_shape(operand) - iota = iota(builder, Shape.make_shape(iota_type, shape.dims), dimension) + typespec = Value.get_typespec(operand) + iota_typespec = Typespec.to_type(typespec, iota_type) + iota = iota(builder, dimension, iota_typespec) - [_, result] = Value.sort([operand, iota], comparator, dimension, stable) + typespecs = [typespec, iota_typespec] + [_, result] = Value.sort([operand, iota], comparator, dimension, stable, typespecs) result end diff --git a/exla/lib/exla/mlir/context_pool.ex b/exla/lib/exla/mlir/context_pool.ex index 751af9088a..14cf11429f 100644 --- a/exla/lib/exla/mlir/context_pool.ex +++ b/exla/lib/exla/mlir/context_pool.ex @@ -14,7 +14,7 @@ defmodule EXLA.MLIR.ContextPool do @impl NimblePool def init_worker(pool_state) do - {:ok, context} = EXLA.NIF.new_mlir_context() + {:ok, context} = EXLA.NIF.mlir_new_context() {:ok, context, pool_state} end diff --git a/exla/lib/exla/mlir/function.ex b/exla/lib/exla/mlir/function.ex index d7718fbf6b..d9aaa4f7b1 100644 --- a/exla/lib/exla/mlir/function.ex +++ b/exla/lib/exla/mlir/function.ex @@ -2,7 +2,7 @@ defmodule EXLA.MLIR.Function do @moduledoc false # Representation of an MLIR Function or `func.func` type. - defstruct [:module, :ref, :name, :return_shape] + defstruct [:module, :ref, :name, :return_typespecs] alias __MODULE__, as: Function alias EXLA.MLIR.Value @@ -13,17 +13,26 @@ defmodule EXLA.MLIR.Function do which can be used in MLIR operations. """ def get_arguments(%Function{ref: ref} = function) do - arg_refs = EXLA.NIF.get_mlir_function_arguments(ref) |> unwrap!() + arg_refs = EXLA.NIF.mlir_get_function_arguments(ref) |> unwrap!() Enum.map(arg_refs, fn arg -> %Value{ref: arg, function: function} end) end - def push_region(%Function{ref: ref} = function, %Region{ref: region}) do - ref - |> EXLA.NIF.mlir_push_region(region) - |> unwrap!() - |> Enum.map(&%Value{function: function, ref: &1}) + @doc """ + Creates a new region within the current function and sets it as the + insertion point for subsequent operations. + + Returns `{region, args}`, where args is a list of values referencing + the block arguments. + """ + def push_region(%Function{ref: ref} = function, arg_typespecs) do + arg_mlir_types = Value.typespecs_to_mlir_types(arg_typespecs) + {region, args} = EXLA.NIF.mlir_push_region(ref, arg_mlir_types) |> unwrap!() + {%Region{ref: region}, Enum.map(args, &%Value{function: function, ref: &1})} end + @doc """ + Pops region created with `push_region/2`. + """ def pop_region(%Function{ref: ref}) do EXLA.NIF.mlir_pop_region(ref) |> unwrap!() end diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index 8bb190dca9..3c91024c32 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -9,17 +9,16 @@ defmodule EXLA.MLIR.Module do alias EXLA.Client alias EXLA.Executable - alias EXLA.Shape @doc """ Creates a new MLIR module. 
""" - def new(arg_shapes, return_shape, fun) when is_function(fun, 1) do + def new(arg_typespecs, return_typespecs, fun) when is_function(fun, 1) do ContextPool.checkout(fn context -> - ref = context |> EXLA.NIF.new_mlir_module() |> unwrap!() + ref = context |> EXLA.NIF.mlir_new_module() |> unwrap!() %__MODULE__{ref: ref} - |> create_function("main", arg_shapes, return_shape, true) + |> create_function("main", arg_typespecs, return_typespecs, true) |> fun.() end) end @@ -27,32 +26,32 @@ defmodule EXLA.MLIR.Module do @doc """ Adds a new function to the MLIR module. """ - def add_function(module, name, arg_shapes, return_shapes) do - create_function(module, name, arg_shapes, return_shapes, false) + def add_function(module, name, arg_typespecs, return_typespecs) do + create_function(module, name, arg_typespecs, return_typespecs, false) end defp create_function( %__MODULE__{ref: module_ref} = module, name, - arg_shapes, - return_shapes, + arg_typespecs, + return_typespecs, is_public ) when is_binary(name) do - arg_shape_refs = Enum.map(arg_shapes, fn %Shape{ref: ref} -> ref end) - return_shape_refs = Enum.map(return_shapes, fn %Shape{ref: ref} -> ref end) + arg_types = EXLA.MLIR.Value.typespecs_to_mlir_types(arg_typespecs) + return_types = EXLA.MLIR.Value.typespecs_to_mlir_types(return_typespecs) ref = - EXLA.NIF.create_mlir_function( + EXLA.NIF.mlir_create_function( module_ref, name, - arg_shape_refs, - return_shape_refs, + arg_types, + return_types, if(is_public, do: 1, else: 0) ) |> unwrap!() - %Function{module: module, ref: ref, name: name, return_shape: return_shapes} + %Function{module: module, ref: ref, name: name, return_typespecs: return_typespecs} end @doc """ @@ -83,8 +82,8 @@ defmodule EXLA.MLIR.Module do def compile( module = %__MODULE__{}, client = %Client{}, - argument_shapes, - return_shape, + argument_typespecs, + return_typespecs, options \\ [] ) do num_replicas = Keyword.get(options, :num_replicas, 1) @@ -103,7 +102,7 @@ defmodule EXLA.MLIR.Module do EXLA.NIF.mlir_compile( client.ref, module.ref, - Enum.map(argument_shapes, & &1.ref), + Enum.map(argument_typespecs, &EXLA.Typespec.nif_encode/1), num_replicas, num_partitions, use_spmd, @@ -114,13 +113,21 @@ defmodule EXLA.MLIR.Module do %Executable{ client: client, ref: ref, - output_shape: return_shape, + output_typespecs: return_typespecs, num_replicas: num_replicas, num_partitions: num_partitions, device_id: device_id } end + @doc """ + Returns a human-readable representation of the module using MLIR + syntax. + """ + def to_string(module = %__MODULE__{}) do + EXLA.NIF.mlir_module_to_string(module.ref) |> unwrap!() + end + defp unwrap!(:ok), do: :ok defp unwrap!({:ok, ref}), do: ref defp unwrap!({:error, error}), do: raise(List.to_string(error)) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index a84a90eb01..262c8a02f1 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -1,320 +1,384 @@ defmodule EXLA.MLIR.Value do @moduledoc false # Representation of an MLIR Value. + # # MLIR Values are SSA and generally are either operations or # block arguments. This module is used to construct most of the # MLIR operations. + # + # See the full specification of the stablehlo MLIR dialect [1]. Note + # that the URL points to the exact stablehlo revision that we depend + # on via elixir-nx/xla. 
+ # + # [1]: https://github.com/openxla/stablehlo/blob/04291aea6b50d9573e6f4de184938d83b9564cd0/docs/spec.md defstruct [:ref, :function] - alias __MODULE__, as: Value + alias __MODULE__ + alias EXLA.Typespec alias EXLA.MLIR.Region alias EXLA.MLIR.Function - @bin_ops [:add, :subtract, :multiply, :divide, :pow, :min] ++ - [:max, :remainder, :atan2, :equal, :less, :less_equal] ++ - [:greater, :greater_equal, :not_equal, :bitwise_and] ++ - [:bitwise_or, :bitwise_xor] ++ - [:left_shift, :right_shift_arithmetic, :right_shift_logical] - - for op <- @bin_ops do - mlir_op = :"mlir_#{op}" - - def unquote(op)( - func, - %Value{ref: lhs}, - %Value{ref: rhs} - ) do - ref = EXLA.NIF.unquote(mlir_op)(func.ref, lhs, rhs) |> unwrap!() - %Value{ref: ref, function: func} + @bin_ops %{ + add: "stablehlo.add", + subtract: "stablehlo.subtract", + multiply: "stablehlo.multiply", + divide: "stablehlo.divide", + pow: "stablehlo.power", + min: "stablehlo.minimum", + max: "stablehlo.maximum", + remainder: "stablehlo.remainder", + atan2: "stablehlo.atan2", + bitwise_and: "stablehlo.and", + bitwise_or: "stablehlo.or", + bitwise_xor: "stablehlo.xor", + left_shift: "stablehlo.shift_left", + right_shift_arithmetic: "stablehlo.shift_right_arithmetic", + right_shift_logical: "stablehlo.shift_right_logical" + } + + for {op, op_name} <- @bin_ops do + def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + op(func, unquote(op_name), [lhs, rhs], result_types) |> one!() end end - @unary_ops [:abs, :exp, :expm1, :floor, :ceil, :round] ++ - [:log, :log1p, :sigmoid, :sign, :cos] ++ - [:sin, :tan, :acos, :asin, :atan, :cosh, :sinh] ++ - [:tanh, :acosh, :asinh, :atanh, :sqrt, :cbrt] ++ - [:bitwise_not, :erf, :erfc, :erf_inv] ++ - [:is_infinity, :is_nan, :rsqrt, :negate, :count_leading_zeros] ++ - [:population_count, :real, :imag, :conjugate] + @bin_comparison_ops %{ + equal: :eq, + less: :lt, + less_equal: :le, + greater: :gt, + greater_equal: :ge, + not_equal: :ne + } - for op <- @unary_ops do - mlir_op = :"mlir_#{op}" - - def unquote(op)(%Value{ref: operand, function: %Function{} = func}) do - ref = EXLA.NIF.unquote(mlir_op)(func.ref, operand) |> unwrap!() - %Value{ref: ref, function: func} + for {op, direction} <- @bin_comparison_ops do + def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do + compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction)) end end - def reshape(%Value{function: %Function{} = func} = op, shape_tuple) do - ref = EXLA.NIF.mlir_reshape(func.ref, op.ref, shape_tuple) |> unwrap!() - %Value{op | ref: ref} - end - - def reverse(%Value{function: %Function{} = func} = op, dims) do - ref = EXLA.NIF.mlir_reverse(func.ref, op.ref, dims) |> unwrap!() - %Value{op | ref: ref} - end + defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do + %{type: lhs_type} = get_typespec(lhs) + %{type: rhs_type} = get_typespec(rhs) - def transpose(%Value{} = op, axes) when is_tuple(axes) do - transpose(op, Tuple.to_list(axes)) - end + comparison_type = + if Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) do + attr_comparison_type(:totalorder) + else + attr_comparison_type(:notype) + end - def transpose(%Value{function: %Function{} = func} = op, axes) do - ref = EXLA.NIF.mlir_transpose(func.ref, op.ref, axes) |> unwrap!() - %Value{op | ref: ref} + attributes = [ + comparison_direction: attr_comparison_direction(direction), + comparison_type: comparison_type + ] + + result_types = 
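+      # stablehlo.compare always yields a pred (i1) tensor, so the result
+      # type below is derived from the typespec with {:pred, 8}: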
typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) + + op(func, "stablehlo.compare", [lhs, rhs], result_types, attributes: attributes) |> one!() + end + + @unary_ops %{ + abs: "stablehlo.abs", + exp: "stablehlo.exponential", + expm1: "stablehlo.exponential_minus_one", + floor: "stablehlo.floor", + ceil: "stablehlo.ceil", + round: "stablehlo.round_nearest_afz", + log: "stablehlo.log", + log1p: "stablehlo.log_plus_one", + sigmoid: "stablehlo.logistic", + sign: "stablehlo.sign", + cos: "stablehlo.cosine", + sin: "stablehlo.sine", + tan: "chlo.tan", + acos: "chlo.acos", + asin: "chlo.asin", + atan: "chlo.atan", + cosh: "chlo.cosh", + sinh: "chlo.sinh", + tanh: "stablehlo.tanh", + acosh: "chlo.acosh", + asinh: "chlo.asinh", + atanh: "chlo.atanh", + sqrt: "stablehlo.sqrt", + cbrt: "stablehlo.cbrt", + bitwise_not: "stablehlo.not", + erf: "chlo.erf", + erfc: "chlo.erfc", + erf_inv: "chlo.erf_inv", + rsqrt: "stablehlo.rsqrt", + negate: "stablehlo.negate", + count_leading_zeros: "stablehlo.count_leading_zeros", + population_count: "stablehlo.popcnt", + real: "stablehlo.real", + imag: "stablehlo.imag", + conjugate: "chlo.conj" + } + + for {op, op_name} <- @unary_ops do + def unquote(op)(%Value{function: func} = operand, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + op(func, unquote(op_name), [operand], result_types, []) |> one!() + end end - def slice(%Value{function: %Function{} = func} = op, starts, limits, strides) do - ref = EXLA.NIF.mlir_slice(func.ref, op.ref, starts, limits, strides) |> unwrap!() - %Value{op | ref: ref} - end + def is_infinity(%Value{function: func} = operand, typespec) do + %{type: type} = get_typespec(operand) - def dynamic_slice(%Value{function: %Function{} = func} = op, starts, lengths) do - starts = Enum.map(starts, fn %Value{ref: ref} -> ref end) - ref = EXLA.NIF.mlir_dynamic_slice(func.ref, op.ref, starts, lengths) |> unwrap!() - %Value{op | ref: ref} - end + typespec = Typespec.to_type(typespec, {:pred, 8}) - def get_shape(%Value{ref: ref}) do - shape_ref = EXLA.NIF.mlir_get_shape(ref) |> unwrap!() - EXLA.Shape.get_shape_info(shape_ref) - end + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_inf_real = is_infinity(real, typespec) + is_inf_imag = is_infinity(imag, typespec) + bitwise_or(is_inf_real, is_inf_imag, typespec) - def convert(%Value{ref: in_ref, function: %Function{} = func} = value, dtype) do - out_ref = - EXLA.NIF.mlir_convert(func.ref, in_ref, EXLA.Shape.dtype_to_charlist(dtype)) |> unwrap!() + Nx.Type.integer?(type) -> + # Integers are never infinity. 
We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) - %Value{value | ref: out_ref} + true -> + result_types = typespecs_to_mlir_types([typespec]) + op(func, "chlo.is_inf", [operand], result_types) |> one!() + end end - def bitcast_convert(%Value{ref: in_ref, function: %Function{} = func} = value, dtype) do - shape = get_shape(value) - - out_ref = - EXLA.NIF.mlir_bitcast_convert( - func.ref, - in_ref, - EXLA.Shape.dtype_to_charlist(dtype), - shape.dims - ) - |> unwrap!() + def is_nan(%Value{function: func} = operand, typespec) do + %{type: type} = get_typespec(operand) + + typespec = Typespec.to_type(typespec, {:pred, 8}) + + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_nan_real = is_nan(real, typespec) + is_nan_imag = is_nan(imag, typespec) + bitwise_or(is_nan_real, is_nan_imag, typespec) + + Nx.Type.integer?(type) -> + # Integers are never nan. We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) + + true -> + result_types = typespecs_to_mlir_types([typespec]) + is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() + is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() + is_not_inf = bitwise_not(is_inf, typespec) + is_not_finite = bitwise_not(is_finite, typespec) + bitwise_and(is_not_inf, is_not_finite, typespec) + end + end - %Value{value | ref: out_ref} + def reshape(%Value{function: func} = operand, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + op(func, "stablehlo.reshape", [operand], result_types) |> one!() end - def top_k(%Value{function: %Function{ref: func_ref}, ref: op_ref} = val, k) do - [val_ref, idx_ref] = EXLA.NIF.mlir_top_k(func_ref, op_ref, k) |> unwrap!() - [%Value{val | ref: val_ref}, %Value{val | ref: idx_ref}] + def reverse(%Value{function: func} = operand, dims, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + attributes = [dimensions: attr_dense_i64_elements(dims)] + op(func, "stablehlo.reverse", [operand], result_types, attributes: attributes) |> one!() end - def sort(%Value{} = value, comparator_fun, axis, stable) do - [result] = sort([value], comparator_fun, axis, stable) - result + def transpose(%Value{function: func} = operand, axes, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + attributes = [permutation: attr_dense_i64_elements(axes)] + op(func, "stablehlo.transpose", [operand], result_types, attributes: attributes) |> one!() end - def sort( - [%Value{function: %Function{ref: func_ref}} | _] = values, - %Function{ref: comparator_fun}, - axis, - stable - ) - when is_integer(axis) and is_boolean(stable) do - stable = if stable, do: 1, else: 0 + def slice(%Value{function: func} = operand, starts, limits, strides, typespec) do + result_types = typespecs_to_mlir_types([typespec]) - in_refs = - Enum.map(values, fn %Value{ref: ref, function: %Function{ref: ^func_ref}} -> ref end) + attributes = [ + start_indices: attr_dense_i64_elements(starts), + limit_indices: attr_dense_i64_elements(limits), + strides: attr_dense_i64_elements(strides) + ] - out_refs = - EXLA.NIF.mlir_sort(func_ref, in_refs, axis, comparator_fun, stable) |> unwrap!() + op(func, "stablehlo.slice", [operand], result_types, attributes: attributes) |> one!() + end - Enum.zip_with(values, out_refs, fn value, out_ref -> 
%Value{value | ref: out_ref} end) + def dynamic_slice(%Value{function: func} = operand, starts, lengths, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + operands = [operand] ++ starts + attributes = [slice_sizes: attr_dense_i64_elements(lengths)] + op(func, "stablehlo.dynamic_slice", operands, result_types, attributes: attributes) |> one!() end - def iota(%Function{} = func, shape, dim) do - ref = EXLA.NIF.mlir_iota(func.ref, shape.ref, dim) |> unwrap!() - %Value{ref: ref, function: func} + def convert(%Value{function: func} = operand, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + op(func, "stablehlo.convert", [operand], result_types) |> one!() end - def constant_r0(%Function{} = func, value, {:c, width} = type) - when type in [{:c, 64}, {:c, 128}] do - {re, im} = - case value do - %Complex{re: re, im: im} -> {re, im} - n when is_float(n) -> {n, 0.0} - n when is_integer(n) -> {n * 1.0, 0.0} - true -> {1.0, 0.0} - false -> {0.0, 0.0} - end + def bitcast_convert(%Value{function: func} = operand, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + op(func, "stablehlo.bitcast_convert", [operand], result_types) |> one!() + end - width = div(width, 2) + def top_k(%Value{function: func} = operand, k, typespecs) do + [typespec, index_typespec] = typespecs + result_types = typespecs_to_mlir_types([typespec, Typespec.to_type(index_typespec, {:s, 32})]) - data = <<re::float-native-size(width), im::float-native-size(width)>> + attributes = [k: attr_i64(k)] + [result, idx] = op(func, "chlo.top_k", [operand], result_types, attributes: attributes) - ref = - EXLA.NIF.mlir_constant_from_binary( - func.ref, - data, - EXLA.Shape.dtype_to_charlist(type), - {1} - ) - |> unwrap!() + idx = convert(idx, index_typespec) - reshape(%Value{ref: ref, function: func}, {}) + [result, idx] + end - def constant_r0(%Function{} = func, value, type) - when value in [:infinity, :nan, :neg_infinity] do - data = - value - |> Nx.tensor(backend: Nx.BinaryBackend, type: type) - |> Nx.to_binary() - - ref = - EXLA.NIF.mlir_constant_from_binary( - func.ref, - data, - EXLA.Shape.dtype_to_charlist(type), - {} + def sort( + [%Value{function: func} | _] = operands, + %Region{ref: comparator}, + axis, + stable, + typespecs ) - |> unwrap!() - - %Value{ref: ref, function: func} - end + when is_integer(axis) and is_boolean(stable) do + result_types = typespecs_to_mlir_types(typespecs) - def constant_r0(%Function{} = func, value, type) do - value = - if Nx.Type.float?(type) and not is_float(value) do - value * 1.0 - else - value - end + attributes = [ + dimension: attr_i64(axis), + is_stable: attr_boolean(stable) + ] - ref = - EXLA.NIF.mlir_constant_r0( - func.ref, - value, - EXLA.Shape.dtype_to_charlist(type) - ) - |> unwrap!() + regions = [comparator] - %Value{ref: ref, function: func} + op(func, "stablehlo.sort", operands, result_types, attributes: attributes, regions: regions) end - def constant_from_binary(%Function{} = func, data, shape) do - ref = - EXLA.NIF.mlir_constant_from_binary( - func.ref, - data, - EXLA.Shape.dtype_to_charlist(shape.dtype), - shape.dims - ) - |> unwrap!() + def iota(%Function{} = func, dim, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + attributes = [iota_dimension: attr_i64(dim)] + op(func, "stablehlo.iota", [], result_types, attributes: attributes) |> one!() + end - %Value{ref: ref, function: func} + def constant(%Function{} = func, data, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + value = attr_dense_elements(data, typespec.type, typespec.shape) + attributes = [value:
value] + op(func, "stablehlo.constant", [], result_types, attributes: attributes) |> one!() end def dot_general( - output_shape, %Value{function: func} = lhs, %Value{function: func} = rhs, dnums, - precision_config + precision_config, + typespec ) do - config = get_precision_config_int(precision_config) + result_types = typespecs_to_mlir_types([typespec]) - ref = - EXLA.NIF.mlir_dot_general(func.ref, output_shape.ref, lhs.ref, rhs.ref, dnums, config) - |> unwrap!() + attr_precision_config = attr_precision_config(precision_config) - %Value{ref: ref, function: func} - end + {contract_axes1, batch_axes1, contract_axes2, batch_axes2} = dnums - def broadcast_in_dim(%Value{function: func} = operand, output_shape, axes) do - ref = - EXLA.NIF.mlir_broadcast_in_dim(func.ref, output_shape.ref, operand.ref, axes) - |> unwrap!() + dot_dimension_numbers = + attr_struct("stablehlo.dot", + lhs_batching_dimensions: join_list(batch_axes1), + rhs_batching_dimensions: join_list(batch_axes2), + lhs_contracting_dimensions: join_list(contract_axes1), + rhs_contracting_dimensions: join_list(contract_axes2) + ) - %Value{function: func, ref: ref} + attributes = [ + dot_dimension_numbers: dot_dimension_numbers, + precision_config: "[#{attr_precision_config}, #{attr_precision_config}]" + ] + + op(func, "stablehlo.dot_general", [lhs, rhs], result_types, attributes: attributes) |> one!() end - def concatenate([%Value{function: func} | _rest] = operands, dimension) do - refs = Enum.map(operands, & &1.ref) + def broadcast_in_dim(%Value{function: func} = operand, axes, typespec) do + result_types = typespecs_to_mlir_types([typespec]) - ref = - EXLA.NIF.mlir_concatenate(func.ref, refs, dimension) - |> unwrap!() + attributes = [ + broadcast_dimensions: attr_dense_i64_elements(axes) + ] - %Value{ref: ref, function: func} + op(func, "stablehlo.broadcast_in_dim", [operand], result_types, attributes: attributes) + |> one!() end - def optimization_barrier(%Value{function: func} = operand) do - ref = - EXLA.NIF.mlir_optimization_barrier(func.ref, operand.ref) - |> unwrap!() - - %Value{ref: ref, function: func} + def concatenate([%Value{function: func} | _rest] = operands, dimension, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + attributes = [dimension: attr_i64(dimension)] + op(func, "stablehlo.concatenate", operands, result_types, attributes: attributes) |> one!() end def clamp( %Value{function: func} = operand, %Value{function: func} = min, - %Value{function: func} = max + %Value{function: func} = max, + typespec ) do - ref = - EXLA.NIF.mlir_clamp(func.ref, operand.ref, min.ref, max.ref) - |> unwrap!() - - %Value{ref: ref, function: func} + result_types = typespecs_to_mlir_types([typespec]) + op(func, "stablehlo.clamp", [min, operand, max], result_types) |> one!() end def select( %Value{function: func} = pred, %Value{function: func} = on_true, - %Value{function: func} = on_false + %Value{function: func} = on_false, + typespec ) do - ref = - EXLA.NIF.mlir_select(func.ref, pred.ref, on_true.ref, on_false.ref) - |> unwrap!() + result_types = typespecs_to_mlir_types([typespec]) + op(func, "stablehlo.select", [pred, on_true, on_false], result_types) |> one!() + end + + def pad( + %Value{function: func} = operand, + %Value{function: func} = pad, + padding_config, + typespec + ) do + result_types = typespecs_to_mlir_types([typespec]) + + {padding_low, padding_high, padding_mid} = unzip_padding_config(padding_config) + + attributes = [ + edge_padding_low: attr_dense_i64_elements(padding_low), + edge_padding_high: 
attr_dense_i64_elements(padding_high), + interior_padding: attr_dense_i64_elements(padding_mid) + ] - %Value{ref: ref, function: func} + op(func, "stablehlo.pad", [operand, pad], result_types, attributes: attributes) |> one!() end - def pad(%Value{function: func} = operand, %Value{function: func} = pad, padding_config) do - {padding_low, padding_high, padding_mid} = - Enum.reduce(padding_config, {[], [], []}, fn {low, high, mid}, - {low_acc, high_acc, mid_acc} -> - {[low | low_acc], [high | high_acc], [mid | mid_acc]} - end) + defp unzip_padding_config(padding_config), + do: unzip_padding_config(padding_config, {[], [], []}) - ref = - EXLA.NIF.mlir_pad( - func.ref, - operand.ref, - pad.ref, - Enum.reverse(padding_low), - Enum.reverse(padding_high), - Enum.reverse(padding_mid) - ) - |> unwrap!() + defp unzip_padding_config([], {low_acc, high_acc, mid_acc}) do + {Enum.reverse(low_acc), Enum.reverse(high_acc), Enum.reverse(mid_acc)} + end - %Value{ref: ref, function: func} + defp unzip_padding_config([{low, high, mid} | rest], {low_acc, high_acc, mid_acc}) do + unzip_padding_config(rest, {[low | low_acc], [high | high_acc], [mid | mid_acc]}) end - def fft(%Value{function: func} = value, fft_kind, fft_length) + def fft(%Value{function: func} = value, fft_kind, fft_length, typespec) when fft_kind in [:fft, :ifft] when is_list(fft_length) or is_integer(fft_length) do - ref = - EXLA.NIF.mlir_fft( - func.ref, - value.ref, - if(fft_kind == :fft, do: 1, else: 0), - List.wrap(fft_length) - ) - |> unwrap!() + result_types = typespecs_to_mlir_types([typespec]) + + fft_type = attr_fft_type(fft_kind) - %Value{value | ref: ref} + attributes = [ + fft_type: fft_type, + fft_length: attr_dense_i64_elements(List.wrap(fft_length)) + ] + + op(func, "stablehlo.fft", [value], result_types, attributes: attributes) |> one!() end def scatter( @@ -325,27 +389,50 @@ defmodule EXLA.MLIR.Value do indices_rank, update_window_dims, inserted_window_dims, - index_dims_to_window_dims + index_dims_to_window_dims, + typespec ) when kind in [:add, :put] and is_integer(indices_rank) and is_list(update_window_dims) and is_list(inserted_window_dims) and is_list(index_dims_to_window_dims) do - add_or_put = if(kind == :add, do: 1, else: 0) - - ref = - EXLA.NIF.mlir_scatter( - func.ref, - target.ref, - indices.ref, - updates.ref, - add_or_put, - indices_rank, - update_window_dims, - inserted_window_dims, - index_dims_to_window_dims + result_types = typespecs_to_mlir_types([typespec]) + + operands = [target, indices, updates] + + scatter_dimension_numbers = + attr_struct("stablehlo.scatter", + update_window_dims: join_list(update_window_dims), + inserted_window_dims: join_list(inserted_window_dims), + scatter_dims_to_operand_dims: join_list(index_dims_to_window_dims), + index_vector_dim: Integer.to_string(indices_rank) ) - |> unwrap!() - %Value{target | ref: ref} + attributes = [scatter_dimension_numbers: scatter_dimension_numbers] + + scatter_computation = scatter_computation(func, kind, typespec) + regions = [scatter_computation.ref] + + op(func, "stablehlo.scatter", operands, result_types, + attributes: attributes, + regions: regions + ) + |> one!() + end + + defp scatter_computation(%Function{} = function, kind, typespec) do + arg_typespec = Typespec.to_shape(typespec, {}) + {region, [value, update]} = Function.push_region(function, [arg_typespec, arg_typespec]) + + res = + case kind do + :add -> add(value, update, arg_typespec) + :put -> update + end + + return(function, [res]) + + Function.pop_region(function) + + region end def 
select_and_scatter( @@ -355,25 +442,41 @@ defmodule EXLA.MLIR.Value do comparison, window_dimensions, window_strides, - padding + padding, + typespec ) when comparison in [:gt, :lt] do - gt_or_lt = if(comparison == :gt, do: 1, else: 0) - - ref = - EXLA.NIF.mlir_select_and_scatter( - func.ref, - target.ref, - source.ref, - init_value.ref, - gt_or_lt, - window_dimensions, - window_strides, - padding - ) - |> unwrap!() + operands = [target, source, init_value] + + result_types = typespecs_to_mlir_types([typespec]) + + attributes = [ + window_dimensions: attr_dense_i64_elements(window_dimensions), + window_strides: attr_dense_i64_elements(window_strides), + padding: attr_padding(padding) + ] + + select_computation = select_computation(func, comparison, typespec) + scatter_computation = scatter_computation(func, :add, typespec) + regions = [select_computation.ref, scatter_computation.ref] - %Value{target | ref: ref} + op(func, "stablehlo.select_and_scatter", operands, result_types, + attributes: attributes, + regions: regions + ) + |> one!() + end + + defp select_computation(function, direction, typespec) do + arg_typespec = Typespec.to_shape(typespec, {}) + {region, [arg0, arg1]} = Function.push_region(function, [arg_typespec, arg_typespec]) + + res = compare_and_return_bool(function, arg0, arg1, arg_typespec, direction) + return(function, [res]) + + Function.pop_region(function) + + region end def gather( @@ -383,42 +486,44 @@ defmodule EXLA.MLIR.Value do slice_sizes, offset_dims, collapsed_slice_dims, - start_index_map - ) do - ref = - EXLA.NIF.mlir_gather( - func.ref, - source.ref, - indices.ref, - slice_sizes, - offset_dims, - collapsed_slice_dims, start_index_map, - index_vector_dim + typespec + ) do + result_types = typespecs_to_mlir_types([typespec]) + + dimension_numbers = + attr_struct("stablehlo.gather", + offset_dims: join_list(offset_dims), + collapsed_slice_dims: join_list(collapsed_slice_dims), + start_index_map: join_list(start_index_map), + index_vector_dim: Integer.to_string(index_vector_dim) ) - |> unwrap!() - %Value{source | ref: ref} + attributes = [ + dimension_numbers: dimension_numbers, + slice_sizes: attr_dense_i64_elements(slice_sizes), + indices_are_sorted: attr_boolean(false) + ] + + op(func, "stablehlo.gather", [source, indices], result_types, attributes: attributes) + |> one!() end - defp get_precision_config_int(precision_config) do + defp attr_precision_config(precision_config) do case precision_config do :default -> - 0 + attr_precision(:default) :high -> - 1 + attr_precision(:high) :highest -> - 2 - - :packed_nibble -> - 3 + attr_precision(:highest) _ -> raise ArgumentError, "expected precision configuration to be one of" <> - " :default, :high, :highest, or :packed_nibble," <> + " :default, :high, or :highest," <> " got: #{inspect(precision_config)}" end end @@ -434,206 +539,428 @@ defmodule EXLA.MLIR.Value do feature_group_count, batch_group_count, precision_config, - output_shape + typespec ) do - precision_config = get_precision_config_int(precision_config) + result_types = typespecs_to_mlir_types([typespec]) - ref = - EXLA.NIF.mlir_convolution( - func.ref, - tensor.ref, - kernel.ref, - strides, - padding, - input_dilation, - kernel_dilation, - dimension_numbers, - feature_group_count, - batch_group_count, - precision_config, - Tuple.to_list(output_shape) - ) - |> unwrap!() + attr_precision_config = attr_precision_config(precision_config) + + attributes = [ + window_strides: attr_dense_i64_elements(strides), + padding: attr_padding(padding), + lhs_dilation: 
attr_dense_i64_elements(input_dilation), + rhs_dilation: attr_dense_i64_elements(kernel_dilation), + dimension_numbers: attr_conv_dimension_numbers(dimension_numbers), + feature_group_count: attr_i64(feature_group_count), + batch_group_count: attr_i64(batch_group_count), + precision_config: "[#{attr_precision_config}, #{attr_precision_config}]" + ] - %{tensor | ref: ref} + op(func, "stablehlo.convolution", [tensor, kernel], result_types, attributes: attributes) + |> one!() end - def triangular_solve(a, b, left_side, lower, transform) do - ref = - EXLA.NIF.mlir_triangular_solve( - a.function.ref, - a.ref, - b.ref, - if(left_side, do: 1, else: 0), - if(lower, do: 1, else: 0), - if(transform == :transpose, do: 1, else: 0) - ) - |> unwrap!() + defp attr_conv_dimension_numbers(dimension_numbers) do + {input_permutation, kernel_permutation, output_permutation} = dimension_numbers + input_string = convolution_dims_permutation(input_permutation, "b", "f") + kernel_string = convolution_dims_permutation(kernel_permutation, "o", "i") + output_string = convolution_dims_permutation(output_permutation, "b", "f") + "#stablehlo.conv<[#{input_string}]x[#{kernel_string}]->[#{output_string}]>" + end - %{a | ref: ref} + defp convolution_dims_permutation(permutation, dim1_mark, dim2_mark) do + [dim1, dim2 | spatial_dims] = permutation + + dims_with_marks = + [{dim1, dim1_mark}, {dim2, dim2_mark}] ++ + Enum.with_index(spatial_dims, fn dim, idx -> {dim, Integer.to_string(idx)} end) + + dims_with_marks + |> Enum.sort() + |> Enum.map_join(",", fn {_dim, mark} -> mark end) end - def dynamic_update_slice(operand, updates, starts) do - ref = - EXLA.NIF.mlir_dynamic_update_slice( - operand.function.ref, - operand.ref, - updates.ref, - Enum.map(starts, & &1.ref) - ) - |> unwrap!() + def triangular_solve( + %Value{function: func} = a, + %Value{function: func} = b, + left_side, + lower, + transform, + typespec + ) do + result_types = typespecs_to_mlir_types([typespec]) + + complex? = Nx.Type.complex?(typespec.type) + + transpose_a = + case transform do + :transpose when complex? 
-> attr_transpose(:adjoint) + :transpose -> attr_transpose(:transpose) + :none -> attr_transpose(:no_transpose) + end + + attributes = [ + left_side: attr_boolean(left_side), + lower: attr_boolean(lower), + unit_diagonal: attr_boolean(false), + transpose_a: transpose_a + ] - %{operand | ref: ref} + op(func, "stablehlo.triangular_solve", [a, b], result_types, attributes: attributes) |> one!() + end + + def dynamic_update_slice(%Value{function: func} = operand, updates, starts, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + + op(func, "stablehlo.dynamic_update_slice", [operand, updates] ++ starts, result_types) + |> one!() end def reduce( - %Function{ref: reducer}, + %Region{ref: reducer}, [%Value{function: func} | _] = init_values, [%Value{function: func} | _] = inputs, - dimensions + dimensions, + typespecs ) do - init_value_refs = Enum.map(init_values, & &1.ref) - input_refs = Enum.map(inputs, & &1.ref) - - refs = - EXLA.NIF.mlir_reduce(func.ref, reducer, init_value_refs, input_refs, dimensions) - |> unwrap!() - - Enum.map(refs, &%Value{ref: &1, function: func}) + operands = inputs ++ init_values + result_types = typespecs_to_mlir_types(typespecs) + attributes = [dimensions: attr_dense_i64_elements(dimensions)] + regions = [reducer] + op(func, "stablehlo.reduce", operands, result_types, attributes: attributes, regions: regions) end def window_reduce( - %Function{ref: reducer}, + %Region{ref: reducer}, [%Value{function: func} | _] = init_values, [%Value{function: func} | _] = inputs, window_dimensions, window_strides, input_dilations, window_dilations, - padding + padding, + typespecs ) do - init_value_refs = Enum.map(init_values, & &1.ref) - input_refs = Enum.map(inputs, & &1.ref) + operands = inputs ++ init_values + result_types = typespecs_to_mlir_types(typespecs) - refs = - EXLA.NIF.mlir_window_reduce( - func.ref, - reducer, - init_value_refs, - input_refs, - window_dimensions, - window_strides, - input_dilations, - window_dilations, - padding - ) - |> unwrap!() + attributes = [ + window_dimensions: attr_dense_i64_elements(window_dimensions), + window_strides: attr_dense_i64_elements(window_strides), + base_dilations: attr_dense_i64_elements(input_dilations), + window_dilations: attr_dense_i64_elements(window_dilations), + padding: attr_padding(padding) + ] + + regions = [reducer] - Enum.map(refs, &%Value{ref: &1, function: func}) + op(func, "stablehlo.reduce_window", operands, result_types, + attributes: attributes, + regions: regions + ) end def map( - %Function{ref: mapper}, + %Region{ref: mapper}, [%Value{function: func} | _] = inputs, - dimensions + dimensions, + typespec ) do - input_refs = Enum.map(inputs, & &1.ref) + result_types = typespecs_to_mlir_types([typespec]) - ref = - EXLA.NIF.mlir_map(func.ref, mapper, input_refs, dimensions) - |> unwrap!() + attributes = [ + dimensions: attr_dense_i64_elements(dimensions) + ] - %Value{ref: ref, function: func} - end + regions = [mapper] - def if_op(%Value{} = pred, [%EXLA.Shape{} | _] = output_shapes) do - {refs, true_region, false_region} = - EXLA.NIF.mlir_if( - pred.function.ref, - pred.ref, - flatten_shapes(output_shapes) - ) - |> unwrap!() + op(func, "stablehlo.map", inputs, result_types, attributes: attributes, regions: regions) + |> one!() + end - results = Enum.map(refs, &%Value{ref: &1, function: pred.function}) + def if_op( + %Value{function: func} = pred, + %Region{ref: on_true}, + %Region{ref: on_false}, + typespecs + ) do + result_types = typespecs_to_mlir_types(typespecs) + regions = [on_true, 
on_false] + pred = convert(pred, Typespec.tensor({:pred, 8}, {})) + op(func, "stablehlo.if", [pred], result_types, regions: regions) + end - {results, %Region{ref: true_region}, %Region{ref: false_region}} + def infeed(%Value{function: func} = token, typespecs) do + result_types = typespecs_to_mlir_types(typespecs ++ [Typespec.token()]) + results = op(func, "stablehlo.infeed", [token], result_types) + {results, [token]} = Enum.split(results, -1) + {token, results} end - def infeed(%Value{} = token, %EXLA.Shape{} = shape) do - infeed(token, [shape]) + def outfeed(%Value{} = input, token), do: outfeed([input], token) + + def outfeed(inputs, %Value{function: func} = token) do + result_types = [type_token()] + op(func, "stablehlo.outfeed", inputs ++ [token], result_types) |> one!() end - def infeed(%Value{function: function} = token, [%EXLA.Shape{} | _] = shapes) do - {token_ref, result_refs} = - EXLA.NIF.mlir_infeed(function.ref, token.ref, Enum.map(shapes, & &1.ref)) |> unwrap!() + def create_token(%Function{} = func) do + result_types = [type_token()] + op(func, "stablehlo.create_token", [], result_types) |> one!() + end - {%Value{token | ref: token_ref}, Enum.map(result_refs, &%Value{token | ref: &1})} + def call(%Function{} = func, args, %Function{} = computation, typespecs) do + result_types = typespecs_to_mlir_types(typespecs) + attributes = [callee: attr_symbol_reference(computation.name)] + op(func, "func.call", args, result_types, attributes: attributes) end - def outfeed(%Value{} = input, token), do: outfeed([input], token) + def while(%Function{} = func, %Region{ref: pred}, %Region{ref: body}, initial) do + typespecs = Enum.map(initial, &get_typespec/1) + result_types = typespecs_to_mlir_types(typespecs) - def outfeed(inputs, %Value{function: function} = token) do - input_refs = Enum.map(inputs, & &1.ref) - ref = EXLA.NIF.mlir_outfeed(function.ref, token.ref, input_refs) |> unwrap!() - %{token | ref: ref} + regions = [pred, body] + + op(func, "stablehlo.while", initial, result_types, regions: regions) end - def create_token(%Function{ref: ref} = function) do - ref = EXLA.NIF.mlir_create_token(ref) |> unwrap!() - %Value{ref: ref, function: function} + def return(func, values) when is_list(values) do + op(func, "stablehlo.return", values, []) end - def call(%Function{ref: fun_ref} = function, args, %Function{ref: computation_ref}) do - arg_refs = Enum.map(args, & &1.ref) - refs = EXLA.NIF.mlir_call(fun_ref, arg_refs, computation_ref) |> unwrap!() - Enum.map(refs, &%Value{ref: &1, function: function}) + def qr(%Value{function: func} = value, q_typespec, r_typespec) do + %{type: op_type, shape: op_shape} = get_typespec(value) + %{type: q_type, shape: q_shape} = q_typespec + %{type: r_type, shape: r_shape} = r_typespec + + dim_sizes = [tuple_size(op_shape), tuple_size(q_shape), tuple_size(r_shape)] + operand_dims = Tuple.to_list(op_shape) + q_dims = Tuple.to_list(q_shape) + r_dims = Tuple.to_list(r_shape) + + dim_sizes = constant(func, dim_sizes, Typespec.tensor({:s, 64}, {length(dim_sizes)})) + operand_dims = constant(func, operand_dims, Typespec.tensor({:s, 64}, {length(operand_dims)})) + q_dims = constant(func, q_dims, Typespec.tensor({:s, 64}, {length(q_dims)})) + r_dims = constant(func, r_dims, Typespec.tensor({:s, 64}, {length(r_dims)})) + operands = [value, dim_sizes, operand_dims, q_dims, r_dims] + + q_result_type = type_tensor(q_type, q_shape) + r_result_type = type_tensor(r_type, r_shape) + result_types = [type_tuple([q_result_type, r_result_type])] + + call_target_name = + 
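+      # The target name must match a QR custom call registered with XLA's
+      # CPU custom-call registry by the native library; unsupported types
+      # raise below.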
case op_type do + {:f, 32} -> "qr_cpu_custom_call_f32" + {:f, 64} -> "qr_cpu_custom_call_f64" + {:f, 16} -> "qr_cpu_custom_call_f16" + {:bf, 16} -> "qr_cpu_custom_call_bf16" + type -> raise "QR decomposition not supported for type #{inspect(type)}" + end + + attributes = [ + call_target_name: attr_string(call_target_name), + backend_config: attr_string("Host") + ] + + result = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() + + q = get_tuple_element(result, 0, q_typespec) + r = get_tuple_element(result, 1, r_typespec) + + {q, r} end - def while(function, initial) do - {result_refs, pred_ref, body_ref} = - EXLA.NIF.mlir_while(function.ref, flatten_tuples(initial)) |> unwrap!() + def get_tuple_element(%Value{function: func} = operand, index, typespec) do + result_types = typespecs_to_mlir_types([typespec]) + attributes = [index: attr_i32(index)] + + op(func, "stablehlo.get_tuple_element", [operand], result_types, attributes: attributes) + |> one!() + end - results = Enum.map(result_refs, &%Value{function: function, ref: &1}) + def get_typespec(value) do + EXLA.NIF.mlir_get_typespec(value.ref) + |> unwrap!() + |> Typespec.nif_decode() + end - {results, %Region{ref: pred_ref}, %Region{ref: body_ref}} + def typespecs_to_mlir_types(shapes) do + Enum.map(shapes, &typespec_to_mlir_type/1) end - def variadic_return(function, values) when is_list(values) do - refs = Enum.map(values, & &1.ref) + defp typespec_to_mlir_type(%{type: :token}), do: type_token() + defp typespec_to_mlir_type(%{type: type, shape: shape}), do: type_tensor(type, shape) - refs = EXLA.NIF.mlir_return(function.ref, refs) |> unwrap!() + defp unwrap!(:ok), do: :ok + defp unwrap!({:ok, value}), do: value + defp unwrap!(other), do: raise("#{inspect(other)}") + + defp one!([value]), do: value - Enum.map(refs, fn ref -> %Value{function: function, ref: ref} end) + defp one!(other) do + raise "expected a list with a single element, got: #{inspect(other)}" end - def qr(%Value{function: function, ref: ref}, q_shape, r_shape) - when is_tuple(q_shape) and is_tuple(r_shape) do - {q_ref, r_ref} = - EXLA.NIF.mlir_qr(function.ref, ref, Tuple.to_list(q_shape), Tuple.to_list(r_shape)) + defp complex_part_type({:c, size}), do: {:f, div(size, 2)} + + defp op(%Function{} = function, op_name, operands, result_types, opts \\ []) do + opts = Keyword.validate!(opts, attributes: [], regions: []) + + %{ref: function_ref} = function + + refs = + Enum.map(operands, fn + %Value{ref: ref, function: %Function{ref: ^function_ref}} -> ref + end) + + refs = + EXLA.NIF.mlir_op( + function.ref, + op_name, + refs, + result_types, + opts[:attributes], + opts[:regions] + ) |> unwrap!() - { - %Value{function: function, ref: q_ref}, - %Value{function: function, ref: r_ref} - } + Enum.map(refs, &%Value{function: function, ref: &1}) end - defp flatten_shapes(val) when is_list(val) do - Enum.flat_map(val, &flatten_shapes/1) + defp type_tensor(type, shape) do + shape_sequence = shape |> Tuple.to_list() |> Enum.map_join("", &"#{&1}x") + "tensor<#{shape_sequence}#{type_number(type)}>" end - defp flatten_shapes(val) do - [val.ref] + defp type_number({:pred, 8}), do: "i1" + defp type_number({:s, width}), do: "i#{width}" + defp type_number({:u, width}), do: "ui#{width}" + defp type_number({:f, width}), do: "f#{width}" + defp type_number({:bf, width}), do: "bf#{width}" + defp type_number({:c, 64}), do: "complex<f32>" + defp type_number({:c, 128}), do: "complex<f64>" + + defp type_token(), do: "!stablehlo.token" + + defp type_tuple(children) do
"tuple<#{Enum.join(children, ", ")}>" end + defp number_literal(value, type) do + cond do + Nx.Type.complex?(type) -> + {re, im} = + case value do + %Complex{re: re, im: im} -> {re, im} + true -> {1, 0} + false -> {0, 0} + n -> {n, 0} + end + + subtype = complex_part_type(type) + "(#{number_literal(re, subtype)}, #{number_literal(im, subtype)})" + + Nx.Type.float?(type) -> + # We pass floats using binary representation, because that is + # likely more robust and not subject to formatting limits and + # rounding. Based on the examples in the docs, the hexadecimal + # representation is always big-endian. + # + # See https://mlir.llvm.org/docs/Dialects/Builtin/#floatattr + hex_data = float_hex(value, type) + "0x#{hex_data}" + + true -> + "#{value}" + end + end + defp float_hex(value, {_, size} = type) do + data = + case value do + :nan -> type |> Nx.Type.nan_binary() |> native_to_big() + :infinity -> type |> Nx.Type.infinity_binary() |> native_to_big() + :neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big() + value -> <<value::float-big-size(size)>> + end + + Base.encode16(data) + end + defp native_to_big(binary) do + size = byte_size(binary) * 8 + <<number::size(size)-native>> = binary + <<number::size(size)-big>> + end + defp attr_dense_i64_elements(list) do + attr_dense_elements(list, {:s, 64}, {length(list)}) + end + defp attr_dense_elements([], type, {0} = shape) do + "dense<[]> : #{type_tensor(type, shape)}" + end + defp attr_dense_elements(list, type, shape) do + literals = Enum.map(list, &number_literal(&1, type)) + + list_literal = + shape + |> Tuple.to_list() + |> List.foldr(literals, fn size, acc -> + acc + |> Enum.chunk_every(size) + |> Enum.map(fn chunk -> + ["[", Enum.intersperse(chunk, ", "), "]"] + end) + end) + |> IO.iodata_to_binary() + + "dense<#{list_literal}> : #{type_tensor(type, shape)}" + end + defp attr_string(string), do: ~s["#{string}"] + + defp attr_symbol_reference(id), do: "@#{id}" + + defp attr_boolean(true), do: "true" + defp attr_boolean(false), do: "false" + + defp attr_i32(number), do: "#{number} : i32" + defp attr_i64(number), do: "#{number} : i64" + + defp attr_padding(padding) do + list = Enum.flat_map(padding, &Tuple.to_list/1) + attr_dense_elements(list, {:s, 64}, {length(padding), 2}) + end + + defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], + do: attr_enum("stablehlo", "comparison_direction", value) + + defp attr_comparison_type(value) when value in [:totalorder, :notype], + do: attr_enum("stablehlo", "comparison_type", value) + + defp attr_precision(value) when value in [:default, :high, :highest], + do: attr_enum("stablehlo", "precision", value) + + defp attr_transpose(value) when value in [:adjoint, :transpose, :no_transpose], + do: attr_enum("stablehlo", "transpose", value) + + defp attr_fft_type(value) when value in [:fft, :ifft], + do: attr_enum("stablehlo", "fft_type", value) + + defp attr_enum(dialect, enum_name, value) do + value = value |> Atom.to_string() |> String.upcase() + "##{dialect}<#{enum_name} #{value}>" + end + + defp attr_struct(name, keyword_list) do + content = Enum.map_join(keyword_list, ", ", fn {key, value} -> "#{key} = #{value}" end) + "##{name}<#{content}>" + end + + defp join_list(list) do + "[" <> Enum.join(list, ", ") <> "]" + end end diff --git a/exla/lib/exla/nif.ex
b/exla/lib/exla/nif.ex index 0e83d77b5a..6830df726c 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -7,70 +7,21 @@ defmodule EXLA.NIF do :erlang.load_nif(path, 0) end - def new_mlir_context, do: :erlang.nif_error(:undef) + def mlir_new_context, do: :erlang.nif_error(:undef) - def new_mlir_module(_context), do: :erlang.nif_error(:undef) + def mlir_new_module(_context), do: :erlang.nif_error(:undef) - def create_mlir_function(_module, _name, _arg_types, _ret_type, _is_public), + def mlir_create_function(_module, _name, _arg_types, _ret_type, _is_public), do: :erlang.nif_error(:undef) - def get_mlir_function_arguments(_function), do: :erlang.nif_error(:undef) + def mlir_get_function_arguments(_function), do: :erlang.nif_error(:undef) - @bin_ops [:add, :subtract, :multiply, :divide, :pow, :min] ++ - [:max, :remainder, :atan2, :equal, :not_equal] ++ - [:less, :less_equal, :greater, :greater_equal] ++ - [:bitwise_and, :bitwise_or, :bitwise_xor] ++ - [:left_shift, :right_shift_arithmetic, :right_shift_logical] - - for op <- @bin_ops do - mlir_op = :"mlir_#{op}" - def unquote(mlir_op)(_function, _lhs, _rhs), do: :erlang.nif_error(:undef) - end - - @unary_ops [:abs, :exp, :expm1, :floor, :ceil, :round] ++ - [:log, :log1p, :sigmoid, :sign, :cos] ++ - [:sin, :tan, :acos, :asin, :atan, :cosh, :sinh] ++ - [:tanh, :acosh, :asinh, :atanh, :sqrt, :cbrt] ++ - [:bitwise_not, :erf, :erfc, :erf_inv] ++ - [:is_infinity, :is_nan, :rsqrt, :negate, :count_leading_zeros] ++ - [:population_count, :real, :imag, :conjugate] - - for op <- @unary_ops do - mlir_op = :"mlir_#{op}" - def unquote(mlir_op)(_function, _operand), do: :erlang.nif_error(:undef) - end - - def mlir_reshape(_function, _operand, _shape), do: :erlang.nif_error(:undef) - def mlir_reverse(_function, _operand, _shape), do: :erlang.nif_error(:undef) - def mlir_transpose(_function, _operand, _shape), do: :erlang.nif_error(:undef) - def mlir_slice(_function, _operand, _starts, _limits, _strides), do: :erlang.nif_error(:undef) - def mlir_dynamic_slice(_function, _operand, _starts, _lengths), do: :erlang.nif_error(:undef) - def mlir_pad(_function, _tensor, _pad, _low, _high, _mid), do: :erlang.nif_error(:undef) - - def mlir_reduce(_function, _reducer, _init_values, _inputs, _dimensions), - do: :erlang.nif_error(:undef) - - def mlir_window_reduce( - _function, - _reducer, - _init_values, - _inputs, - _window_dimensions, - _window_strides, - _input_dilations, - _window_dilations, - _padding - ), - do: :erlang.nif_error(:undef) - - def mlir_map(_function, _mapper, _inputs, _dimensions), + def mlir_op(_function, _op_name, _operands, _result_type, _attributes, _blocks), do: :erlang.nif_error(:undef) - def mlir_if(_function, _pred, _output_shape), + def mlir_push_region(_function, _arg_types), do: :erlang.nif_error(:undef) - def mlir_push_region(_function, _region), do: :erlang.nif_error(:undef) - def mlir_pop_region(_function), do: :erlang.nif_error(:undef) @@ -87,109 +38,9 @@ defmodule EXLA.NIF do ), do: :erlang.nif_error(:undef) - def mlir_convert(_function, _tensor, _type), do: :erlang.nif_error(:undef) - def mlir_bitcast_convert(_function, _tensor, _type, _dims), do: :erlang.nif_error(:undef) - def mlir_top_k(_function, _tensor, _k), do: :erlang.nif_error(:undef) - def mlir_sort(_function, _tensors, _dim, _comparator, _stable), do: :erlang.nif_error(:undef) - - def mlir_get_shape(_tensor), do: :erlang.nif_error(:undef) - - def dump_mlir_module(_builder), do: :erlang.nif_error(:undef) - - def mlir_iota(_function, _shape, _dim), do: 
:erlang.nif_error(:undef) - def mlir_constant_r0(_function, _value, _type), do: :erlang.nif_error(:undef) - def mlir_constant_from_binary(_function, _data, _type, _dims), do: :erlang.nif_error(:undef) - - def mlir_dot_general(_function, _shape, _lhs, _rhs, _dims, _precision), - do: :erlang.nif_error(:undef) - - def mlir_broadcast_in_dim(_function, _shape, _operand, _axes), do: :erlang.nif_error(:undef) - def mlir_concatenate(_function, _operands, _dimension), do: :erlang.nif_error(:undef) - def mlir_optimization_barrier(_function, _operand), do: :erlang.nif_error(:undef) - def mlir_clamp(_function, _operand, _min, _max), do: :erlang.nif_error(:undef) - - def mlir_select(_function, _pred, _on_true, _on_false), - do: :erlang.nif_error(:undef) - - def mlir_scatter( - _function, - _target, - _indices, - _updates, - _add_or_put, - _indices_rank, - _update_window_dims, - _inserted_window_dims, - _index_axes_to_target_axes - ), - do: :erlang.nif_error(:undef) - - def mlir_select_and_scatter( - _function, - _target, - _source, - _init_value, - _gt_or_lt, - _window_dimensions, - _window_strides, - _padding - ), - do: :erlang.nif_error(:undef) - - def mlir_gather( - _function, - _sorce, - _indices, - _slice_sizes, - _offset_dims, - _collapsed_slice_dims, - _start_index_map, - _index_vector_dim - ), - do: :erlang.nif_error(:undef) - - def mlir_fft(_function, _tensor, _forward_fft, _fft_lenght), do: :erlang.nif_error(:undef) - - def mlir_convolution( - _function, - _tensor, - _kernel, - _strides, - _padding_config, - _tensor_dilation, - _kernel_dilation, - _dimension_numbers, - _feature_group_count, - _batch_group_count, - _precision_config, - _output_dims - ), - do: :erlang.nif_error(:undef) - - def mlir_create_token(_function), do: :erlang.nif_error(:undef) + def mlir_get_typespec(_tensor), do: :erlang.nif_error(:undef) - def mlir_triangular_solve(_function, _a, _b, _left_side, _lower, _transpose_a), - do: :erlang.nif_error(:undef) - - def mlir_dynamic_update_slice(_function, _operand, _updates, _starts), - do: :erlang.nif_error(:undef) - - def mlir_infeed(_function, _token, _shape), do: :erlang.nif_error(:undef) - def mlir_outfeed(_function, _token, _inputs), do: :erlang.nif_error(:undef) - - def mlir_call(_function, _args, _computation), do: :erlang.nif_error(:undef) - def mlir_while(_function, _initial), do: :erlang.nif_error(:undef) - def mlir_return(_function, _operands), do: :erlang.nif_error(:undef) - - def mlir_qr(_function, _operand, _q_shape, _r_shape), do: :erlang.nif_error(:undef) - - def get_shape_info(_ref), do: :erlang.nif_error(:undef) - - def make_shape(_type, _dims), - do: :erlang.nif_error(:undef) - - def make_token_shape(), - do: :erlang.nif_error(:undef) + def mlir_module_to_string(_builder), do: :erlang.nif_error(:undef) def get_host_client(), do: :erlang.nif_error(:undef) @@ -232,12 +83,12 @@ defmodule EXLA.NIF do _client, _opaque_pointer, _pointer_kind, - _shape, + _typespec, _device_id ), do: :erlang.nif_error(:undef) - def binary_to_device_mem(_client, _binary, _shape, _device_ordinal), + def binary_to_device_mem(_client, _binary, _typespec, _device_ordinal), do: :erlang.nif_error(:undef) def read_device_mem(_buffer, _size), @@ -246,10 +97,10 @@ defmodule EXLA.NIF do def deallocate_device_mem(_buffer), do: :erlang.nif_error(:undef) - def transfer_to_infeed(_client, _device, _data_shapes), + def transfer_to_infeed(_client, _device, _data_typespecs), do: :erlang.nif_error(:undef) - def transfer_from_outfeed(_client, _device, _shapes, _pid, _ref), + def 
transfer_from_outfeed(_client, _device, _typespecs, _pid, _ref), do: :erlang.nif_error(:undef) def copy_buffer_to_device(_client, _buffer, _device), diff --git a/exla/lib/exla/shape.ex b/exla/lib/exla/shape.ex deleted file mode 100644 index bb142f9aba..0000000000 --- a/exla/lib/exla/shape.ex +++ /dev/null @@ -1,68 +0,0 @@ -defmodule EXLA.Shape do - @moduledoc """ - Wrapper around XLA's shape. - """ - - alias __MODULE__ - import Kernel, except: [byte_size: 1] - - @enforce_keys [:ref, :dims, :dtype] - defstruct [:ref, :dims, :dtype] - - @doc false - def get_shape_info(ref) when is_reference(ref) do - {dims_term, type_str} = EXLA.NIF.get_shape_info(ref) |> unwrap!() - %Shape{dims: dims_term, dtype: charlist_to_dtype(type_str), ref: ref} - end - - @doc """ - Creates a shape with the given type-size tuple and dimensions. - """ - def make_shape({type, size}, dims) when is_tuple(dims) do - validate_dims!(dims, tuple_size(dims)) - ref = EXLA.NIF.make_shape(dtype_to_charlist({type, size}), dims) |> unwrap!() - %Shape{ref: ref, dtype: {type, size}, dims: dims} - end - - @doc """ - Creates a token shape. - """ - def make_token_shape() do - ref = EXLA.NIF.make_token_shape() |> unwrap!() - %Shape{dims: {}, dtype: :token, ref: ref} - end - - defp validate_dims!(_dims, 0), do: :ok - - defp validate_dims!(dims, i) - when is_integer(:erlang.element(i, dims)), - do: validate_dims!(dims, i - 1) - - defp validate_dims!(dims, _i) do - raise ArgumentError, "dimensions must be a tuple of integers, got: #{inspect(dims)}" - end - - @doc """ - Returns the shape size in bytes. - """ - def byte_size(%EXLA.Shape{dtype: {_, bit_size}, dims: dims}) do - Tuple.product(dims) * div(bit_size, 8) - end - - @doc """ - Converts a charlist type into Nx' tuple format. - """ - def charlist_to_dtype(~c"token"), do: :token - def charlist_to_dtype(~c"bf16"), do: {:bf, 16} - def charlist_to_dtype(~c"pred"), do: {:pred, 8} - def charlist_to_dtype([letter | int]), do: {List.to_atom([letter]), List.to_integer(int)} - - @doc """ - Converts Nx's tuple format into charlist. - """ - def dtype_to_charlist({:pred, _}), do: ~c"pred" - def dtype_to_charlist({type, size}), do: Atom.to_charlist(type) ++ Integer.to_charlist(size) - - defp unwrap!({:ok, ref}), do: ref - defp unwrap!({:error, error}), do: raise(List.to_string(error)) -end diff --git a/exla/lib/exla/typespec.ex b/exla/lib/exla/typespec.ex new file mode 100644 index 0000000000..471a25aace --- /dev/null +++ b/exla/lib/exla/typespec.ex @@ -0,0 +1,76 @@ +defmodule EXLA.Typespec do + @moduledoc """ + Combined type and shape information about tensors. + + In addition to the Nx types, also supports `{:pred, 8}` and `:token`, + which are used internally within the compiler. + + This struct corresponds to the `xla::Shape` class in the XLA compiler, + but is also meant as a lightweight data structure for passing the + information around. + + Note: the name "typespec" has been chosen intentionally to distinguish + it from both "type" and "shape". + """ + + @enforce_keys [:type, :shape] + defstruct [:type, :shape] + + @doc """ + Builds a tensor typespec. + """ + def tensor(type, shape) do + %__MODULE__{type: type, shape: shape} + end + + @doc """ + Builds a token typespec. + """ + def token() do + %__MODULE__{type: :token, shape: {}} + end + + @doc """ + Returns an updated typespec with the given type. + """ + def to_type(typespec, type), do: %{typespec | type: type} + + @doc """ + Returns an updated typespec with the given shape. 
+ """ + def to_shape(typespec, shape), do: %{typespec | shape: shape} + + @doc false + def nif_encode(typespec) do + {type_to_charlist(typespec.type), typespec.shape} + end + + @doc false + def nif_decode({type_charlist, shape}) do + %__MODULE__{shape: shape, type: charlist_to_type(type_charlist)} + end + + type_to_charlist = %{ + :token => ~c"token", + {:pred, 8} => ~c"pred", + {:s, 8} => ~c"s8", + {:s, 16} => ~c"s16", + {:s, 32} => ~c"s32", + {:s, 64} => ~c"s64", + {:u, 8} => ~c"u8", + {:u, 16} => ~c"u16", + {:u, 32} => ~c"u32", + {:u, 64} => ~c"u64", + {:f, 16} => ~c"f16", + {:f, 32} => ~c"f32", + {:f, 64} => ~c"f64", + {:bf, 16} => ~c"bf16", + {:c, 64} => ~c"c64", + {:c, 128} => ~c"c128" + } + + for {type, charlist} <- type_to_charlist do + defp charlist_to_type(unquote(charlist)), do: unquote(type) + defp type_to_charlist(unquote(type)), do: unquote(charlist) + end +end diff --git a/exla/mix.exs b/exla/mix.exs index 23ca14ccb7..517e051772 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -88,7 +88,7 @@ defmodule EXLA.MixProject do EXLA.Client, EXLA.DeviceBuffer, EXLA.Executable, - EXLA.Shape + EXLA.Typespec ] ] ] diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index 96bf48b566..c540784c67 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -4053,7 +4053,7 @@ defmodule EXLA.Defn.ExprTest do test "raises on bad precision" do valid_precision = - ":default, :high, :highest, or :packed_nibble" + ":default, :high, or :highest" assert_raise ArgumentError, "expected precision configuration to be one of" <> diff --git a/exla/test/exla/device_buffer_test.exs b/exla/test/exla/device_buffer_test.exs index 863ab9c23e..db47039485 100644 --- a/exla/test/exla/device_buffer_test.exs +++ b/exla/test/exla/device_buffer_test.exs @@ -1,13 +1,13 @@ defmodule EXLA.DeviceBufferTest do use ExUnit.Case, async: true - alias EXLA.{DeviceBuffer, Shape} + alias EXLA.{DeviceBuffer, Typespec} import EXLAHelpers describe "buffer" do test "place_on_device/4" do - b1 = DeviceBuffer.place_on_device(<<1::32>>, Shape.make_shape({:s, 32}, {}), client(), 0) + b1 = DeviceBuffer.place_on_device(<<1::32>>, Typespec.tensor({:s, 32}, {}), client(), 0) assert is_reference(b1.ref) end @@ -15,7 +15,7 @@ defmodule EXLA.DeviceBufferTest do b1 = DeviceBuffer.place_on_device( <<1::32, 2::32, 3::32, 4::32>>, - Shape.make_shape({:s, 32}, {4}), + Typespec.tensor({:s, 32}, {4}), client(), 0 ) @@ -35,7 +35,7 @@ defmodule EXLA.DeviceBufferTest do end test "deallocate/1" do - b1 = DeviceBuffer.place_on_device(<<1::32>>, Shape.make_shape({:s, 32}, {}), client(), 0) + b1 = DeviceBuffer.place_on_device(<<1::32>>, Typespec.tensor({:s, 32}, {}), client(), 0) assert :ok = DeviceBuffer.deallocate(b1) assert :already_deallocated = DeviceBuffer.deallocate(b1) diff --git a/exla/test/exla/executable_test.exs b/exla/test/exla/executable_test.exs index c74d5c14e7..28e276edfc 100644 --- a/exla/test/exla/executable_test.exs +++ b/exla/test/exla/executable_test.exs @@ -4,27 +4,27 @@ defmodule EXLA.ExecutableTest do alias EXLA.BinaryBuffer alias EXLA.DeviceBuffer alias EXLA.Executable - alias EXLA.Shape + alias EXLA.Typespec alias EXLA.MLIR.Value import EXLAHelpers describe "run" do test "with no inputs and default options" do assert [a = %DeviceBuffer{}] = - run_one([], [], Shape.make_shape({:s, 32}, {}), fn b -> - [Value.constant_r0(b, 1, {:s, 32})] + run_one([], [], Typespec.tensor({:s, 32}, {}), fn b -> + [Value.constant(b, [1], s32_typespec())] end) assert <<1::32-native>> == 
DeviceBuffer.read(a) end test "with 2 inputs and default options" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) assert [a = %DeviceBuffer{}] = - run_one([t1, t2], [], [t1.shape], fn b, x, y -> - [Value.add(b, x, y)] + run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> + [Value.add(x, y, s32_typespec())] end) assert <<2::32-native>> == DeviceBuffer.read(a) @@ -34,7 +34,7 @@ defmodule EXLA.ExecutableTest do t1 = DeviceBuffer.place_on_device( <<1::32-native>>, - Shape.make_shape({:s, 32}, {}), + Typespec.tensor({:s, 32}, {}), client(), 0 ) @@ -42,19 +42,19 @@ defmodule EXLA.ExecutableTest do t2 = DeviceBuffer.place_on_device( <<1::32-native>>, - Shape.make_shape({:s, 32}, {}), + Typespec.tensor({:s, 32}, {}), client(), 0 ) assert [%DeviceBuffer{}] = - run_one([t1, t2], [], t1.shape, fn b, x, y -> - [Value.add(b, x, y)] + run_one([t1, t2], [], t1.typespec, fn _b, x, y -> + [Value.add(x, y, s32_typespec())] end) assert [%DeviceBuffer{}] = - run_one([t1, t2], [], [t1.shape], fn b, x, y -> - [Value.add(b, x, y)] + run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> + [Value.add(x, y, s32_typespec())] end) assert DeviceBuffer.read(t1) == <<1::32-native>> @@ -62,12 +62,12 @@ defmodule EXLA.ExecutableTest do end test "with data from a previous run" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) exec = - compile([t1.shape, t2.shape], [], [t1.shape], fn b, x, y -> - [Value.add(b, x, y)] + compile([t1.typespec, t2.typespec], [], [t1.typespec], fn _b, x, y -> + [Value.add(x, y, s32_typespec())] end) assert [[t3 = %DeviceBuffer{}]] = Executable.run(exec, [[t1, t2]]) @@ -80,27 +80,27 @@ defmodule EXLA.ExecutableTest do t1 = DeviceBuffer.place_on_device( <<1::32-native>>, - Shape.make_shape({:s, 32}, {}), + Typespec.tensor({:s, 32}, {}), client(), 0 ) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) assert [a = %DeviceBuffer{}] = - run_one([t1, t2], [], [t1.shape], fn b, x, y -> - [Value.add(b, x, y)] + run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> + [Value.add(x, y, s32_typespec())] end) assert <<3::32-native>> == DeviceBuffer.read(a) end test "with tuple return" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] = - run_one([t1, t2], [], [t1.shape, t2.shape], fn _b, x, y -> + run_one([t1, t2], [], [t1.typespec, t2.typespec], fn _b, x, y -> [x, y] end) @@ -110,16 +110,16 @@ defmodule EXLA.ExecutableTest do @tag :multi_device test "runs on a specific device" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 
32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}, c = %DeviceBuffer{}] = run_one( [t1, t2], [device_id: 1], - [t1.shape, t2.shape, t1.shape], - fn b, x, y -> - [x, y, Value.add(b, x, y)] + [t1.typespec, t2.typespec, t1.typespec], + fn _b, x, y -> + [x, y, Value.add(x, y, s32_typespec())] end ) @@ -131,12 +131,14 @@ defmodule EXLA.ExecutableTest do assert c.device_id == 1 assert_raise RuntimeError, ~r"Expected buffer to be placed on device 0", fn -> - run_one([a, b], [device_id: 0], t1.shape, fn b, x, y -> - [Value.add(b, x, y)] + run_one([a, b], [device_id: 0], t1.typespec, fn _b, x, y -> + [Value.add(x, y, s32_typespec())] end) end end end + + defp s32_typespec(), do: Typespec.tensor({:s, 32}, {}) end defmodule EXLA.ExecutableFeedTest do @@ -147,83 +149,85 @@ defmodule EXLA.ExecutableFeedTest do alias EXLA.BinaryBuffer alias EXLA.DeviceBuffer alias EXLA.Client - alias EXLA.Shape + alias EXLA.Typespec alias EXLA.MLIR.Function alias EXLA.MLIR.Value import EXLAHelpers describe "infeed/outfeed" do test "successfully sends to/from device asynchronously" do - t = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) + t = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) assert res = Task.async(fn -> - run_one([], [], [Shape.make_token_shape()], fn b -> + run_one([], [], [Typespec.token()], fn b -> token = Value.create_token(b) - {new_token, [val]} = Value.infeed(token, t.shape) + {new_token, [val]} = Value.infeed(token, [t.typespec]) - outfeed_val = Value.add(b, val, val) + outfeed_val = Value.add(val, val, s32_typespec()) _outfeed_token = Value.outfeed(outfeed_val, new_token) - [Value.add(b, outfeed_val, val)] + [Value.add(outfeed_val, val, s32_typespec())] end) end) - assert :ok = Client.to_infeed(client(), 0, [{t.data, t.shape}]) - assert from_outfeed(client(), 0, Shape.make_shape({:s, 32}, {})) == <<2::32-native>> + assert :ok = Client.to_infeed(client(), 0, [{t.data, t.typespec}]) + assert from_outfeed(client(), 0, Typespec.tensor({:s, 32}, {})) == <<2::32-native>> assert [a = %DeviceBuffer{}] = Task.await(res) assert DeviceBuffer.read(a) == <<3::32-native>> end test "successfully sends to/from device asynchronously in a loop" do - t = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {})) + t = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - token_shape = Shape.make_token_shape() + token_shape = Typespec.token() assert res = Task.async(fn -> - run_one([], [], [token_shape, t.shape], fn b -> + run_one([], [], [token_shape, t.typespec], fn b -> token = Value.create_token(b) - {token, [val]} = Value.infeed(token, t.shape) - - {[_token, result], condition_region, body_region} = - Value.while(b, [token, val]) + arg_shapes = [token_shape, t.typespec] - [_token, val] = Function.push_region(b, condition_region) - zero = Value.constant_r0(b, 0, {:s, 32}) - Value.variadic_return(b, [Value.not_equal(b, val, zero)]) + {condition_region, [_token, val]} = Function.push_region(b, arg_shapes) + zero = Value.constant(b, [0], s32_typespec()) + Value.return(b, [Value.not_equal(val, zero, Typespec.tensor({:u, 8}, {}))]) Function.pop_region(b) - [body_token, val] = Function.push_region(b, body_region) + {body_region, [body_token, val]} = Function.push_region(b, arg_shapes) - body_token = Value.outfeed(Value.add(b, val, val), body_token) - 
{body_token, [input]} = Value.infeed(body_token, t.shape) + body_token = Value.outfeed(Value.add(val, val, s32_typespec()), body_token) + {body_token, [input]} = Value.infeed(body_token, [t.typespec]) - Value.variadic_return(b, [body_token, input]) + Value.return(b, [body_token, input]) Function.pop_region(b) + {token, [val]} = Value.infeed(token, [t.typespec]) + [_token, result] = Value.while(b, condition_region, body_region, [token, val]) + [result] end) end) - assert :ok = Client.to_infeed(client(), 0, [{<<1::32-native>>, t.shape}]) - assert from_outfeed(client(), 0, Shape.make_shape({:s, 32}, {})) == <<2::32-native>> + assert :ok = Client.to_infeed(client(), 0, [{<<1::32-native>>, t.typespec}]) + assert from_outfeed(client(), 0, Typespec.tensor({:s, 32}, {})) == <<2::32-native>> - assert :ok = Client.to_infeed(client(), 0, [{<<2::32-native>>, t.shape}]) - assert from_outfeed(client(), 0, Shape.make_shape({:s, 32}, {})) == <<4::32-native>> + assert :ok = Client.to_infeed(client(), 0, [{<<2::32-native>>, t.typespec}]) + assert from_outfeed(client(), 0, Typespec.tensor({:s, 32}, {})) == <<4::32-native>> - assert :ok = Client.to_infeed(client(), 0, [{<<0::32-native>>, t.shape}]) + assert :ok = Client.to_infeed(client(), 0, [{<<0::32-native>>, t.typespec}]) assert [a = %DeviceBuffer{}] = Task.await(res) assert DeviceBuffer.read(a) == <<0::32-native>> end end - defp from_outfeed(client, device_id, shape) do + defp s32_typespec(), do: Typespec.tensor({:s, 32}, {}) + + defp from_outfeed(client, device_id, typespec) do ref = make_ref() - Client.from_outfeed(client, device_id, [shape], self(), ref) + Client.from_outfeed(client, device_id, [typespec], self(), ref) receive do {^ref, msg} -> msg diff --git a/exla/test/exla/shape_test.exs b/exla/test/exla/shape_test.exs deleted file mode 100644 index 4f20c5b95f..0000000000 --- a/exla/test/exla/shape_test.exs +++ /dev/null @@ -1,19 +0,0 @@ -defmodule EXLA.ShapeTest do - use ExUnit.Case, async: true - - alias EXLA.Shape - - describe "make_shape/2" do - test "creates shape" do - shape = Shape.make_shape({:s, 32}, {1, 1}) - assert %Shape{dtype: {:s, 32}, dims: {1, 1}, ref: _} = shape - assert Shape.byte_size(shape) == 4 - end - - test "creates bf16 shape" do - shape = Shape.make_shape({:bf, 16}, {}) - assert %Shape{dtype: {:bf, 16}, dims: {}, ref: _} = shape - assert Shape.byte_size(shape) == 2 - end - end -end diff --git a/exla/test/support/exla_helpers.ex b/exla/test/support/exla_helpers.ex index 68ad12cdc4..db7689d144 100644 --- a/exla/test/support/exla_helpers.ex +++ b/exla/test/support/exla_helpers.ex @@ -9,26 +9,24 @@ defmodule EXLAHelpers do It expects a list of shapes which will be given as parameters. 
""" - def compile(shapes, opts \\ [], output \\ nil, fun) do + def compile(typespecs, opts \\ [], output \\ nil, fun) do compile_fn = fn builder -> params = EXLA.MLIR.Function.get_arguments(builder) fun |> apply([builder | params]) - |> then(&EXLA.MLIR.Value.variadic_return(builder, List.wrap(&1))) + |> then(&EXLA.MLIR.Value.return(builder, List.wrap(&1))) EXLA.MLIR.Module.compile( builder.module, client(), - Enum.map(params, &EXLA.MLIR.Value.get_shape/1), - builder.return_shape, + Enum.map(params, &EXLA.MLIR.Value.get_typespec/1), + builder.return_typespecs, opts ) end - shapes = exla_shape(shapes) - output = exla_shape(output) - EXLA.MLIR.Module.new(List.wrap(shapes), List.wrap(output), compile_fn) + EXLA.MLIR.Module.new(List.wrap(typespecs), List.wrap(output), compile_fn) end @doc """ @@ -38,28 +36,8 @@ defmodule EXLAHelpers do used for compilation and then given on execution. """ def run_one(args, opts \\ [], output \\ nil, fun) do - exec = compile(Enum.map(args, & &1.shape), opts, output, fun) + exec = compile(Enum.map(args, & &1.typespec), opts, output, fun) [result] = EXLA.Executable.run(exec, [args], opts) result end - - defp exla_shape(tensors) when is_list(tensors) do - Enum.flat_map(tensors, &exla_shape/1) - end - - defp exla_shape(%{type: :token}) do - [EXLA.Shape.make_token_shape()] - end - - defp exla_shape(%{shape: shape, type: type}) do - [EXLA.Shape.make_shape(type, shape)] - end - - defp exla_shape(%EXLA.Shape{} = shape) do - [shape] - end - - defp exla_shape(%EXLA.MLIR.Value{} = value) do - [EXLA.MLIR.Value.get_shape(value)] - end end diff --git a/torchx/c_src/nx_nif_utils.hpp b/torchx/c_src/nx_nif_utils.hpp index 4f92734ced..79b2d947a8 100644 --- a/torchx/c_src/nx_nif_utils.hpp +++ b/torchx/c_src/nx_nif_utils.hpp @@ -301,4 +301,4 @@ namespace nx return 1; } } -} \ No newline at end of file +}