Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Oct 31, 2023
1 parent da90a7e commit dd9b848
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,8 @@ PjRtComputationClient::ExecuteReplicated(
argument_handles[d][i] = shard->buffer.get();
}
};
env::ScheduleIoClosure(util::MultiWait::Completer(
mwait, std::move(buffer_converter)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(buffer_converter)));
}
mwait->Wait();
}
Expand Down Expand Up @@ -673,8 +673,9 @@ PjRtComputationClient::ExecuteReplicated(
tsl::profiler::TraceMeLevel::kInfo);

xla::HloModuleConfig hlo_config(computation.program_shape());
std::unique_ptr<xla::HloModule> hlo_modules = ConsumeValue(
xla::HloModule::CreateFromProto(computation.computation().proto(), hlo_config));
std::unique_ptr<xla::HloModule> hlo_modules =
ConsumeValue(xla::HloModule::CreateFromProto(
computation.computation().proto(), hlo_config));
const xla::Shape& result_shape = hlo_modules->result_shape();
TF_VLOG(3) << "Processing output with shape " << result_shape.ToString();
const std::vector<xla::Shape>& output_shapes =
Expand All @@ -683,17 +684,20 @@ PjRtComputationClient::ExecuteReplicated(

std::vector<xla::OpSharding> output_shardings;
if (hlo_modules->has_spmd_output_sharding()) {
xla::OpSharding output_sharding = hlo_modules->spmd_output_sharding().ToProto();
xla::OpSharding output_sharding =
hlo_modules->spmd_output_sharding().ToProto();
if (output_sharding.type() == xla::OpSharding::TUPLE) {
auto tuple_shardings = output_sharding.tuple_shardings();
output_shardings = std::vector<xla::OpSharding>({tuple_shardings.begin(), tuple_shardings.end()});
output_shardings = std::vector<xla::OpSharding>(
{tuple_shardings.begin(), tuple_shardings.end()});
} else {
output_shardings = std::vector<xla::OpSharding>({output_sharding});
}
} else {
// Without an explicit sharding annotation, the output is implicitly
// replicated
output_shardings = std::vector<xla::OpSharding>(output_shapes.size(), xla::HloSharding::Replicate().ToProto());
output_shardings = std::vector<xla::OpSharding>(
output_shapes.size(), xla::HloSharding::Replicate().ToProto());
}
XLA_CHECK_EQ(output_shapes.size(), output_shardings.size());

Expand All @@ -707,14 +711,13 @@ PjRtComputationClient::ExecuteReplicated(
}

data_handles[i] = std::make_shared<PjRtShardedData>(
spmd_device_str,
output_shapes[i],
std::move(shards), output_shardings[i]);
spmd_device_str, output_shapes[i], std::move(shards),
output_shardings[i]);
TF_VLOG(5) << "Created sharded data with shape "
<< data_handles[i]->shape().ToString();
};
env::ScheduleIoClosure(util::MultiWait::Completer(
mwait, std::move(collect_shards)));
env::ScheduleIoClosure(
util::MultiWait::Completer(mwait, std::move(collect_shards)));
}
mwait->Wait();
}
Expand Down

0 comments on commit dd9b848

Please sign in to comment.