diff --git a/WORKSPACE b/WORKSPACE index 2d3e69033fd..b2154725282 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -42,13 +42,12 @@ http_archive( "//openxla_patches:constexpr_return.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:gpu_topk_rewriter.diff", "//openxla_patches:quant_dequant_converter.diff", "//openxla_patches:stablehlo_quant_seralization.diff", ], - strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", + strip_prefix = "xla-8744c9a94782cd7804f015e6d29df253437af3cb", urls = [ - "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", + "https://github.com/openxla/xla/archive/8744c9a94782cd7804f015e6d29df253437af3cb.tar.gz", ], ) diff --git a/openxla_patches/cache_urls.diff b/openxla_patches/cache_urls.diff index 10aeadbb2a4..e1d10bfe916 100644 --- a/openxla_patches/cache_urls.diff +++ b/openxla_patches/cache_urls.diff @@ -1,5 +1,5 @@ diff --git a/xla/mlir_hlo/WORKSPACE b/xla/mlir_hlo/WORKSPACE -index cc9eeb64f..b290eb455 100644 +index c3115e33d..d315ad745 100644 --- a/xla/mlir_hlo/WORKSPACE +++ b/xla/mlir_hlo/WORKSPACE @@ -35,7 +35,10 @@ http_archive( @@ -13,10 +13,9 @@ index cc9eeb64f..b290eb455 100644 + ], ) - load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") - load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") + load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index a4574d75d..f9ce37094 100644 +index d7f3a8093..a7af9c68a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -13,7 +13,9 @@ def repo(name): @@ -28,4 +27,4 @@ index a4574d75d..f9ce37094 100644 + "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), ], build_file = "//third_party/llvm:llvm.BUILD", - patch_file = [ \ No newline at end of file + patch_file = [ diff --git a/openxla_patches/constexpr_return.diff b/openxla_patches/constexpr_return.diff index 99825c02409..0872b5f6e78 100644 --- a/openxla_patches/constexpr_return.diff +++ b/openxla_patches/constexpr_return.diff @@ -1,12 +1,12 @@ diff --git a/xla/primitive_util.h b/xla/primitive_util.h -index 696147844..dfea15a4d 100644 +index 63fa4e193..ab352626c 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h -@@ -748,6 +748,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) { +@@ -706,6 +706,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) { std::numeric_limits::max() >= x; } LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty); -+ return false; ++ return false; }, ty); } diff --git a/openxla_patches/gpu_race_condition.diff b/openxla_patches/gpu_race_condition.diff index dfdc3aa7460..683b156e7d2 100644 --- a/openxla_patches/gpu_race_condition.diff +++ b/openxla_patches/gpu_race_condition.diff @@ -1,8 +1,8 @@ diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc -index 242961dd1..787275868 100644 +index 1f9903cb3..763b7fc23 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc -@@ -563,8 +563,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( +@@ -589,8 +589,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( } // Force synchronous execution if the allocator requires it. 
@@ -10,5 +10,5 @@ index 242961dd1..787275868 100644 - !memory_allocator->AllowsAsynchronousDeallocation(); + const bool block_host_until_done = true; - - // Lock the GPU with a shared lock so that we don't interfere with autotuning \ No newline at end of file + // Lock the GPU with a shared lock so that we don't interfere with autotuning + // that may be running during JIT compilation while allowing multiple XLA diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff deleted file mode 100644 index 47ee3fa0f0a..00000000000 --- a/openxla_patches/gpu_topk_rewriter.diff +++ /dev/null @@ -1,184 +0,0 @@ -diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc -index da872d962..1b7141055 100644 ---- a/xla/service/topk_rewriter.cc -+++ b/xla/service/topk_rewriter.cc -@@ -196,6 +196,8 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { - return std::nullopt; - } - const int64_t sort_dim = sort->sort_dimension(); -+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; -+ const bool has_batch = data->shape().rank() == 2; - - bool supported = true; - std::optional k; -@@ -220,15 +222,10 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { - supported = false; - break; - } -- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) { -- if (i != sort_dim && -- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) { -- // Slicing along a non-sort dimension isn't supported. -- supported = false; -- break; -- } -- } -- if (!supported) { -+ if (has_batch && slice->slice_limits(batch_dim) != -+ slice->operand(0)->shape().dimensions(batch_dim)) { -+ // Slicing along the batch dimension isn't supported. -+ supported = false; - break; - } - if (k == std::nullopt) { -@@ -260,57 +257,29 @@ StatusOr TopkRewriter::TransformToCustomCall( - HloSortInstruction* sort = DynCast(inst); - HloInstruction* data = sort->mutable_operand(0); - const PrimitiveType element_type = data->shape().element_type(); -- const Shape data_shape = data->shape(); - -- if (element_type != F32 && element_type != BF16) { -+ if ((data->shape().rank() != 1 && data->shape().rank() != 2) || -+ (element_type != F32 && element_type != BF16)) { - continue; - } - -- // Sort dimension must be the first or last dimension. - const int64_t sort_dim = sort->sort_dimension(); -- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) { -- continue; -- } -+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; -+ const bool has_batch = data->shape().rank() == 2; - - // Profitability check. - if (!is_profitable_to_convert_(sort, *k)) { - continue; - } - -- HloInstruction* input = data; -- const bool has_batch = data_shape.rank() >= 2; -- const int64_t input_size = data_shape.dimensions(sort_dim); -- int64_t batch_size = 1; -- Shape topk_input_shape; -- -- if (has_batch) { -- // The TopK custom call expects either a 1d tensor or a 2d tensor with -- // the last dimension being the sort dimension. An input with rank > 2 -- // is reshaped into a 2d tensor by combining non-sort dimensions into a -- // single batch dimension. The original non-sort dimensions are -- // restored for the outputs with another reshape after the custom call. -- batch_size = -- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim); -- topk_input_shape = -- ShapeUtil::MakeShape(element_type, {batch_size, input_size}); -- -- if (data_shape.rank() > 2) { -- // Reshape to 2d. -- input = comp->AddInstruction(HloInstruction::CreateReshape( -- sort_dim == 0 -- ? 
ShapeUtil::MakeShape(element_type, {input_size, batch_size}) -- : ShapeUtil::MakeShape(element_type, -- {batch_size, input_size}), -- input)); -- } -- -- if (sort_dim == 0) { -- // Transpose for the custom call when sorting the first dimension. -- input = comp->AddInstruction( -- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0})); -- } -- } else { -- topk_input_shape = data_shape; -+ const int64_t batch_size = -+ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; -+ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim); -+ HloInstruction* input = sort->mutable_operand(0); -+ if (has_batch && sort_dim == 0) { -+ input = comp->AddInstruction(HloInstruction::CreateTranspose( -+ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, -+ {1, 0})); - } - - Shape topk_shape = -@@ -331,26 +300,13 @@ StatusOr TopkRewriter::TransformToCustomCall( - comp->AddInstruction(HloInstruction::CreateGetTupleElement( - topk->shape().tuple_shapes(1), topk, 1)); - -- if (has_batch) { -- if (sort_dim == 0) { -- // Transpose back. -- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), -- value_gte, {1, 0})); -- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, -- {1, 0})); -- } -- if (data_shape.rank() > 2) { -- // Reshape back. -- std::vector shape_dim(data_shape.dimensions().begin(), -- data_shape.dimensions().end()); -- shape_dim[sort_dim] = k.value(); -- value_gte = comp->AddInstruction(HloInstruction::CreateReshape( -- ShapeUtil::MakeShape(element_type, shape_dim), value_gte)); -- index_gte = comp->AddInstruction(HloInstruction::CreateReshape( -- ShapeUtil::MakeShape(S32, shape_dim), index_gte)); -- } -+ if (has_batch && sort_dim == 0) { -+ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -+ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), -+ value_gte, {1, 0})); -+ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -+ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, -+ {1, 0})); - } - - for (HloInstruction* user : sort->users()) { -diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc -index 36e723737..25ce150e0 100644 ---- a/xla/service/topk_rewriter_test.cc -+++ b/xla/service/topk_rewriter_test.cc -@@ -326,42 +326,6 @@ ENTRY cluster { - EXPECT_THAT(cc->custom_call_target(), "TopK"); - } - --TEST_F(TopkRewriterTest, RewriteReshape) { -- const std::string hlo_string = R"( --HloModule module --)" + getComparator() + R"( --ENTRY cluster { -- %arg_tuple.1 = f32[3,8,1234567] parameter(0) -- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2 -- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4), -- dimensions={2}, is_stable=true, to_apply=%compare -- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0 -- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]} -- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1 -- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]} -- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31) --})"; -- TF_ASSERT_OK_AND_ASSIGN(auto module, -- ParseAndReturnVerifiedModule(hlo_string)); -- TopkRewriter rewriter( -- [](const HloSortInstruction*, int64_t) { return true; }); -- TF_ASSERT_OK_AND_ASSIGN(bool changed, 
rewriter.Run(module.get())); -- TF_ASSERT_OK(HloDCE().Run(module.get()).status()); -- EXPECT_TRUE(changed); -- EXPECT_THAT(module->entry_computation()->root_instruction(), -- GmockMatch(m::Tuple( -- m::Reshape(m::GetTupleElement( -- m::CustomCall(m::Reshape(m::Parameter(0))), 0)), -- m::Reshape(m::GetTupleElement( -- m::CustomCall(m::Reshape(m::Parameter(0))), 1))))); -- const HloInstruction* cc = module->entry_computation() -- ->root_instruction() -- ->operand(0) -- ->operand(0) -- ->operand(0); -- EXPECT_THAT(cc->custom_call_target(), "TopK"); --} -- - TEST_F(TopkRewriterTest, RewriteNoIota) { - const std::string hlo_string = R"( - HloModule module diff --git a/openxla_patches/quant_dequant_converter.diff b/openxla_patches/quant_dequant_converter.diff index d35e36a5e22..0ad551e779f 100644 --- a/openxla_patches/quant_dequant_converter.diff +++ b/openxla_patches/quant_dequant_converter.diff @@ -2,7 +2,7 @@ // stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize. // The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO. diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD -index f74973ae1..8e3f0e06b 100644 +index 0f0c5e842..59a30c585 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD @@ -67,6 +67,7 @@ cc_library( @@ -14,7 +14,7 @@ index f74973ae1..8e3f0e06b 100644 "@llvm-project//mlir:SparseTensorDialect", "@tsl//tsl/platform:statusor", diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc -index 08d5f49c8..2f9ad1e0b 100644 +index cc7aa9e9e..c24f24b50 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -664,6 +664,70 @@ StatusOr HloFunctionImporter::ImportInstruction( @@ -115,7 +115,7 @@ index 08d5f49c8..2f9ad1e0b 100644 } else { attributes.push_back(builder_->getNamedAttr( diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc -index 9f05992c8..03cf4840d 100644 +index 1494efd9e..dcb3d9e89 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -19,6 +19,8 @@ limitations under the License. 
diff --git a/openxla_patches/stablehlo_quant_seralization.diff b/openxla_patches/stablehlo_quant_seralization.diff index fc4328dcfa7..2b6ece8d62f 100644 --- a/openxla_patches/stablehlo_quant_seralization.diff +++ b/openxla_patches/stablehlo_quant_seralization.diff @@ -33,7 +33,7 @@ index 000000000..24e23b67d + } // namespace + diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl -index 9f4494aac..64fa072bb 100644 +index 80ab0e479..caaa11080 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -15,5 +15,6 @@ def repo(): diff --git a/setup.py b/setup.py index a8a04c4c286..ae3d9824667 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_libtpu_version = '0.1.dev20231022' +_libtpu_version = '0.1.dev20231125' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' diff --git a/test/test_autocast.py b/test/test_autocast.py index edbd834b61b..acbd0e03be3 100644 --- a/test/test_autocast.py +++ b/test/test_autocast.py @@ -229,24 +229,30 @@ def __init__(self): class TestAutocastBase(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.autocast_unsupported_lists = None + + @classmethod + def tearDownClass(cls): + del cls.autocast_lists + def setUp(self): super(TestAutocastBase, self).setUp() self.is_autocast_enabled = None - self.autocast_lists = None - self.autocast_unsupported_lists = None def tearDown(self): - del self.autocast_lists super(TestAutocastBase, self).tearDown() - def get_autocast_list(self, list_name): - if self.autocast_unsupported_lists: + @classmethod + def get_autocast_list(cls, list_name): + if cls.autocast_unsupported_lists: return [ - tp for tp in getattr(self.autocast_lists, list_name) - if tp[0] not in getattr(self.autocast_unsupported_lists, list_name) + tp for tp in getattr(cls.autocast_lists, list_name) + if tp[0] not in getattr(cls.autocast_unsupported_lists, list_name) ] else: - return [tp for tp in getattr(self.autocast_lists, list_name)] + return [tp for tp in getattr(cls.autocast_lists, list_name)] def args_maybe_kwargs(self, op_with_args): if len(op_with_args) == 2: @@ -345,46 +351,51 @@ def compare(first, second): f"CUDA autocast test.") class TestAutocastCuda(TestAutocastBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) + cls.autocast_lists_extra = AutocastCudaTestExtraLists( + torch.device(xm.xla_device())) + cls.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() + def setUp(self): super(TestAutocastCuda, self).setUp() self.is_autocast_enabled = torch.is_autocast_xla_enabled - self.autocast_lists = AutocastTestLists(torch.device(xm.xla_device())) - self.autocast_lists_extra = AutocastCudaTestExtraLists( - torch.device(xm.xla_device())) - self.autocast_unsupported_lists = AutocastCudaTestUnsupportedLists() def test_autocast_nn_fp16(self): with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('nn_fp16'): + for op, args in TestAutocastCuda.get_autocast_list('nn_fp16'): self._run_autocast_outofplace( op, args, torch.float16, module=torch._C._nn) def test_autocast_linalg_fp16(self): with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('linalg_fp16'): + for op, args in TestAutocastCuda.get_autocast_list('linalg_fp16'): 
self._run_autocast_outofplace( op, args, torch.float16, module=torch._C._linalg) def test_autocast_methods_fp16(self): with torch.backends.cudnn.flags(enabled=True, deterministic=True): - for op, args in self.get_autocast_list('methods_fp16'): + for op, args in TestAutocastCuda.get_autocast_list('methods_fp16'): self._run_autocast_outofplace(op, args, torch.float16, module=None) def test_autocast_banned(self): with torch.cuda.amp.autocast(): - for op, args, module in self.get_autocast_list('banned'): + for op, args, module in TestAutocastCuda.get_autocast_list('banned'): with self.assertRaises(RuntimeError): getattr(module, op)(*args) def test_autocast_torch_fp32(self): - for op_with_args in self.get_autocast_list('torch_fp32'): + for op_with_args in TestAutocastCuda.get_autocast_list('torch_fp32'): op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( op, args, torch.float32, add_kwargs=maybe_kwargs) def test_autocast_torch_bf16(self): bf16_test_list = [ - tp for tp in getattr(self.autocast_lists_extra, 'torch_bf16') + tp + for tp in getattr(TestAutocastCuda.autocast_lists_extra, 'torch_bf16') ] for op_with_args in bf16_test_list: op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) @@ -396,26 +407,26 @@ def test_autocast_torch_bf16(self): autocast_dtype=torch.bfloat16) def test_autocast_torch_need_autocast_promote(self): - for op, args in self.get_autocast_list('torch_need_autocast_promote'): + for op, args in TestAutocastCuda.get_autocast_list( + 'torch_need_autocast_promote'): self._run_autocast_outofplace(op, args, torch.float32) def test_autocast_torch_expect_builtin_promote(self): - for op, args, out_type in self.get_autocast_list( + for op, args, out_type in TestAutocastCuda.get_autocast_list( 'torch_expect_builtin_promote'): self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) def test_autocast_nn_fp32(self): - for op, args in self.get_autocast_list('nn_fp32'): + for op, args in TestAutocastCuda.get_autocast_list('nn_fp32'): self._run_autocast_outofplace( op, args, torch.float32, module=torch._C._nn) def test_autocast_methods_fp32(self): - for op, args in self.get_autocast_list('methods_fp32'): - print("autocast fp32", op) + for op, args in TestAutocastCuda.get_autocast_list('methods_fp32'): self._run_autocast_outofplace(op, args, torch.float32, module=None) def test_autocast_methods_expect_builtin_promote(self): - for op, args, out_type in self.get_autocast_list( + for op, args, out_type in TestAutocastCuda.get_autocast_list( 'methods_expect_builtin_promote'): self._run_autocast_outofplace( op, args, torch.float32, module=None, out_type=out_type) @@ -424,42 +435,46 @@ def test_autocast_methods_expect_builtin_promote(self): @unittest.skipIf(not xm.get_xla_supported_devices("TPU"), f"TPU autocast test.") class TestAutocastTPU(TestAutocastBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) + def setUp(self): super(TestAutocastTPU, self).setUp() self.is_autocast_enabled = torch.is_autocast_xla_enabled - self.autocast_lists = AutocastTPUTestLists(torch.device(xm.xla_device())) def test_autocast_methods_bf16(self): - for op, args in self.get_autocast_list('methods_bf16'): + for op, args in TestAutocastTPU.get_autocast_list('methods_bf16'): self._run_autocast_outofplace(op, args, torch.bfloat16, module=None) def test_autocast_torch_fp32(self): - for op_with_args in self.get_autocast_list('torch_fp32'): + for op_with_args in 
TestAutocastTPU.get_autocast_list('torch_fp32'): op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args) self._run_autocast_outofplace( op, args, torch.float32, add_kwargs=maybe_kwargs) def test_autocast_torch_need_autocast_promote(self): - for op, args in self.get_autocast_list('torch_need_autocast_promote'): + for op, args in TestAutocastTPU.get_autocast_list( + 'torch_need_autocast_promote'): self._run_autocast_outofplace(op, args, torch.float32) def test_autocast_torch_expect_builtin_promote(self): - for op, args, out_type in self.get_autocast_list( + for op, args, out_type in TestAutocastTPU.get_autocast_list( 'torch_expect_builtin_promote'): self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type) def test_autocast_nn_fp32(self): - for op, args in self.get_autocast_list('nn_fp32'): + for op, args in TestAutocastTPU.get_autocast_list('nn_fp32'): self._run_autocast_outofplace( op, args, torch.float32, module=torch._C._nn) def test_autocast_methods_fp32(self): - for op, args in self.get_autocast_list('methods_fp32'): - print("autocast fp32", op) + for op, args in TestAutocastTPU.get_autocast_list('methods_fp32'): self._run_autocast_outofplace(op, args, torch.float32, module=None) def test_autocast_methods_expect_builtin_promote(self): - for op, args, out_type in self.get_autocast_list( + for op, args, out_type in TestAutocastTPU.get_autocast_list( 'methods_expect_builtin_promote'): self._run_autocast_outofplace( op, args, torch.float32, module=None, out_type=out_type) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 28e09be6c68..9dc516d4316 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -321,7 +321,7 @@ class ComputationClient { virtual int GetNumProcesses() const = 0; using DeviceAttribute = - std::variant, float, bool>; + std::variant, float>; virtual const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 1aa017bc33d..8b74fb67b22 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -152,29 +152,28 @@ PjRtComputationClient::PjRtComputationClient() { xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; if (distributed_client != nullptr) { std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix](const std::string& k, - absl::Duration timeout) { + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { return distributed_client->BlockingKeyValueGet( absl::StrCat(key_prefix, k), timeout); }; - kv_put = [distributed_client, key_prefix](const std::string& k, - const std::string& v) { + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); }; } TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" << global_process_rank << ", num_nodes=" << global_world_size; - client_ = std::move(xla::GetStreamExecutorGpuClient( - /*asynchronous=*/async, - /*allocator_config=*/GetGpuAllocatorConfig(), - /*node_id=*/global_process_rank, - /*num_nodes=*/global_world_size, - /*allowed_devices=*/allowed_devices, - /*platform_name=*/"gpu", - /*should_stage_host_to_device_transfers=*/true, - /*kv_get=*/kv_get, - /*kv_put=*/kv_put) - .value()); 
+    xla::GpuClientOptions options;
+    options.allocator_config = GetGpuAllocatorConfig();
+    options.node_id = global_process_rank;
+    options.num_nodes = global_world_size;
+    options.allowed_devices = allowed_devices;
+    options.platform_name = "gpu";
+    options.kv_get = kv_get;
+    options.kv_put = kv_put;
+    client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
   } else if (device_type == "XPU") {
     TF_VLOG(1) << "Initializing PjRt XPU client...";
     XLA_CHECK_OK(
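
The test_autocast.py hunks above move fixture construction out of per-test setUp and into class-level setUpClass hooks, and turn get_autocast_list into a classmethod, so the device-touching op lists are built once per test class instead of once per test method. Below is a minimal, self-contained sketch of that pattern; FakeOpLists and FakeUnsupportedLists are hypothetical stand-ins for AutocastTestLists and AutocastCudaTestUnsupportedLists, not code from this patch.

import unittest


class FakeOpLists:
    """Hypothetical stand-in for AutocastTestLists; pretend construction is expensive."""

    def __init__(self):
        self.methods_fp32 = [("mm", (2, 2)), ("cholesky", (3,))]


class FakeUnsupportedLists:
    """Hypothetical stand-in for AutocastCudaTestUnsupportedLists."""

    def __init__(self):
        self.methods_fp32 = ["cholesky"]


class TestAutocastPattern(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # Class-level fixtures: built once for the whole class rather than in setUp.
        cls.autocast_lists = FakeOpLists()
        cls.autocast_unsupported_lists = FakeUnsupportedLists()

    @classmethod
    def tearDownClass(cls):
        # Release the class-level fixtures when the class is done.
        del cls.autocast_lists
        del cls.autocast_unsupported_lists

    @classmethod
    def get_autocast_list(cls, list_name):
        # Same filtering shape as the classmethod in the patch: drop ops that
        # appear in the matching unsupported list, if one is configured.
        if cls.autocast_unsupported_lists:
            return [
                tp for tp in getattr(cls.autocast_lists, list_name)
                if tp[0] not in getattr(cls.autocast_unsupported_lists, list_name)
            ]
        return list(getattr(cls.autocast_lists, list_name))

    def test_methods_fp32_filtering(self):
        ops = [op for op, _ in TestAutocastPattern.get_autocast_list("methods_fp32")]
        self.assertEqual(ops, ["mm"])


if __name__ == "__main__":
    unittest.main()

The patch applies the same shape to TestAutocastCuda and TestAutocastTPU above: each subclass's setUpClass calls super().setUpClass() and then attaches its own lists (for example AutocastTPUTestLists) as class attributes, presumably to avoid re-creating XLA-device tensors for every individual test method.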