Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mlir): gather #1340

Merged
merged 1 commit into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ static ErlNifFunc exla_funcs[] = {
{"mlir_sort", 4, mlir_sort},
{"mlir_scatter", 5, mlir_scatter},
{"mlir_select_and_scatter", 8, mlir_select_and_scatter},
{"mlir_gather", 8, mlir_gather},
{"mlir_reshape", 3, mlir_reshape},
{"mlir_reverse", 3, mlir_reverse},
{"mlir_transpose", 3, mlir_transpose},
Expand Down
8 changes: 8 additions & 0 deletions exla/c_src/exla/mlir/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,14 @@ mlir::Value MLIRFunction::SelectAndScatterOp(
return op.getResult();
}

mlir::Value MLIRFunction::GatherOp(mlir::Value source, mlir::Value indices, std::vector<int64_t> offset_dims, std::vector<int64_t> collapsed_slice_dims, std::vector<int64_t> start_index_map, std::vector<int64_t> slice_sizes, int64_t index_vector_dim) {
auto builder = module_->builder();
builder->setInsertionPointToEnd(&func_->getBody().back());
auto gather_dimension_numbers = mlir::mhlo::GatherDimensionNumbersAttr::get(builder->getContext(), offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim);
auto slice_sizes_attr = Int64ToDenseIntElementsAttr(module_->builder(), slice_sizes);
return builder->create<mlir::mhlo::GatherOp>(builder->getUnknownLoc(), source, indices, gather_dimension_numbers, slice_sizes_attr, false);
}

mlir::Value MLIRFunction::FFTOp(mlir::Value tensor, bool forward_fft, std::vector<int64_t> fft_length) {
auto builder = module_->builder();
builder->setInsertionPointToEnd(&func_->getBody().back());
Expand Down
1 change: 1 addition & 0 deletions exla/c_src/exla/mlir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class MLIRFunction {
mlir::Value SelectOp(mlir::Value pred, mlir::Value on_true, mlir::Value on_false);
mlir::Value ScatterOp(mlir::Value target, mlir::Value indices, mlir::Value updates, bool add_or_put);
mlir::Value SelectAndScatterOp(mlir::Value target, mlir::Value source, mlir::Value init_value, bool gt_or_lt, std::vector<int64_t> window_dimensions, std::vector<int64_t> window_strides, std::vector<int64_t> padding);
mlir::Value GatherOp(mlir::Value source, mlir::Value indices, std::vector<int64_t> offset_dims, std::vector<int64_t> collapsed_slice_dims, std::vector<int64_t> start_index_map, std::vector<int64_t> slice_sizes, int64_t index_vector_dim);
mlir::Value FFTOp(mlir::Value tensor, bool forward_fft, std::vector<int64_t> fft_length);
mlir::Value ConvOp(mlir::Value tensor, mlir::Value kernel, std::vector<int64_t> window_strides, std::vector<int64_t> padding, std::vector<int64_t> tensor_dilation, std::vector<int64_t> kernel_dilation, xla::ConvolutionDimensionNumbers dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, uint64_t precision_config, std::vector<int64_t> output_dims);
mlir::Value CreateTokenOp();
Expand Down
40 changes: 40 additions & 0 deletions exla/c_src/exla/mlir/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,46 @@ ERL_NIF_TERM mlir_select_and_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TER
return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, res));
}

ERL_NIF_TERM mlir_gather(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 8) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
mlir::Value *source, *indices;

