From 41ae49fd695752b139e060139f733ca3f69132a7 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Fri, 13 Oct 2023 01:05:25 -0300 Subject: [PATCH] feat(mlir): gather --- exla/c_src/exla/exla.cc | 1 + exla/c_src/exla/mlir/builder.cc | 8 +++++ exla/c_src/exla/mlir/builder.h | 1 + exla/c_src/exla/mlir/ops.cc | 40 +++++++++++++++++++++++++ exla/c_src/exla/mlir/ops.h | 1 + exla/lib/exla/defn.ex | 21 +++++++++++++ exla/lib/exla/mlir/value.ex | 25 ++++++++++++++++ exla/lib/exla/nif.ex | 12 ++++++++ exla/test/exla/mlir/executable_test.exs | 16 ++++++++++ 9 files changed, 125 insertions(+) diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 9a8d597cad..d4d0abfe3c 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -680,6 +680,7 @@ static ErlNifFunc exla_funcs[] = { {"mlir_sort", 4, mlir_sort}, {"mlir_scatter", 5, mlir_scatter}, {"mlir_select_and_scatter", 8, mlir_select_and_scatter}, + {"mlir_gather", 8, mlir_gather}, {"mlir_reshape", 3, mlir_reshape}, {"mlir_reverse", 3, mlir_reverse}, {"mlir_transpose", 3, mlir_transpose}, diff --git a/exla/c_src/exla/mlir/builder.cc b/exla/c_src/exla/mlir/builder.cc index 7b866c288d..a973a45401 100644 --- a/exla/c_src/exla/mlir/builder.cc +++ b/exla/c_src/exla/mlir/builder.cc @@ -815,6 +815,14 @@ mlir::Value MLIRFunction::SelectAndScatterOp( return op.getResult(); } +mlir::Value MLIRFunction::GatherOp(mlir::Value source, mlir::Value indices, std::vector offset_dims, std::vector collapsed_slice_dims, std::vector start_index_map, std::vector slice_sizes, int64_t index_vector_dim) { + auto builder = module_->builder(); + builder->setInsertionPointToEnd(&func_->getBody().back()); + auto gather_dimension_numbers = mlir::mhlo::GatherDimensionNumbersAttr::get(builder->getContext(), offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim); + auto slice_sizes_attr = Int64ToDenseIntElementsAttr(module_->builder(), slice_sizes); + return builder->create(builder->getUnknownLoc(), source, indices, gather_dimension_numbers, slice_sizes_attr, false); +} + mlir::Value MLIRFunction::FFTOp(mlir::Value tensor, bool forward_fft, std::vector fft_length) { auto builder = module_->builder(); builder->setInsertionPointToEnd(&func_->getBody().back()); diff --git a/exla/c_src/exla/mlir/builder.h b/exla/c_src/exla/mlir/builder.h index 60c6e280ff..58deacae58 100644 --- a/exla/c_src/exla/mlir/builder.h +++ b/exla/c_src/exla/mlir/builder.h @@ -98,6 +98,7 @@ class MLIRFunction { mlir::Value SelectOp(mlir::Value pred, mlir::Value on_true, mlir::Value on_false); mlir::Value ScatterOp(mlir::Value target, mlir::Value indices, mlir::Value updates, bool add_or_put); mlir::Value SelectAndScatterOp(mlir::Value target, mlir::Value source, mlir::Value init_value, bool gt_or_lt, std::vector window_dimensions, std::vector window_strides, std::vector padding); + mlir::Value GatherOp(mlir::Value source, mlir::Value indices, std::vector offset_dims, std::vector collapsed_slice_dims, std::vector start_index_map, std::vector slice_sizes, int64_t index_vector_dim); mlir::Value FFTOp(mlir::Value tensor, bool forward_fft, std::vector fft_length); mlir::Value ConvOp(mlir::Value tensor, mlir::Value kernel, std::vector window_strides, std::vector padding, std::vector tensor_dilation, std::vector kernel_dilation, xla::ConvolutionDimensionNumbers dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, uint64_t precision_config, std::vector output_dims); mlir::Value CreateTokenOp(); diff --git a/exla/c_src/exla/mlir/ops.cc b/exla/c_src/exla/mlir/ops.cc index 0dbb665eea..e91d746e70 100644 --- a/exla/c_src/exla/mlir/ops.cc +++ b/exla/c_src/exla/mlir/ops.cc @@ -1023,6 +1023,46 @@ ERL_NIF_TERM mlir_select_and_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TER return exla::nif::ok(env, exla::nif::make(env, res)); } +ERL_NIF_TERM mlir_gather(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 8) { + return exla::nif::error(env, "Bad argument count."); + } + + exla::MLIRFunction** function; + mlir::Value *source, *indices; + + int64_t index_vector_dim; + std::vector slice_sizes, offset_dims, collapsed_slice_dims, start_index_map; + + if (!exla::nif::get(env, argv[0], function)) { + return exla::nif::error(env, "Unable to get function."); + } + if (!exla::nif::get(env, argv[1], source)) { + return exla::nif::error(env, "Unable to get source."); + } + if (!exla::nif::get(env, argv[2], indices)) { + return exla::nif::error(env, "Unable to get indices."); + } + if (!exla::nif::get_list(env, argv[3], slice_sizes)) { + return exla::nif::error(env, "Unable to get slice_sizes."); + } + if (!exla::nif::get_list(env, argv[4], offset_dims)) { + return exla::nif::error(env, "Unable to get offset_dims."); + } + if (!exla::nif::get_list(env, argv[5], collapsed_slice_dims)) { + return exla::nif::error(env, "Unable to get collapsed_slice_dims."); + } + if (!exla::nif::get_list(env, argv[6], start_index_map)) { + return exla::nif::error(env, "Unable to get start_index_map."); + } + if (!exla::nif::get(env, argv[7], &index_vector_dim)) { + return exla::nif::error(env, "Unable to get index_vector_dim."); + } + + mlir::Value res = (*function)->GatherOp(*source, *indices, offset_dims, collapsed_slice_dims, start_index_map, slice_sizes, index_vector_dim); + return exla::nif::ok(env, exla::nif::make(env, res)); +} + ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 4) { return exla::nif::error(env, "Bad argument count."); diff --git a/exla/c_src/exla/mlir/ops.h b/exla/c_src/exla/mlir/ops.h index 6beab5a6df..8cab3fc95c 100644 --- a/exla/c_src/exla/mlir/ops.h +++ b/exla/c_src/exla/mlir/ops.h @@ -95,6 +95,7 @@ ERL_NIF_TERM mlir_concatenate(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[ ERL_NIF_TERM dump_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM mlir_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM mlir_select_and_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); +ERL_NIF_TERM mlir_gather(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); ERL_NIF_TERM mlir_triangular_solve(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]); diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index cf0798b72d..f2dbc7d4eb 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -1490,6 +1490,27 @@ defmodule EXLA.Defn do ) end + defp to_operator(:gather, [%Value{} = tensor, indices], _ans, _state) do + tensor_rank = tensor |> op_shape() |> tuple_size() + indices_rank = indices |> op_shape() |> tuple_size() + + index_vector_dim = indices_rank - 1 + slice_sizes = List.duplicate(1, tensor_rank) + offset_dims = [] + collapsed_slice_dims = axes_for_rank(tensor_rank) + start_index_map = axes_for_rank(tensor_rank) + + Value.gather( + tensor, + indices, + slice_sizes, + offset_dims, + collapsed_slice_dims, + start_index_map, + index_vector_dim + ) + end + defp to_operator(:gather, [tensor, indices], _ans, _state) do tensor_rank = tensor |> op_shape() |> tuple_size() indices_rank = indices |> op_shape() |> tuple_size() diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 28cc39db1e..63ff1d26ab 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -338,6 +338,31 @@ defmodule EXLA.MLIR.Value do %Value{target | ref: ref} end + def gather( + %Value{function: func} = source, + %Value{function: func} = indices, + slice_sizes, + offset_dims, + collapsed_slice_dims, + start_index_map, + index_vector_dim + ) do + ref = + EXLA.NIF.mlir_gather( + func.ref, + source.ref, + indices.ref, + slice_sizes, + offset_dims, + collapsed_slice_dims, + start_index_map, + index_vector_dim + ) + |> unwrap!() + + %Value{source | ref: ref} + end + defp get_precision_config_int(precision_config) do case precision_config do :default -> diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 420d69b6b7..47171de398 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -97,6 +97,18 @@ defmodule EXLA.NIF do ), do: :erlang.nif_error(:undef) + def mlir_gather( + _function, + _sorce, + _indices, + _slice_sizes, + _offset_dims, + _collapsed_slice_dims, + _start_index_map, + _index_vector_dim + ), + do: :erlang.nif_error(:undef) + def mlir_fft(_function, _tensor, _forward_fft, _fft_lenght), do: :erlang.nif_error(:undef) def mlir_convolution( diff --git a/exla/test/exla/mlir/executable_test.exs b/exla/test/exla/mlir/executable_test.exs index 1ce245284c..d1343643c5 100644 --- a/exla/test/exla/mlir/executable_test.exs +++ b/exla/test/exla/mlir/executable_test.exs @@ -897,4 +897,20 @@ defmodule EXLA.MLIR.ExecutableTest do ) end end + + describe "gather" do + test "works" do + t = Nx.tensor([[1, 2], [3, 4]]) + idx = Nx.tensor([[[1, 1], [0, 0]], [[1, 0], [0, 1]]]) + result = EXLA.jit(fn t, idx -> Nx.gather(t, idx) end, compiler_mode: :mlir).(t, idx) + + assert_equal( + result, + Nx.tensor([ + [4, 1], + [3, 2] + ]) + ) + end + end end