diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 54f32a86932..76a9d7c19b6 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -390,11 +390,13 @@ ComputationClient::DataPtr PjRtComputationClient::CopyToDevice( std::shared_ptr PjRtComputationClient::ReplicateShardedData( const ComputationClient::DataPtr& handle) { - if (PjRtShardedData* sharded_data = - dynamic_cast(handle.get())) { + if (auto unsharded_data = std::dynamic_pointer_cast(handle)) { + return unsharded_data; + } else if (auto sharded_data = + std::dynamic_pointer_cast(handle)) { XLA_COUNTER("ReplicateShardedData", 1); - TF_VLOG(1) << "ReplicateShardedData (handle=" << handle->GetHandle() - << ", shape=" << handle->shape() << ")"; + TF_VLOG(1) << "ReplicateShardedData (handle=" << sharded_data->GetHandle() + << ", shape=" << sharded_data->shape() << ")"; if (sharded_data->GetSharding().type() == xla::OpSharding::REPLICATED) { // Data is replicated, return the first shard return sharded_data->shards[0]; @@ -432,8 +434,9 @@ PjRtComputationClient::ReplicateShardedData( torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions execute_options; - auto sharded_results = ExecuteReplicated( - *computations.front(), {handle}, GetLocalDevices(), execute_options); + auto sharded_results = + ExecuteReplicated(*computations.front(), {sharded_data}, + GetLocalDevices(), execute_options); XLA_CHECK(sharded_results.size() > 0) << "empty ExecuteReplicated results returned."; XLA_CHECK(sharded_results.size() == 1) @@ -442,9 +445,9 @@ PjRtComputationClient::ReplicateShardedData( return std::dynamic_pointer_cast(sharded_results[0]) ->shards[0]; } - auto pjrt_data = std::dynamic_pointer_cast(handle); - XLA_CHECK(pjrt_data) << "Data must be PjRtData or PjRtShardedData."; - return pjrt_data; + + XLA_ERROR() << "Data must be PjRtData or PjRtShardedData, got " + << handle->ToString(); } std::vector PjRtComputationClient::TransferFromServer(