Skip to content
This repository has been archived by the owner on Mar 6, 2023. It is now read-only.

Commit

Permalink
Merge pull request #9607 from jakevdp:scatter-apply
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 429797642
  • Loading branch information
jax authors committed Feb 19, 2022
2 parents 3290dd3 + e13c847 commit 660616c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
64 changes: 63 additions & 1 deletion jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import enum
from functools import partial
from typing import Any, NamedTuple, Optional, Sequence, Union
from typing import Any, Callable, NamedTuple, Optional, Sequence, Union
import weakref

import numpy as np

Expand Down Expand Up @@ -490,6 +491,67 @@ def scatter_max(
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=GatherScatterMode.from_any(mode))

# To avoid recompilation, we store a dict of weak references to funcs.
# Keys are the callables passed to ``scatter_apply`` as ``func``; each value is
# the two-argument wrapper built around that callable, so repeated calls with
# the same ``func`` reuse one wrapper object instead of creating a new one.
_scatter_apply_cache: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

def scatter_apply(
    operand: Array, scatter_indices: Array,
    func: Callable[[Array], Array],
    dimension_numbers: ScatterDimensionNumbers, *,
    indices_are_sorted: bool = False, unique_indices: bool = False,
    mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
  """Scatter-apply operator.

  Wraps `XLA's Scatter operator
  <https://www.tensorflow.org/xla/operation_semantics#scatter>`_, where values
  from ``operand`` are replaced with ``func(operand)``, with duplicate indices
  resulting in multiple applications of ``func``.

  The semantics of scatter are complicated, and its API might change in the
  future. For most use cases, you should prefer the
  :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
  the familiar NumPy indexing syntax.

  Note that in the current implementation, ``scatter_apply`` is not compatible
  with automatic differentiation.

  Args:
    operand: an array to which the scatter should be applied.
    scatter_indices: an array that gives the indices in ``operand`` at which
      ``func`` should be applied.
    func: unary function that will be applied at each index.
    dimension_numbers: a `lax.ScatterDimensionNumbers` object that describes
      how dimensions of ``operand``, ``scatter_indices``, the internally
      generated placeholder updates and the output relate.
    indices_are_sorted: whether ``scatter_indices`` is known to be sorted. If
      true, may improve performance on some backends.
    unique_indices: whether the indices to be updated in ``operand`` are
      guaranteed to not overlap with each other. If true, may improve
      performance on some backends.
    mode: how to handle indices that are out of bounds: when set to 'clip',
      indices are clamped so that the slice is within bounds, and when
      set to 'fill' or 'drop' out-of-bounds updates are dropped. The behavior
      for out-of-bounds indices when set to 'promise_in_bounds' is
      implementation-defined.

  Returns:
    An array containing the result of applying ``func`` to ``operand`` at the
    given indices.
  """
  # The scatter primitive takes an updates operand, but ``func`` is unary, so
  # an all-zero array is passed purely as a stand-in.
  # TODO: can we implement this without a placeholder?
  placeholder_updates = lax.full(scatter_indices.shape[:1], 0, operand.dtype)

  def _apply_ignoring_update(x, _):
    # Two-argument update computation that discards the placeholder value.
    return func(x)

  # Reuse the wrapper previously built for this ``func`` when there is one, so
  # the same wrapper object is handed to ``_reduction_jaxpr`` on every call.
  # NOTE(review): the wrapper closes over ``func``, so the cached value holds a
  # strong reference to its own key -- confirm entries can actually be
  # collected from the WeakKeyDictionary.
  try:
    update_fn = _scatter_apply_cache.setdefault(func, _apply_ignoring_update)
  except TypeError:  # func is not weak referenceable
    update_fn = _apply_ignoring_update

  operand_aval = lax._abstractify(lax._zero(operand))
  jaxpr, consts = lax._reduction_jaxpr(update_fn, operand_aval)
  # TODO: implement this via its own primitive so we can define appropriate
  # autodiff rules.
  return scatter_p.bind(
      operand, scatter_indices, placeholder_updates,
      update_jaxpr=jaxpr, update_consts=consts,
      dimension_numbers=dimension_numbers,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=GatherScatterMode.from_any(mode))

# Define this outside of scatter to ensure cache hits.
_scatter_reduction_computation = lambda x, y: y

Expand Down
25 changes: 24 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6931,6 +6931,7 @@ class _IndexUpdateHelper:
``x = x.at[idx].power(y)`` ``x[idx] **= y``
``x = x.at[idx].min(y)`` ``x[idx] = minimum(x[idx], y)``
``x = x.at[idx].max(y)`` ``x[idx] = maximum(x[idx], y)``
``x = x.at[idx].apply(ufunc)`` ``ufunc.at(x, idx)``
``x = x.at[idx].get()`` ``x = x[idx]``
============================== ================================
Expand Down Expand Up @@ -7047,14 +7048,36 @@ def set(self, values, indices_are_sorted=False, unique_indices=False,
"""Pure equivalent of ``x[idx] = y``.
Returns the value of ``x`` that would result from the NumPy-style
:mod:indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
:mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.
See :mod:`jax.ops` for details.
"""
return scatter._scatter_update(self.array, self.index, values, lax.scatter,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)

def apply(self, func, indices_are_sorted=False, unique_indices=False,
          mode=None):
  """Pure equivalent of ``func.at(x, idx)`` for a unary ufunc ``func``.

  Returns the value of ``x`` that would result from applying the unary
  function ``func`` to ``x`` at the given indices. This is similar to
  ``x.at[idx].set(func(x[idx]))``, but differs in the case of repeated
  indices: in ``x.at[idx].apply(func)``, repeated indices result in the
  function being applied multiple times.

  Note that in the current implementation, ``scatter_apply`` is not compatible
  with automatic differentiation.

  See :mod:`jax.ops` for details.
  """
  def _scatter_apply(x, indices, _updates, dims, **kwargs):
    # Adapter with the scatter-op calling shape: the updates argument is
    # accepted but ignored, and ``func`` is forwarded in its place.
    return lax.scatter_apply(x, indices, func, dims, **kwargs)

  # A zero value is supplied only to fill the updates slot of
  # ``_scatter_update``; the adapter above never reads it.
  placeholder = lax._zero(self.array.dtype)
  return scatter._scatter_update(
      self.array, self.index, placeholder, _scatter_apply,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices,
      mode=mode)

def add(self, values, indices_are_sorted=False, unique_indices=False,
mode=None):
"""Pure equivalent of ``x[idx] += y``.
Expand Down
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@
index_in_dim as index_in_dim,
index_take as index_take,
scatter as scatter,
scatter_apply as scatter_apply,
scatter_add as scatter_add,
scatter_add_p as scatter_add_p,
scatter_max as scatter_max,
Expand Down
21 changes: 21 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,27 @@ def testStaticIndexing(self, shape, dtype, indexer):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)


@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_{funcname}", "funcname": funcname}
    for funcname in ["negative", "sin", "cos", "square", "sqrt", "log", "exp"]))
def testIndexApply(self, funcname, size=10, dtype='float32'):
  """Check ``x.at[idx].apply(f)`` against NumPy's in-place ``f.at(x, idx)``."""
  value_rng = jtu.rand_default(self.rng())
  index_rng = jtu.rand_int(self.rng(), -size, size)
  # Look up the same-named function in both libraries.
  np_func = getattr(np, funcname)
  jnp_func = getattr(jnp, funcname)

  @jtu.ignore_warning(category=RuntimeWarning)
  def np_op(x, idx):
    # ``ufunc.at`` mutates its argument, so operate on a copy.
    result = x.copy()
    np_func.at(result, idx)
    return result

  def jnp_op(x, idx):
    return jnp.asarray(x).at[idx].apply(jnp_func)

  def args_maker():
    return [value_rng(size, dtype), index_rng(size, int)]

  self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
  self._CompileAndCheck(jnp_op, args_maker)


@parameterized.named_parameters({
"testcase_name":
f"{jtu.format_shape_dtype_string(shape, dtype)}_inshape={name}"
Expand Down

0 comments on commit 660616c

Please sign in to comment.