Skip to content

Commit

Permalink
Compute and hash buffer_donor_indices for step marker (#8467)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcuiaws authored Dec 17, 2024
1 parent b2b890e commit 0154850
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 65 deletions.
19 changes: 19 additions & 0 deletions test/test_input_output_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,25 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps):
alias_count == 1.0
), f"Expect 1 input-output alias pair for gradient accumulation, got {alias_count}"

def test_separate_graphs(self):
"""
Test that paramater aliasing differences should produce different graphs.
"""
xla_device = xm.xla_device()
t0 = torch.tensor([1], device=xla_device)
t1 = torch.tensor([2], device=xla_device)
xm.mark_step()

t1.add_(t0)
xm.mark_step()

# This needs to be a separate graph, otherwise t1 can be corrupted
# or result in PJRT error.
t2 = t1 + t0
xm.mark_step()

self.assertEqual(t1.item(), 3)


if __name__ == '__main__':
test = unittest.main()
Expand Down
121 changes: 62 additions & 59 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,20 +451,6 @@ std::vector<size_t> GetBufferDonorIndexFromUserConfig(
return buffer_donor_indexs;
}

std::vector<size_t> XLAGraphExecutor::SetBufferDonorsFromUserConfig(
LoweringContext* lowering_ctx) {
const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
lowering_ctx->GetParametersData();
std::vector<size_t> buffer_donor_indexs =
GetBufferDonorIndexFromUserConfig(parameters_data);
for (size_t i : buffer_donor_indexs) {
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
/*param_index=*/{});
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
return buffer_donor_indexs;
}

void XLAGraphExecutor::WaitDeviceOps(absl::Span<const std::string> devices) {
std::set<torch::lazy::BackendDevice> wait_devices;
if (!devices.empty()) {
Expand Down Expand Up @@ -1262,9 +1248,9 @@ XLAGraphExecutor::TryRunCachedSync(
std::move(cached_computation), tensor_data_vec));
}

std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
std::vector<size_t> GetBufferDonorIndexForStepMarker(
const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
LoweringContext* lowering_ctx) {
const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
std::unordered_map<int64_t, size_t> output_tensor_id_map;
std::vector<size_t> buffer_donor_indexs;
// tensors[indices] represent all tensors that needs to be updated after
Expand All @@ -1275,7 +1261,6 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
output_tensor_id_map[tensor_id] = i;
}
const auto& parameters_data = lowering_ctx->GetParametersData();
std::vector<ssize_t> alias_map(indices.size(), -1);
for (size_t i = 0; i < parameters_data.size(); ++i) {
auto* data_info =
Expand All @@ -1287,45 +1272,19 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
// this buffer is not needed after execution since XLATensor will get a
// new buffer.
if (it != output_tensor_id_map.end()) {
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
/*param_index=*/{});
buffer_donor_indexs.push_back(i);
}
}
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
return buffer_donor_indexs;
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
const SyncTensorCollection& coll, PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values) {
tsl::profiler::TraceMe activity(
[&] {
return tsl::profiler::TraceMeEncode(
"XLAGraphExecutor::Compile",
{{"graph_hash", torch::lazy::HashToString(coll.hash)}});
},
tsl::profiler::TraceMeLevel::kInfo);
std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
const std::vector<XLATensorPtr>& tensors, const SyncTensorCollection& coll,
const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
static const bool enable_aliasing =
runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
static const size_t parameter_wrapping_threadshold =
runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
static const bool use_autosharding = ShardingUtil::GetAutoSharding();
std::string graph_name =
(CurrentGraphName() != "") ? CurrentGraphName() : "SyncTensorsGraph";
LoweringContext lowering_ctx(graph_name, coll.device, po_data->post_order,
std::move(po_data->emission_map));
for (auto ir_value : ir_values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}
// Always execute sharded when running in SPMD mode
bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
// Annotate HLO sharding selectively in the compuation.
ShardingUtil::SetHloSharding(&lowering_ctx);

std::vector<size_t> buffer_donor_indices;
// TODO(yeounoh) enable aliasing is disabled for partitioned computation,
Expand Down Expand Up @@ -1357,14 +1316,59 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
// will later fetch the new value of A, which is incorrect.
// But, when we issue a step barrier (force_ltc_data == true) we have to
// turn everything into DEVICE_DATA, so we can activate aliasing.
buffer_donor_indices =
SetBufferDonors(tensors, coll.indices, &lowering_ctx);
buffer_donor_indices = GetBufferDonorIndexForStepMarker(
tensors, coll.indices, parameters_data);
} else if (GetAliasWithBufferDonorConfig()) {
// only alias based on buffer donor if LTC can't auto infer the input
// output aliasing.
buffer_donor_indices = SetBufferDonorsFromUserConfig(&lowering_ctx);
buffer_donor_indices = GetBufferDonorIndexFromUserConfig(parameters_data);
}
}
return buffer_donor_indices;
}

