For the changes specific to the experimental Pallas APIs,
see {ref}`pallas-changelog`.
- Breaking Changes
  - This release lands "stackless", an internal change to JAX's tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, `post_process_call`, `new_base_main`, `custom_bind`, and so on. The change should only affect users that use JAX internals. If you do use JAX internals then you may need to update your code (see https://github.com/jax-ml/jax/commit/c36e1f7c1ad4782060cbc8e8c596d85dfb83986f for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find this change breaks your non-JAX-internals-using code then try the `config.jax_data_dependent_tracing_fallback` flag as a workaround, and if you need help updating your code then please file a bug.
  - {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` or with `enable_xla=False` has been deprecated since July 2024, with JAX version 0.4.31. Now we have removed support for these use cases. `jax2tf` with native serialization will still be supported.
  - In `jax.interpreters.xla`, the `xb`, `xc`, and `xe` symbols have been removed after being deprecated in JAX v0.4.31. Instead use `xb = jax.lib.xla_bridge`, `xc = jax.lib.xla_client`, and `xe = jax.lib.xla_extension`.
  - The deprecated module `jax.experimental.export` has been removed. It was replaced by {mod}`jax.export` in JAX v0.4.30. See the migration guide for information on migrating to the new API.
  - The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` has been removed, after being deprecated in v0.4.27.
  - Calling `np.asarray` on typed PRNG keys (i.e. keys produced by {func}`jax.random.key`) now raises an error. Previously, this returned a scalar object array.
  - The following deprecated methods and functions in {mod}`jax.export` have been removed:
    - `jax.export.DisabledSafetyCheck.shape_assertions`: it already had no effect.
    - `jax.export.Exported.lowering_platforms`: use `platforms`.
    - `jax.export.Exported.mlir_module_serialization_version`: use `calling_convention_version`.
    - `jax.export.Exported.uses_shape_polymorphism`: use `uses_global_constants`.
    - the `lowering_platforms` kwarg for {func}`jax.export.export`: use `platforms` instead.
  - Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`.
  - Refactor: the JAX build CLI (`build/build.py`) now uses a subcommand structure and replaces the previous `build.py` usage. Run `python build/build.py --help` for more details. Brief overview of the new subcommand options:
    - `build`: Builds JAX wheel packages. For example, `python build/build.py build --wheels=jaxlib,jax-cuda-pjrt`
    - `requirements_update`: Updates `requirements_lock.txt` files.
  - {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` on the function inputs.
  - {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now return NaN for negative integer inputs, to match the behavior of SciPy from scipy/scipy#21827.
  - `jax.clear_backends` was removed after being deprecated in v0.4.26.
- New Features
  - {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for passing compilation options to XLA. For the moment it's undocumented and may be in flux.
  - {func}`jax.tree_util.register_dataclass` now allows metadata fields to be declared inline via {func}`dataclasses.field`. See the function documentation for examples.
  - Added {func}`jax.numpy.put_along_axis`.
  - {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions ({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now supported on GPU. See {jax-issue}`#24663` for more details.
- Bug fixes
  - Fixed a bug where the GPU implementations of LU and QR decomposition would
    result in an indexing overflow for batch sizes close to int32 max. See
    {jax-issue}`#24843` for more details.
- Deprecations
  - `jax.lib.xla_extension.ArrayImpl` and `jax.lib.xla_client.ArrayImpl` are deprecated; use `jax.Array` instead.
  - `jax.lib.xla_extension.XlaRuntimeError` is deprecated; use `jax.errors.JaxRuntimeError` instead.
- Breaking Changes
  - {func}`jax.numpy.isscalar` now returns True for any array-like object with zero dimensions. Previously it only returned True for zero-dimensional array-like objects with a weak dtype.
  - `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we have removed it. See {jax-issue}`#20385` for a discussion of alternatives.
- Changes:
  - `jax.lax.FftType` was introduced as a public name for the enum of FFT operations. The semi-public API `jax.lib.xla_client.FftType` has been deprecated.
  - TPU: JAX now installs TPU support from the `libtpu` package rather than `libtpu-nightly`. For the next few releases JAX will pin an empty version of `libtpu-nightly` as well as `libtpu` to ease the transition; that dependency will be removed in Q1 2025.
- Deprecations:
  - The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated. No JAX APIs consume this type, so there is no replacement.
  - The default behavior of {func}`jax.pure_callback` and {func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has the `vectorized` parameter to those functions. The `vmap_method` parameter should be used instead for better defined behavior. See the discussion in {jax-issue}`#23881` for more details.
  - The semi-public API `jax.lib.xla_client.register_custom_call_target` has been deprecated. Use the JAX FFI instead.
  - The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO instead.
- New Functionality
  - This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
  - `jax.errors.JaxRuntimeError` has been added as a public alias for the formerly private `XlaRuntimeError` type.
- Breaking changes
  - The `jax_pmap_no_rank_reduction` flag is set to `True` by default.
    - `array[0]` on a pmap result now introduces a reshape (use `array[0:1]` instead).
    - The per-shard shape (accessible via `jax_array.addressable_shards` or `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as `jit`. This avoids costly reshapes when passing results from `pmap` into `jit`.
  - `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we set the default value of the `--jax_host_callback_legacy` configuration value to `False`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API. If this breaks your code, for a very limited time, you can set `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See {jax-issue}`#20385` for a discussion.
- Deprecations
  - In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result in an error.
  - Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30.
  - `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
  - `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use `jax.errors.JaxRuntimeError` instead.
- Deletion:
  - `jax.xla_computation` is deleted. It's been 3 months since its deprecation in the 0.4.30 JAX release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
    - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    - You can also use the `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape and dtype).
    - For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
  - {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap`, which was removed in 0.4.31.
  - `jax.tree.map(f, None, non-None)`, which previously emitted a `DeprecationWarning`, now raises an error. `None` is only a tree-prefix of itself. To preserve the previous behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
  - `jax.sharding.XLACompatibleSharding` has been removed. Please use `jax.sharding.Sharding`.
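The AOT replacement can be sketched as follows, assuming a recent JAX version. Instead of the `compiler_ir('hlo')` call quoted above, the sketch uses `Lowered.as_text()` to get the module text and reads `out_info` for the output shape and dtype; `fn` is a hypothetical function.

```python
import jax
import jax.numpy as jnp


def fn(x, y):
    return jnp.sin(x) * y


x = jnp.ones((3,), dtype=jnp.float32)

# Ahead-of-time path: trace and lower without executing.
lowered = jax.jit(fn).lower(x, x)

text = lowered.as_text()     # textual IR of the lowered module
out_info = lowered.out_info  # output shape/dtype information, as a pytree
print(out_info.shape, out_info.dtype)
```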
- Bug fixes
  - Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified.
  - Edited the implementation of {func}`jax.numpy.ldexp` to get the correct gradient.
This is a patch release on top of jax 0.4.32 that fixes two bugs found in that release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu`.

This release fixes an inaccurate result for F64 tanh on CPU (#23590).

Note: This release was yanked from PyPI because of a data corruption bug on TPU. See the 0.4.33 release notes for more details.
- New Functionality
  - Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering` to support the use of the new {ref}`ffi-tutorial` to interface with custom C++ and CUDA code from JAX.
- Changes
  - The `jax_enable_memories` flag is set to `True` by default.
  - {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information.
  - Computations on the CPU backend may now be dispatched asynchronously in
    more cases. Previously non-parallel computations were always dispatched
    synchronously. You can recover the old behavior by setting
    `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
  - Added new {func}`jax.process_indices` function to replace the `jax.host_ids()` function that was deprecated in JAX v0.2.13.
  - To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been modified to no longer support complex dtypes.
  - `jax.tree_util.register_dataclass` now checks that `data_fields` and `meta_fields` include all dataclass fields with `init=True` and only them, if `nodetype` is a dataclass.
  - Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc` interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`, {obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`, {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`, {obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`. In future releases we plan to expand these to other ufuncs.
  - Added {func}`jax.lax.optimization_barrier`, which allows users to prevent compiler optimizations such as common-subexpression elimination and to control scheduling.
- Breaking changes
  - The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the `stablehlo` dialect instead.
- Deprecations
  - Complex inputs to {func}`jax.numpy.clip` and {func}`jax.numpy.hypot` are no longer allowed, after being deprecated since JAX v0.4.27.
  - Deprecated the following APIs:
    - `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
    - `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
    - `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
  - The `jax.experimental.array_api` module is deprecated, and importing it is no longer required to use the Array API. `jax.numpy` supports the array API directly; see {ref}`python-array-api` for more information.
  - The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in the future.
  - `jax.numpy.round_` has been deprecated, following removal of the corresponding API in NumPy 2.0. Use {func}`jax.numpy.round` instead.
  - Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated. The argument to {func}`jax.dlpack.from_dlpack` should be an array from another framework that implements the `__dlpack__` protocol.

Note: This release was yanked from PyPI because of a data corruption bug on TPU. See the 0.4.33 release notes for more details.
- Breaking changes
  - This release of jaxlib switched to a new version of the CPU backend, which
    should compile faster and leverage parallelism better. If you experience
    any problems due to this change, you can temporarily enable the old CPU
    backend by setting the environment variable
    `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this, please file a JAX bug with instructions to reproduce.
  - Hermetic CUDA support is added. Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, and then use CUDA libraries and tools as dependencies in various Bazel targets. This enables more reproducible builds for JAX and its supported CUDA versions.
- Changes
  - SparseCore profiling is added.
    - JAX now supports profiling SparseCore on TPUv5p chips. These traces will be viewable in Tensorboard Profiler's TraceViewer.
- Deletion
  - xmap has been deleted. Please use {func}`shard_map` as the replacement.
- Changes
  - The minimum cuDNN version is v9.1. This was true in previous releases also, but we now declare this version constraint formally.
  - The minimum Python version is now 3.10. 3.10 will remain the minimum supported version until July 2025.
  - The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum supported version until December 2024.
  - The minimum SciPy version is now 1.10. SciPy 1.10 will remain the minimum supported version until January 2025.
  - {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return output of the same dtype as the input, i.e. they no longer upcast integer or boolean inputs to floating point.
  - `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be installed either as a part of a local CUDA installation, or via NVIDIA's CUDA pip wheels.
  - {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to be passed before `index_map`. The old argument order is deprecated and will be removed in a future release.
  - Updated the repr of GPU devices to be more consistent
    with TPUs/CPUs. For example, `cuda(id=0)` will now be `CudaDevice(id=0)`.
  - Added the `device` property and `to_device` method to {class}`jax.Array`, as part of JAX's Array API support.
- Deprecations
  - Removed a number of previously-deprecated internal APIs related to
    polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`, `dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
  - HLO lowering rules should no longer wrap singleton `ir.Value`s in tuples. Instead, return singleton `ir.Value`s unwrapped. Support for wrapped values will be removed in a future version of JAX.
  - {func}`jax.experimental.jax2tf.convert` with `native_serialization=False` or `enable_xla=False` is now deprecated and this support will be removed in a future version. Native serialization has been the default since JAX 0.4.16 (September 2023).
  - The previously-deprecated function `jax.random.shuffle` has been removed; instead use `jax.random.permutation` with `independent=True`.
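A minimal sketch of the suggested migration from `jax.random.shuffle`: with `independent=True`, {func}`jax.random.permutation` shuffles each row separately along the chosen axis.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jnp.arange(12).reshape(3, 4)

# Shuffle each row independently along axis 1.
shuffled = jax.random.permutation(key, x, axis=1, independent=True)
print(shuffled)
```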
- Bug fixes
  - Fixed a bug that meant that negative `static_argnums` to a jit were mishandled by the jit dispatch fast path.
  - Fixed a bug that meant triangular solves of batches of singular matrices produced nonsensical finite values, instead of inf or nan (#3589, #15429).
- Changes
  - JAX supports ml_dtypes >= 0.2. In the 0.4.29 release, the ml_dtypes version was bumped to 0.4.0 but this has been rolled back in this release to give users of both TensorFlow and JAX more time to migrate to a newer TensorFlow release.
  - `jax.experimental.mesh_utils` can now create an efficient mesh for TPU v5e.
  - jax now depends on jaxlib directly. This change was enabled by the CUDA
    plugin switch: there are no longer multiple jaxlib variants. You can install
    a CPU-only jax with `pip install jax`, no extras required.
  - Added an API for exporting and serializing JAX functions. This used
    to exist in `jax.experimental.export` (which is being deprecated), and will now live in `jax.export`. See the documentation.
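A minimal sketch of the {mod}`jax.export` workflow, assuming the `export.export`, `Exported.serialize`, `export.deserialize`, and `Exported.call` entry points: export a jitted function for a fixed input shape, round-trip it through bytes, and call the restored artifact. The function `f` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    return jnp.sin(x) * 2.0


# Export for a concrete input shape/dtype, then round-trip through bytes.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
exp = export.export(jax.jit(f))(spec)
blob = exp.serialize()  # a bytearray that can be stored and reloaded later
restored = export.deserialize(blob)

x = jnp.arange(3, dtype=jnp.float32)
y = restored.call(x)
print(exp.platforms, y)
```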
- Deprecations
  - Internal pretty-printing tools `jax.core.pp_*` are deprecated, and will be removed in a future release.
  - Hashing of tracers is deprecated, and will lead to a `TypeError` in a future JAX release. This previously was the case, but there was an inadvertent regression in the last several JAX releases.
  - `jax.experimental.export` is deprecated. Use {mod}`jax.export` instead. See the migration guide.
  - Passing an array in place of a dtype is now deprecated in most cases; e.g. for arrays `x` and `y`, `x.astype(y)` will raise a warning. To silence it use `x.astype(y.dtype)`.
  - `jax.xla_computation` is deprecated and will be removed in a future release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
    - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
    - You can also use the `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape and dtype).
    - For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
- Support for monolithic CUDA jaxlibs has been dropped. You must use the
  plugin-based installation (`pip install jax[cuda12]` or `pip install jax[cuda12_local]`).
- Changes
  - We anticipate that this will be the last release of JAX and jaxlib
    supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
    plugin jaxlib (e.g. `pip install jax[cuda12]`).
  - JAX now requires ml_dtypes version 0.4.0 or newer.
  - Removed backwards-compatibility support for old usage of the
    `jax.experimental.export` API. It is not possible anymore to use `from jax.experimental.export import export`, and instead you should use `from jax.experimental import export`. The removed functionality has been deprecated since 0.4.24.
  - Added `is_leaf` argument to {func}`jax.tree.all` & {func}`jax.tree_util.tree_all`.
- Deprecations
  - `jax.sharding.XLACompatibleSharding` is deprecated. Please use `jax.sharding.Sharding`.
  - `jax.experimental.Exported.in_shardings` has been renamed as `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`. The old names will be removed after 3 months.
  - Removed a number of previously-deprecated APIs:
    - from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
    - from {mod}`jax.lax`: `tie_in`
    - from {mod}`jax.nn`: `normalize`
    - from {mod}`jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`.
  - The `tol` argument of {func}`jax.numpy.linalg.matrix_rank` is being deprecated and will soon be removed. Use `rtol` instead.
  - The `rcond` argument of {func}`jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead.
  - The deprecated `jax.config` submodule has been removed. To configure JAX use `import jax` and then reference the config object via `jax.config`.
  - {mod}`jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of {func}`jax.vmap` in such cases.
  - In {func}`jax.scipy.special.beta`, the `x` and `y` parameters have been renamed to `a` and `b` for consistency with other `beta` APIs.
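Following the recommendation above, a batch of keys can be consumed with an explicit {func}`jax.vmap`; this is a sketch assuming typed keys from {func}`jax.random.key`.

```python
import jax

keys = jax.random.split(jax.random.key(0), 4)  # a batch of 4 typed keys

# Map explicitly over the batch instead of passing it to jax.random.normal.
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
print(samples.shape)
```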
- New Functionality
  - Added {func}`jax.experimental.Exported.in_shardings_jax` to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the `Exported` objects.
- Bug fixes
  - Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
  - Fixed a bug where XLA:CPU miscompiled certain matmul fusions (openxla/xla#13301).
  - Fixed a compiler crash on GPU (jax-ml#21396).
- Deprecations
  - `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will raise an error in a future version of jax. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
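The suggested workaround can be seen on a small example; `f`, `a`, and `b` are hypothetical values chosen so that `a` contains a `None`.

```python
import jax


def f(x, y):  # hypothetical per-leaf function
    return x + y


a = [1, None, 3]
b = [10, 20, 30]

# Treat None in `a` as a leaf so the structures line up, and pass it through.
out = jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b,
                   is_leaf=lambda x: x is None)
print(out)
```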
- Bug fixes
  - Reverted a change to `make_jaxpr` that was breaking Equinox (#21116).
- Deprecations & removals
  - The `kind` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort` is now removed. Use `stable=True` or `stable=False` instead.
  - Removed `get_compute_capability` from the `jax.experimental.pallas.gpu` module. Use the `compute_capability` attribute of a GPU device, returned by {func}`jax.devices` or {func}`jax.local_devices`, instead.
  - The `newshape` argument to {func}`jax.numpy.reshape` is being deprecated and will soon be removed. Use `shape` instead.
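A sketch of the replacement arguments, assuming a JAX version where `stable` and `shape` are accepted: a stable {func}`jax.numpy.argsort` keeps tied elements in their original order, and {func}`jax.numpy.reshape` takes `shape` by keyword.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])

order = jnp.argsort(x, stable=True)           # ties keep original order: [1, 3, 2, 0]
y = jnp.reshape(jnp.arange(6), shape=(2, 3))  # `shape` instead of `newshape`
print(order, y.shape)
```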
- Changes
  - The minimum jaxlib version of this release is 0.4.27.
- Bug fixes
  - Fixed a memory corruption bug in the type name of Array and JIT Python objects in Python 3.10 or earlier.
  - Fixed a warning `'+ptx84' is not a recognized feature for this target` under CUDA 12.4.
  - Fixed a slow compilation problem on CPU.
- Changes
  - The Windows build is now built with Clang instead of MSVC.
- New Functionality
  - Added {func}`jax.numpy.unstack` and {func}`jax.numpy.cumulative_sum`, following their addition in the array API 2023 standard, soon to be adopted by NumPy.
  - Added a new config option `jax_cpu_collectives_implementation` to select the implementation of cross-process collective operations used by the CPU backend. Choices available are `'none'` (default), `'gloo'` and `'mpi'` (requires jaxlib 0.4.26). If set to `'none'`, cross-process collective operations are disabled.
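A short sketch of the two new functions, assuming the array API signatures (`unstack(x, axis=0)` and `cumulative_sum(x, include_initial=...)`).

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

rows = jnp.unstack(x)  # tuple of arrays split along axis 0
sums = jnp.cumulative_sum(jnp.array([1, 2, 3]), include_initial=True)  # [0, 1, 3, 6]
print(len(rows), rows[1], sums)
```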
- Changes
  - {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback` now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover the old behavior by transforming the arguments via `jax.tree.map(np.asarray, args)` before passing them to the callback.
  - `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
  - `core.Token` now is a non-trivial class which wraps a `jax.Array`. It could be created and threaded in and out of computations to build up dependency. The singleton object `core.token` has been removed; users now should create and use fresh `core.Token` objects instead.
  - On GPU, the Threefry PRNG implementation no longer lowers to a kernel call
    by default. This choice can improve runtime memory usage at a compile-time
    cost. Prior behavior, which produces a kernel call, can be recovered with
    `jax.config.update('jax_threefry_gpu_kernel_lowering', True)`. If the new default causes issues, please file a bug. Otherwise, we intend to remove this flag in a future release.
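The NumPy-style boolean conversion for complex arrays can be checked directly; this is a sketch assuming a JAX version with this change.

```python
import jax.numpy as jnp

z = jnp.array([0 + 0j, 0 + 1j, 2 + 0j])
mask = z.astype(bool)  # False only where both real and imaginary parts are zero
print(mask)
```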
- Deprecations & Removals
  - Pallas now exclusively uses XLA for compiling kernels on GPU. The old
    lowering pass via Triton Python APIs has been removed and the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
  - {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and `a_max` are deprecated in favor of `x` (positional only), `min`, and `max` ({jax-issue}`20550`).
  - The `device()` method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Use `arr.devices()` instead.
  - The `initial` argument to {func}`jax.nn.softmax` and {func}`jax.nn.log_softmax` is deprecated; empty inputs to softmax are now supported without setting this.
  - In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames` now leads to an error rather than a warning.
  - The minimum jaxlib version is now 0.4.23.
  - The {func}`jax.numpy.hypot` function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed.
  - Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and related functions now raise an error, following a similar change in NumPy.
  - The config option `jax_cpu_enable_gloo_collectives` is deprecated. Use `jax.config.update('jax_cpu_collectives_implementation', 'gloo')` instead.
  - The `jax.Array.device_buffer` and `jax.Array.device_buffers` methods have been removed after being deprecated in JAX v0.4.22. Instead use {attr}`jax.Array.addressable_shards` and {meth}`jax.Array.addressable_data`.
  - The `condition`, `x`, and `y` parameters of `jax.numpy.where` are now positional-only, following deprecation of the keywords in JAX v0.4.21.
  - Non-array arguments to functions in {mod}`jax.lax.linalg` now must be specified by keyword. Previously, this raised a DeprecationWarning.
  - Array-like arguments are now required in several {mod}`jax.numpy` APIs, including {func}`~jax.numpy.apply_along_axis`, {func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`, {func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`, {func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
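A sketch of the placement APIs that replace the removed methods, assuming a single-device array. Mapping `device_buffer` to `addressable_data(0)` is an interpretation for the single-device case rather than something these notes state.

```python
import jax
import jax.numpy as jnp

arr = jnp.arange(6.0)

devices = arr.devices()                            # replaces arr.device()
shards = [s.data for s in arr.addressable_shards]  # replaces arr.device_buffers
first = arr.addressable_data(0)                    # single-device stand-in for arr.device_buffer
print(devices, len(shards), first.shape)
```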
- Bug fixes
  - {func}`jax.numpy.astype` will now always return a copy when `copy=True`. Previously, no copy would be made when the output array would have the same dtype as the input array. This may result in some increased memory usage. The default value is set to `copy=False` to preserve backwards compatibility.
- New Functionality
  - Added {func}`jax.numpy.trapezoid`, following the addition of this function in NumPy 2.0.
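A minimal sketch of {func}`jax.numpy.trapezoid`, which integrates samples with the trapezoidal rule; with unit spacing, `[1, 2, 3]` gives (1+2)/2 + (2+3)/2 = 4.

```python
import jax.numpy as jnp

y = jnp.array([1.0, 2.0, 3.0])

area = jnp.trapezoid(y)               # unit spacing: 4.0
area_half = jnp.trapezoid(y, dx=0.5)  # half the spacing halves the area: 2.0
print(area, area_half)
```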
- Changes
  - Complex-valued {func}`jax.numpy.geomspace` now chooses the logarithmic spiral branch consistent with that of NumPy 2.0.
  - The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'` and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` has changed so that mapping over keys results in random generation only from the first key in the batch.
  - Docs now use `jax.random.key` for construction of PRNG key arrays rather than `jax.random.PRNGKey`.
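To see how the two key styles relate, here is a sketch assuming the default `threefry2x32` PRNG implementation, in which a typed key and a legacy `PRNGKey` built from the same seed hold the same underlying bits.

```python
import numpy as np
import jax

typed = jax.random.key(0)       # typed key array (the style the docs now use)
legacy = jax.random.PRNGKey(0)  # legacy raw uint32 key data

# Compare the raw bits underlying the typed key with the legacy key.
same = np.array_equal(np.asarray(jax.random.key_data(typed)), np.asarray(legacy))

# Legacy key data can be wrapped into a typed key when migrating code.
rewrapped = jax.random.wrap_key_data(legacy)
print(same, typed.shape, rewrapped.shape)
```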
- Deprecations & Removals
  - {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
  - {func}`jax.clear_backends` is deprecated as it does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use {func}`jax.clear_caches` if you only want to clean up compilation caches. For backward compatibility, or if you really need to switch/reinitialize the default backend, use {func}`jax.extend.backend.clear_backends`.
  - The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the `spmd_axis_name` argument for expressing SPMD device-parallel computations.
  - The `jax.experimental.host_callback` module is deprecated. Use instead the new JAX external callbacks. Added the `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the new callbacks. See {jax-issue}`#20385` for a discussion.
  - Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array now results in an exception.
  - The deprecated flag `jax_parallel_functions_output_gda` has been removed. This flag was long deprecated and did nothing; its use was a no-op.
  - The previously-deprecated imports `jax.interpreters.ad.config` and `jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config` and `jax.extend.source_info_util` instead.
  - JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. See a description of the versions. This change could break clients that set a specific JAX serialization version lower than 9.
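As a migration sketch away from `jax.experimental.host_callback`, the block below uses {func}`jax.experimental.io_callback` for a host-side side effect inside `jit`. The `host_log` function and `log` list are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

log = []  # hypothetical host-side log


def host_log(x):
    # Side effect on the host, then return a value with the declared shape/dtype.
    log.append(float(x))
    return np.asarray(x + 1, dtype=np.float32)


@jax.jit
def step(x):
    y = io_callback(host_log, jax.ShapeDtypeStruct((), jnp.float32), x)
    return y * 2


out = step(jnp.float32(1.0))
out.block_until_ready()
print(log, out)
```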
- Changes
  - JAX now supports CUDA 12.1 or newer only. Support for CUDA 11.8 has been dropped.
  - JAX now supports NumPy 2.0.
- New Features
  - Added CUDA Array Interface import support (requires jaxlib 0.4.24).
  - JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.
  - Added {mod}`jax.tree` module, with a more convenient interface for referencing functions in {mod}`jax.tree_util`.
  - {func}`jax.tree.transpose` (i.e. {func}`jax.tree_util.tree_transpose`) now accepts `inner_treedef=None`, in which case the inner treedef will be automatically inferred.
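A small sketch of {func}`jax.tree.transpose` with `inner_treedef=None`, where the inner structure is inferred from the first subtree.

```python
import jax

tree = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
outer = jax.tree.structure([0, 0])  # a list of two leaves

# inner_treedef=None lets JAX infer the inner structure ({"a": *, "b": *}).
swapped = jax.tree.transpose(outer, None, tree)
print(swapped)
```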
- Changes
  - Pallas now uses XLA instead of the Triton Python APIs to compile Triton
    kernels. You can revert to the old behavior by setting the
    `JAX_TRITON_COMPILE_VIA_XLA` environment variable to `"0"`.
  - Several deprecated APIs in {mod}`jax.interpreters.xla` that were removed in v0.4.24 have been re-added in v0.4.25, including `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, and `XLAOp`. These are still considered deprecated, and will be removed again in the future when better replacements are available. Refer to {jax-issue}`#19816` for discussion.
- Deprecations & Removals
  - {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D solves with `b.ndim > 1`. In the future these will be treated as batched 2D solves.
  - Conversion of a non-scalar array to a Python scalar now raises an error, regardless of the size of the array. Previously a deprecation warning was raised in the case of non-scalar arrays of size 1. This follows a similar deprecation in NumPy.
  - The previously deprecated configuration APIs have been removed
    following a standard 3 months deprecation cycle (see {ref}`api-compatibility`). These include
    - the `jax.config.config` object and
    - the `define_*_state` and `DEFINE_*` methods of {data}`jax.config`.
  - Importing the `jax.config` submodule via `import jax.config` is deprecated. To configure JAX use `import jax` and then reference the config object via `jax.config`.
  - The minimum jaxlib version is now 0.4.20.
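A sketch of the stricter scalar conversion, assuming a JAX version with this change: converting a size-1, one-dimensional array with `float` raises `TypeError`, so index down to a zero-dimensional array first.

```python
import jax.numpy as jnp

x = jnp.array([1.0])  # size 1, but not a scalar (ndim == 1)

try:
    float(x)
    converted = True
except TypeError:
    converted = False

value = float(x[0])  # x[0] is zero-dimensional, so this conversion is allowed
print(converted, value)
```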
- Changes
  - JAX lowering to StableHLO does not depend on physical devices anymore. If your primitive wraps custom_partitioning or JAX callbacks in the lowering rule (i.e., the function passed to the `rule` parameter of `mlir.register_lowering`), then add your primitive to the `jax._src.dispatch.prim_requires_devices_during_lowering` set. This is needed because custom_partitioning and JAX callbacks need physical devices to create `Sharding`s during lowering. This is a temporary state until we can create `Sharding`s without physical devices.
  - {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable` and `descending` arguments.
  - Several changes to the handling of shape polymorphism (used in {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`):
    - cleaner pretty-printing of symbolic expressions ({jax-issue}`#19227`)
    - added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to work around limitations in the reasoning about inequalities. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    - with the addition of symbolic constraints ({jax-issue}`#19235`) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes cannot interact, e.g., in arithmetic operations. Scopes are introduced by {func}`jax.experimental.jax2tf.convert`, {func}`jax.experimental.export.symbolic_shape`, and {func}`jax.experimental.export.symbolic_args_specs`. The scope of a symbolic expression `e` can be read with `e.scope` and passed into the above functions to direct them to construct symbolic expressions in a given scope. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints.
    - simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 ({jax-issue}`#19231`; note that this may result in user-visible behavior changes)
    - improved the error messages for inconclusive inequality comparisons ({jax-issue}`#19235`).
    - the `core.non_negative_dim` API (introduced recently) was deprecated and `core.max_dim` and `core.min_dim` were introduced ({jax-issue}`#18953`) to express `max` and `min` for symbolic dimensions. You can use `core.max_dim(d, 0)` instead of `core.non_negative_dim(d)`.
    - `shape_poly.is_poly_dim` is deprecated in favor of `export.is_symbolic_dim` ({jax-issue}`#19282`).
    - `export.args_specs` is deprecated in favor of `export.symbolic_args_specs` ({jax-issue}`#19283`).
    - `shape_poly.PolyShape` and `jax2tf.PolyShape` are deprecated; use strings for polymorphic shape specifications ({jax-issue}`#19284`).
    - JAX default native serialization version is now 9. This is relevant for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`. See description of version numbers.
  - Refactored the API for `jax.experimental.export`. Instead of `from jax.experimental.export import export` you should now use `from jax.experimental import export`. The old way of importing will continue to work for a deprecation period of 3 months.
  - Added {func}`jax.scipy.stats.sem`.
  - {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices reshaped to the dimension of the input, following a similar change to {func}`numpy.unique` in NumPy 2.0.
  - {func}`jax.numpy.sign` now returns `x / abs(x)` for nonzero complex inputs. This is consistent with the behavior of {func}`numpy.sign` in NumPy version 2.0.
  - {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0 convention for the complex sign, `x / abs(x)`. This is consistent with the behavior of {func}`scipy.special.logsumexp` in SciPy v1.13.
  - JAX now supports the bool DLPack type for both import and export. Previously bool values could not be imported and were exported as integers.
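
A short sketch of several of the new behaviors listed above, assuming a JAX version recent enough to include them: the `stable`/`descending` sort arguments, the complex `sign` convention `x / abs(x)`, and {func}`jax.scipy.stats.sem` with its default `ddof=1`.

```python
import jax.numpy as jnp
from jax.scipy.stats import sem

x = jnp.array([3, 1, 2, 1])

desc = jnp.sort(x, descending=True)
# stable=True keeps tied elements (the two 1s) in their original relative order
order = jnp.argsort(x, stable=True, descending=True)

# For nonzero complex inputs, sign(z) == z / abs(z); sign(0) stays 0.
z = jnp.array([3 + 4j, 0j])
s = jnp.sign(z)

# Standard error of the mean: std(ddof=1) / sqrt(n)
err = sem(jnp.array([1.0, 2.0, 3.0, 4.0]))
```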
- Deprecations & Removals
  - A number of previously deprecated functions have been removed, following a standard 3+ month deprecation cycle (see {ref}`api-compatibility`). This includes:
    - From {mod}`jax.core`: `TracerArrayConversionError`, `TracerIntegerConversionError`, `UnexpectedTracerError`, `as_hashable_function`, `collections`, `dtypes`, `lu`, `map`, `namedtuple`, `partial`, `pp`, `ref`, `safe_zip`, `safe_map`, `source_info_util`, `total_ordering`, `traceback_util`, `tuple_delete`, `tuple_insert`, and `zip`.
    - From {mod}`jax.lax`: `dtypes`, `itertools`, `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`, `standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule`.
    - The `jax.linear_util` submodule and all its contents.
    - The `jax.prng` submodule and all its contents.
    - From {mod}`jax.random`: `PRNGKeyArray`, `KeyArray`, `default_prng_impl`, `threefry_2x32`, `threefry2x32_key`, `threefry2x32_p`, `rbg_key`, and `unsafe_rbg_key`.
    - From {mod}`jax.tree_util`: `register_keypaths`, `AttributeKeyPathEntry`, and `GetItemKeyPathEntry`.
    - From {mod}`jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `axis_groups`, `ShapedArray`, `ConcreteArray`, `AxisEnv`, `backend_compile`, and `XLAOp`.
    - From {mod}`jax.numpy`: `NINF`, `NZERO`, `PZERO`, `row_stack`, `issubsctype`, `trapz`, and `in1d`.
    - From {mod}`jax.scipy.linalg`: `tril` and `triu`.
  - The previously-deprecated method `PRNGKeyArray.unsafe_raw_array` has been removed. Use {func}`jax.random.key_data` instead.
  - `bool(empty_array)` now raises an error rather than returning `False`. This previously raised a deprecation warning, and follows a similar change in NumPy.
  - Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be removed in the future. Use the "stablehlo" dialect instead.
  - {mod}`jax.random`: passing batched keys directly to random number generation functions, such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching.
  - {func}`jax.lax.tie_in` is deprecated: it has been a no-op since JAX v0.2.0.
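
A minimal sketch of the two `jax.random` migrations above, assuming the default threefry PRNG implementation: batch over keys explicitly with `jax.vmap`, and use {func}`jax.random.key_data` where `unsafe_raw_array` was used before.

```python
import jax

keys = jax.random.split(jax.random.key(0), 3)  # a batch of 3 typed keys

# Deprecated: jax.random.normal(keys, (2,)) with a batched key argument.
# Explicit batching with vmap instead, one key per call:
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)  # shape (3, 2)

# Replacement for the removed PRNGKeyArray.unsafe_raw_array():
raw = jax.random.key_data(keys)  # uint32 data; shape (3, 2) for threefry keys
```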
- Changes
  - JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been dropped.
  - `cost_analysis` now works with cross-compiled `Compiled` objects (i.e. when using `.lower().compile()` with a topology object, e.g., to compile for Cloud TPU from a non-TPU computer).
  - Added CUDA Array Interface import support (requires jax 0.4.25).
  - Fixed a bug that caused verbose logging from the GPU compiler during compilation.
- Deprecations
  - The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated. Explicit buffers have been replaced by the more flexible array sharding interface, but the previous outputs can be recovered this way:
    - `arr.device_buffer` becomes `arr.addressable_data(0)`
    - `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
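
The two replacements above, written out as a small runnable sketch. It assumes an unsharded array on the default device, so there is exactly one addressable shard:

```python
import jax.numpy as jnp

arr = jnp.arange(8.0)

# Replacement for arr.device_buffer: the data held on the first addressable device.
first_buffer = arr.addressable_data(0)

# Replacement for arr.device_buffers: one per-device array per addressable shard.
buffers = [shard.data for shard in arr.addressable_shards]
```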
- New Features
  - Added {obj}`jax.nn.squareplus`.
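
A quick sketch comparing {obj}`jax.nn.squareplus` against its closed form `(x + sqrt(x**2 + b)) / 2`, assuming the default `b=4`:

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 3.0])
y = jax.nn.squareplus(x)  # default b=4

# The same function written out by hand: (x + sqrt(x**2 + b)) / 2 with b = 4
reference = (x + jnp.sqrt(x**2 + 4.0)) / 2
```

At `x = 0` both evaluate to `sqrt(4) / 2 = 1`, and the output is always positive, like a smooth ReLU.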
- Changes
  - The minimum jaxlib version is now 0.4.19.
  - Released wheels are now built with clang instead of gcc.
  - Enforce that the device backend has not been initialized prior to calling `jax.distributed.initialize()`.
  - Automate arguments to `jax.distributed.initialize()` in Cloud TPU environments.
- Deprecations
  - The previously-deprecated `sym_pos` argument has been removed from {func}`jax.scipy.linalg.solve`. Use `assume_a='pos'` instead.
  - Passing `None` to {func}`jax.array` or {func}`jax.asarray`, either directly or within a list or tuple, is deprecated and now raises a {obj}`FutureWarning`. Currently it is converted to NaN; in the future it will raise a {obj}`TypeError`.
  - Passing the `condition`, `x`, and `y` parameters to `jax.numpy.where` as keyword arguments has been deprecated, to match `numpy.where`.
  - Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv` that cannot be converted to a JAX array is deprecated and now raises a {obj}`DeprecationWarning`. Currently the functions return `False`; in the future this will raise an exception.
  - The `device()` method of JAX arrays is deprecated. Depending on the context, it may be replaced with one of the following:
    - {meth}`jax.Array.devices` returns the set of all devices used by the array.
    - {attr}`jax.Array.sharding` gives the sharding configuration used by the array.
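
A sketch of the non-deprecated spellings, assuming a single-process JAX setup: `jnp.where` with positional arguments, and `devices()` / `sharding` in place of the deprecated `device()` method.

```python
import jax.numpy as jnp

x = jnp.array([1.0, -2.0, 3.0])

# Pass condition, x and y positionally, as numpy.where does.
clipped = jnp.where(x > 0, x, 0.0)

# Replacements for the deprecated x.device():
devices = x.devices()  # set of every device holding part of the array
sharding = x.sharding  # the full sharding configuration
```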
- Changes
  - In preparation for adding distributed CPU support, JAX now treats CPU devices identically to GPU and TPU devices, that is:
    - `jax.devices()` includes all devices present in a distributed job, even those not local to the current process. `jax.local_devices()` still only includes devices local to the current process, so if the change to `jax.devices()` breaks you, you most likely want to use `jax.local_devices()` instead.
    - CPU devices now receive a globally unique ID number within a distributed job; previously CPU devices would receive a process-local ID number.
    - The `process_index` of each CPU device will now match any GPU or TPU devices within the same process; previously the `process_index` of a CPU device was always 0.
  - On NVIDIA GPU, JAX now prefers a Jacobi SVD solver for matrices up to 1024x1024. The Jacobi solver appears faster than the non-Jacobi version.
- Bug fixes
  - Fixed error/hang when an array with non-finite values is passed to a non-symmetric eigendecomposition (#18226). Arrays with non-finite values now produce arrays full of NaNs as outputs.
- Bug fixes
  - Fixed some type confusion between E4M3 and E5M2 float8 types.
- New Features
  - Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that are convertible to JAX dtypes.
  - Added `jax.numpy.fill_diagonal`.
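
A small sketch using both additions together. The helper `unit_diagonal` is hypothetical, written only for illustration. Note that JAX arrays are immutable, so JAX's `fill_diagonal` needs `inplace=False` and returns a new array:

```python
import jax
import jax.numpy as jnp
from jax.typing import DTypeLike


def unit_diagonal(n: int, dtype: DTypeLike = jnp.float32) -> jax.Array:
    """Hypothetical helper: an n x n zero matrix with ones on the diagonal."""
    # inplace=False is required in JAX; the filled copy is returned.
    return jnp.fill_diagonal(jnp.zeros((n, n), dtype), 1, inplace=False)


m = unit_diagonal(3, "int32")  # a dtype string is also a valid DTypeLike
```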
- Changes
  - JAX now requires SciPy 1.9 or newer.
- Bug fixes
  - Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network file system such as GCS.
  - The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.
- Changes
  - jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (nvidia-... packages) over any other CUDA installation if they are installed, including installations named in `LD_LIBRARY_PATH`. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip-installed CUDA library packages.
- Changes
  - CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended `cuda12_pip` installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.
  - We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.
  - {meth}`jax.Array.item` now supports optional index arguments.
- Deprecations
  - A number of internal utilities and inadvertent exports in {mod}`jax.lax` have been deprecated, and will be removed in a future release.
    - `jax.lax.dtypes`: use `jax.dtypes` instead.
    - `jax.lax.itertools`: use `itertools` instead.
    - `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`, `standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are internal utilities, now deprecated without replacement.
- Bug fixes
  - Fixed Cloud TPU regression where compilation would OOM due to smem.
- New features
  - Added new {func}`jax.numpy.bitwise_count` function, matching the API of the similar function recently added to NumPy.
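
For illustration, {func}`jax.numpy.bitwise_count` returns the number of set bits (the population count) of each element:

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 7, 255], dtype=jnp.uint8)
counts = jnp.bitwise_count(x)  # set bits per element: 0, 1, 3 (0b111), 8 (0b11111111)
```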
- Deprecations
  - Removed the deprecated module `jax.abstract_arrays` and all its contents.
  - Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
    - `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
    - `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
    - `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
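
A sketch of the `impl` migration above. Raw `PRNGKey` outputs are plain `uint32` arrays whose shape depends on the implementation: `(2,)` for threefry and `(4,)` for rbg.

```python
import jax

seed = 0

# Replacements for the deprecated named constructors:
threefry = jax.random.PRNGKey(seed, impl="threefry2x32")  # was threefry2x32_key(seed)
rbg = jax.random.PRNGKey(seed, impl="rbg")                # was rbg_key(seed)

# jax.random.key accepts the same impl argument and returns a typed key.
typed_rbg = jax.random.key(seed, impl="rbg")
```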
- Changes:
  - CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.
  - Removed the "No GPU/TPU found" warning. Instead, JAX now warns if, on Linux, an NVIDIA GPU or a Google TPU is found but not used and `--jax_platforms` was not specified.
  - {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy 1.11.
  - Most `jax.numpy` functions and attributes now have fully-defined type stubs. Previously many of these were treated as `Any` by static type checkers like `mypy` and `pytype`.
- Changes:
  - Python 3.12 wheels were added in this release.
  - The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
- Bug fixes:
  - Fixed log spam from ABSL when the JAX CPU backend was initialized.
- Changes
  - Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert any scalar-valued function into a {func}`numpy.ufunc`-like object, with methods such as {meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`, {meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and {meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
  - Added {func}`jax.scipy.integrate.trapezoid`.
  - When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks (without the "unfiltered stack trace" that previously appeared). This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
  - jax2tf default serialization version is now 7, which introduces new shape safety assertions.
  - Devices passed to `jax.sharding.Mesh` should be hashable. This specifically applies to mock devices or user-created devices. Devices returned by `jax.devices()` are already hashable.
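
A minimal sketch of the new ufunc machinery and `trapezoid`. The lambda passed to `frompyfunc` is an arbitrary illustration, and `identity=0` is assumed to be the appropriate identity for addition:

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

# Wrap a scalar binary function as a ufunc-like object with reduction methods.
add = jnp.frompyfunc(lambda a, b: a + b, nin=2, nout=1, identity=0)

x = jnp.arange(1, 5)         # [1, 2, 3, 4]
total = add.reduce(x)        # 1 + 2 + 3 + 4
running = add.accumulate(x)  # running sums
table = add.outer(x, x)      # pairwise sums, shape (4, 4)

# Trapezoidal rule with unit spacing: (1 + 2) / 2 + (2 + 3) / 2
area = trapezoid(jnp.array([1.0, 2.0, 3.0]))
```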
- Breaking changes:
  - jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.
  - The option `--jax_coordination_service` has been removed. It is now always `True`.
  - `jax.jaxpr_util` has been removed from the public JAX namespace.
  - `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
  - The backwards compatibility flag `--jax_host_callback_ad_transforms`, introduced in December 2021, has been removed.
- Deprecations:
  - Several `jax.numpy` APIs have been deprecated following NumPy NEP-52:
    - `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
    - `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
    - `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
    - `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
    - `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
    - `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
    - `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
  - `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated, following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
  - `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. Use the built-in `math.prod` instead.
  - A number of exports from `jax.interpreters.xla` related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities in `jax.interpreters.mlir` instead.
  - The following previously-deprecated functions have been removed after a three-month deprecation period:
    - `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
    - `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
    - `jax.numpy.alltrue`: use `jax.numpy.all`.
    - `jax.numpy.sometrue`: use `jax.numpy.any`.
    - `jax.numpy.product`: use `jax.numpy.prod`.
    - `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
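
The replacements listed above, collected into one sketch. The input values are arbitrary and are there only to make the calls concrete:

```python
import math

import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])

neg_inf = -jnp.inf                               # instead of jnp.NINF
stacked = jnp.vstack([x, x])                     # instead of jnp.row_stack
mask = jnp.isin(jnp.array([1, 5]), x)            # instead of jnp.in1d
is_int = jnp.issubdtype(x.dtype, jnp.integer)    # instead of jnp.issubsctype(x, jnp.integer)
n_elems = math.prod(x.shape)                     # instead of jax.lax.prod
lower = jnp.tril(x)                              # instead of jax.scipy.linalg.tril
```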
- Deprecations/removals:
  - The internal submodule `jax.prng` is now deprecated. Its contents are available at {mod}`jax.extend.random`.
  - The internal submodule path `jax.linear_util` has been deprecated. Use {mod}`jax.extend.linear_util` instead (Part of {ref}`jax-extend-jep`)
  - `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array` for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for runtime detection of typed prng keys.
  - The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use {func}`jax.random.key_data` instead.
  - `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use `jax.lax.with_sharding_constraint` instead.
  - The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` have been removed. Opaque dtypes have been renamed to Extended dtypes; use `jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
  - The utility `jax.interpreters.xla.register_collective_primitive` has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.
- Changes:
  - Sparse CSR matrix multiplications via the experimental jax sparse APIs no longer use a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.
- Bug fixes:
  - Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
- Changes
  - `jax.jit` takes `donate_argnames` as an argument. Its semantics are similar to `static_argnames`. If neither `donate_argnums` nor `donate_argnames` is provided, no arguments are donated. If `donate_argnums` is not provided but `donate_argnames` is, or vice versa, JAX uses `inspect.signature(fun)` to find any positional arguments that correspond to `donate_argnames` (or vice versa). If both `donate_argnums` and `donate_argnames` are provided, `inspect.signature` is not used, and only actual parameters listed in either `donate_argnums` or `donate_argnames` will be donated.
  - {func}`jax.random.gamma` has been re-factored to a more efficient algorithm with more robust endpoint behavior ({jax-issue}`#16779`). This means that the sequence of values returned for a given `key` will change between JAX v0.4.13 and v0.4.14 for `gamma` and related samplers (including {func}`jax.random.ball`, {func}`jax.random.beta`, {func}`jax.random.chisquare`, {func}`jax.random.dirichlet`, {func}`jax.random.generalized_normal`, {func}`jax.random.loggamma`, {func}`jax.random.t`).
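
A minimal sketch of donation by name. The function `update` and its arguments are hypothetical. Because only `donate_argnames` is given, JAX maps the name `params` to its position through `inspect.signature`. Some backends (for example CPU in some releases) may not honor donation and only emit a warning.

```python
import jax
import jax.numpy as jnp


def update(params, grads, lr):
    """Hypothetical SGD-style step used only to illustrate donation."""
    return params - lr * grads


# Donate `params` by name rather than by position.
update_jit = jax.jit(update, donate_argnames="params")

params = jnp.ones(3)
grads = jnp.full(3, 0.5)
new_params = update_jit(params, grads, 0.1)
# Treat `params` as consumed after this call: its buffer may have been reused.
```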
- Deletions
  - `in_axis_resources` and `out_axis_resources` have been deleted from pjit since it has been more than 3 months since their deprecation. Please use `in_shardings` and `out_shardings` as the replacement. This is a safe and trivial name replacement. It does not change any of the current pjit semantics and doesn't break any code. You can still pass in `PartitionSpec`s to `in_shardings` and `out_shardings`.
- Deprecations
  - Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
  - JAX now requires NumPy 1.22 or newer as per https://jax.readthedocs.io/en/latest/deprecation.html
  - Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
  - The following `jax.Array` methods have been removed, after being deprecated in JAX v0.4.5:
    - `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
    - `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
    - `jax.Array.split`: use {func}`jax.numpy.split` instead.
  - The following APIs have been removed after previous deprecation:
    - `jax.ad`: use {mod}`jax.interpreters.ad`.
    - `jax.curry`: use `curry = lambda f: partial(partial, f)`.
    - `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
    - `jax.pxla`: use {mod}`jax.interpreters.pxla`.
    - `jax.xla`: use {mod}`jax.interpreters.xla`.
    - `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
    - `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
    - `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
    - `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
    - `jax.numpy.DeviceArray`: use {class}`jax.Array`.
    - `jax.stages.Compiled.compiler_ir`: use {func}`jax.stages.Compiled.as_text`.
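
A sketch of several replacements listed above: keyword-only `.at[]` options, the `jax.lax` / `jax.numpy` functions that replace the removed `jax.Array` methods, and the one-line `curry` substitute. The example values are arbitrary.

```python
from functools import partial

import jax.numpy as jnp
from jax import lax

x = jnp.array([10, 20, 30, 40])

# Optional .at[] arguments must now be passed by keyword.
picked = x.at[jnp.array([0, 2])].get(indices_are_sorted=True)

# Replacements for the removed jax.Array methods:
b = lax.broadcast(x, (2,))                        # was x.broadcast((2,)); shape (2, 4)
b_in_dim = lax.broadcast_in_dim(x, (3, 4), (1,))  # was x.broadcast_in_dim(...)
halves = jnp.split(x, 2)                          # was x.split(2)

# Replacement for the removed jax.curry:
curry = lambda f: partial(partial, f)
add3 = curry(lambda a, b, c: a + b + c)
result = add3(1, 2)(3)  # partial(f, 1, 2) called with 3
```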
- Breaking changes
  - JAX now requires ml_dtypes version 0.2.0 or newer.
  - To fix a corner case, calls to {func}`jax.lax.cond` with five arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.
  - The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, which did nothing, have been removed. These options have been true by default for many releases.
- New features
  - JAX now supports a configuration flag `--jax_serialization_version` and a `JAX_SERIALIZATION_VERSION` environment variable to control the serialization version ({jax-issue}`#16746`).
  - jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
- Deprecations
  - Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
- Changes
  - `jax.jit` now allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
    - For in_shardings, JAX will mark it as replicated but this behavior can change in the future.
    - For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
  - `jax.experimental.pjit.pjit` also allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
    - If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
      - For in_shardings, JAX will mark it as replicated but this behavior can change in the future.
      - For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
    - If the mesh context manager is provided, None will imply that the value will be replicated on all devices of the mesh.
  - `Executable.cost_analysis()` works on Cloud TPU.
  - Added a warning if a non-allowlisted `jaxlib` plugin is in use.
  - Added `jax.tree_util.tree_leaves_with_path`.
  - `None` is not a valid input to `jax.experimental.multihost_utils.host_local_array_to_global_array` or `jax.experimental.multihost_utils.global_array_to_host_local_array`. Please use `jax.sharding.PartitionSpec()` if you wanted to replicate your input.
- Bug fixes
  - Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named `cudnn89` instead of `cudnn88`.
- Deprecations
  - The `native_serialization_strict_checks` parameter to {func}`jax.experimental.jax2tf.convert` is deprecated in favor of the new `native_serialization_disabled_checks` ({jax-issue}`#16347`).
- Changes
  - Added Windows CPU-only wheels to the `jaxlib` PyPI release.
- Bug fixes
  - `__cuda_array_interface__` was broken in previous jaxlib versions and is now fixed ({jax-issue}`16440`).
  - Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
- Changes
  - Added {class}`scipy.spatial.transform.Rotation` and {class}`scipy.spatial.transform.Slerp`
- Deprecations
  - `jax.abstract_arrays` and its contents are now deprecated. See related functionality in {mod}`jax.core`.
  - `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation of `numpy.alltrue` in NumPy version 1.25.0.
  - `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation of `numpy.sometrue` in NumPy version 1.25.0.
  - `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation of `numpy.product` in NumPy version 1.25.0.
  - `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation of `numpy.cumproduct` in NumPy version 1.25.0.
  - `jax.sharding.OpShardingSharding` has been removed since it has been 3 months since it was deprecated.
- Changes
  - Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.
- Bug fixes
  - Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.
  - Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).
- Deprecations
  - The following APIs have been removed after a 3 month deprecation period, in accordance with the {ref}`api-compatibility` policy:
    - `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
    - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`
    - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
    - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
    - `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects as input and remove the optional `in_shardings` argument to `pjit`.
    - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
    - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`
    - `jax.interpreters.xla.Buffer`: use `jax.Array`.
    - `jax.interpreters.xla.Device`: use `jax.Device`.
    - `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
    - `jax.interpreters.xla.device_put`: use `jax.device_put`.
    - `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
    - `axis_resources` argument of `with_sharding_constraint` is removed. Please use `shardings` instead.
- Changes
  - Added `memory_stats()` method to `Device`s. If supported, this returns a dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU.
  - Re-added support for the Python buffer protocol (`memoryview`) on CPU devices.
- Changes
  - Fixed `'apple-m1' is not a recognized processor for this target (ignoring processor)` issue that prevented the previous release from running on Mac M1.
- Changes
  - The flags `experimental_cpp_jit`, `experimental_cpp_pjit` and `experimental_cpp_pmap` have been removed. They are now always on.
  - Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).
- Deprecations
  - `jax.experimental.gda_serialization` is deprecated and has been renamed to `jax.experimental.array_serialization`. Please change your imports to use `jax.experimental.array_serialization`.
  - The `in_axis_resources` and `out_axis_resources` arguments of pjit have been deprecated. Please use `in_shardings` and `out_shardings` respectively.
  - The function `jax.numpy.msort` has been removed. It has been deprecated since JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
  - `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation` since they were only used with sharded_jit and sharded_jit is long gone.
  - `instantiate_const_outputs` argument has been removed from `jax.xla_computation` since it has been unused for a very long time.
- Breaking changes
  - A major component of the Cloud TPU runtime has been upgraded. This enables the following new features on Cloud TPU:
    - {func}`jax.debug.print`, {func}`jax.debug.callback`, and {func}`jax.debug.breakpoint()` now work on Cloud TPU
    - Automatic TPU memory defragmentation

    {func}`jax.experimental.host_callback` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new `jax.debug` APIs are insufficient for your use case.

    The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.
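
As a sketch of the `jax.debug` APIs mentioned above, `jax.debug.print` prints runtime values from inside a jitted function, where a plain `print` would show only a tracer. The function `scaled_sum` is hypothetical, and the example is assumed to run on any backend, including CPU:

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x, scale):
    """Hypothetical example function."""
    total = jnp.sum(x) * scale
    # Prints the concrete value at run time, even under jit.
    jax.debug.print("total = {t}", t=total)
    return total


out = scaled_sum(jnp.arange(4.0), 2.0)  # (0 + 1 + 2 + 3) * 2
```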
- Changes
  - The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
- Deprecations
  - CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
  - `global_arg_shapes` argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove `global_arg_shapes` from pmap.
- Changes
  - As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration
    - `jax.config.jax_array` cannot be disabled anymore.
    - `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
  - {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` parameter to use JAX's native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`.
  - JAX now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
  - JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
- Deprecations
  - The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, for which it is an alias.
  - The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use `jax.Array` instead.
  - Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
  - `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
  - `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
  - `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded `jax.Array`s as input and remove the `in_shardings` argument to pjit since it is optional.
- Changes:
  - jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
- Changes
  - `jax.tree_util` now contains a set of APIs that allow users to define keys for their custom pytree nodes. This includes:
    - `tree_flatten_with_path`, which flattens a tree and returns not only each leaf but also its key path.
    - `tree_map_with_path`, which can map a function that takes the key path as an argument.
    - `register_pytree_with_keys`, which registers how the key path and leaves should look in a custom pytree node.
    - `keystr`, which pretty-prints a key path.
  - {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`) that can be used to declare the output shape and type of the result. This enables {func}`jax2tf.call_tf` to work in the presence of shape polymorphism ({jax-issue}`#14734`).
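
A minimal sketch of the key-path APIs on a plain dict/list pytree. Dict entries produce `DictKey` path elements and list entries produce `SequenceKey` elements, which `keystr` renders in indexing notation:

```python
from jax.tree_util import keystr, tree_flatten_with_path, tree_map_with_path

tree = {"a": 1, "b": [2, 3]}

# Each leaf is returned together with its key path.
leaves_with_paths, treedef = tree_flatten_with_path(tree)
names = [keystr(path) for path, _ in leaves_with_paths]

# Map over the leaves with access to each leaf's path.
tagged = tree_map_with_path(lambda path, leaf: f"{keystr(path)}={leaf}", tree)
```

Here `names` is `["['a']", "['b'][0]", "['b'][1]"]`, and `tagged` keeps the original structure with each leaf replaced by its labeled string.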
- Deprecations
  - The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months from Mar 10 2023:
    - `register_keypaths`: use {func}`jax.tree_util.register_pytree_with_keys` instead.
    - `AttributeKeyPathEntry`: use `GetAttrKey` instead.
    - `GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
- Deprecations
  - `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`. `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
  - The following `jax.Array` methods are deprecated and will be removed 3 months from Feb 23 2023:
    - `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
    - `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
    - `jax.Array.split`: use {func}`jax.numpy.split` instead.
- Changes
  - The implementation of `jit` and `pjit` has been merged. Merging jit and pjit changes the internals of JAX without affecting the public API of JAX. Before, `jit` was a final style primitive. Final style means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive, which means that we trace to jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc. easier. You can disable it only via an environment variable, i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. The merge must be disabled via an environment variable because it affects JAX at import time, so it needs to be disabled before jax is imported.
  - The `axis_resources` argument of `with_sharding_constraint` is deprecated. Please use `shardings` instead. There is no change needed if you were using `axis_resources` as a positional argument. If you were using it as a keyword argument, then please use `shardings` instead. `axis_resources` will be removed after 3 months from Feb 13, 2023.
  - Added the {mod}`jax.typing` module, with tools for type annotations of JAX functions.
  - The following names have been deprecated:
    - `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
    - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh` instead.
    - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
    - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
    - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
    - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
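
Here is a rough illustration of how {mod}`jax.typing` might be used. It defines a hypothetical helper, `scale`, that annotates its inputs with `ArrayLike` and `DTypeLike` and its return value with `jax.Array`. This is a sketch of annotation style, not a prescribed pattern.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike, DTypeLike

def scale(x: ArrayLike, factor: float = 2.0, dtype: DTypeLike = jnp.float32) -> jax.Array:
    # ArrayLike covers jax.Array, NumPy arrays and Python scalars at annotation level;
    # jnp.asarray converts any of them into a jax.Array.
    return jnp.asarray(x, dtype=dtype) * factor

out = scale(np.array([1, 2, 3]))  # NumPy input
scalar_out = scale(3)             # Python scalar input
```
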
- Breaking Changes
  - The `initial` argument to reduction functions like {func}`jax.numpy.sum` is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar `initial` values was an unintentional implementation detail ({jax-issue}`#14446`).
- Breaking changes
  - Support for NVIDIA Kepler series GPUs has been removed from the default `jaxlib` builds. If Kepler support is needed, it is still possible to build `jaxlib` from source with Kepler support (via the `--cuda_compute_capabilities=sm_35` option to `build.py`); however, note that CUDA 12 has completely dropped support for Kepler GPUs.
- Breaking changes
  - Deleted {func}`jax.scipy.linalg.polar_unitary`, which was a deprecated JAX extension to the scipy API. Use {func}`jax.scipy.linalg.polar` instead.
- Changes
  - Added {func}`jax.scipy.stats.rankdata`.
- `jax.Array` now has the non-blocking `is_ready()` method, which returns `True` if the array is ready (see also {func}`jax.block_until_ready`).
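
The sketch below shows how the non-blocking `is_ready()` check relates to {func}`jax.block_until_ready`. The array size is arbitrary. Whether the first `is_ready()` call returns `False` depends on timing, so only the post-blocking state is guaranteed.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1_000_000, dtype=jnp.float32)
y = jnp.cumsum(x)  # dispatched asynchronously

# is_ready() never blocks; it may report False while the computation is running.
print("ready before blocking:", y.is_ready())

# block_until_ready waits for the result; afterwards is_ready() is True.
jax.block_until_ready(y)
print("ready after blocking:", y.is_ready())
```
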
- Breaking changes
  - Deleted `jax.experimental.callback`.
  - Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimension to JAX arrays. Operations involving symbolic dimensions and `np.ndarray` now can raise errors when the result is used as a shape value ({jax-issue}`#14106`).
  - jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations ({jax-issue}`#14102`).
- Changes
  - {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
  - Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions ({jax-issue}`#14108`).
- Changes
  - Set `JAX_USE_PJRT_C_API_ON_TPU=1` to enable the new Cloud TPU runtime, featuring automatic device memory defragmentation.
- Changes
  - Support for Python 3.7 has been dropped, in accordance with JAX's {ref}`version-support-policy`.
  - We introduce `jax.Array`, which is a unified array type that subsumes the `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking changes to the `pjit` API. The jax.Array migration guide can help you migrate your codebase to `jax.Array`. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
  - `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are deprecated and will be removed in 3 months.
  - The new public endpoint of `with_sharding_constraint` is `jax.lax.with_sharding_constraint`.
  - If using ABSL flags together with `jax.config`, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading `jax.config` options, which are used pervasively in JAX.
  - The `jax2tf.call_tf` function now uses the first TF device of the same platform as the embedding JAX computation for TF lowering. Before, it was using the 0th device for the JAX-default backend.
  - A number of `jax.numpy` functions now have their arguments marked as positional-only, matching NumPy.
  - `jnp.msort` is now deprecated, following the deprecation of `np.msort` in numpy 1.24. It will be removed in a future release, in accordance with the {ref}`api-compatibility` policy. It can be replaced with `jnp.sort(a, axis=0)`.
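
Here is a small sketch of the non-experimental sharding endpoints. It assumes a 1-D mesh over whatever devices are visible, which may be a single CPU, so the example also runs on a laptop. The axis name `"x"` and the helper `double` are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over all visible devices; "x" is an illustrative axis name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))

# Use a length divisible by the device count so the array shards evenly.
n = len(devices) * 4
x = jax.device_put(jnp.arange(n, dtype=jnp.float32), sharding)

@jax.jit
def double(v):
    # The public endpoint for sharding constraints lives in jax.lax.
    v = jax.lax.with_sharding_constraint(v, sharding)
    return v * 2

y = double(x)
```
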
- Changes
  - Support for Python 3.7 has been dropped, in accordance with JAX's {ref}`version-support-policy`.
  - The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
  - The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead.
- The release was yanked.
- The release was yanked.
- Changes
  - {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
  - {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires jaxlib > 0.3.24.
  - New functions {func}`jax.lax.linalg.hessenberg`, {func}`jax.lax.linalg.tridiagonal`, and {func}`jax.lax.linalg.householder_product` were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.
  - The gradients of `svd` and `jax.numpy.linalg.pinv` are now computed more economically for non-square matrices.
- Breaking Changes
  - Deleted the `jax_experimental_name_stack` config option.
  - A string `axis_names` argument to the {class}`jax.experimental.maps.Mesh` constructor is now converted into a singleton tuple instead of being unpacked into a sequence of single-character axis names.
- Changes
  - Added support for tridiagonal reductions on CPU and GPU.
  - Added support for upper Hessenberg reductions on CPU.
- Bugs
  - Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+.
- Changes
  - JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX's import time.
  - Set the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.
  - Added {func}`jax.scipy.stats.mode`.
  - The default device order used by `pmap` on TPU if no order is specified now matches `jax.devices()` for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
- Breaking Changes
  - {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`, and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`).
  - Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of {jax-issue}`#7737`.
- Deprecations
  - `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`. The `jax.sharding.MeshPspecSharding` name will be removed in 3 months.
- Changes
  - Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.
- Changes
  - Updated the Colab TPU driver version for the new jaxlib release.
- Changes
  - Add `JAX_PLATFORMS=tpu,cpu` as the default setting in TPU initialization, so JAX will raise an error if the TPU cannot be initialized instead of falling back to CPU. Set `JAX_PLATFORMS=''` to override this behavior and automatically choose an available backend (the original default), or set `JAX_PLATFORMS=cpu` to always use CPU regardless of whether the TPU is available.
- Deprecations
  - Several test utilities deprecated in JAX v0.3.8 are now removed from {mod}`jax.test_util`.
- GitHub commits.
- Changes
  - The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}`#12582`), so program execution can continue if something goes wrong with the cache. Set `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
- Bug fixes:
  - Adds `.pyi` files that were missing from the previous release ({jax-issue}`#12536`).
  - Fixes an incompatibility between `jax` 0.3.19 and the libtpu version it pinned ({jax-issue}`#12550`). Requires jaxlib 0.3.20.
  - Fix incorrect `pip` URL in `setup.py` comment ({jax-issue}`#12528`).
- GitHub commits.
- Bug fixes
  - Fixes support for limiting the visible CUDA devices via `jax_cuda_visible_devices` in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU ({jax-issue}`#12533`).
- GitHub commits.
- Fixes required jaxlib version.
- GitHub commits.
- Changes
  - Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See the overview and the API docs for {mod}`jax.stages`.
  - Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Notice that this included some subtle changes to how `isinstance` works for {class}`jax.numpy.ndarray` for jax-internal objects, as {class}`jax.numpy.ndarray` is now a simple alias of {class}`jax.Array`.
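
The ahead-of-time path can be sketched as follows: lower, optionally inspect, compile, then call. The function `f` and the input shapes are illustrative. The last lines use `jax.Array` in an `isinstance` check, as described above.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x, y) + 1.0

x = jnp.ones((2, 3))
y = jnp.ones((3, 4))

lowered = jax.jit(f).lower(x, y)  # trace and lower for these argument shapes, no compilation yet
hlo_text = lowered.as_text()      # inspect the lowered IR as text
compiled = lowered.compile()      # compile ahead of time
out = compiled(x, y)              # run the precompiled executable

# jax.Array can be used directly in isinstance checks and type annotations.
print(isinstance(out, jax.Array))
```
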
- Breaking changes
  - `jax._src` is no longer imported into the public `jax` namespace. This may break users that were using JAX internals.
  - `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead. `jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.
- GitHub commits.
- Bugs
  - Fix corner case issue in gradient of `lax.pow` with an exponent of zero ({jax-issue}`#12041`).
- Breaking changes
  - {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see JEP 11830.
- Changes
  - Added {func}`jax.pure_callback` that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
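
The following is a minimal sketch of {func}`jax.pure_callback` inside `jax.jit`. The host function `host_sin` is a hypothetical stand-in for any pure NumPy routine. The callback's result shape and dtype must be declared up front with `jax.ShapeDtypeStruct`.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_sin(x):
    # Runs on the host with NumPy; it must be pure (no side effects).
    return np.sin(x)

@jax.jit
def f(x):
    # Declare the shape and dtype of the callback's result for tracing.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)

out = f(jnp.linspace(0.0, 1.0, 5))
```
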
- Deprecations:
  - The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile` ({jax-issue}`#11944`).
  - `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
- GitHub commits.
- Breaking changes
  - Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
- Changes
  - Added {mod}`jax.debug`, which includes utilities for runtime value debugging such as {func}`jax.debug.print` and {func}`jax.debug.breakpoint`.
  - Added new documentation for runtime value debugging.
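
The sketch below illustrates the runtime-debugging entry point. A plain `print()` inside `jax.jit` would show tracers, whereas {func}`jax.debug.print` shows the actual runtime values. The function body is illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    y = jnp.sin(x)
    # Prints concrete values at run time, even though f is jit-compiled.
    jax.debug.print("x = {x}, sin(x) = {y}", x=x, y=y)
    return y * 2

out = f(jnp.array([0.0, 1.0]))
```
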
- Deprecations
  - The {func}`jax.mask` and {func}`jax.shapecheck` APIs have been removed. See {jax-issue}`#11557`.
  - {mod}`jax.experimental.loops` has been removed. See {jax-issue}`#10278` for an alternative API.
  - {func}`jax.tree_util.tree_multimap` has been removed. It has been deprecated since JAX release 0.3.5, and {func}`jax.tree_util.tree_map` is a direct replacement.
  - Removed `jax.experimental.stax`; it has long been a deprecated alias of {mod}`jax.example_libraries.stax`.
  - Removed `jax.experimental.optimizers`; it has long been a deprecated alias of {mod}`jax.example_libraries.optimizers`.
  - {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
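
As a sketch of the `tree_multimap` migration: {func}`jax.tree_util.tree_map` accepts several trees with matching structure and passes corresponding leaves to the function together. The `params` and `grads` trees and the 0.1 step size are illustrative.

```python
import jax

params = {"w": 1.0, "b": [2.0, 3.0]}
grads = {"w": 0.5, "b": [0.1, 0.2]}

# tree_map zips leaves from all trees, which is what tree_multimap used to do.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```
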
- GitHub commits.
- Changes
  - `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
  - Added {class}`jax.scipy.stats.gaussian_kde` ({jax-issue}`#11237`).
  - Binary operations between JAX arrays and built-in collections (`dict`, `list`, `set`, `tuple`) now raise a `TypeError` in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy ({jax-issue}`#11234`).
  - Several {mod}`jax.tree_util` routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the {ref}`api-compatibility` policy:
    - {func}`jax.treedef_is_leaf` is deprecated in favor of {func}`jax.tree_util.treedef_is_leaf`
    - {func}`jax.tree_flatten` is deprecated in favor of {func}`jax.tree_util.tree_flatten`
    - {func}`jax.tree_leaves` is deprecated in favor of {func}`jax.tree_util.tree_leaves`
    - {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
    - {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
    - {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
  - The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`, following a similar deprecation in {func}`scipy.linalg.solve`.
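
The following sketch shows the stricter handling of built-in collections in binary operations. Mixing a JAX array with a Python `list` raises `TypeError`. Converting the collection explicitly with `jnp.asarray` is one straightforward workaround, though it is not the only one.

```python
import jax.numpy as jnp

x = jnp.array([1, 2])

raised = False
try:
    x + [1, 2]  # a JAX array combined with a built-in list
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Convert the collection explicitly instead.
y = x + jnp.asarray([1, 2])
```
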
- GitHub commits.
- Breaking changes
  - {func}`jax.experimental.compilation_cache.initialize_cache` no longer supports `max_cache_size_bytes` and will not accept it as an input.
  - `JAX_PLATFORMS` now raises an exception when platform initialization fails.
- Changes
  - Fixed compatibility problems with NumPy 1.23.
  - {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
  - {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
  - `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when used on jax arrays ({jax-issue}`#10659`). In particular:
    - `pickle` and `deepcopy` previously returned `np.ndarray` objects when used on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`, the copied array is on the same device as the original. For `pickle`, the deserialized array will be on the default device.
    - Within function transformations (i.e. traced code), `deepcopy` and `copy` previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
    - Calling `pickle` on a traced array now results in an explicit `ConcretizationTypeError`.
  - The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
  - {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64; instead it promotes to float32 for integer inputs of size int32 or smaller ({jax-issue}`#10921`).
  - Add a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and {func}`jax.profiler.trace`. When used, the profiler will generate a link to the Perfetto UI to view the trace.
  - Changed the semantics of {func}`jax.profiler.start_server(...)` to store the keepalive globally, rather than requiring the user to keep a reference to it.
  - Added {func}`jax.random.generalized_normal`.
  - Added {func}`jax.random.ball`.
  - Added {func}`jax.default_device`.
  - Added a `python -m jax.collect_profile` script to manually capture program traces as an alternative to the TensorBoard UI.
  - Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`).
  - In scatter-update operations (i.e. {attr}`jax.numpy.ndarray.at`), unsafe implicit dtype casts are deprecated, and now result in a `FutureWarning`. In a future release, this will become an error. An example of an unsafe implicit cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was silently truncated to `1`.
  - {func}`jax.experimental.compilation_cache.initialize_cache` now supports a GCS bucket path as input.
  - Added {func}`jax.scipy.stats.gennorm`.
  - {func}`jax.numpy.roots` is now better behaved when `strip_zeros=False` and the coefficients have leading zeros ({jax-issue}`#11215`).
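
Here is a small sketch of the `pickle`/`copy.deepcopy` behavior described above. Both round trips yield `jax.Array` objects rather than `np.ndarray`. The array contents are arbitrary.

```python
import copy
import pickle
import jax
import jax.numpy as jnp

x = jnp.arange(4.0)

y = copy.deepcopy(x)               # a jax.Array on the same device as x
z = pickle.loads(pickle.dumps(x))  # a jax.Array on the default device
```
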
- GitHub commits.
- x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
- The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
- The Python flatbuffers package is no longer a dependency of jaxlib.
- GitHub commits.
- Changes
  - Fixes #10717.
- GitHub commits.
- Changes
  - {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU.
- Deprecations
  - Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked keyword-only. As a backward-compatibility step, passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use {mod}`jax.numpy.linalg` instead.
  - {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
- GitHub commits.
- Changes
  - A TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.
- GitHub commits.
- Changes
  - Added support for fully asynchronous checkpointing for GlobalDeviceArray.
- GitHub commits.
- Changes
  - {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
  - {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
  - {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
  - {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
  - {func}`jax.scipy.cluster.vq.vq` has been added.
  - `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
  - {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`).
  - {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
  - {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
  - Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
  - {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of {func}`numpy.take_along_axis`. Previously non-integer indices were silently cast to integers.
  - {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of {func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently cast to integers.
  - {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of {func}`numpy.split`. Previously non-integer `axis` was silently cast to integers.
  - {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of {func}`numpy.indices`. Previously non-integer dimensions were silently cast to integers.
  - {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of {func}`numpy.diag`. Previously non-integer `k` was silently cast to integers.
  - Added {func}`jax.random.orthogonal`.
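
The sketch below compares the out-of-bounds modes described above on a three-element array. The values and the out-of-range index `5` are illustrative. The final line checks that a dropped scatter update receives a zero gradient.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # index 5 is out of bounds
oob = jnp.array(5)          # an out-of-bounds scatter index

filled = jnp.take(x, idx, mode="fill")   # out-of-bounds -> NaN for floating-point arrays
clipped = jnp.take(x, idx, mode="clip")  # out-of-bounds -> clamped to the last element

# Scatter updates use "drop" semantics: the out-of-bounds update is ignored...
dropped = x.at[oob].set(-1.0)
# ...and its gradient is zero rather than flowing to a clamped index.
g = jax.grad(lambda v: x.at[oob].set(v).sum())(1.0)
```
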
- Deprecations
  - Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` ({jax-issue}`#10389`). These, along with the previously deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in e.g. {mod}`unittest`, {mod}`absl.testing`, {mod}`numpy.testing`, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`. Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not public APIs and as such may be changed or removed without notice in future releases.
- GitHub commits.
- Changes:
  - Fixed a performance problem if the indices passed to {func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`).
  - {func}`jax.scipy.special.expit` and {func}`jax.scipy.special.logit` now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  - The `DeviceArray.tile()` method is deprecated, because numpy arrays do not have a `tile()` method. As a replacement for this, use {func}`jax.numpy.tile` ({jax-issue}`#10266`).
- Changes:
  - Linux wheels are now built conforming to the `manylinux2014` standard, instead of `manylinux2010`.
- GitHub commits.
- Changes:
  - Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.
- Deprecations:
  - {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` for an alternative API.
- GitHub commits.
- Changes:
  - Added {func}`jax.random.loggamma` and improved the behavior of {func}`jax.random.beta` and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
  - The private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
  - Added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`, and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
  - `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`).
  - Added {func}`jax.scipy.linalg.rsf2csf`.
  - `jax.experimental.sharded_jit` has been deprecated and will be removed soon.
- Deprecations:
  - {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
  - {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
  - `jax.experimental.sharded_jit` is deprecated. Use `pjit` instead.
- Bug fixes
  - Fixed a bug where double-precision complex-to-real IRFFTs would mutate their input buffers on GPU ({jax-issue}`#9946`).
  - Fixed incorrect constant-folding of complex scatters ({jax-issue}`#10159`).
- GitHub commits.
- Changes:
  - The functions `jax.ops.index_update` and `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use the `.at` property on JAX arrays instead, e.g., `x.at[idx].set(y)`.
  - Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`.
  - {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
  - The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
  - `pjit` now works on CPU (in addition to previous TPU and GPU support).
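
As a migration sketch for the removed `jax.ops.index_*` helpers, the equivalent `.at` updates are shown below. JAX arrays are immutable, so each update returns a new array and leaves `x` unchanged. The values are illustrative.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Previously written with jax.ops.index_update; now an out-of-place .at update.
y = x.at[2].set(7.0)
# Previously written with jax.ops.index_add; slices work as indices too.
z = y.at[1:3].add(1.0)
```
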
- Changes
  - `XlaComputation.as_hlo_text()` now supports printing large constants by passing the boolean flag `print_large_constants=True`.
- Deprecations:
  - The `.block_host_until_ready()` method on JAX arrays has been deprecated. Use `.block_until_ready()` instead.
- Changes:
  - `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. The suggested replacement is to use `parameterized.TestCase` directly. For tests that rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`, which work directly with JAX arrays ({jax-issue}`#9620`).
  - `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default ({jax-issue}`#9562`). To recover the previous behavior, use the new `jax.test_util.with_config` decorator:

    ```python
    @jtu.with_config(jax_numpy_rank_promotion='allow')
    class MyTestCase(jtu.JaxTestCase):
      ...
    ```

  - Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`, {func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`, {func}`jax.scipy.signal.welch`.
- Changes
  - jax version has been bumped to 0.3.0. Please see the design doc for the explanation.
- Changes
  - Bazel 5.0.0 is now required to build jaxlib.
  - jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
- GitHub commits.
- `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed.
- `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR `ir.Module` object instead of its string representation.
- New features
  - Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
  - With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
- Breaking changes
  - Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
- Bug fixes
  - Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).
  - The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
- Breaking changes:
  - Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
  - The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the `JAX_HOST_CALLBACK_AD_TRANSFORMS` environment variable, or the `--jax_host_callback_ad_transforms` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs ({jax-issue}`#8678`).
  - Sorting now matches the behavior of NumPy for `0.0` and `NaN` regardless of the bit representation. In particular, `0.0` and `-0.0` are now treated as equivalent, where previously `-0.0` was treated as less than `0.0`. Additionally all `NaN` representations are now treated as equivalent and sorted to the end of the array. Previously negative `NaN` values were sorted to the front of the array, and `NaN` values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns ({jax-issue}`#9178`).
  - {func}`jax.numpy.unique` now treats `NaN` values in the same way as `np.unique` in NumPy versions 1.21 and newer: at most one `NaN` value will appear in the uniquified output ({jax-issue}`#9184`).
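
Here is a short sketch of the NumPy-compatible ordering described above. Signed zeros compare equal, every `NaN` (including a negated one) sorts to the end, and `jnp.unique` keeps at most one `NaN`. The input values are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0, 0.0, -0.0, -jnp.nan])

s = jnp.sort(x)  # 0.0 and -0.0 are equivalent; all NaNs go to the end
u = jnp.unique(jnp.array([1.0, jnp.nan, jnp.nan]))  # at most one NaN survives
```
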
- Bug fixes:
  - host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).
- New features:
  - Add `jax.block_until_ready` ({jax-issue}`#8941`).
  - Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
  - Added `jax.ensure_compile_time_eval` to the public API ({jax-issue}`#7987`).
  - jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change the lowering for associative reductions, e.g., `jnp.cumsum`, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details ({jax-issue}`#9189`).
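
The following sketch combines two of the additions above. Inside `jax.ensure_compile_time_eval`, operations run eagerly during tracing, so their results are concrete and can be turned into Python floats. `jax.block_until_ready` then waits for the jitted result. The constant `sin(1.0)` is illustrative.

```python
import math
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    with jax.ensure_compile_time_eval():
        # Evaluated eagerly while tracing, so float() works on the result.
        scale = float(jnp.sin(1.0))
    return x * scale

out = jax.block_until_ready(f(jnp.ones(3)))
```
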
- New features:
  - Support for Python 3.10.
- Bug fixes:
  - Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with `FILL_OR_DROP` semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
  - jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions ({jax-issue}`#7839`).
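
The sketch below shows `FILL_OR_DROP` handling in `jax.ops.segment_sum`. An entry whose segment id falls outside `[0, num_segments)` is dropped from the sums, and its reverse-mode gradient is 0. The data and ids are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 1, 1, 5])  # id 5 is out of range for num_segments=2

# The entry with the out-of-range id does not contribute to any segment.
sums = ops.segment_sum(data, segment_ids, num_segments=2)

# Its gradient is 0 rather than being attributed to a clamped segment.
grads = jax.grad(lambda d: ops.segment_sum(d, segment_ids, num_segments=2).sum())(data)
```
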
- Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
- Added experimental MLIR Python bindings for use by JAX.
- New features:
  - (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
  - `jax.random.permutation` supports a new `independent` keyword argument ({jax-issue}`#8430`).
- Breaking changes
  - Moved `jax.experimental.stax` to `jax.example_libraries.stax`.
  - Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`.
- New features:
  - Added `jax.lax.linalg.qdwh`.
- New features:
  - `jax.random.choice` and `jax.random.permutation` now support multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`).
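
The following is a minimal sketch of the `axis` and `independent` arguments. With `axis=1` alone, the columns move as whole units, so every row gets the same permutation. With `independent=True`, each row is shuffled separately. The key, seed, and array are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# One permutation applied to the columns of every row.
cols = jax.random.permutation(key, x, axis=1)
# A separate permutation for each row.
rows_indep = jax.random.permutation(key, x, axis=1, independent=True)
# choice with an axis: draw 2 distinct rows.
picked = jax.random.choice(key, x, shape=(2,), replace=False, axis=0)
```
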
- Breaking changes:
  - `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs (see {jax-issue}`#7737`).
- Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
  - cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
  - cuDNN 8.0.5 or newer.
- Breaking changes:
  - The install commands for GPU jaxlib are as follows:

    ```bash
    pip install --upgrade pip

    # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
    pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

    # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
    pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

    # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
    pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
    ```
-
- GitHub commits.
- Breaking Changes
  - Static arguments to `jax.pmap` must now be hashable.

    Unhashable static arguments have long been disallowed on `jax.jit`, but they were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static arguments using object identity.

    This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap` wants to compare static arguments by object identity, they can define `__hash__` and `__eq__` methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use `functools.partial` to encapsulate the unhashable static arguments into the function object.
  - `jax.util.partial` was an accidental export that has now been removed. Use `functools.partial` from the Python standard library instead.
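The two workarounds described above can be sketched in plain Python. `IdentityKey` and `apply` are hypothetical names used only for illustration, not part of JAX; an `IdentityKey(config)` instance is the kind of object you would pass as a static argument to `jax.pmap`.

```python
import functools


class IdentityKey:
    """Hypothetical wrapper: hashes and compares the wrapped object by identity."""

    def __init__(self, obj):
        self.obj = obj

    def __hash__(self):
        return id(self.obj)

    def __eq__(self, other):
        return isinstance(other, IdentityKey) and self.obj is other.obj


config = {"scale": 2}  # dicts are unhashable
wrapped = IdentityKey(config)
print(hash(wrapped) == hash(IdentityKey(config)))  # True: same object
print(wrapped == IdentityKey({"scale": 2}))        # False: equal contents, different object


def apply(scale_config, x):
    return x * scale_config["scale"]


# Alternative: bind the unhashable value with functools.partial so it never
# has to be passed as a static argument at all.
apply_with_config = functools.partial(apply, config)
print(apply_with_config(3))  # 6
```

Identity semantics mean a new dict with equal contents still triggers a recompile, which is exactly the trade-off the entry warns about.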
- Deprecations
  - The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use the `.at` property on JAX arrays instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`.
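A minimal sketch of the `.at` replacements, assuming a recent JAX release. The comments name the deprecated `jax.ops` calls each line roughly corresponds to.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Each .at[...] update returns a new array; x itself is never modified.
y = x.at[1].set(3.0)    # roughly jax.ops.index_update(x, jax.ops.index[1], 3.0)
z = x.at[1:3].add(1.0)  # roughly jax.ops.index_add(x, jax.ops.index[1:3], 1.0)

print(y)  # [0. 3. 0. 0. 0.]
print(z)  # [0. 1. 1. 0. 0.]
print(x)  # [0. 0. 0. 0. 0.]
```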
- New features:
  - An optimized C++ code-path improving the dispatch time for `pmap` is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
  - `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`#8121`).
- Breaking changes:
- Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
- Bug fixes:
- Fixes jax-ml#7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
- GitHub commits.
- Breaking Changes
  - `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in `jax.*` instead.
  - `jax.partial` and `jax.lax.partial` were accidental exports that have now been removed. Use `functools.partial` from the Python standard library instead.
  - Boolean scalar indices now raise a `TypeError`; previously this silently returned wrong results ({jax-issue}`#7925`).
  - Many more `jax.numpy` functions now require array-like inputs, and will error if passed a list ({jax-issue}`#7747` {jax-issue}`#7802` {jax-issue}`#7907`). See {jax-issue}`#7737` for a discussion of the rationale behind this change.
  - When inside a transformation such as `jax.jit`, `jax.numpy.array` always stages the array it produces into the traced computation. Previously `jax.numpy.array` would sometimes produce an on-device array, even under a `jax.jit` decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
  - `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will now return `False` ({jax-issue}`#7927`).
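A minimal sketch of the NumPy workaround for static shape computations and of the new `isinstance` behavior, assuming a recent JAX release. `zeros_with_doubled_shape` is a hypothetical function used only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def zeros_with_doubled_shape(x):
    # x.shape is static under jit. Doing the arithmetic with classic NumPy
    # keeps it concrete; jnp.array(x.shape) would be staged into the trace
    # and could not be used as a shape.
    new_shape = np.array(x.shape) * 2
    return jnp.zeros(tuple(int(d) for d in new_shape), dtype=x.dtype)


out = zeros_with_doubled_shape(jnp.ones((2, 3)))
print(out.shape)  # (4, 6)

# jnp.ndarray is a base class for JAX arrays only, not for NumPy arrays.
print(isinstance(jnp.ones(2), jnp.ndarray))  # True
print(isinstance(np.ones(2), jnp.ndarray))   # False
```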
- New features:
  - Added {func}`jax.numpy.insert` implementation ({jax-issue}`#7936`).
- GitHub commits.
- Breaking Changes
  - `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`).
  - `jnp.unique` and other set-like operations now require array-like inputs ({jax-issue}`#7662`).
- Breaking changes:
- Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
- GitHub commits.
- Breaking changes:
  - Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
  - The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as `+`.

    This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., `x + 2**40`). The workaround is to cast the constant to an explicit type (e.g., `np.float64(2**40)`).
- New features:
  - Improved the support for shape polymorphism in jax2tf for operations that
    need to use a dimension size in array computation, e.g., `jnp.mean`. ({jax-issue}`#7317`)
- Bug fixes:
  - Some leaked trace errors from the previous release ({jax-issue}`#7613`)
- Breaking changes:
  - Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
  - Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
  - The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
- Breaking changes:
  - Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
  - The minimum jaxlib version is now 0.1.69.
  - The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been removed.
- New features:
  - Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
- Bug fixes:
  - Tightened the checks for `lax.argmin` and `lax.argmax` to ensure they are
    not used with an invalid `axis` value, or with an empty reduction dimension. ({jax-issue}`#7196`)
- Fixed bugs in the TFRT CPU backend that produced incorrect results.
- GitHub commits.
- Bug fixes:
- Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
- New features:
  - New SciPy function {py:func}`jax.scipy.special.sph_harm`.
  - Reverse-mode autodiff functions ({func}`jax.grad`, {func}`jax.value_and_grad`, {func}`jax.vjp`, and {func}`jax.linear_transpose`) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only {func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).
- New features:
  - #7042 Turned on TFRT CPU backend with significant dispatch performance improvements on CPU.
  - The {func}`jax2tf.convert` supports inequalities and min/max for booleans ({jax-issue}`#6956`).
  - New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
- Breaking changes:
  - Support for NumPy 1.16 has been dropped, per the deprecation policy.
- Bug fixes:
  - Fixed bug that prevented round-tripping from JAX to TF and back:
    `jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).
- Bug fixes:
  - Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.
- New features:
  - The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
  - A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
  - A new traceback filtering mode using `__tracebackhide__` is now enabled by default in sufficiently recent versions of IPython.
  - The {func}`jax2tf.convert` supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)` ({jax-issue}`#6827`).
  - The {func}`jax2tf.convert` generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
  - New SciPy function {py:func}`jax.scipy.special.lpmn`.
- Bug fixes:
  - The {func}`jax2tf.convert` now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX ({jax-issue}`#6883`).
  - The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter properly to apply only during the just-in-time conversion ({jax-issue}`#6720`).
  - The {func}`jax2tf.convert` now converts `lax.dot_general` using the `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision ({jax-issue}`#6717`).
  - The {func}`jax2tf.convert` now has support for inequality comparisons and min/max for complex numbers ({jax-issue}`#6892`).
- New features:
  - CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

    NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.

    There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
  - Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
  - Added automatic support for static keyword arguments to the {func}`jit` implementation.
  - Added support for pretransformation exception traces.
  - Initial support for pruning unused arguments from {func}`jit`-transformed computations. Pruning is still a work in progress.
  - Improved the string representation of {class}`PyTreeDef` objects.
  - Added support for XLA's variadic ReduceWindow.
- Bug fixes:
  - Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
  - Fix a bug that meant that JAX garbage collection was not triggered by {func}`jit`-transformed functions.
- New features:
  - When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify keyword arguments as static.
  - {func}`jax.nonzero` has a new optional `size` argument that allows it to be used within `jit` ({jax-issue}`#6501`).
  - {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
  - {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
  - Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
  - The order of the filtered and unfiltered stack traces in exceptions has been
    changed. The traceback attached to an exception thrown from JAX-transformed
    code is now filtered, with an `UnfilteredStackTrace` exception containing the original trace as the `__cause__` of the filtered exception. Filtered stack traces now also work with Python 3.6.
  - If an exception is thrown by code that has been transformed by reverse-mode
    automatic differentiation, JAX now attempts to attach as a `__cause__` of the exception a `JaxStackTraceBeforeTransformation` object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
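A minimal sketch of the two `jit`-related additions above, assuming a recent JAX release (where the `size` argument lives on `jax.numpy.nonzero`). `power` and `first_nonzero_indices` are hypothetical functions used only for illustration.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("n",))
def power(x, n):
    # n is a compile-time Python int, so this loop is unrolled while tracing.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


@jax.jit
def first_nonzero_indices(x):
    # A static size makes the output shape known at trace time; missing
    # entries are padded (with index 0 by default).
    (idx,) = jnp.nonzero(x, size=3)
    return idx


print(power(jnp.float32(2.0), n=3))                    # 8.0
print(first_nonzero_indices(jnp.array([0, 2, 0, 5])))  # [1 3 0]
```

Each distinct value of `n` triggers a separate compilation, which is the usual cost of marking an argument static.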
- Breaking changes:
  - The following function names have changed. There are still aliases, so this
    should not break existing code, but the aliases will eventually be removed
    so please change your code.
    - `host_id` --> {func}`~jax.process_index`
    - `host_count` --> {func}`~jax.process_count`
    - `host_ids` --> `range(jax.process_count())`
  - Similarly, the argument to {func}`~jax.local_devices` has been renamed from `host_id` to `process_index`.
  - Arguments to {func}`jax.jit` other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to `jit`.
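A minimal sketch of the renamed calls, assuming a recent JAX release running as a single process (no `jax.distributed.initialize`), so the printed values reflect that assumption.

```python
import jax

# Single-process run: these calls replace the old host_* names.
print(jax.process_index())               # 0   (was host_id())
print(jax.process_count())               # 1   (was host_count())
print(list(range(jax.process_count())))  # [0] (was host_ids())

# local_devices now takes process_index= instead of host_id=.
devices = jax.local_devices(process_index=jax.process_index())
print(len(devices) >= 1)  # True
```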
- Bug fixes:
  - The {func}`jax2tf.convert` now works in presence of gradients for functions with integer inputs ({jax-issue}`#6360`).
  - Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured `tf.Variable` ({jax-issue}`#6572`).
- GitHub commits.
- New features
  - New profiling APIs: {func}`jax.profiler.start_trace`, {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
  - {func}`jax.lax.reduce` is now differentiable.
- Breaking changes:
  - The minimum jaxlib version is now 0.1.64.
  - Some profiler API names have been changed. There are still aliases, so this
    should not break existing code, but the aliases will eventually be removed
    so please change your code.
    - `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
    - `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
    - `trace_function` --> {func}`~jax.profiler.annotate_function`
  - Omnistaging can no longer be disabled. See omnistaging for more information.
  - Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
  - Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an `OverflowError` rather than having their value silently truncated.
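A minimal sketch of the `OverflowError` behavior, assuming a recent JAX release with X64 mode disabled (the default), so Python ints are converted to `int32`. Passing a Python float is one way to avoid the integer conversion.

```python
import jax.numpy as jnp

overflowed = False
try:
    # 2**40 does not fit in int32, so the conversion raises instead of truncating.
    jnp.asarray(2**40)
except OverflowError as err:
    overflowed = True
    print("OverflowError:", err)

# A float value is converted to float32 instead; 2**40 is exactly representable.
big = jnp.asarray(float(2**40))
print(big.dtype)  # float32
```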
- Bug fixes:
  - `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
  - {func}`jax.random.randint` clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype ({jax-issue}`#5868`)
- New features:
- Bug fixes:
  - #6136 generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes.
  - #6129 fixed a bug with handling some constants like `enum.IntEnums`
  - #6145 fixed batching issues with incomplete beta functions
  - #6014 fixed H2D transfers during tracing
  - #6165 avoids OverflowErrors when converting some large Python integers to floats
- Breaking changes:
  - The minimum jaxlib version is now 0.1.62.
- GitHub commits.
- New features:
  - {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
  - {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
  - Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions from JAX ({jax-issue}`#5627` and README).
  - Extended the batching rule for `lax.pad` to support batching of the padding values.
- Bug fixes:
  - {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
- Breaking changes:
  - JAX's promotion rules were adjusted to make promotion more consistent and
    invariant to JIT. In particular, binary operations can now result in weakly-typed
    values when appropriate. The main user-visible effect of the change is that
    some operations result in outputs of different precision than before; for
    example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)`
    previously returned a `float64` array, and now returns a `bfloat16` array.
    JAX's type promotion behavior is described at {ref}`type-promotion`.
  - {func}`jax.numpy.linspace` now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
  - {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
  - Several {mod}`jax.numpy` functions no longer accept tuples or lists in place of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`, {func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`. In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
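A minimal sketch of the weak-type example above, assuming a recent JAX release. Multiplying an integer array by a Python float yields a weakly typed float, which then defers to the strongly typed `bfloat16` scalar.

```python
import jax.numpy as jnp

x = jnp.bfloat16(1)          # strongly typed bfloat16 scalar
y = 0.1 * jnp.arange(10)     # int32 array times a Python float: weakly typed float
result = x + y

print(result.dtype)  # bfloat16
```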
- New features:
  - jaxlib wheels are now built to require AVX instructions on x86-64 machines
    by default. If you want to use JAX on a machine that doesn't support AVX,
    you can build a jaxlib from source using the `--target_cpu_features` flag to `build.py`. `--target_cpu_features` also replaces `--enable_march_native`.
- Bug fixes:
  - Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
  - `bool`, `int8`, and `uint8` are now considered safe to cast to `bfloat16` NumPy extension type.
- GitHub commits.
- New features:
  - Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved error checking and error messages.
  - Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`. These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
- Breaking changes:
  - {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
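A minimal sketch of the dropping behavior, assuming a recent JAX release where out-of-range segment IDs are still dropped by default. The data values are illustrative.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1, 2, 3, 4])
segment_ids = jnp.array([0, 1, 1, 5])  # 5 is outside [0, num_segments)

# The out-of-range id 5 is dropped rather than wrapped, so data[3] is ignored.
sums = ops.segment_sum(data, segment_ids, num_segments=2)
print(sums)  # [1 5]
```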
- GitHub commits.
- New features:
  - Add {func}`jax.closure_convert` for use with higher-order custom derivative functions. ({jax-issue}`#5244`)
  - Add {func}`jax.experimental.host_callback.call` to call a custom Python function on the host and return a result to the device computation. ({jax-issue}`#5243`)
- Bug fixes:
  - `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for complex inputs ({jax-issue}`#5156`)
  - `host_callback.id_tap` now works for `jax.pmap` also. There is an optional parameter for `id_tap` and `id_print` to request that the device from which the value is tapped be passed as a keyword argument to the tap function ({jax-issue}`#5182`).
- Breaking changes:
  - `jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values` has been removed. In addition, passing unsupported keyword arguments raises an error.
  - Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`#5243`):
    - Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`. (This support has been deprecated for a few months.)
    - Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print` to use '(' instead of '['.
    - Changed the {func}`jax.experimental.host_callback.id_print` in presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
    - `host_callback.outfeed_receiver` has been removed (it is not necessary, and was deprecated a few months ago).
- New features:
  - New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`#5224`).
- GitHub commits.
- New features:
  - Add `jax.device_put_replicated`
  - Add multi-host support to `jax.experimental.sharded_jit`
  - Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
  - Add support for building on Windows platforms
  - Add support for general in_axes and out_axes in `jax.pmap`
  - Add complex support for `jax.numpy.linalg.slogdet`
- Bug fixes:
  - Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
  - Fix some hard-to-hit bugs around symbolic zeros in transpose rules
- Breaking changes:
  - `jax.experimental.optix` has been deleted, in favor of the standalone `optax` Python package.
  - Indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing has been deprecated in Numpy since v1.16, and in JAX since v0.2.4. See {jax-issue}`#4564`.
- New Features:
  - Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
- Breaking change cleanup
  - Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
  - Improve consistency of type promotion behavior ({jax-issue}`#4744`):
    - Adding a complex Python scalar to a JAX floating point number respects the precision of
      the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously it returned `complex128`.
    - Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type
      are now independent of the order of arguments. For example:
      `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
      `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously the first returned `float64` and the second returned `float16`.
  - The contents of the (undocumented) `jax.lax_linalg` linear algebra module are now exposed publicly as `jax.lax.linalg`.
  - `jax.random.PRNGKey` now produces the same results in and out of JIT compilation ({jax-issue}`#4877`). This required changing the result for a given seed in a few particular cases:
    - With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
    - Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError` rather than a `TypeError`. This matches the behavior in JIT.

    To recover the keys returned previously for negative integers with `jax_enable_x64=False` outside JIT, you can use: `key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)`
  - DeviceArray now raises `RuntimeError` instead of `ValueError` when trying to access its value while it has been deleted.
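A minimal sketch of the promotion examples listed above, assuming a recent JAX release with default settings. A Python complex scalar is weakly typed, so it adopts the precision of the JAX float it is combined with.

```python
import jax.numpy as jnp

# The weakly typed Python complex scalar follows the float32 operand.
z = jnp.float32(1) + 1j
print(z.dtype)  # complex64

# Multi-term promotion involving uint64 no longer depends on argument order.
print(jnp.result_type(jnp.uint64, jnp.int64, jnp.float16))  # float16
print(jnp.result_type(jnp.float16, jnp.uint64, jnp.int64))  # float16
```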
- Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
- Fixed a crash when constant-folding certain int16 operations. (#4971)
- Added an `is_leaf` predicate to {func}`pytree.flatten`.
- Fixed manylinux2010 compliance issues in GPU wheels.
- Switched the CPU FFT implementation from Eigen to PocketFFT.
- Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
- Add support for retaining ownership when passing arrays to DLPack (#4636).
- Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
- Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
- Fixed a bug in profiler where tools are missing (#4427).
- Dropped support for CUDA 10.0.
- GitHub commits.
- Improvements:
  - Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
  - Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
- Improvements:
  - Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
- Deprecations
  - Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy.
    In a future release, this will result in a TypeError. See {jax-issue}`#4564`.
- GitHub commits.
- This release follows the previous one so closely because we needed to temporarily roll back a new jit fast path while we investigate a performance degradation.
- GitHub commits.
- Improvements:
  - As a benefit of omnistaging, the host_callback functions are executed (in program
    order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`
    / {py:func}`jax.experimental.host_callback.id_tap` is not used in the computation.
- GitHub commits.
- Improvements:
  - Omnistaging on by default. See {jax-issue}`#3370` and omnistaging
- Breaking changes:
  - New simplified interface for {py:func}`jax.experimental.host_callback.id_tap` (#4101)
- Update XLA:
- Fix bug in DLPackManagedTensorToBuffer (#4196)
- GitHub commits.
- Bug Fixes:
- make jnp.abs() work for unsigned inputs (#3914)
- Improvements:
- "Omnistaging" behavior added behind a flag, disabled by default (#3370)
- GitHub commits.
- New Features:
- BFGS (#3101)
- TPU support for half-precision arithmetic (#3878)
- Bug Fixes:
- Prevent some accidental dtype warnings (#3874)
- Fix a multi-threading bug in custom derivatives (#3845, #3869)
- Improvements:
- Faster searchsorted implementation (#3873)
- Better test coverage for jax.numpy sorting algorithms (#3836)
- Update XLA.
- GitHub commits.
- The minimum jaxlib version is now 0.1.51.
- New Features:
- jax.image.resize. (#3703)
- hfft and ihfft (#3664)
- jax.numpy.intersect1d (#3726)
- jax.numpy.lexsort (#3812)
- `lax.scan` and the `scan` primitive support an `unroll` parameter for loop unrolling when lowering to XLA ({jax-issue}`#3738`).
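A minimal sketch of the `unroll` parameter, assuming a recent JAX release. Unrolling changes only how the loop is lowered (here, two iterations per XLA loop body), not the results; `step` is a hypothetical running-sum function.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)


xs = jnp.arange(5.0)
total, running = lax.scan(step, jnp.float32(0.0), xs, unroll=2)
print(total)  # 10.0
```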
- Bug Fixes:
- Fix reduction repeated axis error (#3618)
- Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
- make psum transpose handle zero cotangents (#3653)
- Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
- Support differentiation through jax.lax.all_to_all (#3733)
- address nan issue in jax.scipy.special.zeta (#3777)
- Improvements:
- Many improvements to jax2tf
- Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
- Enable XLA SPMD partitioning by default. (#3151)
- Add support for 0d transpose convolution (#3643)
- Make LU gradient work for low-rank matrices (#3610)
- support multiple_results and custom JVPs in jet (#3657)
- Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
- Implement complex convolutions on CPU and GPU. (#3735)
- Make jnp.take work for empty slices of empty arrays. (#3751)
- Relax dimension ordering rules for dot_general. (#3778)
- Enable buffer donation for GPU. (#3800)
- Add support for base dilation and window dilation to reduce window op… (#3803)
- Update XLA.
- Add new runtime support for host_callback.
- GitHub commits.
- Bug fixes:
  - Fix an odeint bug introduced in the previous release, see {jax-issue}`#3587`.
- GitHub commits.
- The minimum jaxlib version is now 0.1.48.
- Bug fixes:
  - Allow `jax.experimental.ode.odeint` dynamics functions to close over values with respect to which we're differentiating {jax-issue}`#3562`.
- Add support for CUDA 11.0.
- Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
- Update XLA.
- Bug fixes:
- Fix build issue that could result in slow compiles (https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1)
- New features:
  - Adds support for fast traceback collection.
  - Adds preliminary support for on-device heap profiling.
  - Implements `np.nextafter` for `bfloat16` types.
  - Complex128 support for FFTs on CPU and GPU.
- Bug fixes:
  - Improved float64 `tanh` accuracy on GPU.
  - float64 scatters on GPU are much faster.
  - Complex matrix multiplication on CPU should be much faster.
  - Stable sorts on CPU should actually be stable now.
  - Concurrency bug fix in CPU backend.
- GitHub commits.
- New features:
  - `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` primitive {jax-issue}`#3318`.
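A minimal sketch of `lax.switch`, assuming a recent JAX release. The branch functions are illustrative; an out-of-range index is clamped to the nearest valid branch.

```python
import jax.numpy as jnp
from jax import lax

branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

x = jnp.float32(3.0)
print(lax.switch(1, branches, x))  # 6.0
# Indices are clamped to [0, len(branches) - 1], so 7 selects branch 2.
print(lax.switch(7, branches, x))  # -3.0
```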
- GitHub commits.
- New features:
  - {func}`lax.cond` supports a single-operand form, taken as the argument to both branches {jax-issue}`#2993`.
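A minimal sketch of the single-operand form of `lax.cond`, assuming a recent JAX release. `piecewise` is a hypothetical function; it also shows that `lax.cond` can be differentiated in reverse mode.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise(x):
    # Single-operand form: x is passed to whichever branch runs.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: 3.0 * v, x)


print(piecewise(jnp.float32(2.0)))   # 4.0
print(piecewise(jnp.float32(-1.0)))  # -3.0
print(jax.grad(piecewise)(2.0))      # 4.0
print(jax.grad(piecewise)(-1.0))     # 3.0
```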
- Notable changes:
  - The format of the `transforms` keyword for the {func}`jax.experimental.host_callback.id_tap` primitive has changed {jax-issue}`#3132`.
- GitHub commits.
- New features:
  - Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`.
  - Experimental support for printing and calling host-side Python function from
    compiled code. See id_print and id_tap ({jax-issue}`#3006`).
- Notable changes:
  - The visibility of names exported from {mod}`jax.numpy` has been tightened. This may break code that was making use of names that were previously exported accidentally.
- Fixes crash for outfeed.
- GitHub commits.
- New features:
  - Support for `in_axes=None` on {func}`pmap` {jax-issue}`#2896`.
- Fixes crash for linear algebra functions on Mac OS X (#432).
- Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
- GitHub commits.
- New features:
  - Differentiation of determinants of singular matrices {jax-issue}`#2809`.
- Bug fixes:
  - Fix {func}`odeint` differentiation with respect to time of ODEs with time-dependent dynamics {jax-issue}`#2817`, also add ODE CI testing.
  - Fix {func}`lax_linalg.qr` differentiation {jax-issue}`#2867`.
- Fixes segfault: {jax-issue}`#2755`
- Plumb is_stable option on Sort HLO through to Python.
- GitHub commits.
- New features:
  - Add syntactic sugar for functional indexed updates {jax-issue}`#2684`.
  - Add {func}`jax.numpy.linalg.multi_dot` {jax-issue}`#2726`.
  - Add {func}`jax.numpy.unique` {jax-issue}`#2760`.
  - Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
  - Add more primitive rules for {func}`jax.experimental.jet`.
- Bug fixes:
  - Fix {func}`logaddexp` and {func}`logaddexp2` differentiation at zero {jax-issue}`#2107`.
  - Improve memory usage in reverse-mode autodiff without {func}`jit` {jax-issue}`#2719`.
- Better errors:
  - Improves error message for reverse-mode differentiation of {func}`lax.while_loop` {jax-issue}`#2129`.
- Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
- Bugfix for `batch_group_count` convolutions.
- Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
- GitHub commits.
- Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the tutorial notebook. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
- Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
- Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
- Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`.
- Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
- Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
- Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
- Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
- Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
- Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
- Add `callback_transform` {jax-issue}`#2665`.
- Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`, `trunc`, `roots`, and `quantile`/`percentile` interpolation options.
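A minimal sketch of `jax.custom_jvp`, assuming a recent JAX release. `log1pexp` (softplus) is an illustrative function: its custom rule supplies a numerically stable derivative where naive autodiff through `log1p(exp(x))` would produce `nan` for large `x`.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = sigmoid(x), written as 1 - 1 / (1 + e^x) so that
    # large x gives 1.0 instead of inf / inf.
    return ans, x_dot * (1.0 - 1.0 / (1.0 + jnp.exp(x)))


print(jax.grad(log1pexp)(0.0))    # 0.5
print(jax.grad(log1pexp)(100.0))  # 1.0
```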
- Fixed a performance regression for Resnet-50 on GPU.
- GitHub commits.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
- Added an `all_gather` parallel convenience function.
- More type annotations in core code.
- jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- GitHub commits.
- Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
- GitHub commits.
- New features:
  - {py:func}`jax.pmap` has a `static_broadcast_argnums` argument, which allows the user to specify arguments that should be treated as compile-time constants and broadcast to all devices. It works analogously to `static_argnums` in {py:func}`jax.jit`.
  - Improved error messages for when tracers are mistakenly saved in global state.
  - Added {py:func}`jax.nn.one_hot` utility function.
  - Added {mod}`jax.experimental.jet` for exponentially faster higher-order automatic differentiation.
  - Added more correctness checking to arguments of {py:func}`jax.lax.broadcast_in_dim`.
- The minimum jaxlib version is now 0.1.41.
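A minimal sketch of a static, broadcast `pmap` argument: because the argument is a compile-time constant rather than a traced value, ordinary Python control flow can branch on it. The entry above spells the argument `static_broadcast_argnums`; this sketch assumes the spelling `static_broadcasted_argnums` used by recent JAX releases. The function `apply` and its `mode` values are hypothetical.

```python
import jax
import jax.numpy as jnp


def apply(x, mode):
    # `mode` is a plain Python string. Branching on it with `if` only works
    # because it is static (a compile-time constant), not a traced value.
    if mode == "double":
        return x * 2.0
    return x + 1.0


# The same `mode` value is broadcast to every device; x is split along axis 0.
papply = jax.pmap(apply, static_broadcasted_argnums=(1,))

n = jax.local_device_count()  # 1 on a plain CPU machine
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
out = papply(x, "double")
```

Each distinct static value triggers its own compilation, so static arguments suit small, hashable configuration values rather than data.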
- Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
- Includes prototype support for multihost GPU computations that communicate via NCCL.
- Improves performance of NCCL collectives on GPU.
- Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
- Supports device assignments known at XLA compilation time.
- Breaking changes
  - The minimum jaxlib version is now 0.1.38.
  - Simplified {py:class}`Jaxpr` by removing the `Jaxpr.freevars` and `Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`, `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive` to primitives.
- New features:
  - Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it now differentiable in both modes ({jax-issue}`#2091`).
  - JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
  - JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
  - JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
  - Added `JAX_SKIP_SLOW_TESTS` environment variable to skip tests known as slow.
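A small sketch of reverse-mode differentiation through `lax.cond`. It assumes the calling convention of current JAX releases, `lax.cond(pred, true_fun, false_fun, *operands)`, which may differ from the signature available at the time of this release; `f` is a hypothetical example function.

```python
import jax
from jax import lax


def f(x):
    # Branch on the sign of x; each branch is differentiable on its own.
    return lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)


grad_f = jax.grad(f)
g_pos = grad_f(3.0)   # derivative of v ** 2 at 3.0, i.e. 6.0
g_neg = grad_f(-3.0)  # derivative of -v, i.e. -1.0
```

Only the branch selected by the predicate contributes to the gradient.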
- Updates XLA.
- CUDA 9.0 is no longer supported.
- CUDA 10.2 wheels are now built by default.
- Breaking changes
  - JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
- New features
  - Forward-mode automatic differentiation (`jvp`) of while loop ({jax-issue}`#1980`)
  - New NumPy and SciPy functions:
    - {py:func}`jax.numpy.fft.fft2`
    - {py:func}`jax.numpy.fft.ifft2`
    - {py:func}`jax.numpy.fft.rfft`
    - {py:func}`jax.numpy.fft.irfft`
    - {py:func}`jax.numpy.fft.rfft2`
    - {py:func}`jax.numpy.fft.irfft2`
    - {py:func}`jax.numpy.fft.rfftn`
    - {py:func}`jax.numpy.fft.irfftn`
    - {py:func}`jax.numpy.fft.fftfreq`
    - {py:func}`jax.numpy.fft.rfftfreq`
    - {py:func}`jax.numpy.linalg.matrix_rank`
    - {py:func}`jax.numpy.linalg.matrix_power`
    - {py:func}`jax.scipy.special.betainc`
  - Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
- With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should help with installation.
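To illustrate forward-mode differentiation of a while loop, the sketch below pushes a tangent through `lax.while_loop` whose trip count depends on the runtime value of the input. It is a minimal example under current JAX APIs; `square_until_large` is a hypothetical function.

```python
import jax
from jax import lax


def square_until_large(x):
    # Repeatedly square x while it is below 100. The number of iterations
    # depends on the value of x, not just its shape.
    return lax.while_loop(lambda c: c < 100.0, lambda c: c * c, x)


# Forward mode carries the tangent alongside the loop state.
y, y_dot = jax.jvp(square_until_large, (3.0,), (1.0,))
# 3 -> 9 -> 81 -> 6561, so y = 3**8 and y_dot = 8 * 3**7 = 17496.
```

The tangent is updated inside the same loop as the primal value, so no per-iteration history needs to be stored.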