Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial groundwork for using the IREE compiler/runtime #1489

Closed
wants to merge 44 commits into from
Closed
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
6cb3fa1
chore: pick makefile
polvalente May 8, 2024
f1830a0
chore: import more changes
polvalente May 8, 2024
5a4522b
wip
polvalente May 8, 2024
5f1e257
feat: export module and pass it to the runtime
polvalente May 8, 2024
0c36ef5
feat: lay scaffolding out for runtime
polvalente May 8, 2024
d261ff6
feat: jit workflow mostly working (sans outputs)
polvalente May 9, 2024
44a0e18
wip
polvalente May 9, 2024
3245113
feat: float32 working
polvalente May 9, 2024
0fca437
feat: more support for other types
polvalente May 10, 2024
bf34d35
feat: make all types work (with 64->32 demotion for ints and floats)
polvalente May 10, 2024
b223c59
test: get tests to run without beam crashes
polvalente May 11, 2024
23a468d
test: skip broken tests
polvalente May 11, 2024
2960496
fix: shape allocation issues
polvalente May 11, 2024
859aab0
fix: input and shape vector handling
polvalente May 11, 2024
0c007c8
test: unskip shape error tests
polvalente May 11, 2024
77b9096
refactor: use single iree instance
polvalente May 13, 2024
f52a6e6
wip
polvalente May 13, 2024
f9324dc
wip
polvalente May 14, 2024
f7e20c9
wip
polvalente May 14, 2024
5f02ed7
feat: invoke not segfaulting (yet?)
polvalente May 15, 2024
20ac627
feat: output retrieval working
polvalente May 16, 2024
2c650b5
feat: instance resource pooling
polvalente May 16, 2024
d619e30
test: skip 64 bit tests
polvalente May 16, 2024
4564756
feat: lazy read from device
polvalente May 16, 2024
6f18f76
fix: use func.return only on the higher level
polvalente May 16, 2024
8859178
chore: format
polvalente May 16, 2024
b76b717
feat: more coverage
polvalente May 19, 2024
c02d839
chore: update iree
polvalente May 21, 2024
39c408d
Merge remote-tracking branch 'origin/main' into pv-feat/iree-compiler
polvalente May 21, 2024
0fef901
feat: update reduce
polvalente May 21, 2024
7e14d03
feat: update all dense i64 to array i64
polvalente May 21, 2024
cc9a601
Merge branch 'pv-feat/update-mlir-notation' into pv-feat/iree-compiler
polvalente May 21, 2024
7d0f6f4
feat: update stablehlo
polvalente May 21, 2024
f198bb9
feat: parameterize device on application setup
polvalente May 21, 2024
9ca3acb
fix: return ui8 from is_nan and is_infinity
polvalente May 21, 2024
7d288ed
chore: remove type mismatch tag
polvalente May 21, 2024
bdd34c4
chore: use metal again
polvalente May 21, 2024
1883636
feat: all green tests (some skipped)
polvalente May 22, 2024
4e05bd9
wip: trace debugging
polvalente May 23, 2024
8930724
wip: enable ios runtime compilation
polvalente May 23, 2024
28b4bbd
Merge remote-tracking branch 'origin/main' into pv-feat/iree-compiler
polvalente May 23, 2024
2f4e879
Merge remote-tracking branch 'origin/main' into pv-feat/iree-compiler
polvalente May 23, 2024
0144cf7
chore: xla token bug
polvalente May 23, 2024
bee7309
wip: split runtime
polvalente May 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,31 @@ XLA_EXTENSION_DIR = cache/xla_extension
XLA_EXTENSION_LIB = $(XLA_EXTENSION_DIR)/lib
XLA_INCLUDE_PATH = $(XLA_EXTENSION_DIR)/include

IREE_COMPILER_DIR = iree/build/lib
IREE_COMPILER_LIB = cache/$(IREE_COMPILER_DIR)
IREE_COMPILER_INCLUDE_PATH = cache/iree/compiler/bindings/c
IREE_RUNTIME_INCLUDE_PATH = cache/iree/runtime/src

IREE_INSTALL_PREFIX = $(abspath cache/iree/build/)

# Cache configuration
EXLA_CACHE_SO = cache/libexla.so
EXLA_CACHE_OBJ_DIR = cache/objs
EXLA_CACHE_IREE_COMPILER_SO = cache/libireecompiler.so

