diff --git a/WORKSPACE b/WORKSPACE index 82e5db1fc335..c648d9173e72 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -49,11 +49,10 @@ http_archive( "//openxla_patches:cache_urls.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:unbounded_dynamism.diff", ], - strip_prefix = "xla-25c8a6781af6be51d3bc43a0953b07803ab761ea", + strip_prefix = "xla-b9ece5154181608ff1acea9b02a1ee7a526af16b", urls = [ - "https://github.com/openxla/xla/archive/25c8a6781af6be51d3bc43a0953b07803ab761ea.tar.gz", + "https://github.com/openxla/xla/archive/b9ece5154181608ff1acea9b02a1ee7a526af16b.tar.gz", ], ) diff --git a/openxla_patches/unbounded_dynamism.diff b/openxla_patches/unbounded_dynamism.diff deleted file mode 100644 index 37a424eddaad..000000000000 --- a/openxla_patches/unbounded_dynamism.diff +++ /dev/null @@ -1,40 +0,0 @@ -diff --git a/xla/client/xla_builder.cc b/xla/client/xla_builder.cc -index 57f0529b5..5f8a1c582 100644 ---- a/xla/client/xla_builder.cc -+++ b/xla/client/xla_builder.cc -@@ -1182,12 +1182,16 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, - this, rhs, lhs, *lhs_shape)); - } - } else { -- TF_ASSIGN_OR_RETURN(UnboundedBroadcastResult broadcast_result, -- BroadcastToOutputShapeWithUnbounded( -- this, lhs, *lhs_shape, rhs, *rhs_shape, shape, -- broadcast_dimensions)); -- updated_lhs = broadcast_result.lhs; -- updated_rhs = broadcast_result.rhs; -+ if (!ShapeUtil::SameDimensions(*lhs_shape, *rhs_shape)) { -+ Shape output_shape = shape; -+ output_shape.set_element_type(lhs_shape->element_type()); -+ TF_ASSIGN_OR_RETURN(UnboundedBroadcastResult broadcast_result, -+ BroadcastToOutputShapeWithUnbounded( -+ this, lhs, *lhs_shape, rhs, *rhs_shape, -+ output_shape, broadcast_dimensions)); -+ updated_lhs = broadcast_result.lhs; -+ updated_rhs = broadcast_result.rhs; -+ } - } - } - -diff --git a/xla/client/xla_builder_test.cc b/xla/client/xla_builder_test.cc -index ee9d100a3..231cd6baf 100644 ---- a/xla/client/xla_builder_test.cc -+++ b/xla/client/xla_builder_test.cc -@@ -2436,6 +2436,8 @@ INSTANTIATE_TEST_SUITE_P( - /*broadcast_dimensions=*/{}, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Mul}, - {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, - "f32[?, 10]", &Mul}, -+ {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, -+ "pred[?, 10]", &Ne}, - {"f32[1, ?, 2, ?, <=2, ?, ?]", "f32[?, 1, ?, 2, ?, <=2, ?]", - /*broadcast_dimensions=*/{}, "f32[?, ?, 2, 2, <=2, <=2, ?]", &Pow}, - {"f32[?, 10]", "f32[1]", /*broadcast_dimensions=*/zero_array, diff --git a/setup.py b/setup.py index 16f14494f743..22b2e0b4956f 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240320' +_date = '20240322' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.26.dev{_date}'