diff --git a/WORKSPACE b/WORKSPACE index c73231dd4021..aebf8d6ba732 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -38,6 +38,9 @@ python_configure( # b) get the sha256 hash of the commit by running: # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. + +xla_hash = 'dc2b3b3545b41aa3280291fe40face744d187ad7' + http_archive( name = "xla", patch_args = [ @@ -50,9 +53,9 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", ], - strip_prefix = "xla-80462ef5b22360df177fe24fc13c81b235d3f3a2", + strip_prefix = "xla-" + xla_hash, urls = [ - "https://github.com/openxla/xla/archive/80462ef5b22360df177fe24fc13c81b235d3f3a2.tar.gz", + "https://github.com/openxla/xla/archive/" + xla_hash + ".tar.gz", ], ) diff --git a/setup.py b/setup.py index df54224ed98b..6f38e5645a42 100644 --- a/setup.py +++ b/setup.py @@ -64,10 +64,10 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240502' +_date = '20240513' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' -_jax_version = f'0.4.27.dev{_date}' +_jax_version = f'0.4.29.dev{_date}' def _get_build_mode(): diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index f7b29d1cf3b1..56702e79279b 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -459,9 +459,9 @@ xla::XlaOp BuildArgMin(xla::XlaOp input, int64_t dim, bool keepdim) { shape = &ShapeHelper::ShapeOfXlaOp(operand); } } - xla::XlaOp result = xla::ArgMin( + xla::XlaOp result = xla::ArgMinMax( operand, GetXlaPrimitiveTypeForCurrentDevice(xla::PrimitiveType::S64), - dim); + dim, /* is_min */ true); if (keepdim) { auto dimensions = torch::lazy::ToVector(shape->dimensions()); if (dim_is_none) {