# Private configuration
EXLA_DIR = c_src/exla
PRIV_DIR = $(MIX_APP_PATH)/priv
EXLA_SO = $(PRIV_DIR)/libexla.so
EXLA_IREE_COMPILER_SO = $(PRIV_DIR)/libireecompiler.so
EXLA_LIB_DIR = $(PRIV_DIR)/xla_extension/lib
EXLA_IREE_LIB_DIR = $(PRIV_DIR)/$(IREE_COMPILER_DIR)

# Link paths
XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_LIB)
XLA_EXTENSION_LIB_LINK_PATH = ../$(CWD_RELATIVE_TO_PRIV_PATH)/$(XLA_EXTENSION_DIR)/$(XLA_EXTENSION_LIB)
IREE_COMPILER_LIB_LINK_PATH = ../../$(CWD_RELATIVE_TO_PRIV_PATH)/$(IREE_COMPILER_LIB)
EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
EXLA_CACHE_IREE_COMPILER_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_IREE_COMPILER_SO)

# Build flags
# c++17 is needed, otherwise xla headers
Expand All @@ -29,6 +41,8 @@ CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compa
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
-std=c++17 -w -DLLVM_VERSION_STRING=

IREE_CFLAGS = $(CFLAGS) -I$(IREE_COMPILER_INCLUDE_PATH) -I$(IREE_RUNTIME_INCLUDE_PATH)

NVCCFLAGS = -shared -Xcompiler -fPIC

ifdef DEBUG
Expand All @@ -39,32 +53,41 @@ else
endif

LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared
IREE_LDFLAGS = $(LDFLAGS) -lIREECompiler

ifeq ($(shell uname -s), Darwin)
LDFLAGS += -flat_namespace -undefined dynamic_lookup -rpath @loader_path/xla_extension/lib
LDFLAGS += -flat_namespace -undefined dynamic_lookup -rpath @loader_path/xla_extension/lib -rpath @loader_path/$(IREE_COMPILER_DIR)
else
# Use a relative RPATH, so at runtime libexla.so looks for libxla_extension.so
# in ./lib regardless of the absolute location. This way priv can be safely
# packed into an Elixir release. Also, we use $$ to escape Makefile variable
# and single quotes to escape shell variable
LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib'
LDFLAGS += -Wl,-rpath,'$$ORIGIN/xla_extension/lib' -Wl,-rpath,'$$ORIGIN/$(IREE_COMPILER_DIR)'
endif

$(EXLA_SO): $(EXLA_CACHE_SO)
$(EXLA_SO): $(EXLA_CACHE_SO) $(EXLA_CACHE_IREE_COMPILER_SO)
@ mkdir -p $(PRIV_DIR)
@ mkdir -p $(PRIV_DIR)/xla_extension
@ mkdir -p $(PRIV_DIR)/iree/build
@ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \
cp -a $(abspath $(XLA_EXTENSION_LIB)) $(EXLA_LIB_DIR) ; \
cp -a $(abspath $(IREE_COMPILER_LIB)) $(EXLA_IREE_LIB_DIR) ; \
cp -a $(abspath $(EXLA_CACHE_SO)) $(EXLA_SO) ; \
cp -a $(abspath $(EXLA_CACHE_IREE_COMPILER_SO)) $(EXLA_IREE_COMPILER_SO) ; \
else \
ln -sf $(XLA_EXTENSION_LIB_LINK_PATH) $(EXLA_LIB_DIR) ; \
ln -sf $(IREE_COMPILER_LIB_LINK_PATH) $(EXLA_IREE_LIB_DIR) ; \
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
ln -sf $(EXLA_CACHE_IREE_COMPILER_SO_LINK_PATH) $(EXLA_IREE_COMPILER_SO) ; \
fi

SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/exla_mlir_nif_util.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h $(EXLA_DIR)/exla_mlir_nif_util.h
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o

IREE_SOURCES = $(EXLA_DIR)/iree/iree.cc $(EXLA_DIR)/iree/compiler.cc $(EXLA_DIR)/iree/runtime.cc
IREE_HEADERS = $(EXLA_DIR)/iree/compiler.h $(EXLA_DIR)/iree/runtime.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_mlir.h
IREE_OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(IREE_SOURCES))

NVCC_RESULT := $(shell which nvcc 2> /dev/null)
NVCC_TEST := $(notdir $(NVCC_RESULT))
Expand All @@ -82,12 +105,48 @@ $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cud
$(NVCC) $(NVCCFLAGS) -c $< -o $@

$(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS)
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)/mlir
@ mkdir -p $(EXLA_CACHE_OBJ_DIR)/iree
$(CXX) $(CFLAGS) -c $< -o $@

