Skip to content

Commit

Permalink
Merge branch 'main' into bp-make-take-along-axis-optional
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin-Philip committed May 11, 2024
2 parents 7c3ec92 + fba9efe commit 5a56cce
Show file tree
Hide file tree
Showing 91 changed files with 4,482 additions and 11,404 deletions.
9 changes: 2 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,17 @@ on:
pull_request:
jobs:
main:
name: Linux (${{ matrix.working_directory }}, ${{ matrix.elixir }}, ${{ matrix.otp }}${{ (matrix.use_mlir && ', mlir' || '') }})
name: Linux (${{ matrix.working_directory }}, ${{ matrix.elixir }}, ${{ matrix.otp }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
working_directory: ["nx", "exla", "torchx"]
elixir: ["1.14.5", "1.15.4"]
otp: ["25.3"]
use_mlir: [false]
include:
- elixir: "1.15.4"
lint: true
- elixir: "1.15.4"
working_directory: "exla"
otp: "25.3"
use_mlir: true
defaults:
run:
working-directory: ${{ matrix.working_directory }}
Expand Down Expand Up @@ -53,7 +48,7 @@ jobs:
- name: Run epmd for distributed tests
run: epmd -daemon
- name: Run tests
run: ${{(matrix.use_mlir && 'EXLA_COMPILER_MODE=mlir') || ''}} mix test
run: mix test

win:
name: Windows (${{ matrix.working_directory }}, ${{ matrix.elixir }}, ${{ matrix.otp }})
Expand Down
11 changes: 11 additions & 0 deletions exla/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# Changelog

## v0.7.1 (2024-02-27)

* Add CustomCallOp for QR decomposition
* Minor improvements to the MLIR modules generated
* MLIR Context pooling for better concurrency

## v0.7.0 (2024-02-22)

* Update to latest Nx
* Introduce a `:mlir` based compiler and use it by default. The previous `:xla` based compiler is deprecated. You can temporarily revert to the previous compiler by setting `config :exla, :compiler_mode, :xla`

## v0.6.4 (2023-11-13)

* Update to latest Nx
Expand Down
37 changes: 33 additions & 4 deletions exla/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ XLA_INCLUDE_PATH = $(XLA_EXTENSION_DIR)/include

# Cache configuration
EXLA_CACHE_SO = cache/libexla.so
EXLA_CACHE_OBJ_DIR = cache/objs

# Private configuration
EXLA_DIR = c_src/exla
Expand All @@ -26,15 +27,18 @@ EXLA_CACHE_SO_LINK_PATH = $(CWD_RELATIVE_TO_PRIV_PATH)/$(EXLA_CACHE_SO)
# Note: this is on :xla 0.5.0 -- things can change with later versions
CFLAGS = -fPIC -I$(ERTS_INCLUDE_DIR) -I$(XLA_INCLUDE_PATH) -Wall -Wno-sign-compare \
-Wno-unused-parameter -Wno-missing-field-initializers -Wno-comment \
-shared -std=c++17 -w -DLLVM_VERSION_STRING=
-std=c++17 -w -DLLVM_VERSION_STRING=

NVCCFLAGS = -shared -Xcompiler -fPIC

ifdef DEBUG
CFLAGS += -g
NVCCFLAGS += -g
else
CFLAGS += -O3
endif

LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension
LDFLAGS = -L$(XLA_EXTENSION_LIB) -lxla_extension -shared

ifeq ($(shell uname -s), Darwin)
LDFLAGS += -flat_namespace -undefined suppress -rpath @loader_path/xla_extension/lib
Expand All @@ -57,8 +61,33 @@ $(EXLA_SO): $(EXLA_CACHE_SO)
ln -sf $(EXLA_CACHE_SO_LINK_PATH) $(EXLA_SO) ; \
fi

$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_ops.cc $(EXLA_DIR)/exla_ops.h $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/ops.h $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/mlir/builder.h $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
$(CXX) $(CFLAGS) $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_nif_util.cc $(EXLA_DIR)/mlir/ops.cc $(EXLA_DIR)/mlir/builder.cc $(EXLA_DIR)/exla_ops.cc $(EXLA_DIR)/exla_client.cc -o $(EXLA_CACHE_SO) $(LDFLAGS)
# Translation units and headers that make up the NIF library.
# Every header is a prerequisite of every object built by the generic pattern rule.
SOURCES = $(EXLA_DIR)/exla.cc $(EXLA_DIR)/exla_mlir.cc $(EXLA_DIR)/custom_calls.cc $(EXLA_DIR)/exla_client.cc $(EXLA_DIR)/exla_nif_util.cc
HEADERS = $(EXLA_DIR)/exla_mlir.h $(EXLA_DIR)/custom_calls.h $(EXLA_DIR)/exla_client.h $(EXLA_DIR)/exla_nif_util.h $(EXLA_DIR)/exla_log_sink.h
# One object per source under the cache object directory, plus exla_cuda.o,
# which is listed explicitly because it is compiled by its own rule.
OBJECTS = $(patsubst $(EXLA_DIR)/%.cc,$(EXLA_CACHE_OBJ_DIR)/%.o,$(SOURCES)) $(EXLA_CACHE_OBJ_DIR)/exla_cuda.o


# Look for nvcc on PATH at parse time; NVCC_TEST is "nvcc" when it was found
# and empty otherwise.
NVCC_RESULT := $(shell which nvcc 2> /dev/null)
NVCC_TEST := $(notdir $(NVCC_RESULT))

ifeq ($(NVCC_TEST),nvcc)
# nvcc available: compile exla_cuda.cc with it and define CUDA_ENABLED.
NVCC := nvcc
NVCCFLAGS += -DCUDA_ENABLED
else
# No nvcc: compile exla_cuda.cc with the regular C++ compiler and flags,
# without defining CUDA_ENABLED.
NVCC := $(CXX)
NVCCFLAGS = $(CFLAGS)
endif

# Compile the CUDA helper object with whichever compiler and flags
# NVCC / NVCCFLAGS hold at this point.
$(EXLA_CACHE_OBJ_DIR)/exla_cuda.o: $(EXLA_DIR)/exla_cuda.cc $(EXLA_DIR)/exla_cuda.h
	@ mkdir -p $(@D)
	$(NVCC) $(NVCCFLAGS) -c $< -o $@

# Compile one C++ source into its object file in the cache.
# Any header change rebuilds every object, since $(HEADERS) is a prerequisite.
# The output directory is derived from the target itself, so sources in
# subdirectories work without hard-coding their directory names here.
$(EXLA_CACHE_OBJ_DIR)/%.o: $(EXLA_DIR)/%.cc $(HEADERS)
	@ mkdir -p $(@D)
	$(CXX) $(CFLAGS) -c $< -o $@

# Link every object file into the cached shared library.
# Only $(OBJECTS) is passed to the linker; $(XLA_EXTENSION_DIR) is a
# prerequisite but not a link input.
$(EXLA_CACHE_SO): $(XLA_EXTENSION_DIR) $(OBJECTS)
	$(CXX) $(OBJECTS) -o $@ $(LDFLAGS)

# Remove the whole build cache (objects and the shared library).
# Declared phony so a stray file named "clean" cannot mask this target.
.PHONY: clean
clean:
	rm -rf cache
2 changes: 2 additions & 0 deletions exla/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ mix deps.get
mix test
```

By default, EXLA passes `["-jN"]` as a Make argument, where `N` is `System.schedulers_online() - 2`, with a minimum of `1`. `config :exla, :make_args, ...` can be used to override this default setting.

In order to run tests on a specific device, use the `EXLA_TARGET` environment variable, which is a dev-only variable for this project (it has no effect when using EXLA as a dependency). For example, `EXLA_TARGET=cuda` or `EXLA_TARGET=rocm`. Make sure to also specify `XLA_TARGET` to fetch or compile a proper version of the XLA binary.

### Building with Docker
Expand Down
111 changes: 111 additions & 0 deletions exla/c_src/exla/custom_calls.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include "custom_calls.h"
#include "exla_nif_util.h"

#include "xla/service/custom_call_target_registry.h"

#include "Eigen/Dense"
#include "Eigen/QR"

template <typename DataType>
void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;

Eigen::Map<RowMajorMatrix> input(in, m, n);
Eigen::HouseholderQR<RowMajorMatrix> qr = input.householderQr();

RowMajorMatrix Q, R;
size_t num_bytes_q, num_bytes_r;

if (complete) {
Q = qr.householderQ() * RowMajorMatrix::Identity(m, m);
R = qr.matrixQR();

num_bytes_q = m * m * sizeof(DataType);

for (int64_t i = 0; i < m; ++i) {
for (int64_t j = 0; j < n; ++j) {
r_out[i * n + j] = (j >= i) ? R(i, j) : static_cast<DataType>(0.0);
}
}
} else {
Q = qr.householderQ() * RowMajorMatrix::Identity(m, k);
R = qr.matrixQR().topRows(k);

num_bytes_q = m * k * sizeof(DataType);

for (int64_t i = 0; i < k; ++i) {
for (int64_t j = 0; j < n; ++j) {
r_out[i * n + j] = (j >= i) ? R(i, j) : static_cast<DataType>(0.0);
}
}
}

memcpy(q_out, Q.data(), num_bytes_q);
}

// Batched QR entry point shared by all element types.
//
// Inputs, as read below:
//   in[0] - operand data, interpreted as DataType
//   in[1] - three int64 values: rank of the operand, rank of Q, rank of R
//   in[2] - operand dimensions (in[1][0] entries)
//   in[3] - Q dimensions       (in[1][1] entries)
//   in[4] - R dimensions       (in[1][2] entries)
// Outputs:
//   out[0] - Q buffer
//   out[1] - R buffer
//
// The last two dimensions of each shape describe one matrix; all leading
// operand dimensions are flattened into a batch, and each matrix in the batch
// is decomposed independently.
//
// NOTE: every rank is assumed to be at least 2, because the last two
// dimensions are indexed unconditionally.
template <typename DataType>
void qr_cpu_custom_call(void *out[], const void *in[]) {
  DataType *operand = (DataType *)in[0];

  const int64_t *ranks = (const int64_t *)in[1];
  const int64_t operand_rank = ranks[0];
  const int64_t q_rank = ranks[1];
  const int64_t r_rank = ranks[2];

  const int64_t *operand_dims = (const int64_t *)in[2];
  const int64_t *q_dims = (const int64_t *)in[3];
  const int64_t *r_dims = (const int64_t *)in[4];

  const int64_t m = q_dims[q_rank - 2];
  const int64_t k = q_dims[q_rank - 1];
  const int64_t r_rows = r_dims[r_rank - 2];
  const int64_t n = r_dims[r_rank - 1];
  // R has as many rows as the input only in the complete decomposition.
  const bool complete = r_rows == m;

  // Number of matrices: product of every operand dimension except the last two.
  int64_t batch_items = 1;
  for (int64_t i = 0; i < operand_rank - 2; i++) {
    batch_items *= operand_dims[i];
  }

  DataType *q = (DataType *)out[0];
  DataType *r = (DataType *)out[1];

  // Strides are counted in ELEMENTS, not bytes: they are added to typed
  // DataType pointers, and pointer arithmetic already scales by
  // sizeof(DataType). Multiplying by sizeof here as well would jump past the
  // intended matrix for every batch index greater than zero.
  const int64_t q_stride = m * k;
  const int64_t r_stride = r_rows * n;
  const int64_t operand_stride = m * n;

  for (int64_t i = 0; i < batch_items; i++) {
    single_matrix_qr_cpu_custom_call<DataType>(
        q + i * q_stride,
        r + i * r_stride,
        operand + i * operand_stride,
        m, k, n, complete);
  }
}

// Non-template entry points, one per supported element type. Each one simply
// forwards to the matching instantiation of qr_cpu_custom_call, giving that
// instantiation a plain function symbol.

// bfloat16 variant.
void qr_cpu_custom_call_bf16(void *out[], const void *in[]) {
qr_cpu_custom_call<exla::bfloat16>(out, in);
}

// 16-bit float variant.
void qr_cpu_custom_call_f16(void *out[], const void *in[]) {
qr_cpu_custom_call<exla::float16>(out, in);
}

// 32-bit float variant.
void qr_cpu_custom_call_f32(void *out[], const void *in[]) {
qr_cpu_custom_call<float>(out, in);
}

// 64-bit float variant.
void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
qr_cpu_custom_call<double>(out, in);
}

// Associate each typed wrapper with a string name of the same spelling.
// NOTE(review): presumably this is the name the compiler side emits as the
// custom-call target, so the two must stay in sync — confirm against the caller.
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
9 changes: 9 additions & 0 deletions exla/c_src/exla/custom_calls.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef EXLA_MLIR_CUSTOM_CALLS_H_
#define EXLA_MLIR_CUSTOM_CALLS_H_

// QR decomposition custom-call entry points, one per element type
// (bf16, f16, f32, f64). All share the same signature: `out` is an array of
// output buffer pointers and `in` is an array of input buffer pointers.
void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
void qr_cpu_custom_call_f16(void *out[], const void *in[]);
void qr_cpu_custom_call_f32(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);

#endif
Loading

0 comments on commit 5a56cce

Please sign in to comment.