Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: gather #44

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ NIF(pad) {
TENSOR_PARAM(4, pad_value);
DEVICE_PARAM(5, device);

TENSOR(mlx::core::pad(*t, axes, low_pad_size, high_pad_size, *pad_value, "constant", device))
TENSOR(mlx::core::pad(*t, axes, low_pad_size, high_pad_size, *pad_value,
"constant", device))
};

NIF(sort) {
Expand Down Expand Up @@ -505,6 +506,16 @@ NIF(take) {
TENSOR(mlx::core::take(*t, *indices, axis, device));
}

NIF(gather) {
TENSOR_PARAM(0, t);
LIST_PARAM(1, std::vector<mlx::core::array>, indices);
LIST_PARAM(2, std::vector<int>, axes);
LIST_PARAM(3, std::vector<int>, slice_sizes);
DEVICE_PARAM(4, device);

TENSOR(mlx::core::gather(*t, indices, axes, slice_sizes, device));
}

/* Reduction Ops */

#define REDUCTION_AXES_OP(OP) REDUCTION_AXES_OP2(OP, OP)
Expand Down Expand Up @@ -877,6 +888,7 @@ static ErlNifFunc nif_funcs[] = {{"strides", 1, strides},
{"concatenate", 3, concatenate},
{"take_along_axis", 4, take_along_axis},
{"take", 4, take},
{"gather", 5, gather},
{"slice", 5, slice},
{"slice_update", 5, slice_update},
{"squeeze", 3, squeeze},
Expand Down
1 change: 1 addition & 0 deletions lib/emlx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ defmodule EMLX do
deftensor concatenate(tensors, axis)
deftensor take_along_axis(tensor, tensorIndices, axis)
deftensor take(tensor, tensorIndices, axis)
deftensor gather(tensor, indices, axes, slice_sizes)
deftensor max(tensor, axes, keep_axes)
deftensor min(tensor, axes, keep_axes)
deftensor clip(tensor, tensor_min, tensor_max)
Expand Down
34 changes: 33 additions & 1 deletion lib/emlx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,39 @@ defmodule EMLX.Backend do
end)
end

@impl true
def gather(out, tensor, indices, opts) do
axes = opts[:axes]

num_axes = Nx.axis_size(indices, -1)

slice_sizes =
Enum.map(Nx.axes(tensor), fn axis ->
if axis in axes do
1
else
Nx.axis_size(tensor, axis)
end
end)

indices_list =
Enum.map(0..(num_axes - 1), fn entry ->
{_device, ref} =
indices
|> Nx.slice_along_axis(entry, 1, axis: -1)
|> Nx.squeeze(axes: [-1])
|> from_nx()

ref
end)

tensor
|> from_nx()
|> EMLX.gather(indices_list, axes, slice_sizes)
|> EMLX.reshape(out.shape)
|> to_nx(out)
end

for {op, arity} <- [
reduce: 5,
window_reduce: 6,
Expand All @@ -1135,7 +1168,6 @@ defmodule EMLX.Backend do
to_pointer: 2,
indexed_put: 5,
indexed_add: 5,
gather: 4,
from_pointer: 5
] do
@impl true
Expand Down
4 changes: 4 additions & 0 deletions lib/emlx/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ defmodule EMLX.NIF do
:erlang.nif_error(:nif_not_loaded)
end

def gather(_tensor, _indices, _axes, _slice_sizes, _device) do
:erlang.nif_error(:nif_not_loaded)
end

def abs(_tensor, _device) do
:erlang.nif_error(:nif_not_loaded)
end
Expand Down
4 changes: 0 additions & 4 deletions test/emlx/nx_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ defmodule EMLX.Nx.DoctestTest do
indexed_put: 4,
# put_diagonal depends on indexed_put
put_diagonal: 3,
# take_diagonal depends on gather
take_diagonal: 2,
# make_diagonal depends on indexed_put
make_diagonal: 2,
# mode depends on indexed_add
Expand All @@ -26,9 +24,7 @@ defmodule EMLX.Nx.DoctestTest do
window_scatter_min: 5,
window_scatter_max: 5,
reverse: 2,
take: 3,
slice: 4,
gather: 3,
reflect: 2
]

Expand Down