$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(OBJECTS)
$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) cache/iree $(OBJECTS)
$(CXX) $(OBJECTS) -o $(EXLA_CACHE_SO) $(LDFLAGS)


IREE_CMAKE_BUILD_DIR = cache/objs/exla_iree_cmake

ifdef DEBUG
IREE_CMAKE_CONFIG = RelWithDebInfo
else
IREE_CMAKE_CONFIG = Release
endif

# This will eventually be extracted into its own library.
# For now, it is built here to get things working.
IREE_COMMIT := d4aa8491a755e31d590f00a507e6c3859dfa662d
cache/iree:
@mkdir -p cache
@git clone https://github.com/iree-org/iree cache/iree
@cd cache/iree && git checkout $(IREE_COMMIT)
@cd cache/iree && git submodule update --init --recursive
@mkdir -p cache/iree/build
cmake -G Ninja -B cache/iree/build -DIREE_BUILD_TESTS=OFF -DIREE_BUILD_SAMPLES=OFF -DIREE_ENABLE_ASSERTIONS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_CXX_FLAGS="-fvisibility=hidden" cache/iree
cmake --build cache/iree/build

$(EXLA_CACHE_IREE_COMPILER_SO): cache/iree
@mkdir -p $(IREE_CMAKE_BUILD_DIR)
@mkdir -p cache/objs/iree_cmake_out
@mkdir -p cache/objs/mlir_cmake_out
@mkdir -p cache/objs/llvm_cmake_out
cmake -S c_src/iree_runtime -B $(IREE_CMAKE_BUILD_DIR) \
-DIREE_COMPILER_INCLUDE_PATH=$(IREE_COMPILER_INCLUDE_PATH) \
-DIREE_COMPILER_DIR=$(IREE_COMPILER_DIR) \
-DXLA_INCLUDE_PATH=$(abspath $(XLA_INCLUDE_PATH)) \
-DIREE_INSTALL_PREFIX=$(IREE_INSTALL_PREFIX) \
-DCACHE_DIR=cache\
-DXLA_EXTENSION_LIB=$(abspath $(XLA_EXTENSION_LIB))\
-DCMAKE_BUILD_TYPE=$(IREE_CMAKE_CONFIG)
cmake --build $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --verbose
cmake --install $(IREE_CMAKE_BUILD_DIR) --config $(IREE_CMAKE_CONFIG) --prefix cache

clean:
rm -rf cache
30 changes: 14 additions & 16 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
#include <string>

#include "exla_mlir.h"
#include "exla_client.h"
#include "exla_cuda.h"
#include "exla_log_sink.h"
#include "exla_mlir.h"
#include "exla_mlir_nif_util.h"
#include "exla_nif_util.h"

#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"

#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/pjrt/pjrt_api.h"
#include "xla/service/platform_util.h"

