From 9e7e6de6ff529e83f1f316f2cec6cf725f90f92a Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Thu, 4 Jul 2024 12:57:05 -0300
Subject: [PATCH] feat(exla): add custom callback for eigh

---
 exla/c_src/exla/custom_calls.cc | 88 +++++++++++++++++++++++++++++++--
 exla/c_src/exla/custom_calls.h  |  3 ++
 exla/lib/exla/defn.ex           | 38 ++++++++++++++
 exla/lib/exla/mlir/value.ex     | 58 ++++++++++++++++++++++
 4 files changed, 184 insertions(+), 3 deletions(-)

diff --git a/exla/c_src/exla/custom_calls.cc b/exla/c_src/exla/custom_calls.cc
index eb2ce90d27..36abcec04a 100644
--- a/exla/c_src/exla/custom_calls.cc
+++ b/exla/c_src/exla/custom_calls.cc
@@ -1,10 +1,36 @@
 #include "custom_calls.h"
 
-#include "exla_nif_util.h"
-
-#include "xla/service/custom_call_target_registry.h"
 #include "Eigen/Dense"
+#include "Eigen/Eigenvalues"
 #include "Eigen/QR"
+#include "exla_nif_util.h"
+#include "xla/service/custom_call_target_registry.h"
+
+template <typename DataType>
+void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, int64_t m, int64_t n) {
+  typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
+
+  // Map the input matrix
+  Eigen::Map<RowMajorMatrix> input(in, m, n);
+
+  // Compute the Eigenvalue decomposition
+  Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver(input);
+
+  if (eigensolver.info() != Eigen::Success) {
+    std::cerr << "Eigenvalue decomposition failed!" << std::endl;
+    return;
+  }
+
+  // Get the eigenvalues and eigenvectors
+  Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
+  RowMajorMatrix eigenvectors = eigensolver.eigenvectors();
+
+  // Copy the eigenvalues to the output
+  std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
+
+  // Copy the eigenvectors to the output
+  std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
+}
 
 template <typename DataType>
 void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
@@ -89,6 +115,50 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
   }
 }
 
+template <typename DataType>
+void eigh_cpu_custom_call(void *out[], const void *in[]) {
+  DataType *operand = (DataType *)in[0];
+
+  int64_t *dim_sizes = (int64_t *)in[1];
+  int64_t num_operand_dims = dim_sizes[0];
+  int64_t num_eigenvalues_dims = dim_sizes[1];
+  int64_t num_eigenvectors_dims = dim_sizes[2];
+
+  int64_t *operand_dims_ptr = (int64_t *)in[2];
+  std::vector<int64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
+
+  int64_t *eigenvalues_dims_ptr = (int64_t *)in[3];
+  std::vector<int64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
+
+  int64_t *eigenvectors_dims_ptr = (int64_t *)in[4];
+  std::vector<int64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
+
+  int64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
+  int64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
+
+  auto leading_dimensions = std::vector<int64_t>(operand_dims.begin(), operand_dims.end() - 2);
+
+  int64_t batch_items = 1;
+  for (int64_t i = 0; i < leading_dimensions.size(); i++) {
+    batch_items *= leading_dimensions[i];
+  }
+
+  DataType *eigenvalues = (DataType *)out[0];
+  DataType *eigenvectors = (DataType *)out[1];
+
+  int64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
+  int64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
+  int64_t inner_stride = m * n * sizeof(DataType);
+
+  for (int64_t i = 0; i < batch_items; i++) {
+    single_matrix_eigh_cpu_custom_call<DataType>(
+        eigenvalues + i * eigenvalues_stride / sizeof(DataType),
+        eigenvectors + i * eigenvectors_stride / sizeof(DataType),
+        operand + i * inner_stride / sizeof(DataType),
+        m, n);
+  }
+}
+
 void qr_cpu_custom_call_bf16(void *out[], const void *in[]) {
   qr_cpu_custom_call<exla::bfloat16>(out, in);
 }
@@ -105,7 +175,19 @@ void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
   qr_cpu_custom_call<double>(out, in);
 }
 
+void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
+  eigh_cpu_custom_call<float>(out, in);
+}
+
+void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
+  eigh_cpu_custom_call<double>(out, in);
+}
+
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
+
+
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
\ No newline at end of file
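
For reference, the Eigen routine used above has two properties the rest of the patch relies on: Eigen::SelfAdjointEigenSolver sorts the eigenvalues in increasing order, and the columns of eigenvectors() are orthonormal, so the input is recovered as V * diag(lambda) * V^T. The following is a small, self-contained sketch of that behaviour; it is not part of the patch and only assumes Eigen is on the include path.

#include <cstdio>

#include "Eigen/Dense"
#include "Eigen/Eigenvalues"

int main() {
  // A single 3x3 symmetric matrix, analogous to one batch item above.
  Eigen::Matrix3d a;
  a << 4, 1, 0,
       1, 3, 1,
       0, 1, 2;

  Eigen::SelfAdjointEigenSolver<Eigen::Matrix3d> solver(a);

  // Eigenvalues come back sorted in increasing order.
  std::printf("smallest eigenvalue: %g\n", solver.eigenvalues()(0));

  // Reconstruct A = V * diag(lambda) * V^T; the error is at machine precision.
  Eigen::Matrix3d reconstructed = solver.eigenvectors() *
                                  solver.eigenvalues().asDiagonal() *
                                  solver.eigenvectors().transpose();
  std::printf("max abs reconstruction error: %g\n",
              (a - reconstructed).cwiseAbs().maxCoeff());
  return 0;
}
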
diff --git a/exla/c_src/exla/custom_calls.h b/exla/c_src/exla/custom_calls.h
index f00834d411..36cd7d9ac5 100644
--- a/exla/c_src/exla/custom_calls.h
+++ b/exla/c_src/exla/custom_calls.h
@@ -6,4 +6,7 @@ void qr_cpu_custom_call_f16(void *out[], const void *in[]);
 void qr_cpu_custom_call_f32(void *out[], const void *in[]);
 void qr_cpu_custom_call_f64(void *out[], const void *in[]);
 
+void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
+void eigh_cpu_custom_call_f64(void *out[], const void *in[]);
+
 #endif
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 482ad540df..b06f139418 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -633,6 +633,44 @@ defmodule EXLA.Defn do
     {[q, r], cache}
   end
 
+  defp cached_recur_operator(
+         :optional,
+         %T{
+           data: %Expr{
+             args: [
+               %{data: %{op: :eigh, args: [tensor, _opts]}},
+               {eigenvecs_expr, eigenvals_expr},
+               _callback
+             ]
+           }
+         },
+         %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
+         cache
+       ) do
+    # We match on platform: :host with an MLIR builder because we only want
+    # to support eigh as a CPU custom call in that case
+    {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
+
+    # Convert to float and ensure that we're either using f32 or f64
+    out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32})
+
+    tensor =
+      if op_type(tensor) != out_type do
+        to_type(tensor, out_type)
+      else
+        tensor
+      end
+
+    {eigenvecs, eigenvals} =
+      Value.eigh(
+        tensor,
+        expr_to_typespec(%{eigenvecs_expr | type: out_type}),
+        expr_to_typespec(%{eigenvals_expr | type: out_type})
+      )
+
+    {[to_type(eigenvecs, eigenvecs_expr.type), to_type(eigenvals, eigenvals_expr.type)], cache}
+  end
+
   defp cached_recur_operator(
          :optional,
          %T{
           data: %Expr{
             args: [
diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex
index 10b9c95f52..63a9aa7626 100644
--- a/exla/lib/exla/mlir/value.ex
+++ b/exla/lib/exla/mlir/value.ex
@@ -710,6 +710,64 @@
     op(func, "stablehlo.return", values, [])
   end
 
+  def eigh(%Value{function: func} = value, eigenvecs_typespec, eigenvals_typespec) do
+    %{type: op_type, shape: op_shape} = get_typespec(value)
+    %{type: eigenvecs_type, shape: eigenvecs_shape} = eigenvecs_typespec
+    %{type: eigenvals_type, shape: eigenvals_shape} = eigenvals_typespec
+
+    dim_sizes = [tuple_size(op_shape), tuple_size(eigenvecs_shape), tuple_size(eigenvals_shape)]
+    operand_dims = Tuple.to_list(op_shape)
+    eigenvecs_dims = Tuple.to_list(eigenvecs_shape)
+    eigenvals_dims = Tuple.to_list(eigenvals_shape)
+
+    dim_sizes = constant(func, dim_sizes, Typespec.tensor({:s, 64}, {length(dim_sizes)}))
+    operand_dims = constant(func, operand_dims, Typespec.tensor({:s, 64}, {length(operand_dims)}))
+
+    eigenvecs_dims =
+      constant(func, eigenvecs_dims, Typespec.tensor({:s, 64}, {length(eigenvecs_dims)}))
+
+    eigenvals_dims =
+      constant(func, eigenvals_dims, Typespec.tensor({:s, 64}, {length(eigenvals_dims)}))
+
+    operands = [value, dim_sizes, operand_dims, eigenvecs_dims, eigenvals_dims]
+
+    eigenvecs_result_type = type_tensor(eigenvecs_type, eigenvecs_shape)
+    eigenvals_result_type = type_tensor(eigenvals_type, eigenvals_shape)
+    result_types = [type_tuple([eigenvecs_result_type, eigenvals_result_type])]
+
+    call_target_name =
+      case op_type do
+        {:bf, 16} ->
+          "eigh_cpu_custom_call_bf16"
+
+        {:f, 16} ->
+          "eigh_cpu_custom_call_f16"
+
+        {:f, 32} ->
+          "eigh_cpu_custom_call_f32"
+
+        {:f, 64} ->
+          "eigh_cpu_custom_call_f64"
+
+        type ->
+          # Because of the matching in EXLA.Defn, we are sure that the device here is always :host
+          raise "Eigh decomposition not supported on :host device for type #{inspect(type)}"
+      end
+
+    attributes = [
+      call_target_name: attr_string(call_target_name),
+      backend_config: attr_string("Host")
+    ]
+
+    result =
+      op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!()
+
+    eigenvecs = get_tuple_element(result, 0, eigenvecs_typespec)
+    eigenvals = get_tuple_element(result, 1, eigenvals_typespec)
+
+    {eigenvecs, eigenvals}
+  end
+
   def qr(%Value{function: func} = value, q_typespec, r_typespec) do
     %{type: op_type, shape: op_shape} = get_typespec(value)
     %{type: q_type, shape: q_shape} = q_typespec
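
Taken together, EXLA.MLIR.Value.eigh/3 passes the operand plus four s64 shape tensors as the custom call operands, and eigh_cpu_custom_call reads them back as in[0]..in[4] and writes the eigenvalue and eigenvector buffers through out[0] and out[1]. The hypothetical harness below (not part of the patch; the buffer names and the {2, 2, 2} shape are made up for illustration) drives eigh_cpu_custom_call<float> directly with that layout for a batch of two 2x2 symmetric matrices, bypassing XLA entirely; it assumes it is compiled in the same translation unit as custom_calls.cc above.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Batch of two 2x2 symmetric matrices, row-major, shape {2, 2, 2}.
  std::vector<float> a = {2.0f, 1.0f, 1.0f, 2.0f,   // eigenvalues {1, 3}
                          5.0f, 0.0f, 0.0f, 4.0f};  // eigenvalues {4, 5}

  // Ranks of the operand and of the two outputs, followed by their
  // dimension lists, mirroring in[1]..in[4] in eigh_cpu_custom_call.
  std::vector<int64_t> dim_sizes = {3, 2, 3};
  std::vector<int64_t> operand_dims = {2, 2, 2};
  std::vector<int64_t> eigenvalues_dims = {2, 2};
  std::vector<int64_t> eigenvectors_dims = {2, 2, 2};

  std::vector<float> eigenvalues(2 * 2);
  std::vector<float> eigenvectors(2 * 2 * 2);

  const void *in[] = {a.data(), dim_sizes.data(), operand_dims.data(),
                      eigenvalues_dims.data(), eigenvectors_dims.data()};
  void *out[] = {eigenvalues.data(), eigenvectors.data()};

  eigh_cpu_custom_call<float>(out, in);

  // Expect roughly {1, 3} for the first matrix and {4, 5} for the second.
  for (int i = 0; i < 2; i++) {
    std::printf("matrix %d eigenvalues: %g %g\n", i,
                eigenvalues[2 * i], eigenvalues[2 * i + 1]);
  }
  return 0;
}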