int64_t index_vector_dim;
std::vector<int64_t> slice_sizes, offset_dims, collapsed_slice_dims, start_index_map;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get<mlir::Value>(env, argv[1], source)) {
return exla::nif::error(env, "Unable to get source.");
}
if (!exla::nif::get<mlir::Value>(env, argv[2], indices)) {
return exla::nif::error(env, "Unable to get indices.");
}
if (!exla::nif::get_list(env, argv[3], slice_sizes)) {
return exla::nif::error(env, "Unable to get slice_sizes.");
}
if (!exla::nif::get_list(env, argv[4], offset_dims)) {
return exla::nif::error(env, "Unable to get offset_dims.");
}
if (!exla::nif::get_list(env, argv[5], collapsed_slice_dims)) {
return exla::nif::error(env, "Unable to get collapsed_slice_dims.");
}
if (!exla::nif::get_list(env, argv[6], start_index_map)) {
return exla::nif::error(env, "Unable to get start_index_map.");
}
if (!exla::nif::get(env, argv[7], &index_vector_dim)) {
return exla::nif::error(env, "Unable to get index_vector_dim.");
}

mlir::Value res = (*function)->GatherOp(*source, *indices, offset_dims, collapsed_slice_dims, start_index_map, slice_sizes, index_vector_dim);
return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, res));
}

ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 4) {
return exla::nif::error(env, "Bad argument count.");
Expand Down
1 change: 1 addition & 0 deletions exla/c_src/exla/mlir/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ ERL_NIF_TERM mlir_concatenate(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[
ERL_NIF_TERM dump_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_select_and_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_gather(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_triangular_solve(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
Expand Down
21 changes: 21 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,27 @@ defmodule EXLA.Defn do
)
end

defp to_operator(:gather, [%Value{} = tensor, indices], _ans, _state) do
tensor_rank = tensor |> op_shape() |> tuple_size()
indices_rank = indices |> op_shape() |> tuple_size()

index_vector_dim = indices_rank - 1
slice_sizes = List.duplicate(1, tensor_rank)
offset_dims = []
collapsed_slice_dims = axes_for_rank(tensor_rank)
start_index_map = axes_for_rank(tensor_rank)

Value.gather(
tensor,
indices,
slice_sizes,
offset_dims,
collapsed_slice_dims,
start_index_map,
index_vector_dim
)
end

defp to_operator(:gather, [tensor, indices], _ans, _state) do
tensor_rank = tensor |> op_shape() |> tuple_size()
indices_rank = indices |> op_shape() |> tuple_size()
Expand Down
25 changes: 25 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,31 @@ defmodule EXLA.MLIR.Value do
%Value{target | ref: ref}
end

def gather(
%Value{function: func} = source,
%Value{function: func} = indices,
slice_sizes,
offset_dims,
collapsed_slice_dims,
start_index_map,
index_vector_dim
) do
ref =
EXLA.NIF.mlir_gather(
func.ref,
source.ref,
indices.ref,
slice_sizes,
offset_dims,
collapsed_slice_dims,
start_index_map,
index_vector_dim
)
|> unwrap!()

%Value{source | ref: ref}
end

defp get_precision_config_int(precision_config) do
case precision_config do
:default ->
Expand Down
12 changes: 12 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ defmodule EXLA.NIF do
),
do: :erlang.nif_error(:undef)

def mlir_gather(
_function,
_sorce,
_indices,
_slice_sizes,
_offset_dims,
_collapsed_slice_dims,
_start_index_map,
_index_vector_dim
),
do: :erlang.nif_error(:undef)

def mlir_fft(_function, _tensor, _forward_fft, _fft_lenght), do: :erlang.nif_error(:undef)

def mlir_convolution(
Expand Down
16 changes: 16 additions & 0 deletions exla/test/exla/mlir/executable_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -897,4 +897,20 @@ defmodule EXLA.MLIR.ExecutableTest do
)
end
end

describe "gather" do
test "works" do
t = Nx.tensor([[1, 2], [3, 4]])
idx = Nx.tensor([[[1, 1], [0, 0]], [[1, 0], [0, 1]]])
result = EXLA.jit(fn t, idx -> Nx.gather(t, idx) end, compiler_mode: :mlir).(t, idx)

assert_equal(
result,
Nx.tensor([
[4, 1],
[3, 2]
])
)
end
end
end