// All of these are created with calls to `new` and subsequently
// passed to the VM as pointers-to-pointers so we balance it out
Expand Down Expand Up @@ -202,19 +201,19 @@ ERL_NIF_TERM mlir_create_function(ErlNifEnv* env, int argc, const ERL_NIF_TERM a

auto arg_types = std::vector<mlir::Type>{};

for (auto const & type_string : arg_type_strings) {
for (auto const& type_string : arg_type_strings) {
auto type = (*module)->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
arg_types.push_back(type);
}

auto ret_types = std::vector<mlir::Type>{};

for (auto const & type_string : ret_type_strings) {
for (auto const& type_string : ret_type_strings) {
auto type = (*module)->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
ret_types.push_back(type);
Expand Down Expand Up @@ -281,19 +280,19 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {

auto result_types = std::vector<mlir::Type>{};

for (auto const & type_string : result_type_strings) {
for (auto const& type_string : result_type_strings) {
auto type = (*function)->module()->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
result_types.push_back(type);
}

auto attributes = std::vector<std::pair<std::string, mlir::Attribute>>{};

for (auto const & pair : attributes_kwlist) {
for (auto const& pair : attributes_kwlist) {
auto attribute_value = (*function)->module()->ParseAttribute(pair.second);
if(attribute_value == nullptr) {
if (attribute_value == nullptr) {
return attribute_parsing_error(env, pair.second);
}
attributes.push_back(std::pair{pair.first, attribute_value});
Expand All @@ -304,7 +303,6 @@ ERL_NIF_TERM mlir_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, exla::nif::make_list<mlir::Value>(env, results));
}


ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 2) {
return exla::nif::error(env, "Bad argument count.");
Expand All @@ -322,9 +320,9 @@ ERL_NIF_TERM mlir_push_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[

auto types = std::vector<mlir::Type>{};

for (auto const & type_string : arg_types) {
for (auto const& type_string : arg_types) {
auto type = (*function)->module()->ParseType(type_string);
if(type == nullptr) {
if (type == nullptr) {
return type_parsing_error(env, type_string);
}
types.push_back(type);
Expand Down
65 changes: 65 additions & 0 deletions exla/c_src/exla/exla_mlir_nif_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "exla_mlir_nif_util.h"

#include "mlir/IR/Builders.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace exla {
namespace nif {

// Returns the short type name for a numeric MLIR type:
//   - 1-bit signless integer -> "pred"
//   - other integers         -> "u<width>" when unsigned, "s<width>" otherwise
//   - bf16                   -> "bf16"
//   - other floats           -> "f<width>"
//   - complex                -> "c<2 * element float width>"
// For any other type, prints a message to stderr and exits the process with
// status 1.
std::string mlir_numeric_type_to_string(mlir::Type type) {
  // Checked ahead of the generic integer branch so 1-bit signless integers
  // are named "pred" rather than by width.
  if (type.isSignlessInteger(1)) return "pred";

  if (auto int_ty = type.dyn_cast<mlir::IntegerType>()) {
    const char* prefix = int_ty.isUnsigned() ? "u" : "s";
    return prefix + std::to_string(int_ty.getWidth());
  }

  // Checked ahead of the generic float branch; presumably bf16 would
  // otherwise be named by its width only -- TODO confirm.
  if (type.isBF16()) return "bf16";

  if (auto fp_ty = type.dyn_cast<mlir::FloatType>()) {
    return "f" + std::to_string(fp_ty.getWidth());
  }

  if (auto cplx_ty = type.dyn_cast<mlir::ComplexType>()) {
    // The reported width covers both components of the complex number.
    const auto component_width =
        cplx_ty.getElementType().cast<mlir::FloatType>().getWidth();
    return "c" + std::to_string(component_width * 2);
  }

  std::cerr << "Unexpected mlir type" << std::endl;
  exit(1);
}

// Builds the 2-tuple `{type, shape}` describing an MLIR type.
//
//   - stablehlo token types produce `{make(env, "token"), {}}`.
//   - Ranked tensor types produce `{element_type, {d0, d1, ...}}`, where
//     `element_type` is the term built from `mlir_numeric_type_to_string`
//     and each `dN` is the corresponding entry of the tensor shape.
//
// Params:
//   env  - NIF environment in which the resulting terms are created.
//   type - MLIR type to describe.
//
// Returns the typespec tuple. For any other kind of type, prints a message to
// stderr and exits the process with status 1.
ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type) {
  if (type.isa<mlir::stablehlo::TokenType>()) {
    auto type_term = make(env, "token");
    auto shape_term = enif_make_tuple(env, 0);

    return enif_make_tuple(env, 2, type_term, shape_term);
  }

  if (auto tensor_type = type.dyn_cast<mlir::RankedTensorType>()) {
    auto dims = tensor_type.getShape();
    auto element_type = tensor_type.getElementType();

    auto dims_array = std::vector<ERL_NIF_TERM>{};
    dims_array.reserve(dims.size());

    for (const auto dim : dims) {
      // Shape entries are 64-bit; enif_make_int takes a plain `int` and would
      // silently truncate large values, so build a 64-bit integer term.
      dims_array.push_back(enif_make_int64(env, static_cast<ErlNifSInt64>(dim)));
    }

    auto type_term = make(env, mlir_numeric_type_to_string(element_type));
    auto shape_term = enif_make_tuple_from_array(env, dims_array.data(), dims_array.size());

    return enif_make_tuple(env, 2, type_term, shape_term);
  }

  std::cerr << "Unexpected mlir type" << std::endl;
  exit(1);
}
} // namespace nif
} // namespace exla
10 changes: 10 additions & 0 deletions exla/c_src/exla/exla_mlir_nif_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once
#include "exla_nif_util.h"
#include "mlir/IR/BuiltinTypes.h"

namespace exla {
namespace nif {
// Converts an MLIR type into a `{type, shape}` 2-tuple term created in `env`.
// The implementation handles stablehlo token types (empty shape tuple) and
// ranked tensor types (one shape entry per dimension); for any other type it
// prints a message to stderr and exits the process.
ERL_NIF_TERM make_typespec(ErlNifEnv* env, mlir::Type type);
} // namespace nif
} // namespace exla
Loading