feat(exla): add LU custom_call (#1549)
Co-authored-by: José Valim <[email protected]>
polvalente and josevalim authored Oct 29, 2024
1 parent 9d73de2 commit 7af065e
Showing 11 changed files with 258 additions and 17 deletions.
10 changes: 9 additions & 1 deletion exla/c_src/exla/custom_calls.cc
@@ -4,6 +4,10 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);
void qr_cpu_custom_call_f16(void *out[], const void *in[]);
void qr_cpu_custom_call_bf16(void *out[], const void *in[]);
void lu_cpu_custom_call_f32(void *out[], const void *in[]);
void lu_cpu_custom_call_f64(void *out[], const void *in[]);
void lu_cpu_custom_call_f16(void *out[], const void *in[]);
void lu_cpu_custom_call_bf16(void *out[], const void *in[]);
void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
void eigh_cpu_custom_call_f64(void *out[], const void *in[]);

@@ -12,4 +16,8 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_cu
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f64", lu_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f32", lu_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f16", lu_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_bf16", lu_cpu_custom_call_bf16);
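
Aside (not part of the commit): every target registered above follows the same CPU custom-call calling convention, a plain function that receives only untyped output and input buffer pointers. No shape information travels with the buffers, which is why the LU lowering below passes every dimension as an extra u64 operand. A minimal sketch of that convention with a hypothetical "add one" kernel:

#include <cstdint>

// Hypothetical target (illustration only): writes in[0] + 1 to out[0].
// The element count has to be passed explicitly, here as a u64 scalar in in[1].
void add_one_cpu_custom_call_f32(void *out[], const void *in[]) {
  const float *input = static_cast<const float *>(in[0]);
  const uint64_t *num_elements = static_cast<const uint64_t *>(in[1]);
  float *output = static_cast<float *>(out[0]);

  for (uint64_t i = 0; i < *num_elements; i++) {
    output[i] = input[i] + 1.0f;
  }
}

// Registration would mirror the lines above:
// XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("add_one_cpu_custom_call_f32", add_one_cpu_custom_call_f32);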
95 changes: 95 additions & 0 deletions exla/c_src/exla/custom_calls/lu.h
@@ -0,0 +1,95 @@
#pragma once

#include "Eigen/LU";

template <typename DataType>
void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;

Eigen::Map<RowMajorMatrix> input(in, n, n);
Eigen::PartialPivLU<RowMajorMatrix> lu = input.partialPivLu();

// Get the permutation matrix P and convert to indices
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic> P = lu.permutationP();
for (uint64_t i = 0; i < n; i++) {
for (uint64_t j = 0; j < n; j++) {
p_out[i * n + j] = static_cast<uint8_t>(P.indices()[i] == j ? 1 : 0);
}
}

// Get L and U matrices
RowMajorMatrix L = lu.matrixLU().template triangularView<Eigen::UnitLower>();
RowMajorMatrix U = lu.matrixLU().template triangularView<Eigen::Upper>();

// Copy L matrix
for (uint64_t i = 0; i < n; i++) {
for (uint64_t j = 0; j < n; j++) {

if (j < i) {
l_out[i * n + j] = static_cast<DataType>(L(i, j));
} else if (j == i) {
l_out[i * n + j] = static_cast<DataType>(1.0);
} else {
l_out[i * n + j] = static_cast<DataType>(0.0);
}
}
}

// Copy U matrix
for (uint64_t i = 0; i < n; i++) {
for (uint64_t j = 0; j < n; j++) {
if (j >= i) {
u_out[i * n + j] = static_cast<DataType>(U(i, j));
} else {
u_out[i * n + j] = static_cast<DataType>(0.0);
}
}
}
}

template <typename DataType>
void lu_cpu_custom_call(void *out[], const void *in[]) {
DataType *operand = (DataType *)in[0];

uint64_t *dim_sizes = (uint64_t *)in[1];
uint64_t num_operand_dims = dim_sizes[0];
uint64_t num_p_dims = dim_sizes[1];
uint64_t num_l_dims = dim_sizes[2];
uint64_t num_u_dims = dim_sizes[3];

uint64_t *operand_dims_ptr = (uint64_t *)in[2];
std::vector<uint64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);

uint64_t *p_dims_ptr = (uint64_t *)in[3];
std::vector<uint64_t> p_dims(p_dims_ptr, p_dims_ptr + num_p_dims);

uint64_t *l_dims_ptr = (uint64_t *)in[4];
std::vector<uint64_t> l_dims(l_dims_ptr, l_dims_ptr + num_l_dims);

uint64_t *u_dims_ptr = (uint64_t *)in[5];
std::vector<uint64_t> u_dims(u_dims_ptr, u_dims_ptr + num_u_dims);

uint64_t n = l_dims[l_dims.size() - 1];

auto leading_dimensions = std::vector<uint64_t>(operand_dims.begin(), operand_dims.end() - 2);

uint64_t batch_items = 1;
for (uint64_t i = 0; i < leading_dimensions.size(); i++) {
batch_items *= leading_dimensions[i];
}

uint8_t *p = (uint8_t *)out[0];
DataType *l = (DataType *)out[1];
DataType *u = (DataType *)out[2];

uint64_t stride = n * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_lu_cpu_custom_call<DataType>(
p + i * stride,
l + i * stride,
u + i * stride,
operand + i * stride,
n);
}
}
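
Aside (not part of the commit): the kernel above leans on Eigen's partial-pivoting convention, in which permutationP() satisfies P * A = L * U with L unit-lower-triangular and U upper-triangular, so the input is recovered as P^-1 * L * U. The p_out buffer materializes a dense permutation matrix from permutationP().indices(), while L and U are copied out of the shared matrixLU() storage. A small standalone check of that convention, using an arbitrary sample matrix and assuming Eigen is on the include path:

#include <iostream>

#include "Eigen/LU"

int main() {
  using Matrix = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

  Matrix A(3, 3);
  A << 2, 1, 1,
       4, 3, 3,
       8, 7, 9;

  Eigen::PartialPivLU<Matrix> lu(A);

  // Same extraction as single_matrix_lu_cpu_custom_call: unit-lower L, upper U.
  Matrix L = lu.matrixLU().triangularView<Eigen::UnitLower>();
  Matrix U = lu.matrixLU().triangularView<Eigen::Upper>();

  // Eigen's documented relation: P * A = L * U, hence A = P^-1 * L * U.
  Matrix permuted_input = lu.permutationP() * A;
  Matrix reconstructed = lu.permutationP().inverse() * (L * U);

  std::cout << "||P*A - L*U||    = " << (permuted_input - L * U).norm() << std::endl;
  std::cout << "||A - P^-1*L*U|| = " << (A - reconstructed).norm() << std::endl;
  return 0;
}

Both norms print as numerically zero, which is the property the batched loop in lu_cpu_custom_call relies on for every n x n slice.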
6 changes: 6 additions & 0 deletions exla/c_src/exla/custom_calls/lu_bf16.cc
@@ -0,0 +1,6 @@
#include "lu.h"
#include "../exla_types.h"

void lu_cpu_custom_call_bf16(void *out[], const void *in[]) {
lu_cpu_custom_call<exla::bfloat16>(out, in);
}
6 changes: 6 additions & 0 deletions exla/c_src/exla/custom_calls/lu_f16.cc
@@ -0,0 +1,6 @@
#include "lu.h"
#include "../exla_types.h"

void lu_cpu_custom_call_f16(void *out[], const void *in[]) {
lu_cpu_custom_call<exla::float16>(out, in);
}
5 changes: 5 additions & 0 deletions exla/c_src/exla/custom_calls/lu_f32.cc
@@ -0,0 +1,5 @@
#include "lu.h"

void lu_cpu_custom_call_f32(void *out[], const void *in[]) {
lu_cpu_custom_call<float>(out, in);
}
5 changes: 5 additions & 0 deletions exla/c_src/exla/custom_calls/lu_f64.cc
@@ -0,0 +1,5 @@
#include "lu.h"

void lu_cpu_custom_call_f64(void *out[], const void *in[]) {
lu_cpu_custom_call<double>(out, in);
}
8 changes: 4 additions & 4 deletions exla/c_src/exla/custom_calls/qr.h
@@ -73,15 +73,15 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
DataType *q = (DataType *)out[0];
DataType *r = (DataType *)out[1];

- uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType);
- uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType);
- uint64_t inner_stride = m * n * sizeof(DataType);
+ uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2];
+ uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2];
+ uint64_t inner_stride = m * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_qr_cpu_custom_call<DataType>(
(DataType *)out[0] + i * q_stride,
(DataType *)out[1] + i * r_stride,
- operand + i * inner_stride * sizeof(DataType),
+ operand + i * inner_stride,
m, k, n, complete);
}
}
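
The stride change above fixes a double scaling: operand, q, and r are typed DataType pointers, so offsetting them by i * stride already advances by whole elements, and multiplying the stride by sizeof(DataType) on top of that overshoots by the element width. A standalone illustration (not from the repo) of the difference:

#include <cstdint>
#include <cstdio>

int main() {
  // Two 4-element "matrices" stored back to back, mimicking a batch of two.
  double batch[2][4] = {{0.0, 1.0, 2.0, 3.0}, {10.0, 11.0, 12.0, 13.0}};
  double *base = &batch[0][0];
  const uint64_t stride = 4;  // elements per matrix, as in the fixed code

  // Typed pointer arithmetic counts elements, not bytes:
  double *second = base + 1 * stride;  // points at batch[1][0]
  std::printf("element-stride offset lands on: %g\n", *second);  // prints 10

  // The old code effectively computed base + stride * sizeof(double), which
  // advances eight times too far for doubles and would read out of bounds here.
  return 0;
}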
41 changes: 37 additions & 4 deletions exla/lib/exla/defn.ex
@@ -544,6 +544,43 @@ defmodule EXLA.Defn do
end
end

defp cached_recur_operator(
:lu,
%T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}},
state,
cache
) do
%{type: {p_type_kind, _}} = p_expr
%{type: {out_type_kind, _}} = l_expr

if state.client.platform != :host do
raise ArgumentError, "XLA does not currently support the LU operation on non-host devices"
end

if p_type_kind == :c or out_type_kind == :c do
raise ArgumentError, "XLA does not currently support the LU operation for complex inputs"
end

{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

tensor =
if op_type(tensor) != u_expr.type do
to_type(tensor, u_expr.type)
else
tensor
end

{p, l, u} =
Value.lu(
tensor,
expr_to_typespec(p_expr),
expr_to_typespec(l_expr),
expr_to_typespec(u_expr)
)

{[p, l, u], cache}
end

defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do
{op, cache} = recur_operator(expr, state, cache)
{_, cache} = recur_operator(token, state, cache)
@@ -772,10 +809,6 @@ end
end
end

- defp to_operator(:lu, [{_, _, _}, _tensor, _opts], _ans, _state) do
-   raise ArgumentError, "XLA does not currently support the LU operation"
- end

## to_operator element-wise

defp to_operator(:negate, [%Value{} = op], ans, _state),
75 changes: 75 additions & 0 deletions exla/lib/exla/mlir/value.ex
@@ -815,6 +815,81 @@ defmodule EXLA.MLIR.Value do
{q, r}
end

def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do
%{type: op_type, shape: op_shape} = get_typespec(value)
%{type: _p_type, shape: p_shape} = p_typespec
%{type: l_type, shape: l_shape} = l_typespec
%{type: u_type, shape: u_shape} = u_typespec

dim_sizes = [
tuple_size(op_shape),
tuple_size(p_shape),
tuple_size(l_shape),
tuple_size(u_shape)
]

operand_dims = Tuple.to_list(op_shape)
p_dims = Tuple.to_list(p_shape)
l_dims = Tuple.to_list(l_shape)
u_dims = Tuple.to_list(u_shape)

dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)}))
operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)}))
p_dims = constant(func, p_dims, Typespec.tensor({:u, 64}, {length(p_dims)}))
l_dims = constant(func, l_dims, Typespec.tensor({:u, 64}, {length(l_dims)}))
u_dims = constant(func, u_dims, Typespec.tensor({:u, 64}, {length(u_dims)}))
operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims]

# Force P to always be u8 to avoid requiring too many template instances during custom_call registration
p_result_type = type_tensor({:u, 8}, p_shape)
l_result_type = type_tensor(l_type, l_shape)
u_result_type = type_tensor(u_type, u_shape)
result_types = [type_tuple([p_result_type, l_result_type, u_result_type])]

call_target_name =
case op_type do
{:f, 32} ->
"lu_cpu_custom_call_f32"

{:f, 64} ->
"lu_cpu_custom_call_f64"

{:f, 16} ->
"lu_cpu_custom_call_f16"

{:bf, 16} ->
"lu_cpu_custom_call_bf16"

type ->
# Due to matching on EXLA.Defn, we are sure that the device here is always :host
raise "LU decomposition not supported on :host device for type #{inspect(type)}"
end

attributes = [
call_target_name: attr_string(call_target_name),
backend_config: attr_string("Host")
]

result =
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()

# This is not the best approach, but the alternative would require many more template instances
u8_typespec = Typespec.to_type(p_typespec, {:u, 8})
p = get_tuple_element(result, 0, u8_typespec)

p =
if u8_typespec != p_typespec do
convert(p, p_typespec)
else
p
end

l = get_tuple_element(result, 1, l_typespec)
u = get_tuple_element(result, 2, u_typespec)

{p, l, u}
end

def get_tuple_element(%Value{function: func} = operand, index, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [index: attr_i32(index)]
2 changes: 1 addition & 1 deletion exla/mix.exs
@@ -187,7 +187,7 @@ defmodule EXLA.MixProject do
File.rm_rf!("cache/#{@version}/libexla.so")

Mix.shell().info("Removing libexla.so cache at #{cached_so}")
- File.rm!(cached_so)
+ File.rm_rf!(cached_so)
end

if cached? do
Expand Down
22 changes: 15 additions & 7 deletions exla/test/exla/nx_linalg_doctest_test.exs
@@ -1,16 +1,24 @@
defmodule EXLA.MLIR.NxLinAlgDoctestTest do
use EXLA.Case, async: true

- @invalid_type_error_doctests [svd: 2, pinv: 2, matrix_rank: 2]
+ @invalid_type_error_doctests [
+ svd: 2,
+ pinv: 2
+ ]

@function_clause_error_doctests [
norm: 2,
- lu: 2,
- solve: 2,
+ solve: 2
]

+ @rounding_error_doctests [
+ triangular_solve: 3,
+ eigh: 2,
+ cholesky: 1,
+ least_squares: 3,
determinant: 1,
invert: 1,
- matrix_power: 2
+ matrix_power: 2,
+ lu: 2
]
- @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3]

@excluded_doctests @function_clause_error_doctests ++
@rounding_error_doctests ++
