Skip to content

Commit

Permalink
clarify ReplicateShardedData
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 30, 2023
1 parent c2b8812 commit 9bc594e
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,13 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice(
std::shared_ptr<PjRtComputationClient::PjRtData>
PjRtComputationClient::ReplicateShardedData(
const ComputationClient::DataPtr& handle) {
if (PjRtShardedData* sharded_data =
dynamic_cast<PjRtShardedData*>(handle.get())) {
if (auto unsharded_data = std::dynamic_pointer_cast<PjRtData>(handle)) {
return unsharded_data;
} else if (auto sharded_data =
std::dynamic_pointer_cast<PjRtShardedData>(handle)) {
XLA_COUNTER("ReplicateShardedData", 1);
TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle()
<< ", shape=" << handle->shape() << ")";
TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle()
<< ", shape=" << sharded_data->shape() << ")";
if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) {
// Data is replicated, return the first shard
return sharded_data->shards[0];
Expand Down Expand Up @@ -432,8 +434,9 @@ PjRtComputationClient::ReplicateShardedData(

torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
execute_options;
auto sharded_results = ExecuteReplicated(
*computations.front(), {handle}, GetLocalDevices(), execute_options);
auto sharded_results =
ExecuteReplicated(*computations.front(), {sharded_data},
GetLocalDevices(), execute_options);
XLA_CHECK(sharded_results.size() > 0)
<< "empty ExecuteReplicated results returned.";
XLA_CHECK(sharded_results.size() == 1)
Expand All @@ -442,9 +445,9 @@ PjRtComputationClient::ReplicateShardedData(
return std::dynamic_pointer_cast<PjRtShardedData>(sharded_results[0])
->shards[0];
}
auto pjrt_data = std::dynamic_pointer_cast<PjRtData>(handle);
XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData.";
return pjrt_data;

XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got "
<< handle->ToString();
}

std::vector<xla::Literal> PjRtComputationClient::TransferFromServer(
Expand Down

0 comments on commit 9bc594e

Please sign in to comment.