From a5f5ef62cd9a07874c47931bd4dcd1640b15c3ef Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Fri, 3 Nov 2023 20:03:46 +0000
Subject: [PATCH] formatting

---
 torch_xla/csrc/runtime/pjrt_computation_client.cc | 10 +++++-----
 torch_xla/csrc/runtime/pjrt_computation_client.h  |  4 ++--
 torch_xla/csrc/runtime/thread_pool.cc             |  5 +++--
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index a0e1ce8f9b0..0333c04b1e3 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -623,8 +623,8 @@ PjRtComputationClient::ExecuteReplicated(
   // TODO: tune and document cost estimate
   pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) {
     tsl::profiler::TraceMe activity(
-      "PjRtComputationClient::ExecuteReplicated_argument_handle_shard",
-      tsl::profiler::TraceMeLevel::kInfo);
+        "PjRtComputationClient::ExecuteReplicated_argument_handle_shard",
+        tsl::profiler::TraceMeLevel::kInfo);
     for (int32_t i = start; i < end; ++i) {
       auto pjrt_data =
           std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
@@ -687,8 +687,8 @@ PjRtComputationClient::ExecuteReplicated(
   // TODO: tune and document cost estimate
   pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) {
     tsl::profiler::TraceMe activity(
-      "PjRtComputationClient::ExecuteReplicated_result_handle_shard",
-      tsl::profiler::TraceMeLevel::kInfo);
+        "PjRtComputationClient::ExecuteReplicated_result_handle_shard",
+        tsl::profiler::TraceMeLevel::kInfo);
     for (int32_t i = start; i < end; ++i) {
       std::vector<std::shared_ptr<PjRtData>> shards(devices.size());
       for (int32_t d = 0; d < devices.size(); d++) {
@@ -709,7 +709,7 @@ PjRtComputationClient::ExecuteReplicated(
   }
 
   pool_.Schedule([&, this, returned_futures = std::move(*returned_futures),
-      timed]() mutable {
+                  timed]() mutable {
     // Grab the shared lock and block the `WaitDeviceOps` until buffer is
     // ready. Since this is the SPMD code path. There is no points to grab
     // devices lock for every individual device.
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index f73be0e8c09..2dd418514b6 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -104,7 +104,8 @@ class PjRtComputationClient : public ComputationClient {
   std::shared_ptr<std::vector<std::string>> replication_devices_;
   std::unordered_map<std::string, std::unique_ptr<std::shared_mutex>>
       device_locks_;
-  tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
+  tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
+      tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
 
   xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
   std::shared_lock<std::shared_mutex> lock_device_shared(
@@ -232,7 +233,6 @@ class PjRtComputationClient : public ComputationClient {
 
   // Use XLA replication to re-assemble the sharded data.
   std::shared_ptr<PjRtData> ReplicateShardedData(const DataPtr& handle);
-
 };
 
 }  // namespace runtime
diff --git a/torch_xla/csrc/runtime/thread_pool.cc b/torch_xla/csrc/runtime/thread_pool.cc
index 78fb57c7bb7..97ffebd5e20 100644
--- a/torch_xla/csrc/runtime/thread_pool.cc
+++ b/torch_xla/csrc/runtime/thread_pool.cc
@@ -5,9 +5,9 @@
 #include
 #include
 
-#include "tsl/platform/threadpool.h"
 #include "torch_xla/csrc/runtime/metrics.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
+#include "tsl/platform/threadpool.h"
 
 namespace torch_xla {
 namespace runtime {
@@ -17,7 +17,8 @@ namespace {
 tsl::thread::ThreadPool* GetThreadPool() {
   static size_t num_threads = sys_util::GetEnvInt(
       "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
-  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", num_threads);
+  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla",
+                                      num_threads);
   return &pool;
 }
 
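Reviewer note (not part of the patch): the hunks above only re-wrap existing calls into the tsl::thread::ThreadPool API. For readers unfamiliar with that API, below is a minimal standalone sketch of the same pattern in isolation. The pool name "example", the element count 1024, and the main() wrapper are illustrative placeholders; only the constructor, ParallelFor, and Schedule signatures mirror what the patch touches.

// Minimal sketch of the tsl::thread::ThreadPool usage pattern seen above.
#include <cstdint>
#include <thread>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

int main() {
  // Same construction pattern as the pool_ member and GetThreadPool():
  // a named pool sized to the host's hardware concurrency.
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "example",
                               std::thread::hardware_concurrency());

  // ParallelFor splits the range [0, total) into shards and invokes the
  // lambda on [start, end) sub-ranges; the second argument is a rough
  // per-element cost estimate in nanoseconds that the pool uses to choose
  // shard sizes (the patch uses 30000 for this estimate).
  pool.ParallelFor(/*total=*/1024, /*cost_per_unit=*/30000,
                   [](int64_t start, int64_t end) {
                     for (int64_t i = start; i < end; ++i) {
                       // Per-element work goes here.
                     }
                   });

  // Schedule() queues fire-and-forget work on the pool, as in the
  // pool_.Schedule call inside ExecuteReplicated above; the pool's
  // destructor waits for queued work to finish.
  pool.Schedule([] { /* deferred work */ });
  return 0;
}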