void XLAGraphExecutor::SetBufferDonors(
LoweringContext* lowering_ctx,
const std::vector<size_t>& buffer_donor_indexs) {
const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
lowering_ctx->GetParametersData();
for (size_t i : buffer_donor_indexs) {
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
/*param_index=*/{});
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
const SyncTensorCollection& coll, PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values,
const std::vector<size_t>& buffer_donor_indices) {
tsl::profiler::TraceMe activity(
[&] {
return tsl::profiler::TraceMeEncode(
"XLAGraphExecutor::Compile",
{{"graph_hash", torch::lazy::HashToString(coll.hash)}});
},
tsl::profiler::TraceMeLevel::kInfo);
static const size_t parameter_wrapping_threadshold =
runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
static const bool use_autosharding = ShardingUtil::GetAutoSharding();
std::string graph_name =
(CurrentGraphName() != "") ? CurrentGraphName() : "SyncTensorsGraph";
LoweringContext lowering_ctx(graph_name, coll.device, po_data->post_order,
std::move(po_data->emission_map));
for (auto ir_value : ir_values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}
// Always execute sharded when running in SPMD mode
bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
// Annotate HLO sharding selectively in the compuation.
ShardingUtil::SetHloSharding(&lowering_ctx);

SetBufferDonors(&lowering_ctx, buffer_donor_indices);

xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
Expand Down Expand Up @@ -1499,14 +1503,13 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
PostOrderData po_data = RunPostOrder(ir_values, &coll);
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
if (GetAliasWithBufferDonorConfig()) {
std::vector<size_t> buffer_donor_index =
GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
if (buffer_donor_index.size() > 0) {
// Do not include hash on a empty vector.
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(buffer_donor_index));
}

std::vector<size_t> buffer_donor_indices =
GetBufferDonors(*tensors, coll, po_data.parameters_data);
if (buffer_donor_indices.size() > 0) {
// Do not include hash on a empty vector.
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(buffer_donor_indices));
}
{
// Auto-sharding configs
Expand All @@ -1529,8 +1532,8 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
// we have a cache hit, execution has been scheduled by TryRunCachedSync.
return cache_res.second;
}
CompilationResult compile_result =
Compile(*tensors, devices, coll, &po_data, ir_values);
CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data,
ir_values, buffer_donor_indices);

TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
Expand Down
14 changes: 8 additions & 6 deletions torch_xla/csrc/xla_graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,20 +356,22 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec,
bool warm_up_cache_only);

std::vector<size_t> SetBufferDonors(const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices,
LoweringContext* lowering_ctx);
std::vector<size_t> GetBufferDonors(
const std::vector<XLATensorPtr>& tensors,
const SyncTensorCollection& coll,
const std::vector<torch::lazy::BackendDataPtr>& parameters_data);

std::vector<size_t> SetBufferDonorsFromUserConfig(
LoweringContext* lowering_ctx);
void SetBufferDonors(LoweringContext* lowering_ctx,
const std::vector<size_t>& buffer_donor_indices);

// TODO(yeounoh) auto-sharding can change tensors shardings, which needs to be
// accounted for in Dynamo integration.
CompilationResult Compile(std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices,
const SyncTensorCollection& coll,
PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values);
const std::vector<torch::lazy::Value>& ir_values,
const std::vector<size_t>& buffer_donor_indices);

// We don't use the upstream SyncTensorsGraphInternal since
// our CachedComputation is different from upstream.
Expand Down

0 comments on commit 0154850

Please sign in to comment.