Skip to content

Commit

Permalink
feat(exla): add custom callback for eigh
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Jul 4, 2024
1 parent b5d4c4c commit 9e7e6de
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 3 deletions.
88 changes: 85 additions & 3 deletions exla/c_src/exla/custom_calls.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
#include "custom_calls.h"
#include "exla_nif_util.h"

#include "xla/service/custom_call_target_registry.h"

#include "Eigen/Dense"
#include "Eigen/Eigenvalues"
#include "Eigen/QR"
#include "exla_nif_util.h"
#include "xla/service/custom_call_target_registry.h"

template <typename DataType>
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, int64_t m, int64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;

// Map the input matrix
Eigen::Map<RowMajorMatrix> input(in, m, n);

// Compute the Eigenvalue decomposition
Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver(input);

if (eigensolver.info() != Eigen::Success) {
std::cerr << "Eigenvalue decomposition failed!" << std::endl;
return;
}

// Get the eigenvalues and eigenvectors
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();

// Copy the eigenvalues to the output
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));

// Copy the eigenvectors to the output
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
}

template <typename DataType>
void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
Expand Down Expand Up @@ -89,6 +115,50 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
}
}

template <typename DataType>
void eigh_cpu_custom_call(void *out[], const void *in[]) {
DataType *operand = (DataType *)in[0];

int64_t *dim_sizes = (int64_t *)in[1];
int64_t num_operand_dims = dim_sizes[0];
int64_t num_eigenvalues_dims = dim_sizes[1];
int64_t num_eigenvectors_dims = dim_sizes[2];

int64_t *operand_dims_ptr = (int64_t *)in[2];
std::vector<int64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);

int64_t *eigenvalues_dims_ptr = (int64_t *)in[3];
std::vector<int64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);

int64_t *eigenvectors_dims_ptr = (int64_t *)in[4];
std::vector<int64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);

int64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
int64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

auto leading_dimensions = std::vector<int64_t>(operand_dims.begin(), operand_dims.end() - 2);

int64_t batch_items = 1;
for (int64_t i = 0; i < leading_dimensions.size(); i++) {
batch_items *= leading_dimensions[i];
}

DataType *eigenvalues = (DataType *)out[0];
DataType *eigenvectors = (DataType *)out[1];

int64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
int64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
int64_t inner_stride = m * n * sizeof(DataType);

for (int64_t i = 0; i < batch_items; i++) {
single_matrix_eigh_cpu_custom_call<DataType>(
eigenvalues + i * eigenvalues_stride,
eigenvectors + i * eigenvectors_stride,
operand + i * inner_stride / sizeof(DataType),
m, n);
}
}

void qr_cpu_custom_call_bf16(void *out[], const void *in[]) {
qr_cpu_custom_call<exla::bfloat16>(out, in);
}
Expand All @@ -105,7 +175,19 @@ void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
qr_cpu_custom_call<double>(out, in);
}

void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
eigh_cpu_custom_call<float>(out, in);
}

void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
eigh_cpu_custom_call<double>(out, in);
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);


XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
3 changes: 3 additions & 0 deletions exla/c_src/exla/custom_calls.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ void qr_cpu_custom_call_f16(void *out[], const void *in[]);
void qr_cpu_custom_call_f32(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);

void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
void eigh_cpu_custom_call_f64(void *out[], const void *in[]);

#endif
41 changes: 41 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,47 @@ defmodule EXLA.Defn do
{[q, r], cache}
end

defp cached_recur_operator(
:optional,
%T{
data: %Expr{
args: [
%{data: %{op: :eigh, args: [tensor, _opts]}},
{%{type: {type_kind, _}} = eigenvecs_expr, eigenvals_expr},
_callback
]
}
},
%{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
cache
) do
dbg({type_kind})
# We match only on platform: :host for MLIR, as we want to support
# eigh-on-cpu as a custom call only in this case
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

# convert to float and ensure that we're either using f32 or f64
out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32})

tensor =
if op_type(tensor) != out_type do
to_type(tensor, out_type)
else
tensor
end

{eigenvecs, eigenvals} =
Value.eigh(
tensor,
expr_to_typespec(%{eigenvecs_expr | type: out_type}),
expr_to_typespec(%{eigenvals_expr | type: out_type})
)

dbg(eigenvecs)

{[to_type(eigenvecs, eigenvecs_expr.type), to_type(eigenvals, eigenvals_expr.type)], cache}
end

defp cached_recur_operator(
:optional,
%T{
Expand Down
58 changes: 58 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,64 @@ defmodule EXLA.MLIR.Value do
op(func, "stablehlo.return", values, [])
end

def eigh(%Value{function: func} = value, eigenvecs_typespec, eigenvals_typespec) do
%{type: op_type, shape: op_shape} = get_typespec(value)
%{type: eigenvecs_type, shape: eigenvecs_shape} = eigenvecs_typespec
%{type: eigenvals_type, shape: eigenvals_shape} = eigenvals_typespec

dim_sizes = [tuple_size(op_shape), tuple_size(eigenvecs_shape), tuple_size(eigenvals_shape)]
operand_dims = Tuple.to_list(op_shape)
eigenvecs_dims = Tuple.to_list(eigenvecs_shape)
eigenvals_dims = Tuple.to_list(eigenvals_shape)

dim_sizes = constant(func, dim_sizes, Typespec.tensor({:s, 64}, {length(dim_sizes)}))
operand_dims = constant(func, operand_dims, Typespec.tensor({:s, 64}, {length(operand_dims)}))

eigenvecs_dims =
constant(func, eigenvecs_dims, Typespec.tensor({:s, 64}, {length(eigenvecs_dims)}))

eigenvals_dims =
constant(func, eigenvals_dims, Typespec.tensor({:s, 64}, {length(eigenvals_dims)}))

operands = [value, dim_sizes, operand_dims, eigenvecs_dims, eigenvals_dims]

eigenvecs_result_type = type_tensor(eigenvecs_type, eigenvecs_shape)
eigenvals_result_type = type_tensor(eigenvals_type, eigenvals_shape)
result_types = [type_tuple([eigenvecs_result_type, eigenvals_result_type])]

call_target_name =
case op_type do
{:bf, 16} ->
"eigh_cpu_custom_call_bf16"

{:f, 16} ->
"eigh_cpu_custom_call_f16"

{:f, 32} ->
"eigh_cpu_custom_call_f32"

{:f, 64} ->
"eigh_cpu_custom_call_f64"

type ->
# Due to matching on EXLA.Defn, we are sure that the device here is always :host
raise "Eigh decomposition not supported on :host device for type #{inspect(type)}"
end

attributes = [
call_target_name: attr_string(call_target_name),
backend_config: attr_string("Host")
]

result =
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()

eigenvecs = get_tuple_element(result, 0, eigenvecs_typespec)
eigenvals = get_tuple_element(result, 1, eigenvals_typespec)

{eigenvecs, eigenvals}
end

def qr(%Value{function: func} = value, q_typespec, r_typespec) do
%{type: op_type, shape: op_shape} = get_typespec(value)
%{type: q_type, shape: q_shape} = q_typespec
Expand Down

0 comments on commit 9e7e6de

Please sign in to comment.