Compute and hash buffer_donor_indices for step marker #8467

Merged: 1 commit, Dec 17, 2024
19 changes: 19 additions & 0 deletions test/test_input_output_aliases.py
@@ -143,6 +143,25 @@ def try_grad_accum(model, device, train_x, train_label, accum_steps):
alias_count == 1.0
), f"Expect 1 input-output alias pair for gradient accumulation, got {alias_count}"

def test_separate_graphs(self):
"""
Test that parameter aliasing differences produce different graphs.
"""
xla_device = xm.xla_device()
t0 = torch.tensor([1], device=xla_device)
t1 = torch.tensor([2], device=xla_device)
xm.mark_step()

t1.add_(t0)
xm.mark_step()

# This needs to be a separate graph, otherwise t1 can be corrupted
# or result in a PJRT error.
t2 = t1 + t0
xm.mark_step()

self.assertEqual(t1.item(), 3)


if __name__ == '__main__':
test = unittest.main()
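The behavior this new test guards can also be observed from Python through the compile counters: once buffer_donor_indices participate in the graph hash, the in-place update of t1 and the later read of t1 compile as two distinct executables instead of sharing a cached one. A minimal sketch, assuming the `torch_xla.debug.metrics` module and its `UncachedCompile` counter behave as in current torch_xla:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
t0 = torch.tensor([1], device=device)
t1 = torch.tensor([2], device=device)
xm.mark_step()

t1.add_(t0)   # in-place update: t1's input buffer is a donation candidate
xm.mark_step()

before = met.counter_value('UncachedCompile') or 0
t2 = t1 + t0  # plain read: t1 must keep its buffer here
xm.mark_step()
after = met.counter_value('UncachedCompile') or 0

# With donor indices folded into the hash, the second graph cannot reuse the
# first graph's cached executable, so a fresh compile is expected.
assert after > before
print(t1.item())  # 3
```

The exact counter values depend on whatever else the process has compiled, so a real test would clear the counters first (for example with `met.clear_all()`, if available in the installed version) before comparing them.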
121 changes: 62 additions & 59 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -451,20 +451,6 @@ std::vector<size_t> GetBufferDonorIndexFromUserConfig(
return buffer_donor_indexs;
}

std::vector<size_t> XLAGraphExecutor::SetBufferDonorsFromUserConfig(
LoweringContext* lowering_ctx) {
const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
lowering_ctx->GetParametersData();
std::vector<size_t> buffer_donor_indexs =
GetBufferDonorIndexFromUserConfig(parameters_data);
for (size_t i : buffer_donor_indexs) {
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
/*param_index=*/{});
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
return buffer_donor_indexs;
}

void XLAGraphExecutor::WaitDeviceOps(absl::Span<const std::string> devices) {
std::set<torch::lazy::BackendDevice> wait_devices;
if (!devices.empty()) {
@@ -1262,9 +1248,9 @@ XLAGraphExecutor::TryRunCachedSync(
std::move(cached_computation), tensor_data_vec));
}

std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
std::vector<size_t> GetBufferDonorIndexForStepMarker(
const std::vector<XLATensorPtr>& tensors, absl::Span<const size_t> indices,
LoweringContext* lowering_ctx) {
const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
std::unordered_map<int64_t, size_t> output_tensor_id_map;
std::vector<size_t> buffer_donor_indexs;
// tensors[indices] represent all tensors that need to be updated after
@@ -1275,7 +1261,6 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
int64_t tensor_id = tensors[tensor_index]->data()->alias_id;
output_tensor_id_map[tensor_id] = i;
}
const auto& parameters_data = lowering_ctx->GetParametersData();
std::vector<ssize_t> alias_map(indices.size(), -1);
for (size_t i = 0; i < parameters_data.size(); ++i) {
auto* data_info =
@@ -1287,45 +1272,19 @@ std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
// this buffer is not needed after execution since XLATensor will get a
// new buffer.
if (it != output_tensor_id_map.end()) {
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
/*param_index=*/{});
buffer_donor_indexs.push_back(i);
}
}
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
return buffer_donor_indexs;
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
const SyncTensorCollection& coll, PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values) {
tsl::profiler::TraceMe activity(
[&] {
return tsl::profiler::TraceMeEncode(
"XLAGraphExecutor::Compile",
{{"graph_hash", torch::lazy::HashToString(coll.hash)}});
},
tsl::profiler::TraceMeLevel::kInfo);
std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
const std::vector<XLATensorPtr>& tensors, const SyncTensorCollection& coll,
const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
static const bool enable_aliasing =
runtime::sys_util::GetEnvBool("XLA_ENABLE_PARAM_ALIASING", true);
static const size_t parameter_wrapping_threadshold =
runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
static const bool use_autosharding = ShardingUtil::GetAutoSharding();
std::string graph_name =
(CurrentGraphName() != "") ? CurrentGraphName() : "SyncTensorsGraph";
LoweringContext lowering_ctx(graph_name, coll.device, po_data->post_order,
std::move(po_data->emission_map));
for (auto ir_value : ir_values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}
// Always execute sharded when running in SPMD mode
bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
// Annotate HLO sharding selectively in the computation.
ShardingUtil::SetHloSharding(&lowering_ctx);

std::vector<size_t> buffer_donor_indices;
// TODO(yeounoh) enable aliasing is disabled for partitioned computation,
@@ -1357,14 +1316,59 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
// will later fetch the new value of A, which is incorrect.
// But, when we issue a step barrier (force_ltc_data == true) we have to
// turn everything into DEVICE_DATA, so we can activate aliasing.
buffer_donor_indices =
SetBufferDonors(tensors, coll.indices, &lowering_ctx);
buffer_donor_indices = GetBufferDonorIndexForStepMarker(
tensors, coll.indices, parameters_data);
} else if (GetAliasWithBufferDonorConfig()) {
// only alias based on buffer donor if LTC can't auto infer the input
// output aliasing.
buffer_donor_indices = SetBufferDonorsFromUserConfig(&lowering_ctx);
buffer_donor_indices = GetBufferDonorIndexFromUserConfig(parameters_data);
}
}
return buffer_donor_indices;
}

void XLAGraphExecutor::SetBufferDonors(
LoweringContext* lowering_ctx,
const std::vector<size_t>& buffer_donor_indexs) {
const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
lowering_ctx->GetParametersData();
for (size_t i : buffer_donor_indexs) {
lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
/*param_index=*/{});
}
TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
std::vector<XLATensorPtr>& tensors, absl::Span<const std::string> devices,
const SyncTensorCollection& coll, PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values,
const std::vector<size_t>& buffer_donor_indices) {
tsl::profiler::TraceMe activity(
[&] {
return tsl::profiler::TraceMeEncode(
"XLAGraphExecutor::Compile",
{{"graph_hash", torch::lazy::HashToString(coll.hash)}});
},
tsl::profiler::TraceMeLevel::kInfo);
static const size_t parameter_wrapping_threadshold =
runtime::sys_util::GetEnvInt("XLA_PARAMETER_WRAPPING_THREADSHOLD", 3200);
static const bool use_autosharding = ShardingUtil::GetAutoSharding();
std::string graph_name =
(CurrentGraphName() != "") ? CurrentGraphName() : "SyncTensorsGraph";
LoweringContext lowering_ctx(graph_name, coll.device, po_data->post_order,
std::move(po_data->emission_map));
for (auto ir_value : ir_values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}
// Always execute sharded when running in SPMD mode
bool is_sharded = (coll.device == GetVirtualDevice()) || UseVirtualDevice();
// Annotate HLO sharding selectively in the computation.
ShardingUtil::SetHloSharding(&lowering_ctx);

SetBufferDonors(&lowering_ctx, buffer_donor_indices);

xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
@@ -1499,14 +1503,13 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
PostOrderData po_data = RunPostOrder(ir_values, &coll);
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
if (GetAliasWithBufferDonorConfig()) {
std::vector<size_t> buffer_donor_index =
GetBufferDonorIndexFromUserConfig(po_data.parameters_data);
if (buffer_donor_index.size() > 0) {
// Do not include the hash of an empty vector.
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(buffer_donor_index));
}

std::vector<size_t> buffer_donor_indices =
GetBufferDonors(*tensors, coll, po_data.parameters_data);
if (buffer_donor_indices.size() > 0) {
// Do not include the hash of an empty vector.
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(buffer_donor_indices));
}
{
// Auto-sharding configs
@@ -1529,8 +1532,8 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
// we have a cache hit, execution has been scheduled by TryRunCachedSync.
return cache_res.second;
}
CompilationResult compile_result =
Compile(*tensors, devices, coll, &po_data, ir_values);
CompilationResult compile_result = Compile(*tensors, devices, coll, &po_data,
ir_values, buffer_donor_indices);

TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
TF_VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;
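The same reordering also covers the explicit buffer-donation path: donors registered through the user config (the `GetAliasWithBufferDonorConfig()` branch above) are now collected by `GetBufferDonors`, folded into `coll.hash`, and only applied to the `LoweringContext` inside `Compile`. A minimal Python sketch of that path follows; it assumes the `torch_xla._XLAC._set_buffer_donation` binding (an assumption, the binding name may differ by version), and whether the donation is honored still depends on the executor's aliasing configuration:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
weight = torch.randn(4, 4, device=device)
xm.mark_step()  # materialize `weight` as device data

# Mark `weight`'s buffer as donatable; the C++ side reads this through
# GetBufferDonorIndexFromUserConfig() when building the donor index list.
torch_xla._XLAC._set_buffer_donation(weight, True)

updated = weight * 0.9
xm.mark_step()  # donor indices are computed, hashed into coll.hash,
                # then applied to the LoweringContext inside Compile()
```

Because the donor indices now contribute to the hash, a graph that donates `weight` and an otherwise identical graph that does not can no longer collide on the same cached executable.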
14 changes: 8 additions & 6 deletions torch_xla/csrc/xla_graph_executor.h
@@ -356,20 +356,22 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec,
bool warm_up_cache_only);

std::vector<size_t> SetBufferDonors(const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices,
LoweringContext* lowering_ctx);
std::vector<size_t> GetBufferDonors(
const std::vector<XLATensorPtr>& tensors,
const SyncTensorCollection& coll,
const std::vector<torch::lazy::BackendDataPtr>& parameters_data);

std::vector<size_t> SetBufferDonorsFromUserConfig(
LoweringContext* lowering_ctx);
void SetBufferDonors(LoweringContext* lowering_ctx,
const std::vector<size_t>& buffer_donor_indices);

// TODO(yeounoh) auto-sharding can change tensors shardings, which needs to be
// accounted for in Dynamo integration.
CompilationResult Compile(std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices,
const SyncTensorCollection& coll,
PostOrderData* po_data,
const std::vector<torch::lazy::Value>& ir_values);
const std::vector<torch::lazy::Value>& ir_values,
const std::vector<size_t>& buffer_donor_indices);

// We don't use the upstream SyncTensorsGraphInternal since
// our CachedComputation is different from upstream.