Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 22, 2023
1 parent 70c4526 commit 80926b3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ PjRtComputationClient::ExecuteReplicated(
// TODO: tune and document cost estimate
pool_.ParallelFor(arguments.size(), 30000, [&](int64_t start, int64_t end) {
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_argument_handle_shard",
tsl::profiler::TraceMeLevel::kInfo);
"PjRtComputationClient::ExecuteReplicated_argument_handle_shard",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = start; i < end; ++i) {
auto pjrt_data =
std::dynamic_pointer_cast<PjRtShardedData>(arguments[i]);
Expand Down Expand Up @@ -733,8 +733,8 @@ PjRtComputationClient::ExecuteReplicated(
// TODO: tune and document cost estimate
pool_.ParallelFor(num_outputs, 30000, [&](int64_t start, int64_t end) {
tsl::profiler::TraceMe activity(
"PjRtComputationClient::ExecuteReplicated_result_handle_shard",
tsl::profiler::TraceMeLevel::kInfo);
"PjRtComputationClient::ExecuteReplicated_result_handle_shard",
tsl::profiler::TraceMeLevel::kInfo);
for (int32_t i = start; i < end; ++i) {
std::vector<std::shared_ptr<PjRtData>> shards(devices.size());
for (int32_t d = 0; d < devices.size(); d++) {
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ class PjRtComputationClient : public ComputationClient {
std::unordered_map<std::string, xla::PjRtDevice* const> string_to_device_;
std::shared_ptr<std::vector<std::string>> replication_devices_;
OperationManager operation_manager_;
tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());
tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
tsl::Env::Default(), "pjrt", std::thread::hardware_concurrency());

xla::PjRtDevice* StringToPjRtDevice(const std::string& device);

Expand Down Expand Up @@ -237,7 +238,6 @@ class PjRtComputationClient : public ComputationClient {

// Use XLA replication to re-assemble the sharded data.
std::shared_ptr<PjRtData> ReplicateShardedData(const DataPtr& handle);

};

} // namespace runtime
Expand Down

0 comments on commit 80926b3

Please sign in to comment.