Skip to content

Commit

Permalink
Merge branch 'main' into pv-feat/add-to-mlir-module
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente committed Jul 11, 2024
2 parents 9ba0721 + bae1249 commit 59cd78f
Show file tree
Hide file tree
Showing 22 changed files with 440 additions and 189 deletions.
88 changes: 85 additions & 3 deletions exla/c_src/exla/custom_calls.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
#include "custom_calls.h"
#include "exla_nif_util.h"

#include "xla/service/custom_call_target_registry.h"

#include "Eigen/Dense"
#include "Eigen/Eigenvalues"
#include "Eigen/QR"
#include "exla_nif_util.h"
#include "xla/service/custom_call_target_registry.h"

template <typename DataType>
void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, int64_t m, int64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;

// Map the input matrix
Eigen::Map<RowMajorMatrix> input(in, m, n);

// Compute the Eigenvalue decomposition
Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver(input);

if (eigensolver.info() != Eigen::Success) {
std::cerr << "Eigenvalue decomposition failed!" << std::endl;
return;
}

// Get the eigenvalues and eigenvectors
Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
RowMajorMatrix eigenvectors = eigensolver.eigenvectors();

// Copy the eigenvalues to the output
std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));

// Copy the eigenvectors to the output
std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
}

template <typename DataType>
void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
Expand Down Expand Up @@ -89,6 +115,50 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
}
}

// Batched eigh custom-call entry point for the CPU.
//
// Input layout (in[]):
//   in[0] - operand data (DataType), one or more matrices stored contiguously.
//   in[1] - three int64 values: the number of dimensions of the operand, of the
//           eigenvalues output and of the eigenvectors output.
//   in[2] - operand dimensions (int64).
//   in[3] - eigenvalues output dimensions (int64).
//   in[4] - eigenvectors output dimensions (int64).
//
// Output layout (out[]):
//   out[0] - eigenvalues buffer (DataType).
//   out[1] - eigenvectors buffer (DataType).
//
// Every leading (batch) dimension of the operand is flattened and the
// decomposition is run once per matrix.
template <typename DataType>
void eigh_cpu_custom_call(void *out[], const void *in[]) {
  DataType *operand = (DataType *)in[0];

  int64_t *dim_sizes = (int64_t *)in[1];
  int64_t num_operand_dims = dim_sizes[0];
  int64_t num_eigenvalues_dims = dim_sizes[1];
  int64_t num_eigenvectors_dims = dim_sizes[2];

  int64_t *operand_dims_ptr = (int64_t *)in[2];
  std::vector<int64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);

  int64_t *eigenvalues_dims_ptr = (int64_t *)in[3];
  std::vector<int64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);

  int64_t *eigenvectors_dims_ptr = (int64_t *)in[4];
  std::vector<int64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);

  int64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
  int64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

  // Number of matrices = product of every operand dimension except the last two.
  int64_t batch_items = 1;
  for (int64_t i = 0; i < num_operand_dims - 2; i++) {
    batch_items *= operand_dims[i];
  }

  DataType *eigenvalues = (DataType *)out[0];
  DataType *eigenvectors = (DataType *)out[1];

  // All strides are expressed in ELEMENTS, not bytes: they are added to
  // DataType pointers, and pointer arithmetic already scales by
  // sizeof(DataType). Using byte strides here would make every batch item
  // after the first write past its slot.
  int64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
  int64_t eigenvectors_stride = m * n;
  int64_t operand_stride = m * n;

  for (int64_t i = 0; i < batch_items; i++) {
    single_matrix_eigh_cpu_custom_call<DataType>(
        eigenvalues + i * eigenvalues_stride,
        eigenvectors + i * eigenvectors_stride,
        operand + i * operand_stride,
        m, n);
  }
}

void qr_cpu_custom_call_bf16(void *out[], const void *in[]) {
qr_cpu_custom_call<exla::bfloat16>(out, in);
}
Expand All @@ -105,7 +175,19 @@ void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
qr_cpu_custom_call<double>(out, in);
}

