[LoweringContext] Support an optimized parameter mapping for SPMD
rpsilva-aws committed Dec 5, 2024
1 parent 39e67b5 commit 8fd7ac7
Showing 5 changed files with 103 additions and 7 deletions.
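For orientation, here is a minimal sketch (not part of the diff) of how the new `device_parameter_id_tensor_mapping` binding sits alongside the existing `parameter_id_tensor_mapping`, following the updated `test_api` below; it assumes a torch_xla install with an XLA device available, and the context name is illustrative only.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.tensor([1.0, 2.0, 3.0], device=device)
b = torch.tensor([4.0, 5.0, 6.0], device=device)
result = a + b

ctx = torch_xla._XLAC.lowering.LoweringContext("Sketch")
ctx.build([result])

# Existing API: materializes host tensors, which transfers data off the device.
host_mapping = ctx.parameter_id_tensor_mapping()
# New API: returns device-backed tensors and skips the device-to-host copy.
device_mapping = ctx.device_parameter_id_tensor_mapping()
assert len(host_mapping) == len(device_mapping)
```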
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -247,6 +247,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/spmd/test_xla_auto_sharding.py"
run_test "$CDIR/spmd/test_spmd_parameter_wrapping.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_spmd_lowering_context.py"
run_test "$CDIR/test_operations_hlo.py" "$@" --verbosity=$VERBOSITY
run_test "$CDIR/test_input_output_aliases.py"
run_test "$CDIR/test_torch_distributed_xla_backend.py"
58 changes: 58 additions & 0 deletions test/spmd/test_spmd_lowering_context.py
@@ -0,0 +1,58 @@
import sys

import unittest

import test_xla_sharding_base

import torch
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd as xs
import torch_xla.core.xla_model as xm
import contextlib


class TestSPMDLoweringContext(test_xla_sharding_base.XlaShardingTest):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()

  def test_device_parameter_id_tensor_mapping(self):
    met.clear_all()

    model_axis = min(8, self.n_devices)
    data_axis = self.n_devices // model_axis
    mesh_shape = (data_axis, model_axis)
    spmd_mesh = self._get_mesh(mesh_shape, axis_names=('x', 'y'))

    device = xm.xla_device()
    a = torch.randn([32, 2048]).to(device)
    xs.mark_sharding(a, spmd_mesh, ('x', 'y'))
    b = torch.ones(2048).to(device)
    xs.mark_sharding(b, spmd_mesh, ('x',))

    def fn(a, b):
      return a + b

    result = fn(a, b)
    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
    ctx.build([result])
    torch_xla.sync()

    mapping = ctx.device_parameter_id_tensor_mapping()
    num_params = len(mapping)
    self.assertEqual(num_params, 2)
    self.assertNotEqual(ctx.tensor_parameter_id(a), -1)
    self.assertNotEqual(ctx.tensor_parameter_id(b), -1)
    self.assertEqual(met.counter_value("VirtualDeviceUsage"), num_params)

    # Ensure that the parameter mapping does not require transferring data
    # from the device to the host when sharded.
    self.assertFalse(met.metric_data("TransferFromDeviceTime"))
    self.assertFalse(met.counter_value("ReplicateShardedData"))


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
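As an aside (hypothetical usage, not part of the commit): because the keys of the device mapping are HLO parameter IDs, a caller could assemble the lowered computation's inputs in positional order without pulling sharded data back to the host, roughly like this.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(32, 2048).to(device)
b = torch.ones(2048).to(device)

ctx = torch_xla._XLAC.lowering.LoweringContext("ParamOrder")
ctx.build([a + b])

# Keys are HLO parameter IDs; sorting them yields the positional argument list.
mapping = ctx.device_parameter_id_tensor_mapping()
ordered_inputs = [mapping[pid] for pid in sorted(mapping)]
```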
14 changes: 10 additions & 4 deletions test/test_operations.py
@@ -2634,6 +2634,7 @@ def test_multi_init_xla_backend(self):
class TestLoweringContext(test_utils.XlaTestCase):

  def test_api(self):
    met.clear_all()
    device = xm.xla_device()
    a = torch.tensor([1.0, 2.0, 3.0], device=device)
    b = torch.tensor([4.0, 5.0, 6.0], device=device)
@@ -2642,14 +2643,19 @@ def test_api(self):

    ctx = torch_xla._XLAC.lowering.LoweringContext("MyCustomName")
    ctx.build([result])
    _ = ctx.hlo()
    hlo_text = ctx.hlo_text()
    self.assertIn('MyCustomName', hlo_text)
    self.assertIn('opcode: "parameter"', hlo_text)
    self.assertEqual(hlo_text.count('opcode: "parameter"'), 2)
    self.assertIn('opcode: "add"', hlo_text)
    num_expected_params = 2
    mapping = ctx.parameter_id_tensor_mapping()
    self.assertEqual(len(mapping), num_expected_params)
    self.assertTrue(met.metric_data("TransferFromDeviceTime"))
    met.clear_all()
    device_mapping = ctx.device_parameter_id_tensor_mapping()
    self.assertEqual(len(device_mapping), num_expected_params)
    self.assertFalse(met.metric_data("TransferFromDeviceTime"))

  def test_get_parameters_scalar(self):
    """Scalar tensor parameters may be shared in the HLO graph if their
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -6,6 +6,7 @@ python3 test/test_operations.py -v
python3 test/pjrt/test_runtime_tpu.py
python3 test/pjrt/test_collective_ops_tpu.py
python3 test/spmd/test_mp_input_sharding.py
python3 test/spmd/test_spmd_lowering_context.py
python3 test/spmd/test_xla_sharding.py
python3 test/spmd/test_xla_virtual_device.py
# TODO(JackCaoG): to reenable
36 changes: 33 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1070,7 +1070,7 @@ class PyLoweringContext {
  // Get a mapping from the HLO input parameters to the backing Tensor values.
  // This allows the caller to get all parameter information regardless of
  // how the parameter was allocated (inline tensor, nn.Parameter, constant,
  // etc.). This will copy the tensor data from the device to the host.
  std::unordered_map<int64_t, at::Tensor> GetParameterIdTensorMapping() {
    // Find parameters in the lowering
    const std::vector<torch::lazy::BackendDataPtr>& device_data =
@@ -1096,10 +1096,38 @@ class PyLoweringContext {
    return results;
  }

  // Returns a mapping from HLO parameter IDs to their corresponding
  // device-backed Tensors. This version only returns parameters that are
  // explicitly backed by device data and therefore retrievable via
  // GetTensorParameterId(). Unlike GetParameterIdTensorMapping(), it avoids
  // transferring data from device to host, making it more efficient,
  // especially for SPMD scenarios where the transfer involves costly
  // collectives.
  std::unordered_map<int64_t, at::Tensor> GetDeviceParameterIdTensorMapping() {
    // Find parameters in the lowering
    const std::vector<torch::lazy::BackendDataPtr>& device_data =
        lowering_ctx.GetParametersData();

    // Create a mapping from parameter id to the tensor data
    std::unordered_map<int64_t, at::Tensor> param_to_tensor;
    param_to_tensor.reserve(device_data.size());

    for (const auto& data : device_data) {
      std::optional<int64_t> param_id = lowering_ctx.GetParameterId(data);
      XLA_CHECK(param_id.has_value())
          << "Parameter ID must exist for device data";

      at::Tensor tensor =
          bridge::AtenFromXlaTensor(torch_xla::XLATensor::Create(data));
      param_to_tensor.emplace(param_id.value(), std::move(tensor));
    }
    return param_to_tensor;
  }

  // Get the parameter identifier of a given tensor. If the tensor is not a
  // parameter, this will always return -1. This is useful in conjunction with
  // GetParameterIdTensorMapping or GetDeviceParameterIdTensorMapping to
  // identify which values can be baked into the graph and which values must
  // remain parameters.
  int64_t GetTensorParameterId(at::Tensor tensor) {
    // Convert tensor into the backing lazy node
    XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
@@ -1201,6 +1229,8 @@ void BuildLoweringContextSubmodule(py::module* m) {
.def("hlo_json", &PyLoweringContext::GetHloJsonText)
.def("parameter_id_tensor_mapping",
&PyLoweringContext::GetParameterIdTensorMapping)
.def("device_parameter_id_tensor_mapping",
&PyLoweringContext::GetDeviceParameterIdTensorMapping)
.def("tensor_parameter_id", &PyLoweringContext::GetTensorParameterId)
.def("set_name_string", &PyLoweringContext::SetNameString)
.def("get_name_string", &PyLoweringContext::GetNameString);
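The comment on GetTensorParameterId above describes one intended workflow: use the parameter ID to decide whether a value can be baked into the graph or must remain a parameter. A hedged Python sketch of that check (illustrative tensor and context names, not from this commit):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(8).to(device)
y = x * 2.0

ctx = torch_xla._XLAC.lowering.LoweringContext("Inspect")
ctx.build([y])

param_id = ctx.tensor_parameter_id(x)
if param_id == -1:
  # `x` is not an HLO parameter, so its value could be baked into the graph.
  backing = None
else:
  # `x` must remain a parameter; fetch its device-backed tensor via the new
  # mapping, without a device-to-host transfer.
  backing = ctx.device_parameter_id_tensor_mapping()[param_id]
```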
