diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 7946fd85618..9dc1b32503b 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -636,8 +636,8 @@ PjRtComputationClient::ExecuteReplicated(
           argument_handles[d][i] = shard->buffer.get();
         }
       };
-      env::ScheduleIoClosure(util::MultiWait::Completer(
-          mwait, std::move(buffer_converter)));
+      env::ScheduleIoClosure(
+          util::MultiWait::Completer(mwait, std::move(buffer_converter)));
     }
     mwait->Wait();
   }
@@ -673,8 +673,9 @@ PjRtComputationClient::ExecuteReplicated(
                                     tsl::profiler::TraceMeLevel::kInfo);
     xla::HloModuleConfig hlo_config(computation.program_shape());
-    std::unique_ptr<xla::HloModule> hlo_modules = ConsumeValue(
-        xla::HloModule::CreateFromProto(computation.computation().proto(), hlo_config));
+    std::unique_ptr<xla::HloModule> hlo_modules =
+        ConsumeValue(xla::HloModule::CreateFromProto(
+            computation.computation().proto(), hlo_config));
     const xla::Shape& result_shape = hlo_modules->result_shape();
     TF_VLOG(3) << "Processing output with shape " << result_shape.ToString();
     const std::vector<xla::Shape>& output_shapes =
@@ -683,17 +684,20 @@ PjRtComputationClient::ExecuteReplicated(
     std::vector<xla::OpSharding> output_shardings;
     if (hlo_modules->has_spmd_output_sharding()) {
-      xla::OpSharding output_sharding = hlo_modules->spmd_output_sharding().ToProto();
+      xla::OpSharding output_sharding =
+          hlo_modules->spmd_output_sharding().ToProto();
       if (output_sharding.type() == xla::OpSharding::TUPLE) {
         auto tuple_shardings = output_sharding.tuple_shardings();
-        output_shardings = std::vector<xla::OpSharding>({tuple_shardings.begin(), tuple_shardings.end()});
+        output_shardings = std::vector<xla::OpSharding>(
+            {tuple_shardings.begin(), tuple_shardings.end()});
       } else {
         output_shardings = std::vector<xla::OpSharding>({output_sharding});
       }
     } else {
       // Without an explicit sharding annotation, the output is implicitly
       // replicated
-      output_shardings = std::vector<xla::OpSharding>(output_shapes.size(), xla::HloSharding::Replicate().ToProto());
+      output_shardings = std::vector<xla::OpSharding>(
+          output_shapes.size(), xla::HloSharding::Replicate().ToProto());
     }
 
     XLA_CHECK_EQ(output_shapes.size(), output_shardings.size());
@@ -707,14 +711,13 @@ PjRtComputationClient::ExecuteReplicated(
         }
         data_handles[i] = std::make_shared<PjRtShardedData>(
-            spmd_device_str,
-            output_shapes[i],
-            std::move(shards), output_shardings[i]);
+            spmd_device_str, output_shapes[i], std::move(shards),
+            output_shardings[i]);
         TF_VLOG(5) << "Created sharded data with shape "
                    << data_handles[i]->shape().ToString();
       };
-      env::ScheduleIoClosure(util::MultiWait::Completer(
-          mwait, std::move(collect_shards)));
+      env::ScheduleIoClosure(
+          util::MultiWait::Completer(mwait, std::move(collect_shards)));
     }
     mwait->Wait();
   }