void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
eigh_cpu_custom_call<float>(out, in);
}

void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
eigh_cpu_custom_call<double>(out, in);
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);


XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
3 changes: 3 additions & 0 deletions exla/c_src/exla/custom_calls.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ void qr_cpu_custom_call_f16(void *out[], const void *in[]);
void qr_cpu_custom_call_f32(void *out[], const void *in[]);
void qr_cpu_custom_call_f64(void *out[], const void *in[]);

void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
void eigh_cpu_custom_call_f64(void *out[], const void *in[]);

#endif
26 changes: 11 additions & 15 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -234,24 +234,16 @@ defmodule EXLA.Backend do

@impl true
def concatenate(out, tensors, axis) do
  # Concatenation is performed on Nx.BinaryBackend: every input is copied
  # there, concatenated, and the result is transferred back to EXLA.Backend
  # using the client/device options that jit_opts/2 derives from the inputs.
  # No JIT computation is built here; its result would only be discarded.
  copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
  result = Nx.BinaryBackend.concatenate(out, copied, axis)
  Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
end

@impl true
def stack(out, tensors, axis) do
  # Stacking is performed on Nx.BinaryBackend: every input is copied there,
  # stacked, and the result is transferred back to EXLA.Backend using the
  # client/device options that jit_opts/2 derives from the inputs.
  # No JIT computation is built here; its result would only be discarded.
  copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
  result = Nx.BinaryBackend.stack(out, copied, axis)
  Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
end

@impl true
Expand Down Expand Up @@ -390,6 +382,10 @@ defmodule EXLA.Backend do
# Convenience arity: the arguments passed to `fun` are also the tensors that
# are inspected to choose the client and device.
defp jit(opts, fun, args), do: jit(opts, fun, args, args)

# Runs `fun` on `args` through EXLA.jit_apply/3. The client and device are
# resolved by jit_opts/2 from `opts` and from the EXLA buffers in `tensors`.
defp jit(opts, fun, tensors, args) do
  # jit_opts/2 takes the options first and the tensors second. Passing them
  # in the other order makes it scan the options for tensors and look up
  # :client/:device_id in the tensor list, so both are silently ignored.
  EXLA.jit_apply(fun, args, [on_conflict: :force] ++ jit_opts(opts, tensors))
end

defp jit_opts(opts, tensors) do
{priority_client, priority_did, backup_client, backup_did} =
for %T{data: %B{buffer: %EXLA.DeviceBuffer{client_name: client_name, device_id: device_id}}} <-
tensors,
Expand Down Expand Up @@ -418,6 +414,6 @@ defmodule EXLA.Backend do
opts[:device_id] || priority_did || backup_did ||
EXLA.Client.fetch!(client).default_device_id

EXLA.jit_apply(fun, args, on_conflict: :force, client: client, device_id: device_id)
[client: client, device_id: device_id]
end
end
56 changes: 52 additions & 4 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ defmodule EXLA.Defn do
output = wrap_tuple_result(acc, acc_typespec)

outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
Value.return(builder, output)
Value.func_return(builder, output)

{{input_typespecs, input_indexes}, outfeed}
end
Expand Down Expand Up @@ -307,7 +307,7 @@ defmodule EXLA.Defn do
{res, cache} = recur_flatten(expr, state, new_cache(outfeed))
outfeed = cache |> get_outfeed() |> Outfeed.close(function)

Value.return(function, res)
Value.func_return(function, res)

{:ok, outfeed}
end
Expand Down Expand Up @@ -433,6 +433,15 @@ defmodule EXLA.Defn do
comp_arg_typespecs =
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec

outputs =
if stream? do
# The computation returns the final accumulator value
{_chunk_result, acc} = outputs
acc
else
outputs
end

out_typespecs =
[outputs]
|> Nx.Defn.Composite.flatten_list()
Expand Down Expand Up @@ -624,6 +633,45 @@ defmodule EXLA.Defn do
{[q, r], cache}
end

