[LoweringContext] Support an optimized parameter mapping for SPMD #8453

Closed · wants to merge 5 commits
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -247,6 +247,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
58 changes: 58 additions & 0 deletions test/spmd/test_spmd_lowering_context.py
@@ -0,0 +1,58 @@
import sys

import unittest

import test_xla_sharding_base

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.core.xla_model as xm
import contextlib


class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
super().setUpClass()

def test_device_parameter_id_tensor_mapping(self):
met.clear_all()

model_axis = min(8, self.n_devices)
data_axis = self.n_devices // model_axis
mesh_shape = (data_axis, model_axis)
spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))

device = xm.xla_device()
a = torch.randn([32, 2048]).to(device)
xs.mark_sharding(a, spmd_mesh, ('x', 'y'))
b = torch.ones(2048).to(device)
xs.mark_sharding(b, spmd_mesh, ('x',))

def fn(a, b):
return a + b

result = fn(a, b)
ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build([result])
torch_xla.sync()

mapping = ctx.device_parameter_id_tensor_mapping()
num_params = len(mapping)
self.assertEqual(num_params, 2)
self.assertNotEqual(ctx.tensor_parameter_id(a), -1)
self.assertNotEqual(ctx.tensor_parameter_id(b), -1)
self.assertEqual(met.counter_value("VirtualDeviceUsage"), num_params)

# Ensure that the parameter mapping does not require transferring data
# from the device to the host when sharded.
self.assertFalse(met.metric_data("TransferFromDeviceTime"))
self.assertFalse(met.counter_value("ReplicateShardedData"))


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
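As a companion to the test above, here is a minimal standalone sketch of the new `device_parameter_id_tensor_mapping()` API. It assumes an SPMD-enabled runtime; the mesh construction, tensor shapes, and names are illustrative and not taken from the PR.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()  # route tensors through the SPMD virtual device

# Illustrative 1-D mesh over all addressable devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('x',))

device = xm.xla_device()
a = torch.randn(32, 2048).to(device)
xs.mark_sharding(a, mesh, ('x', None))  # shard rows, replicate columns
b = torch.ones(2048).to(device)
xs.mark_sharding(b, mesh, ('x',))

result = a + b

# Lower the traced graph and fetch the device-backed parameter mapping.
ctx = torch_xla._XLAC.lowering.LoweringContext("Example")
ctx.build([result])
mapping = ctx.device_parameter_id_tensor_mapping()

# Each entry wraps the existing device data; no device-to-host transfer
# (and, for sharded data, no replication collective) is issued here.
for param_id, tensor in mapping.items():
    print(param_id, tuple(tensor.shape))

# The ids reported for the inputs are valid parameter ids (-1 means
# "not a parameter", per GetTensorParameterId).
assert ctx.tensor_parameter_id(a) != -1
assert ctx.tensor_parameter_id(b) != -1
```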
12 changes: 8 additions & 4 deletions test/test_operations.py
@@ -2634,6 +2634,7 @@ def test_multi_init_xla_backend(self):
class TestLoweringContext(test_utils.XlaTestCase):

def test_api(self):
met.clear_all()
device = xm.xla_device()
a = torch.tensor([1.0, 2.0, 3.0], device=device)
b = torch.tensor([4.0, 5.0, 6.0], device=device)
@@ -2642,14 +2643,17 @@ def test_api(self):

ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
ctx.build([result])
hlo = ctx.hlo()
_ = ctx.hlo()
hlo_text = ctx.hlo_text()
self.assertIn('MyCustomName', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertIn('opcode: "parameter"', hlo_text)
self.assertEqual(hlo_text.count('opcode: "parameter"'), 2)
self.assertIn('opcode: "add"', hlo_text)
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), 2)
num_params = len(mapping)
self.assertEqual(num_params, 2)
mapping = ctx.parameter_id_tensor_mapping()
self.assertEqual(len(mapping), num_params)
self.assertTrue(met.metric_data("TransferFromDeviceTime"))

def test_get_parameters_scalar(self):
"""Scalar tensors parameters may be shared in the HLO graph if their
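For context on the metric assertions added to `test_api`, a short sketch (not part of the PR) contrasting what the two mappings touch. It relies on the `TransferFromDeviceTime` metric used in these tests and assumes the device-backed mapping behaves the same way outside SPMD, which the diff only asserts for the sharded case.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
a = torch.tensor([1.0, 2.0, 3.0], device=device)
b = torch.tensor([4.0, 5.0, 6.0], device=device)
result = a + b

ctx = torch_xla._XLAC.lowering.LoweringContext("MetricsDemo")
ctx.build([result])

met.clear_all()
host_mapping = ctx.parameter_id_tensor_mapping()  # copies device -> host
assert met.metric_data("TransferFromDeviceTime")  # transfer was recorded

met.clear_all()
device_mapping = ctx.device_parameter_id_tensor_mapping()  # stays on device
assert not met.metric_data("TransferFromDeviceTime")  # no transfer recorded

assert len(host_mapping) == len(device_mapping) == 2
```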
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -6,6 +6,7 @@ python3 test/test_operations.py -v
python3 test/pjrt/test_runtime_tpu.py
python3 test/pjrt/test_collective_ops_tpu.py
python3 test/spmd/test_mp_input_sharding.py
python3 test/spmd/test_spmd_lowering_context.py
python3 test/spmd/test_xla_sharding.py
python3 test/spmd/test_xla_virtual_device.py
# TODO(JackCaoG): to reenable
36 changes: 33 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1070,7 +1070,7 @@ class PyLoweringContext {
// Get a mapping from the HLO input parameters to the backing Tensor values.
// This allows the caller to get all parameter information regardless of
// how the parameter was allocated (inline tensor, nn.Parameter, constant,
// etc.)
// etc.). This will copy the tensor data from the device to the host.
std::unordered_map<int64_t, at::Tensor> GetParameterIdTensorMapping() {
// Find parameters in the lowering
const std::vector<torch::lazy::BackendDataPtr>& device_data =
@@ -1096,10 +1096,38 @@ class PyLoweringContext {
return results;
}

// Returns a mapping from HLO parameter IDs to their corresponding
// device-backed Tensors. This version only returns parameters that were
// explicitly allocated as device data, which are accessible via
// GetTensorParameterId(). Unlike GetParameterIdTensorMapping(), it avoids
// transferring data from the device to the host, making it more efficient,
// especially for SPMD scenarios where the transfer would involve costly
// collectives.
std::unordered_map<int64_t, at::Tensor> GetDeviceParameterIdTensorMapping() {
// Find parameters in the lowering
const std::vector<torch::lazy::BackendDataPtr>& device_data =
lowering_ctx.GetParametersData();

// Create a mapping from parameter id to the tensor data
std::unordered_map<int64_t, at::Tensor> param_to_tensor;
param_to_tensor.reserve(device_data.size());

for (const auto& data : device_data) {
std::optional<int64_t> param_id = lowering_ctx.GetParameterId(data);
XLA_CHECK(param_id.has_value())
<< "Parameter ID must exist for device data";

at::Tensor tensor =
bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(data));
param_to_tensor.emplace(param_id.value(), std::move(tensor));
}
return param_to_tensor;
}

// Get the parameter identifier of a given tensor. If the tensor is not a
// parameter this will always return -1. This is useful in conjunction with
// GetParameterIdTensorMapping to identify which values can be baked into
// the graph and which values must remain parameters.
// GetParameterIdTensorMapping or GetDeviceParameterIdTensorMapping, to
// identify which values can be baked into the graph and which values must
// remain parameters.
int64_t GetTensorParameterId(at::Tensor tensor) {
// Convert tensor into the backing lazy node
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
@@ -1201,6 +1229,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
.def("parameter_id_tensor_mapping",
&PyLoweringContext::GetParameterIdTensorMapping)
.def("device_parameter_id_tensor_mapping",
&PyLoweringContext::GetDeviceParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
.def("set_name_string", &PyLoweringContext::SetNameString)
.def("get_name_string", &PyLoweringContext::GetNameString);