diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp
index 6ad2c18d34..29b34bd1f0 100644
--- a/torchx/c_src/torchx.cpp
+++ b/torchx/c_src/torchx.cpp
@@ -533,13 +533,27 @@ NIF(unfold)
   TENSOR(at::native::unfold(*input, dim, size, step));
 }
 
-NIF(put)
-{
+NIF(put) {
   TENSOR_PARAM(0, input);
-  TENSOR_PARAM(1, index);
+  LIST_PARAM(1, std::vector<int64_t>, indices);
   TENSOR_PARAM(2, source);
 
-  TENSOR(at::put(*input, *index, *source));
+  torch::Tensor output = input->clone();
+  torch::Tensor destination = output;
+
+  auto source_shape = source->sizes();
+
+  size_t dim = 0;
+  for (dim = 0; dim < indices.size() - 1; dim++) {
+    auto start = indices[dim];
+    // arguments are dimension, start index and NON-INCLUSIVE end index
+    destination = destination.slice(dim, start, start + source_shape[dim]);
+  }
+
+  auto start = indices[dim];
+  destination.slice(dim, start, start + source_shape[dim]) = *source;
+
+  TENSOR(output);
 }
 
 NIF(permute)
diff --git a/torchx/lib/torchx.ex b/torchx/lib/torchx.ex
index 17c721d89d..59e53bcc7c 100644
--- a/torchx/lib/torchx.ex
+++ b/torchx/lib/torchx.ex
@@ -248,7 +248,7 @@ defmodule Torchx do
   deftensor argsort(tensor, axis, is_descending, stable)
   deftensor flip(tensor, axis)
   deftensor unfold(tensor, dimension, size, step)
-  deftensor put(tensor_input, tensor_index, tensor_source)
+  deftensor put(tensor_input, index, tensor_source)
   deftensor where(tensorA, tensorB, tensorC)
 
   ## Aggregation
diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex
index 0d06086f30..8f877f0440 100644
--- a/torchx/lib/torchx/backend.ex
+++ b/torchx/lib/torchx/backend.ex
@@ -339,7 +339,7 @@ defmodule Torchx.Backend do
 
   @impl true
   def put_slice(out, input, start_indices_unbounded, slice) do
-    {device, _} = input_tx = from_nx(input)
+    input_tx = from_nx(input)
 
     slice_shape_list = Tuple.to_list(slice.shape)
 
@@ -351,30 +351,11 @@ defmodule Torchx.Backend do
       min(max(idx, 0), dim_size - len)
     end)
 
-    range_or_ranges =
-      [start_indices, slice_shape_list]
-      |> Enum.zip_with(fn [s, l] -> s..(s + l - 1)//1 end)
-      |> Enum.reverse()
-      |> Enum.reduce(fn range, acc -> for x <- range, y <- acc, do: List.flatten([x, y]) end)
-
-    # if below is needed for when the reduce receives a single-element list
-    linear_indices_tx =
-      if is_list(range_or_ranges) do
-        range_or_ranges
-        |> Nx.tensor(backend: {__MODULE__, device: device})
-        |> then(&as_torchx_linear_indices(input.shape, &1))
-      else
-        range_or_ranges
-        |> Enum.to_list()
-        |> Nx.tensor(backend: {__MODULE__, device: device})
-        |> Torchx.from_nx()
-      end
-
     slice_tx = slice |> from_nx() |> Torchx.to_type(to_torch_type(out.type))
 
     input_tx
     |> Torchx.to_type(to_torch_type(out.type))
-    |> Torchx.put(linear_indices_tx, slice_tx)
+    |> Torchx.put(start_indices, slice_tx)
     |> to_nx(out)
   end
 
@@ -534,71 +515,6 @@ defmodule Torchx.Backend do
     |> to_nx(out)
   end
 
-  defp as_torchx_linear_indices(shape, idx) do
-    # Nx provides indices as a tensor of shape {*, input_dims}
-    # However, torch expects indices to be a tensor of indices along a given axis.
-    # As such, we need to convert the indices tensor to linear indices.
-    # See the `linear_indices_offsets` function for an explanation on the offsets calculation.
-
-    # Index limit validation
-
-    ndims = tuple_size(shape)
-
-    flattened_idx = Nx.reshape(idx, {div(Nx.size(idx), ndims), ndims})
-    shape_tensor = shape |> Tuple.to_list() |> Nx.tensor()
-
-    upper_clamped_idx =
-      flattened_idx
-      |> Nx.greater_equal(shape_tensor)
-      |> Nx.select(Nx.subtract(shape_tensor, 1), flattened_idx)
-
-    lower_clamp_selector = Nx.less(upper_clamped_idx, 0)
-
-    fully_clamped_idx =
-      lower_clamp_selector |> Nx.select(0, upper_clamped_idx) |> Nx.reshape(idx.shape)
-
-    # Actual conversion algorithm
-
-    linear_indices_offsets =
-      shape
-      |> linear_indices_offsets()
-      |> from_nx()
-
-    lin_idx_num_elements =
-      idx.shape |> Tuple.delete_at(tuple_size(idx.shape) - 1) |> Tuple.product()
-
-    fully_clamped_idx
-    |> from_nx()
-    |> Torchx.tensordot(linear_indices_offsets, [tuple_size(idx.shape) - 1], [0])
-    |> Torchx.reshape({lin_idx_num_elements})
-  end
-
-  defp linear_indices_offsets(shape) do
-    # The offsets tensor calculated below follows a formula in which we
-    # multiply the index along each axis by the number of elements contained in all following axes
-    # For example, for a {3, 5, 7, 2} tensor, the offsets tensor is [70, 14, 2, 1]
-
-    # This offsets tensor is then applied to the indices tensor through matrix multiplication:
-    # indices = [[0, 2, 1, 0], [0, 0, 0, 1], [1, 4, 3, 2]]
-    # offsets = [70, 14, 2, 1]
-    # linear_indices = [14 * 2 + 2 * 1, 1 * 1, 70 * 1 + 14 * 4 + 2 * 3 + 1 * 2] = [30, 1, 134]
-
-    # By linear indices, we refer to the indices of a row-major representation of a tensor
-    # it's easy to see the expected values using Nx.iota(tensor), which will output a tensor
-    # which counts in exactly the same way, when provided no arguments. In effect, Nx.iota outputs
-    # the corresponding linear indices for a given tensor shape.
-
-    {offsets_list, _} =
-      shape
-      |> Tuple.to_list()
-      |> Enum.reverse()
-      |> Enum.reduce({[], 1}, fn x, {acc, multiplier} ->
-        {[multiplier | acc], multiplier * x}
-      end)
-
-    Nx.tensor(offsets_list, backend: __MODULE__)
-  end
-
   @impl true
   def take_along_axis(out, tensor, idx, axis) do
     idx_tx = idx |> from_nx() |> Torchx.to_type(:long)
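For context (not part of the patch), a minimal usage sketch of the behaviour the slice-based put has to preserve: Nx.put_slice/3 writes the slice into the input at the given start indices, and the backend code kept above clamps each start index with min(max(idx, 0), dim_size - len) so the slice always fits. Assuming the Torchx NIFs are compiled and Torchx.Backend is available:

    t = Nx.iota({3, 4}, backend: Torchx.Backend)
    slice = Nx.tensor([[100, 101], [102, 103]], backend: Torchx.Backend)

    # start indices [1, 3] are clamped to [1, 2] so the 2x2 slice fits inside {3, 4};
    # the NIF then narrows a view per dimension and assigns the slice into the cloned input
    Nx.put_slice(t, [1, 3], slice)
    #=> [[0, 1, 2, 3],
    #    [4, 5, 100, 101],
    #    [8, 9, 102, 103]]

Because the NIF clones the input before assigning through the chained views, the original tensor is left untouched, preserving Nx's immutable semantics.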