diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index ff991fc89403..086eb24e51c1 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -114,9 +114,10 @@ void CheckSubOperandTypes(at::ScalarType type1, at::ScalarType type2) { c10::optional PromoteIntegralType( at::ScalarType src_dtype, const c10::optional& opt_dtype) { - return opt_dtype.has_value() ? opt_dtype.value() - : at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong - : opt_dtype; + return opt_dtype.has_value() + ? opt_dtype.value() + : at::isIntegralType(src_dtype, /*includeBool=*/true) ? at::kLong + : opt_dtype; } bool IsTypeWithLargerRangeThanLong(torch::ScalarType dtype) { diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 30e7cc540713..c8c0d7d0cd61 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -969,28 +969,27 @@ void InitXlaModuleBindings(py::module m) { ShardingUtil::ShardingType(sharding_type)), global_shape, minibatch); })); - m.def( - "_xla_tensors_from_aten", - [](const std::vector& tensors, - const std::vector& devices, - const std::optional>& - shardings) { - std::vector result; - { - NoGilSection nogil; - std::vector xla_tensors = - GetXlaTensorsFromAten(tensors, devices, shardings); - result.reserve(xla_tensors.size()); - for (size_t i = 0; i < xla_tensors.size(); ++i) { - result.push_back(torch::autograd::make_variable( - xla_tensors[i], - /*requires_grad=*/tensors.at(i).requires_grad())); + m.def("_xla_tensors_from_aten", + [](const std::vector& tensors, + const std::vector& devices, + const std::optional>& + shardings) { + std::vector result; + { + NoGilSection nogil; + std::vector xla_tensors = + GetXlaTensorsFromAten(tensors, devices, shardings); + result.reserve(xla_tensors.size()); + for (size_t i = 0; i < xla_tensors.size(); ++i) { + result.push_back(torch::autograd::make_variable( + xla_tensors[i], + /*requires_grad=*/tensors.at(i).requires_grad())); + } } - } - return result; - }, - py::arg("tensors"), py::arg("devices"), - py::arg("shardings") = py::none()); + return result; + }, + py::arg("tensors"), py::arg("devices"), + py::arg("shardings") = py::none()); m.def("_xla_get_cpu_tensors", [](const std::vector& tensors) { std::vector result; { @@ -1290,51 +1289,45 @@ void InitXlaModuleBindings(py::module m) { } return list; }); - m.def( - "_xla_set_rng_seed", - [](uint64_t seed, const std::string& device) { - SetRngSeed(seed, device); - }, - py::arg("seed") = 101, py::arg("device") = ""); - m.def( - "_xla_get_rng_seed", - [](const std::string& device) { return GetRngSeed(device); }, - py::arg("device") = ""); - m.def( - "_xla_sync_multi", - [](const std::vector& tensors, - const std::vector& devices, bool wait, - bool sync_xla_data) { - NoGilSection nogil; - SyncTensors(tensors, devices, wait, sync_xla_data); - }, - py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, - py::arg("sync_xla_data") = true); - m.def( - "_xla_warm_up_cache", - [](const std::vector& tensors, - const std::vector& devices) { - NoGilSection nogil; - SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false, - /*warm_up_cache_only=*/true); - }, - py::arg("tensors"), py::arg("devices")); - m.def( - "_xla_sync_live_tensors", - [](const std::string& device, const std::vector& devices, - bool wait) { - NoGilSection nogil; - SyncLiveTensors(device, devices, wait); - }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); - m.def( - "_xla_step_marker", - [](const std::string& device, const std::vector& devices, - bool wait) { - NoGilSection nogil; - StepMarker(device, devices, wait); - }, - py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + m.def("_xla_set_rng_seed", + [](uint64_t seed, const std::string& device) { + SetRngSeed(seed, device); + }, + py::arg("seed") = 101, py::arg("device") = ""); + m.def("_xla_get_rng_seed", + [](const std::string& device) { return GetRngSeed(device); }, + py::arg("device") = ""); + m.def("_xla_sync_multi", + [](const std::vector& tensors, + const std::vector& devices, bool wait, + bool sync_xla_data) { + NoGilSection nogil; + SyncTensors(tensors, devices, wait, sync_xla_data); + }, + py::arg("tensors"), py::arg("devices"), py::arg("wait") = true, + py::arg("sync_xla_data") = true); + m.def("_xla_warm_up_cache", + [](const std::vector& tensors, + const std::vector& devices) { + NoGilSection nogil; + SyncTensors(tensors, devices, /*wait=*/false, /*sync_xla_data=*/false, + /*warm_up_cache_only=*/true); + }, + py::arg("tensors"), py::arg("devices")); + m.def("_xla_sync_live_tensors", + [](const std::string& device, const std::vector& devices, + bool wait) { + NoGilSection nogil; + SyncLiveTensors(device, devices, wait); + }, + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); + m.def("_xla_step_marker", + [](const std::string& device, const std::vector& devices, + bool wait) { + NoGilSection nogil; + StepMarker(device, devices, wait); + }, + py::arg("device") = "", py::arg("devices"), py::arg("wait") = true); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices, @@ -1371,19 +1364,18 @@ void InitXlaModuleBindings(py::module m) { } return retlist; }); - m.def( - "_xla_wait_device_ops", - [](const std::vector& devices) { - NoGilSection nogil; - XLAGraphExecutor::Get()->WaitDeviceOps(devices); - if (UseVirtualDevice()) { - std::vector spmd_device = {"SPMD:0"}; - runtime::GetComputationClient()->WaitDeviceOps(spmd_device); - } else { - runtime::GetComputationClient()->WaitDeviceOps(devices); - } - }, - py::arg("devices")); + m.def("_xla_wait_device_ops", + [](const std::vector& devices) { + NoGilSection nogil; + XLAGraphExecutor::Get()->WaitDeviceOps(devices); + if (UseVirtualDevice()) { + std::vector spmd_device = {"SPMD:0"}; + runtime::GetComputationClient()->WaitDeviceOps(spmd_device); + } else { + runtime::GetComputationClient()->WaitDeviceOps(devices); + } + }, + py::arg("devices")); m.def("_xla_counter_names", []() { auto counter_names = torch::lazy::GetCounterNames(); auto xla_counter_names = runtime::metrics::GetCounterNames(); @@ -1448,23 +1440,21 @@ void InitXlaModuleBindings(py::module m) { torch::lazy::MetricsArena::Get()->ResetMetrics(); runtime::metrics::ClearMetrics(); }); - m.def( - "_xla_tensors_report", - [](size_t nodes_threshold, const std::string& device) { - return GetLiveTensorsReport(nodes_threshold, device); - }, - py::arg("nodes_threshold") = 100, py::arg("device") = ""); + m.def("_xla_tensors_report", + [](size_t nodes_threshold, const std::string& device) { + return GetLiveTensorsReport(nodes_threshold, device); + }, + py::arg("nodes_threshold") = 100, py::arg("device") = ""); m.def("_xla_memory_info", [](const std::string& device) -> py::object { return GetMemoryInfo(device); }); - m.def( - "_xla_set_use_full_mat_mul_precision", - [](bool use_full_mat_mul_precision) { - XlaHelpers::set_mat_mul_precision(use_full_mat_mul_precision - ? xla::PrecisionConfig::HIGHEST - : xla::PrecisionConfig::DEFAULT); - }, - py::arg("use_full_mat_mul_precision") = true); + m.def("_xla_set_use_full_mat_mul_precision", + [](bool use_full_mat_mul_precision) { + XlaHelpers::set_mat_mul_precision( + use_full_mat_mul_precision ? xla::PrecisionConfig::HIGHEST + : xla::PrecisionConfig::DEFAULT); + }, + py::arg("use_full_mat_mul_precision") = true); py::class_(m, "XlaBuilder"); py::class_(m, "XlaOp");