# Lowers an optional :eigh expression to the CPU custom call.
#
# This clause only matches when the client platform is :host and the builder
# is an MLIR %Function{}, since the eigh custom call is meant to be used only
# in that case.
#
# NOTE(review): the C++ custom call writes eigenvalues to its first output and
# eigenvectors to its second; confirm that the binding names below agree with
# the order of the expression tuple and of the values returned by Value.eigh/3.
defp cached_recur_operator(
       :optional,
       %T{
         data: %Expr{
           args: [
             %{data: %{op: :eigh, args: [operand_expr, _opts]}},
             {eigenvecs_expr, eigenvals_expr},
             _callback
           ]
         }
       },
       %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
       cache
     ) do
  {operand, cache} =
    operand_expr
    |> recur_operator(state, cache)
    |> unwrap_single_tensor!()

  # The custom call is only available for f32 and f64, so the computation
  # runs in a floating type that is at least f32.
  compute_type =
    eigenvecs_expr.type
    |> Nx.Type.to_floating()
    |> Nx.Type.merge({:f, 32})

  operand =
    if op_type(operand) == compute_type do
      operand
    else
      to_type(operand, compute_type)
    end

  eigenvecs_typespec = expr_to_typespec(%{eigenvecs_expr | type: compute_type})
  eigenvals_typespec = expr_to_typespec(%{eigenvals_expr | type: compute_type})

  {eigenvecs, eigenvals} = Value.eigh(operand, eigenvecs_typespec, eigenvals_typespec)

  # Cast the results back to the types the expression asked for.
  results = [
    to_type(eigenvecs, eigenvecs_expr.type),
    to_type(eigenvals, eigenvals_expr.type)
  ]

  {results, cache}
end

defp cached_recur_operator(
:optional,
%T{
Expand Down Expand Up @@ -1669,9 +1717,9 @@ defmodule EXLA.Defn do
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))

if outer_token do
Value.return(function, [get_token(comp_cache) | List.flatten(res)])
Value.func_return(function, [get_token(comp_cache) | List.flatten(res)])
else
Value.return(function, List.flatten(res))
Value.func_return(function, List.flatten(res))
end

{function, merge_outfeed(cache, comp_cache)}
Expand Down
32 changes: 25 additions & 7 deletions exla/lib/exla/executable.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ defmodule EXLA.Executable do
end
end

@doc """
Serializes the executable to a binary.
"""
def serialize(%Executable{
ref: executable,
output_typespecs: output_typespecs,
Expand All @@ -36,6 +39,7 @@ defmodule EXLA.Executable do
|> IO.iodata_to_binary()

%{
version: 1,
serialized: serialized_exec,
output_typespecs: output_typespecs,
num_replicas: num_replicas,
Expand All @@ -45,21 +49,35 @@ defmodule EXLA.Executable do
|> :erlang.term_to_binary()
end

@doc """
Deserializes a previously serialized executable.

`binary` must be a version 1 payload as produced by `serialize/1`. The
executable is loaded on `client` and returned as an `EXLA.Executable` struct.

Raises `ArgumentError` if `binary` does not decode to a version 1 payload.
"""
def deserialize(client, binary) do
  # NOTE(review): :erlang.binary_to_term/1 is called without the :safe option,
  # so only binaries from trusted sources should be passed here.
  case :erlang.binary_to_term(binary) do
    %{version: 1, serialized: serialized} = data ->
      # Assert that every expected field is present in the payload.
      %{
        output_typespecs: output_typespecs,
        num_replicas: num_replicas,
        num_partitions: num_partitions,
        device_id: device_id
      } = data

      ref =
        client.ref
        |> EXLA.NIF.deserialize_executable(serialized)
        |> unwrap!()

      %EXLA.Executable{
        output_typespecs: output_typespecs,
        num_replicas: num_replicas,
        num_partitions: num_partitions,
        device_id: device_id,
        ref: ref,
        client: client
      }

    _other ->
      raise ArgumentError, "invalid serialized executable"
  end
end

Expand Down
Loading

0 comments on commit 59cd78f

Please sign in to comment.