Commit a5f5ef6: formatting

will-cromar committed Nov 3, 2023
1 parent 1097ce6

Showing 3 changed files with 10 additions and 9 deletions.

torch_xla/csrc/runtime/pjrt_computation_client.cc (5 additions, 5 deletions)

@@ -623,8 +623,8 @@ PjRtComputationClient::ExecuteReplicated(
   // TODO: tune and document cost estimate
   pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) {
     tsl::profiler::TraceMe activity(
-      "PjRtComputationClient::ExecuteReplicated_argument_handle_shard",
-      tsl::profiler::TraceMeLevel::kInfo);
+        "PjRtComputationClient::ExecuteReplicated_argument_handle_shard",
+        tsl::profiler::TraceMeLevel::kInfo);
     for (int32_t i = start; i < end; ++i) {
       auto pjrt_data =
           std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
@@ -687,8 +687,8 @@ PjRtComputationClient::ExecuteReplicated(
   // TODO: tune and document cost estimate
   pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) {
     tsl::profiler::TraceMe activity(
-      "PjRtComputationClient::ExecuteReplicated_result_handle_shard",
-      tsl::profiler::TraceMeLevel::kInfo);
+        "PjRtComputationClient::ExecuteReplicated_result_handle_shard",
+        tsl::profiler::TraceMeLevel::kInfo);
     for (int32_t i = start; i < end; ++i) {
       std::vector<std::shared_ptr<PjRtData>> shards(devices.size());
       for (int32_t d = 0; d < devices.size(); d++) {
@@ -709,7 +709,7 @@ PjRtComputationClient::ExecuteReplicated(
   }
 
   pool_.Schedule([&, this, returned_futures = std::move(*returned_futures),
-                timed]() mutable {
+                  timed]() mutable {
     // Grab the shared lock and block the `WaitDeviceOps` until buffer is
     // ready. Since this is the SPMD code path. There is no points to grab
     // devices lock for every individual device.
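
A note on the two `ParallelFor` calls touched above: in the TSL thread pool API, the second argument (30000 here) is `cost_per_unit`, an estimate of the cost of processing one element (roughly nanoseconds in the underlying Eigen cost model) that the scheduler uses to pick a shard size; that is what the "tune and document cost estimate" TODO refers to. Below is a minimal, self-contained sketch of the call pattern; the vector-doubling loop body and the pool name are made up for illustration, only the pool construction and the `ParallelFor` signature mirror the real API.

// Standalone sketch of the tsl::thread::ThreadPool::ParallelFor pattern used
// in the diff above (illustrative body, not the handle-shard logic).
#include <cstdint>
#include <thread>
#include <vector>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

int main() {
  tsl::thread::ThreadPool pool(tsl::Env::Default(), "example",
                               std::thread::hardware_concurrency());
  std::vector<int64_t> values(1 << 20, 1);
  // ParallelFor splits [0, total) into shards and runs them on pool threads.
  // 30000 is the estimated cost of one iteration (roughly in nanoseconds);
  // the scheduler uses it to decide how finely to shard. The call blocks
  // until every shard has run, so no extra synchronization is needed after.
  pool.ParallelFor(values.size(), 30000, [&](int64_t start, int64_t end) {
    for (int64_t i = start; i < end; ++i) {
      values[i] *= 2;
    }
  });
  return 0;
}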

torch_xla/csrc/runtime/pjrt_computation_client.h (2 additions, 2 deletions)

@@ -104,7 +104,8 @@ class PjRtComputationClient : public ComputationClient {
   std::shared_ptr<std::vector<std::string>> replication_devices_;
   std::unordered_map<std::string, std::unique_ptr<std::shared_mutex>>
       device_locks_;
-  tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
+  tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
+      tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
 
   xla::PjRtDevice* StringToPjRtDevice(const std::string& device);
   std::shared_lock<std::shared_mutex> lock_device_shared(
@@ -232,7 +233,6 @@ class PjRtComputationClient : public ComputationClient {
 
   // Use XLA replication to re-assemble the sharded data.
   std::shared_ptr<PjRtData> ReplicateShardedData(const DataPtr& handle);
-
 };
 
 }  // namespace runtime
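
A side note on the member being rewrapped here: `pool_` uses a default member initializer of the form `T pool_ = T(...)`. Since C++17 this compiles even for a type that is not meant to be copied, such as a thread pool, because the prvalue on the right initializes the member in place (guaranteed copy elision). A tiny illustration, with a hypothetical stand-in class name:

// Hypothetical reduction of the header change above.
#include <thread>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

class ClientSketch {  // stand-in name; the real class is PjRtComputationClient
 private:
  // `T member_ = T(args...)` initializes the member directly from the prvalue
  // (guaranteed copy elision since C++17), so no copy or move of the pool
  // ever happens.
  tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
      tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
};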

torch_xla/csrc/runtime/thread_pool.cc (3 additions, 2 deletions)

@@ -5,9 +5,9 @@
 #include <exception>
 #include <mutex>
 
-#include "tsl/platform/threadpool.h"
 #include "torch_xla/csrc/runtime/metrics.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
+#include "tsl/platform/threadpool.h"
 
 namespace torch_xla {
 namespace runtime {
@@ -17,7 +17,8 @@ namespace {
 tsl::thread::ThreadPool* GetThreadPool() {
   static size_t num_threads = sys_util::GetEnvInt(
       "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
-  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla", num_threads);
+  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla",
+                                      num_threads);
   return &pool;
 }
 
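
The hunk at line 17 rewraps a lazily-constructed, process-wide pool. For reference, here is a self-contained sketch of the same function-local-static pattern: the pool is built once, on first call, sized from the `XLA_THREAD_POOL_SIZE` environment variable with a hardware-concurrency fallback. The `GetEnvInt` helper below is a hand-rolled stand-in for `sys_util::GetEnvInt`, which is not shown in this commit, under the assumption that the real helper behaves like this.

// Sketch of the env-sized singleton pool pattern in thread_pool.cc.
#include <cstdint>
#include <cstdlib>
#include <string>
#include <thread>

#include "tsl/platform/env.h"
#include "tsl/platform/threadpool.h"

// Stand-in for torch_xla's sys_util::GetEnvInt: read an integer from the
// environment, falling back to a default when the variable is unset.
int64_t GetEnvInt(const char* name, int64_t defval) {
  const char* value = std::getenv(name);
  return value == nullptr ? defval : std::stoll(value);
}

tsl::thread::ThreadPool* GetThreadPool() {
  // Function-local statics are initialized exactly once, thread-safely, on
  // first call, so every caller shares one pool sized at first use.
  static int64_t num_threads = GetEnvInt(
      "XLA_THREAD_POOL_SIZE", std::thread::hardware_concurrency());
  static tsl::thread::ThreadPool pool(tsl::Env::Default(), "pytorchxla",
                                      num_threads);
  return &pool;
}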
