diff --git a/test/spmd/test_dynamo_spmd.py b/test/spmd/test_dynamo_spmd.py
index 87f9d20b62ed..4928ad1de6cc 100644
--- a/test/spmd/test_dynamo_spmd.py
+++ b/test/spmd/test_dynamo_spmd.py
@@ -15,7 +15,7 @@
 
 class SimpleLinear(nn.Module):
 
-  def __init__(self, mark_sharding_inside = False, op_sharding = None):
+  def __init__(self, mesh=None):
     super(SimpleLinear, self).__init__()
     self.fc1 = nn.Linear(128, 128)
     self.relu = nn.ReLU()
@@ -23,11 +23,14 @@ def __init__(self, mark_sharding_inside = False, op_sharding = None):
     # Add an additional 1x1 layer at the end to ensure the final layer
     # is not sharded.
     self.fc3 = nn.Linear(1, 1)
+    # If mesh is not None, we'll do a mark_sharding inside the forward function
+    # to ensure dynamo can recognize and trace it under torch.compile.
+    self.mesh = mesh
 
   def forward(self, x):
-    print(f'self.fc2.weight.device={self.fc2.weight.device}')
-    if self.mark_sharding_inside and self.op_sharding and 'xla' in self.fc2.weight.device:
-      xs.mark_sharding(self.fc2.weight, self.op_sharding)
+    if self.mesh and 'xla' in str(self.fc2.weight.device):
+      xs.mark_sharding(
+          self.fc2.weight, self.mesh, (1, 0), dynamo_custom_op=True)
     y = self.relu(self.fc1(x))
     z = self.fc2(y)
     return self.fc3(z)
@@ -191,23 +194,19 @@ def test_dynamo_input_sharding_threashold(self):
 
   def test_mark_sharding_inside_compile(self):
     device = xm.xla_device()
+    mesh = self._get_mesh((1, self.n_devices))
 
-    def fn_simple(t):
-      xs.mark_sharding_dynamo_custom_op(
-          t, self._get_mesh((1, self.n_devices)), (0, 1))
-
-      x = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
-                       dtype=torch.float,
-                       device=device)
-
-      return t + x
+    # Passing this `mesh` as a parameter to `SimpleLinear` will call the dynamo
+    # custom op variant of mark_sharding inside the forward function.
+    linear = SimpleLinear(mesh=mesh).to(device)
+    linear.eval()
 
-    x_xla = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]]).to(device)
-    xla_res = fn_simple(x_xla)
+    xla_x = torch.randn(1, 128, device=device)
+    xla_res = linear(xla_x)
     xm.mark_step()
 
-    dynamo_fn_simple = torch.compile(fn_simple, backend="openxla")
-    dynamo_res = dynamo_fn_simple(x_xla)
+    dynamo_linear = torch.compile(linear, backend="openxla")
+    dynamo_res = dynamo_linear(xla_x)
     torch.allclose(xla_res.cpu(), dynamo_res.cpu())
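The test above is the pattern this change enables: annotating a parameter from inside `forward()` so dynamo can trace the sharding call. Below is a minimal sketch of that usage (not code from the patch); the module name `TinyLinear`, the 1 x N mesh, and the shapes are illustrative assumptions.

```python
# Hedged sketch: shard a weight inside forward() under torch.compile.
# TinyLinear, the 1 x N mesh, and the shapes below are illustrative only.
import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

xr.use_spmd()  # the SPMD virtual device must be enabled before mark_sharding
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (1, num_devices))


class TinyLinear(nn.Module):

  def __init__(self, mesh=None):
    super().__init__()
    self.fc = nn.Linear(128, 64)
    self.mesh = mesh

  def forward(self, x):
    if self.mesh is not None:
      # dynamo_custom_op=True routes through the traceable custom op variant.
      xs.mark_sharding(self.fc.weight, self.mesh, (1, 0), dynamo_custom_op=True)
    return self.fc(x)


model = TinyLinear(mesh=mesh).to(xm.xla_device())
compiled = torch.compile(model, backend="openxla")
out = compiled(torch.randn(8, 128, device=xm.xla_device()))
```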
TORCH_LAZY_COUNTER("XlaMarkSharding", 1); - XLA_CHECK(UseVirtualDevice()) - << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - auto new_sharding_spec = std::make_shared( - sharding, MakeShapeWithDeviceLayout( - xtensor->shape(), - static_cast(xtensor->GetDevice().type()))); - - // For Non DeviceData IR values, we directly attach the sharding spec - // to the xtensor. - const DeviceData* device_data_node = nullptr; - if (xtensor->CurrentIrValue()) { - device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); - if (!device_data_node) { - tensor_methods::custom_sharding_(xtensor, new_sharding_spec); - return; - } + TORCH_LAZY_COUNTER("XlaMarkSharding", 1); + XLA_CHECK(UseVirtualDevice()) + << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + auto new_sharding_spec = std::make_shared( + sharding, MakeShapeWithDeviceLayout( + xtensor->shape(), + static_cast(xtensor->GetDevice().type()))); + + // For Non DeviceData IR values, we directly attach the sharding spec + // to the xtensor. + const DeviceData* device_data_node = nullptr; + if (xtensor->CurrentIrValue()) { + device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); + if (!device_data_node) { + tensor_methods::custom_sharding_(xtensor, new_sharding_spec); + return; } + } - // For data, we need to deal with the data transfers between - // host and device. - at::Tensor cpu_tensor; - if (xtensor->CurrentTensorData().has_value()) { - TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); - // When virtual device is enabled for SPMD, we defer the initial - // data transfer to the device and retain the original data on the - // host, until the sharded data transfer. - cpu_tensor = xtensor->CurrentTensorData().value(); - } else { - // A new input tensor is not expected to be sharded. But sometimes, - // the same input is called for sharding annotation over multiple steps, - // in which case we can skip if it's the same sharding; however, if it's - // the same input with a different sharding then we block & ask the user - // to clear the existing sharding first. - auto current_sharding_spec = xtensor->sharding_spec(); - if (current_sharding_spec && (current_sharding_spec->sharding.type() != - xla::OpSharding::REPLICATED)) { - XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec, - *current_sharding_spec)) - << "Existing annotation must be cleared first."; - return; - } - - // If the at::Tensor data is not present, we need to re-download the - // tensor from the physical device to CPU. In that case, the value - // must be present on the backend device. - XLA_CHECK((xtensor->CurrentDataHandle() && - xtensor->CurrentDataHandle()->HasValue()) || - device_data_node != nullptr) - << "Cannot shard tensor. Data does not present on any device."; - std::vector xla_tensors{xtensor}; - cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0]; + // For data, we need to deal with the data transfers between + // host and device. + at::Tensor cpu_tensor; + if (xtensor->CurrentTensorData().has_value()) { + TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1); + // When virtual device is enabled for SPMD, we defer the initial + // data transfer to the device and retain the original data on the + // host, until the sharded data transfer. + cpu_tensor = xtensor->CurrentTensorData().value(); + } else { + // A new input tensor is not expected to be sharded. 
@@ -747,28 +750,36 @@ void xla_mark_sharding_dynamo_custom_op(const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
 
 TORCH_LIBRARY(xla, m) {
   m.def(
       "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
       "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
-      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));
+      torch::dispatch(
+          c10::DispatchKey::XLA,
+          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));
 
   m.def(
       "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
       "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
       "-> Tensor",
-      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
+      torch::dispatch(
+          c10::DispatchKey::XLA,
+          TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
 
   m.def(
-      "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] tile_assignment, int[][] group_assignment, int[][] replication_groups, int sharding_type) -> ()",
-      torch::dispatch(c10::DispatchKey::XLA, TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
+      "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
+      "tile_assignment, int[][] group_assignment, int[][] replication_groups, "
+      "int sharding_type) -> ()",
+      torch::dispatch(c10::DispatchKey::XLA,
+                      TORCH_FN(xla_mark_sharding_dynamo_custom_op)));
 }
 
 std::vector<bool> check_materialization_helper(
@@ -1679,30 +1690,38 @@ void InitXlaModuleBindings(py::module m) {
                      tile_assignment, group_assignment, replication_groups,
                      ShardingUtil::ShardingType(sharding_type));
                }));
-  m.def("_xla_mark_sharding", [](const at::Tensor& input,
-                                 xla::OpSharding sharding) {
-    xla_mark_sharding(input, sharding);
-  });
+  m.def("_xla_mark_sharding",
+        [](const at::Tensor& input, xla::OpSharding sharding) {
+          xla_mark_sharding(input, sharding);
+        });
   m.def("_xla_mark_sharding_dynamo_custom_op",
         [](const at::Tensor& input, const py::list& tile_assignment,
-           const py::list& group_assignment,
-           const py::list& replication_groups, int sharding_type) {
-          c10::List<at::IntArrayRef> time_assignment_list = c10::List<at::IntArrayRef>();
+           const py::list& group_assignment, const py::list& replication_groups,
+           int sharding_type) {
+          c10::List<at::IntArrayRef> time_assignment_list =
+              c10::List<at::IntArrayRef>();
           for (auto t : tile_assignment) {
-            time_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+            time_assignment_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
           }
-          c10::List<at::IntArrayRef> group_assignment_list = c10::List<at::IntArrayRef>();
+          c10::List<at::IntArrayRef> group_assignment_list =
+              c10::List<at::IntArrayRef>();
           for (auto t : group_assignment) {
-            group_assignment_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+            group_assignment_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
           }
-          c10::List<at::IntArrayRef> replication_groups_list = c10::List<at::IntArrayRef>();
+          c10::List<at::IntArrayRef> replication_groups_list =
+              c10::List<at::IntArrayRef>();
          for (auto t : replication_groups) {
-            replication_groups_list.push_back(at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+            replication_groups_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
          }
-          xla_mark_sharding_dynamo_custom_op(input, time_assignment_list, group_assignment_list, replication_groups_list, sharding_type);
+          xla_mark_sharding_dynamo_custom_op(
+              input, time_assignment_list, group_assignment_list,
+              replication_groups_list, sharding_type);
        });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
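The `_xla_mark_sharding_dynamo_custom_op` binding above is what `xla_sharding.mark_sharding(..., dynamo_custom_op=True)` calls, and the `TORCH_LIBRARY(xla, ...)` block earlier registers the same kernel with the dispatcher. A hedged sketch of the two call surfaces follows; the mesh shape, the tensor, and the `torch.ops.xla.*` spelling (inferred from standard `TORCH_LIBRARY` behavior) are assumptions, not code from the patch.

```python
# Sketch: two ways to reach the kernel registered above (illustrative setup).
import numpy as np
import torch
import torch_xla  # importing torch_xla loads the op/dispatch registrations
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla.experimental.xla_sharding import Mesh

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (1, num_devices))
t = torch.randn(8, 128).to(xm.xla_device())

# Flattened sharding pieces matching the op's int[][] / int schema.
tile_assignment, group_assignment, replication_groups, sharding_type = (
    mesh.get_op_sharding((0, 1), flatten=True))

# 1) The pybind11 binding added here (what xla_sharding.mark_sharding uses):
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
    t, tile_assignment, group_assignment, replication_groups, sharding_type)

# 2) Presumably also reachable through the dispatcher-registered op surface:
torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
    t, tile_assignment, group_assignment, replication_groups, sharding_type)
```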
diff --git a/torch_xla/experimental/xla_sharding.py b/torch_xla/experimental/xla_sharding.py
index 811823c1415a..84872082b198 100644
--- a/torch_xla/experimental/xla_sharding.py
+++ b/torch_xla/experimental/xla_sharding.py
@@ -82,7 +82,8 @@ def get_axis_name_idx(self, name: str) -> int:
 
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(self,
-                      partition_spec: Tuple, flatten = False) -> torch_xla._XLAC.OpSharding:
+                      partition_spec: Tuple,
+                      flatten=False) -> torch_xla._XLAC.OpSharding:
     """
     Return the OpSharding for the given partition spec. This is an expensive
     operation as the mesh grows, so the value is cached for reuse.
@@ -104,14 +105,15 @@ def get_op_sharding(self,
     replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
     group_assignment, replication_groups = _get_group_assignment(
         sharding_type, tile_assignment, len(partition_spec), replicate_dims)
-
+    # If flatten = True, return the flattened version of OpSharding
     if flatten:
-      return (tile_assignment.tolist(), group_assignment, replication_groups, int(sharding_type))
+      return (tile_assignment.tolist(), group_assignment, replication_groups,
+              int(sharding_type))
     else:
       return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
-                                       group_assignment, replication_groups,
-                                       int(sharding_type))
+                                        group_assignment, replication_groups,
+                                        int(sharding_type))
 
 
 # HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
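`flatten=True` exists so the sharding can cross the custom-op boundary as plain `int[][]`/`int` values rather than an opaque `OpSharding` object. A small sketch of the two return forms (assuming an N-device runtime; not part of the patch):

```python
# Sketch of Mesh.get_op_sharding with and without flatten (illustrative mesh).
import numpy as np
import torch_xla.runtime as xr
from torch_xla.experimental.xla_sharding import Mesh

num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (1, num_devices))

# Default: an opaque torch_xla._XLAC.OpSharding consumed by _xla_mark_sharding.
op_sharding = mesh.get_op_sharding((0, 1))

# flatten=True: plain Python values that can be passed through the dynamo
# custom op boundary as int[][] / int arguments.
tile_assignment, group_assignment, replication_groups, sharding_type = (
    mesh.get_op_sharding((0, 1), flatten=True))
assert isinstance(tile_assignment, list) and isinstance(sharding_type, int)
```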
@@ -454,9 +456,10 @@ def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
 
 
 @xr.requires_pjrt
-def mark_sharding(
-    t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
-    partition_spec: Tuple[Union[Tuple, int, str, None]], dynamo_custom_op: bool = False) -> XLAShardedTensor:
+def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
+                  mesh: Mesh,
+                  partition_spec: Tuple[Union[Tuple, int, str, None]],
+                  dynamo_custom_op: bool = False) -> XLAShardedTensor:
   """
   Annotates the tensor provided with XLA partition spec. Internally,
   it annotates the corresponding XLATensor as sharded for the XLA SpmdPartitioner pass.
@@ -505,41 +508,28 @@ def mark_sharding(
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
-  op_sharding = mesh.get_op_sharding(partition_spec)
-
-  if isinstance(t, XLAShardedTensor):
-    torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
-    return t
-  torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
-  return XLAShardedTensor(t)
-
-
-@xr.requires_pjrt
-def mark_sharding_dynamo_custom_op(
-    t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
-    partition_spec: Tuple[Union[Tuple, int, str, None]]) -> XLAShardedTensor:
-  """
-  Same functionality as `mark_sharding` above, except this variant uses the custom
-  mark_sharding op in torch_xla._XLAC to allow dynamo to recognize and trace it.
-  """
-  num_devices = xr.global_runtime_device_count()
-  assert num_devices > 0, "This requires XLA supported device(s)."
-  assert mesh.size() == num_devices, \
-    f"{mesh.mesh_shape} is not mappable over {num_devices} devices."
-  # We only allow fully specified `partition_spec` to be applicable, as opposed
-  # to filling in the unspecified replicated dims. Fully specified `partiion_spec`
-  # should be of the same rank as `t`. This is to support partial replication
-  # where the group assignment may vary with different input ranks.
-  assert len(t.shape) == len(partition_spec), \
-    f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
-
-  tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(partition_spec, flatten = True)
-
-  if isinstance(t, XLAShardedTensor):
-    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t.global_tensor, tile_assignment, group_assignment, replication_groups, sharding_type)
-    return t
-  torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment, group_assignment, replication_groups, sharding_type)
-  return XLAShardedTensor(t)
+  if dynamo_custom_op:
+    tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
+        partition_spec, flatten=True)
+
+    if isinstance(t, XLAShardedTensor):
+      torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
+          t.global_tensor, tile_assignment, group_assignment,
+          replication_groups, sharding_type)
+      return t
+    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(t, tile_assignment,
+                                                        group_assignment,
+                                                        replication_groups,
+                                                        sharding_type)
+    return XLAShardedTensor(t)
+  else:
+    op_sharding = mesh.get_op_sharding(partition_spec)
+
+    if isinstance(t, XLAShardedTensor):
+      torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
+      return t
+    torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
+    return XLAShardedTensor(t)
 
 
 def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
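With `mark_sharding_dynamo_custom_op` removed, both paths now go through the single `mark_sharding` entry point. A usage sketch under an illustrative 1 x N mesh (not part of the patch); `clear_sharding` is the existing helper shown above:

```python
# Usage sketch of the unified mark_sharding API (illustrative mesh and tensor).
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (1, num_devices))
t = torch.randn(8, 128).to(xm.xla_device())

# Default path: builds an OpSharding object and calls _xla_mark_sharding.
xs.mark_sharding(t, mesh, (0, 1))
xs.clear_sharding(t)

# Dynamo path: flattens the sharding and calls the registered custom op,
# replacing the removed mark_sharding_dynamo_custom_op helper.
xs.mark_sharding(t, mesh, (0, 1), dynamo_custom_op=True)
```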