
ivy.index_add #26934

Closed · wants to merge 46 commits
Changes from 39 commits
Commits (46)
eb4ba51
ivy.index_add issue #26801
imsoumya18 Oct 12, 2023
32796cd
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 13, 2023
20c0565
3. jax -> exp
imsoumya18 Oct 13, 2023
e49b8bc
3. jax -> exp
imsoumya18 Oct 13, 2023
005e825
5. np -> exp
imsoumya18 Oct 13, 2023
1e63433
6. paddle -> exp
imsoumya18 Oct 13, 2023
649a650
7. tf -> exp
imsoumya18 Oct 13, 2023
473d31b
8. torch -> exp
imsoumya18 Oct 13, 2023
bd71306
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 14, 2023
f6e4540
4. mxnet -> exp
imsoumya18 Oct 15, 2023
e08573c
9. ivy -> exp
imsoumya18 Oct 15, 2023
72a7fad
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 15, 2023
9b89eb1
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 16, 2023
23583c1
fix
imsoumya18 Oct 16, 2023
fd568aa
fix
imsoumya18 Oct 17, 2023
1669c25
fix
imsoumya18 Oct 17, 2023
ab0ca56
fix
imsoumya18 Oct 17, 2023
afa668c
fix
imsoumya18 Oct 17, 2023
831b184
fix
imsoumya18 Oct 18, 2023
a719687
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 18, 2023
369fe24
🤖 Lint code
ivy-branch Oct 18, 2023
600f829
fix
imsoumya18 Oct 19, 2023
a77c2ed
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 19, 2023
6be63a5
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 22, 2023
5ff3369
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Oct 25, 2023
203784a
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Nov 8, 2023
257a128
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Nov 9, 2023
6423c6f
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Nov 10, 2023
b89ba5b
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Nov 19, 2023
6976dba
fix
imsoumya18 Nov 19, 2023
5093bee
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Nov 21, 2023
11d3d53
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Nov 30, 2023
dbf7d73
fix
imsoumya18 Nov 30, 2023
b0f196f
fix
imsoumya18 Dec 1, 2023
973f73a
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 1, 2023
4e6aba7
fix
imsoumya18 Dec 4, 2023
ce39700
Merge remote-tracking branch 'origin/ivy.index_add#26801' into ivy.in…
imsoumya18 Dec 4, 2023
2900851
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 4, 2023
e2e10f0
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 5, 2023
d372965
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 7, 2023
8126015
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 7, 2023
5445b29
fix
imsoumya18 Dec 7, 2023
962a30e
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 9, 2023
977a9b3
fix
imsoumya18 Dec 9, 2023
61dd412
Merge branch 'unifyai:main' into ivy.index_add#26801
imsoumya18 Dec 11, 2023
6a0edb5
🤖 Lint code
ivy-branch Dec 11, 2023
46 changes: 46 additions & 0 deletions ivy/data_classes/array/experimental/manipulation.py
@@ -1512,3 +1512,49 @@ def put_along_axis(
changes.
"""
return ivy.put_along_axis(self._data, indices, values, axis, mode=mode, out=out)

def index_add(
self: ivy.Array,
index: ivy.Array,
axis: int,
value: ivy.Array,
/,
*,
name: Optional[str] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Add the elements of ``value`` to the input array along ``axis`` at the positions
given in ``index``. Contributions at repeated indices are accumulated.

Parameters
----------
self : Array
The destination array.
index : Array
The 1-D array of indices along ``axis`` at which to add.
axis : int
The dimension along which to index.
value : Array
The array of values to add; its size along ``axis`` must equal the length of ``index``.
name : str, optional
Name for the operation (currently unused). Default is ``None``.
out : Array, optional
Optional output array, for writing the result to.

Returns
-------
Array
An array with the same shape and dtype as ``self``.

Examples
--------
>>> x = ivy.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
>>> idx = ivy.array([0, 2])
>>> val = ivy.array([[1, 1, 1], [1, 1, 1]])
>>> ivy.index_add(x, idx, 0, val)
ivy.array([[2, 2, 2],
           [1, 1, 1],
           [2, 2, 2]])
"""
return ivy.index_add(self._data, index, axis, value, out=out)
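A minimal usage sketch of the new instance method, mirroring the docstring example above (not part of the diff):

```python
import ivy

x = ivy.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
idx = ivy.array([0, 2])
val = ivy.array([[1, 1, 1], [1, 1, 1]])

# The instance method forwards to ivy.index_add on the underlying data.
res = x.index_add(idx, 0, val)
print(res)  # ivy.array([[2, 2, 2], [1, 1, 1], [2, 2, 2]])
```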
69 changes: 69 additions & 0 deletions ivy/data_classes/container/experimental/manipulation.py
@@ -4137,6 +4137,75 @@ def trim_zeros(
"""
return self._static_trim_zeros(self, trim=trim)

@staticmethod
def static_index_add(
x: ivy.Array,
index: ivy.Array,
axis: int,
value: ivy.Array,
/,
*,
name: Optional[str] = None,
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
) -> ivy.Container:
return ContainerBase.cont_multi_map_in_function(
"index_add",
x,
index,
axis,
value,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

def index_add(
self: ivy.Container,
index: ivy.Container,
axis: Union[int, ivy.Container],
value: ivy.Container,
/,
*,
name: Optional[str] = None,
) -> ivy.Container:
"""
Add the elements of ``value`` to the input along ``axis`` at the positions
given in ``index``. Contributions at repeated indices are accumulated.

Parameters
----------
self : Container
The destination container.
index : Container or Array
The 1-D array of indices along ``axis`` at which to add.
axis : int
The dimension along which to index.
value : Container or Array
The values to add; their size along ``axis`` must equal the length of ``index``.
name : str, optional
Name for the operation (currently unused). Default is ``None``.

Returns
-------
Container
A container whose leaves have the same shape and dtype as the corresponding leaves of ``self``.

Examples
--------
>>> x1 = ivy.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
>>> index1 = ivy.array([0, 2])
>>> value1 = ivy.array([[1, 1, 1], [1, 1, 1]])
>>> ivy.index_add(x1, index1, 0, value1)
ivy.array([[2, 2, 2],
           [1, 1, 1],
           [2, 2, 2]])
"""
return self.static_index_add(self, index, axis, value)


def concat_from_sequence(
self: ivy.Container,
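For the container method above, a minimal usage sketch with hypothetical inputs (not part of the diff); non-container arguments are assumed to be applied to every leaf by `cont_multi_map_in_function`:

```python
import ivy

c = ivy.Container(a=ivy.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]),
                  b=ivy.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]))
idx = ivy.array([0, 2])
val = ivy.array([[1, 1, 1], [1, 1, 1]])

# index_add is mapped over each leaf of the container.
res = c.index_add(idx, 0, val)
# res.a -> ivy.array([[2, 2, 2], [1, 1, 1], [2, 2, 2]])
# res.b -> ivy.array([[1, 1, 1], [0, 0, 0], [1, 1, 1]])
```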
35 changes: 35 additions & 0 deletions ivy/functional/backends/jax/experimental/manipulation.py
@@ -468,3 +468,38 @@ def take(

def trim_zeros(a: JaxArray, /, *, trim: Optional[str] = "bf") -> JaxArray:
return jnp.trim_zeros(a, trim=trim)


def index_add(
x: JaxArray,
index: JaxArray,
axis: int,
value: JaxArray,
/,
*,
name: Optional[str] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
x = jnp.swapaxes(x, axis, 0)
value = jnp.swapaxes(value, axis, 0)
_to_adds = []
index = sorted(zip(index.tolist(), range(len(index))), key=(lambda i: i[0]))
while index:
_curr_idx = index[0][0]
while len(_to_adds) < _curr_idx:
_to_adds.append(jnp.zeros_like(value[0]))
_to_add_cum = value[index[0][1]]
while len(index) > 1 and (index[0][0] == index[1][0]):
_to_add_cum = _to_add_cum + value[index.pop(1)[1]]
index.pop(0)
_to_adds.append(_to_add_cum)
while len(_to_adds) < x.shape[0]:
_to_adds.append(jnp.zeros_like(value[0]))
_to_adds = jnp.stack(_to_adds)
if len(x.shape) < 2:
# Added this line due to the paddle backend treating scalars as 1-d arrays
_to_adds = _to_adds.flatten()

ret = jnp.add(x, _to_adds)
ret = jnp.swapaxes(ret, axis, 0)
return ret
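As a cross-check of the loop above, the same axis-0 behaviour can be expressed with JAX's indexed-update syntax, which also accumulates duplicate indices; a minimal sketch (the helper name is hypothetical):

```python
import jax.numpy as jnp

def index_add_axis0(x, index, value):
    # Rows of `value` are added into `x` at the given row indices;
    # repeated indices are summed together, matching the manual loop.
    return x.at[jnp.asarray(index)].add(value)
```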
37 changes: 37 additions & 0 deletions ivy/functional/backends/mxnet/experimental/manipulation.py
@@ -205,3 +205,40 @@ def concat_from_sequence(
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
raise IvyNotImplementedException()


def index_add(
x: Union[(None, mx.ndarray.NDArray)],
index: Union[(None, mx.ndarray.NDArray)],
axis: int,
value: Union[(None, mx.ndarray.NDArray)],
/,
*,
name: Optional[str] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
x = mx.nd.swapaxes(x, axis, 0)
value = mx.nd.swapaxes(value, axis, 0)
_to_adds = []
index = sorted(
zip(index.asnumpy().tolist(), range(len(index))), key=(lambda i: i[0])
)
while index:
_curr_idx = index[0][0]
while len(_to_adds) < _curr_idx:
_to_adds.append(mx.nd.zeros_like(value[0]))
_to_add_cum = value[index[0][1]]
while len(index) > 1 and (index[0][0] == index[1][0]):
_to_add_cum = _to_add_cum + value[index.pop(1)[1]]
index.pop(0)
_to_adds.append(_to_add_cum)
while len(_to_adds) < x.shape[0]:
_to_adds.append(mx.nd.zeros_like(value[0]))
_to_adds = mx.nd.stack(*_to_adds)
if len(x.shape) < 2:
# Added this line due to the paddle backend treating scalars as 1-d arrays
_to_adds = mx.nd.flatten(_to_adds)

ret = mx.nd.add(x, _to_adds)
ret = mx.nd.swapaxes(ret, axis, 0)
return ret
35 changes: 35 additions & 0 deletions ivy/functional/backends/numpy/experimental/manipulation.py
@@ -598,3 +598,38 @@ def put_along_axis(
put_along_axis.partial_mixed_handler = lambda *args, mode=None, **kwargs: mode in [
"replace",
]


def index_add(
x: np.ndarray,
index: np.ndarray,
axis: int,
value: np.ndarray,
/,
*,
name: Optional[str] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
x = np.swapaxes(x, axis, 0)
value = np.swapaxes(value, axis, 0)
_to_adds = []
index = sorted(zip(index.tolist(), range(len(index))), key=(lambda i: i[0]))
while index:
_curr_idx = index[0][0]
while len(_to_adds) < _curr_idx:
_to_adds.append(np.zeros_like(value[0]))
_to_add_cum = value[index[0][1]]
while len(index) > 1 and (index[0][0] == index[1][0]):
_to_add_cum = _to_add_cum + value[index.pop(1)[1]]
index.pop(0)
_to_adds.append(_to_add_cum)
while len(_to_adds) < x.shape[0]:
_to_adds.append(np.zeros_like(value[0]))
_to_adds = np.stack(_to_adds)
if len(x.shape) < 2:
# Added this line due to the paddle backend treating scalars as 1-d arrays
_to_adds = _to_adds.flatten()

ret = np.add(x, _to_adds)
ret = np.swapaxes(ret, axis, 0)
return ret
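The NumPy loop above can be sanity-checked against `np.add.at`, which likewise accumulates duplicate indices; a minimal sketch under the assumption that the input should not be modified in place (the function name is hypothetical):

```python
import numpy as np

def index_add_ref(x, index, axis, value):
    # Work on a copy so the caller's array is untouched; the moveaxis view
    # lets np.add.at scatter-add along an arbitrary axis.
    out = np.copy(x)
    moved = np.moveaxis(out, axis, 0)
    np.add.at(moved, np.asarray(index, dtype=np.intp), np.moveaxis(value, axis, 0))
    return out
```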
13 changes: 13 additions & 0 deletions ivy/functional/backends/paddle/experimental/manipulation.py
@@ -905,3 +905,16 @@ def put_along_axis(
"sum",
"mul",
]


def index_add(
x: paddle.Tensor,
index: paddle.Tensor,
axis: int,
value: paddle.Tensor,
/,
*,
name: Optional[str] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
return paddle.index_add(x, index, axis, value)
35 changes: 35 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/manipulation.py
@@ -561,3 +561,38 @@ def trim_zeros(a: tf.Tensor, /, *, trim: Optional[str] = "bf") -> tf.Tensor:
last = tf.minimum(last, tf.cast(tf.shape(a)[0], tf.int64))

return a[first:last]


def index_add(
x: tf.Tensor,
index: tf.Tensor,
axis: int,
value: tf.Tensor,
/,
*,
name: Optional[str] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> tf.Tensor:
x = tf.experimental.numpy.swapaxes(x, axis, 0)
value = tf.experimental.numpy.swapaxes(value, axis, 0)
_to_adds = []
index = sorted(zip(index.numpy().tolist(), range(len(index))), key=(lambda i: i[0]))
while index:
_curr_idx = index[0][0]
while len(_to_adds) < _curr_idx:
_to_adds.append(tf.zeros_like(value[0]))
_to_add_cum = value[index[0][1]]
while len(index) > 1 and (index[0][0] == index[1][0]):
_to_add_cum = _to_add_cum + value[index.pop(1)[1]]
index.pop(0)
_to_adds.append(_to_add_cum)
while len(_to_adds) < x.shape[0]:
_to_adds.append(tf.zeros_like(value[0]))
_to_adds = tf.stack(_to_adds)
if len(x.shape) < 2:
# Added this line due to the paddle backend treating scalars as 1-d arrays
_to_adds = tf.reshape(_to_adds, [-1])

ret = tf.add(x, _to_adds)
ret = tf.experimental.numpy.swapaxes(ret, axis, 0)
return ret
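For the axis-0 case, the same scatter-add can also be written with `tf.tensor_scatter_nd_add`, which sums updates at duplicate indices; a minimal sketch (the helper name is hypothetical):

```python
import tensorflow as tf

def index_add_axis0(x, index, value):
    # indices must have shape [num_updates, 1] to address rows of `x`;
    # duplicate rows in `index` are accumulated.
    idx = tf.reshape(tf.cast(index, tf.int32), [-1, 1])
    return tf.tensor_scatter_nd_add(x, idx, value)
```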
35 changes: 35 additions & 0 deletions ivy/functional/backends/torch/experimental/manipulation.py
@@ -639,3 +639,38 @@ def trim_zeros(a: torch.Tensor, /, *, trim: Optional[str] = "bf") -> torch.Tensor:
else:
last = last - 1
return a[first:last]


def index_add(
x: torch.Tensor,
index: torch.Tensor,
axis: int,
value: torch.Tensor,
/,
*,
name: Optional[str] = None,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = torch.swapaxes(x, axis, 0)
value = torch.swapaxes(value, axis, 0)
_to_adds = []
index = sorted(zip(index.tolist(), range(len(index))), key=(lambda i: i[0]))
while index:
_curr_idx = index[0][0]
while len(_to_adds) < _curr_idx:
_to_adds.append(torch.zeros_like(value[0]))
_to_add_cum = value[index[0][1]]
while len(index) > 1 and (index[0][0] == index[1][0]):
_to_add_cum = _to_add_cum + value[index.pop(1)[1]]
index.pop(0)
_to_adds.append(_to_add_cum)
while len(_to_adds) < x.shape[0]:
_to_adds.append(torch.zeros_like(value[0]))
_to_adds = torch.stack(_to_adds)
if len(x.shape) < 2:
# Added this line due to the paddle backend treating scalars as 1-d arrays
_to_adds = torch.flatten(_to_adds)

ret = torch.add(x, _to_adds)
ret = torch.swapaxes(ret, axis, 0)
return ret
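PyTorch also ships a native op with the same accumulation semantics; a minimal sketch for comparison (the helper name is hypothetical, and `index` is assumed to be castable to `torch.long`):

```python
import torch

def index_add_builtin(x, index, axis, value):
    # torch.index_add returns a new tensor with slices of `value` added to `x`
    # along `axis` at the given indices; duplicate indices are accumulated.
    return torch.index_add(x, axis, index.to(torch.long), value)
```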
51 changes: 51 additions & 0 deletions ivy/functional/ivy/experimental/manipulation.py
@@ -2900,3 +2900,54 @@ def trim_zeros(
),
"to_skip": ("inputs_to_ivy_arrays",),
}


@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
@to_native_arrays_and_back
@handle_array_function
@handle_device
def index_add(
x: ivy.Array,
index: ivy.Array,
axis: int,
value: ivy.Array,
/,
*,
name: Optional[str] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Add the elements of ``value`` to ``x`` along ``axis`` at the positions given in
``index``. Contributions at repeated indices are accumulated.

Parameters
----------
x : Array
The destination array.
index : Array
The 1-D array of indices along ``axis`` at which to add.
axis : int
The dimension along which to index.
value : Array
The array of values to add; its size along ``axis`` must equal the length of ``index``.
name : str, optional
Name for the operation (currently unused). Default is ``None``.
out : Array, optional
Optional output array, for writing the result to.

Returns
-------
Array
An array with the same shape and dtype as ``x``.

Examples
--------
>>> x1 = ivy.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
>>> index1 = ivy.array([0, 2])
>>> value1 = ivy.array([[1, 1, 1], [1, 1, 1]])
>>> ivy.index_add(x1, index1, 0, value1)
ivy.array([[2, 2, 2],
           [1, 1, 1],
           [2, 2, 2]])
"""
return ivy.current_backend(x).index_add(x, index, axis, value, out=out)
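A minimal end-to-end sketch of the new function, assuming the NumPy backend is selected; the expected output follows the docstring example above:

```python
import ivy

ivy.set_backend("numpy")  # any backend implementing index_add should behave the same

x = ivy.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
index = ivy.array([0, 2])
value = ivy.array([[1, 1, 1], [1, 1, 1]])

print(ivy.index_add(x, index, 0, value))
# ivy.array([[2, 2, 2],
#            [1, 1, 1],
#            [2, 2, 2]])
```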