diff --git a/.vscode/settings.json b/.vscode/settings.json index 59b86e622e7..76326bd5d48 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -18,5 +18,75 @@ "./bazel-out/_coverage/_coverage_report.dat" ], "python.formatting.provider": "yapf", - "editor.formatOnSave": true + "editor.formatOnSave": true, + "files.associations": { + ".*/BUILD": "starlark", + ".*/METADATA": "starlark", + ".*/WORKSPACE": "starlark", + "*.gss": "css", + "__bit_reference": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__split_buffer": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__verbose_abort": "cpp", + "array": "cpp", + "atomic": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "charconv": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "complex": "cpp", + "condition_variable": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "exception": "cpp", + "fstream": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "ios": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "limits": "cpp", + "list": "cpp", + "locale": "cpp", + "map": "cpp", + "mutex": "cpp", + "new": "cpp", + "optional": "cpp", + "ostream": "cpp", + "ratio": "cpp", + "set": "cpp", + "shared_mutex": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "thread": "cpp", + "tuple": "cpp", + "typeinfo": "cpp", + "unordered_map": "cpp", + "unordered_set": "cpp", + "variant": "cpp", + "vector": "cpp", + "algorithm": "cpp" + } } diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1391b73a16c..115875f51f6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1564,7 +1564,7 @@ void InitXlaModuleBindings(py::module m) { })); m.def("_xla_mark_sharding", [](const at::Tensor& input, xla::OpSharding sharding) { - ShardingUtil::xla_mark_sharding(input, sharding); + ShardingUtil::XlaMarkSharding(input, sharding); }); m.def("_xla_mark_sharding_dynamo_custom_op", [](const at::Tensor& input, const py::list& tile_assignment, @@ -1591,7 +1591,7 @@ void InitXlaModuleBindings(py::module m) { at::IntArrayRef(t.cast>())); } - xla_mark_sharding_dynamo_custom_op( + ShardingUtil::XlaMarkShardingDynamoCustomOp( input, tile_assignment_list, group_assignment_list, replication_groups_list, sharding_type); }); diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index ae586316073..7281ad17219 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -32,6 +32,34 @@ #include "xla/xla.pb.h" namespace torch_xla { + +// Macro for defining a function that will be run at static initialization time +// to define a library of operators in the namespace. Used to define a new set +// of custom operators that do not already exist in PyTorch. +TORCH_LIBRARY(xla, m) { + m.def( + "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " + "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); + + m.def( + "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " + "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " + "-> Tensor", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); + m.def( + "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " + "tile_assignment, int[][] group_assignment, int[][] replication_groups, " + "int sharding_type) -> ()", + torch::dispatch( + c10::DispatchKey::XLA, + TORCH_FN(torch_xla::ShardingUtil::XlaMarkShardingDynamoCustomOp))); +} + namespace { using tsl::ERROR; @@ -740,8 +768,8 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData( source_tensors, GetVirtualDevice().toString(), global_shape, sharding); } -void ShardingUtil::xla_mark_sharding(const at::Tensor& input, - xla::OpSharding sharding) { +void ShardingUtil::XlaMarkSharding(const at::Tensor& input, + xla::OpSharding sharding) { TORCH_LAZY_COUNTER("XlaMarkSharding", 1); XLA_CHECK(UseVirtualDevice()) << "Please enable SPMD via `torch_xla.runtime.use_spmd()`"; @@ -807,7 +835,7 @@ void ShardingUtil::xla_mark_sharding(const at::Tensor& input, XLAGraphExecutor::Get()->RegisterTensor(xtensor->data()); } -void xla_mark_sharding_dynamo_custom_op( +void ShardingUtil::XlaMarkShardingDynamoCustomOp( const at::Tensor& input, c10::List tile_assignment, c10::List group_assignment, c10::List replication_groups, int64_t sharding_type) { @@ -842,33 +870,7 @@ void xla_mark_sharding_dynamo_custom_op( tile_assignment_py, group_assignment_py, replication_groups_py, ShardingUtil::ShardingType(sharding_type)); - ShardingUtil::xla_mark_sharding(input, op_sharding); -} - -// Macro for defining a function that will be run at static initialization time -// to define a library of operators in the namespace. Used to define a new set -// of custom operators that do not already exist in PyTorch. -TORCH_LIBRARY(xla, m) { - m.def( - "max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], " - "int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward))); - - m.def( - "max_pool2d_backward(Tensor grad_output, Tensor self, int[2] " - "kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) " - "-> Tensor", - torch::dispatch( - c10::DispatchKey::XLA, - TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward))); - m.def( - "xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] " - "tile_assignment, int[][] group_assignment, int[][] replication_groups, " - "int sharding_type) -> ()", - torch::dispatch(c10::DispatchKey::XLA, - TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op))); + ShardingUtil::XlaMarkSharding(input, op_sharding); } } // namespace torch_xla diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 697f320f575..7d9a4918b51 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -151,15 +151,16 @@ class ShardingUtil { const std::vector& devices, const XLATensor::ShardingSpecPtr& sharding_spec); - static void xla_mark_sharding(const at::Tensor& input, - xla::OpSharding sharding); + //////////////////////////// Dynamo Integration //////////////////////////// + + static void XlaMarkSharding(const at::Tensor& input, + xla::OpSharding sharding); + static void XlaMarkShardingDynamoCustomOp( + const at::Tensor& input, c10::List tile_assignment, + c10::List group_assignment, + c10::List replication_groups, int64_t sharding_type); }; -void xla_mark_sharding_dynamo_custom_op( - const at::Tensor& input, c10::List tile_assignment, - c10::List group_assignment, - c10::List replication_groups, int64_t sharding_type); - } // namespace torch_xla #endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_ diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index d96531a5616..7e8d2848c5c 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -82,43 +82,16 @@ def get_axis_name_idx(self, name: str) -> int: @functools.lru_cache(maxsize=None) def get_op_sharding(self, - partition_spec: Tuple, - flatten_opsharding=False) -> torch_xla._XLAC.OpSharding: + partition_spec: Tuple) -> torch_xla._XLAC.OpSharding: """ Return the OpSharding for the given partition spec. This is an expensive operation as the mesh grows, so the value is cached for reuse. """ - partition_spec = _translate_named_partition_spec(self, partition_spec) - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." - - tile_assignment = _get_tile_assignment(self, partition_spec) - if len(tile_assignment.shape) > len(partition_spec): - # Use partial replication for sharding a tensor over a higher-rank mesh - sharding_type = ShardingType.PARTIAL - else: - sharding_type = _get_sharding_type(partition_spec, self.size()) - replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} - group_assignment, replication_groups = _get_group_assignment( - sharding_type, tile_assignment, len(partition_spec), replicate_dims) - - # If flatten_opsharding = True, return the flattened version of OpSharding - if flatten_opsharding: - return (tile_assignment.tolist(), group_assignment, replication_groups, - int(sharding_type)) - else: - return torch_xla._XLAC.OpSharding(tile_assignment.tolist(), - group_assignment, replication_groups, - int(sharding_type)) - - -# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4 + return torch_xla._XLAC.OpSharding( + _extract_op_sharding_specs(self, partition_spec)) +# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4ƒ class HybridMesh(Mesh): """Creates a hybrid device mesh of devices connected with ICI and DCN networks. The shape of logical mesh should be ordered by increasing network-intensity @@ -435,6 +408,30 @@ def _get_group_assignment(sharding_type: ShardingType, return group_assignment, replication_groups +def _extract_op_sharding_specs(mesh: Mesh, partition_spec: Tuple): + partition_spec = _translate_named_partition_spec(mesh, partition_spec) + flat_specs = np.hstack([d for d in partition_spec]) + specs = [d for d in flat_specs if d is not None] + assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \ + f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." + assert len(specs) == len(np.unique(specs)), \ + f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." + + tile_assignment = _get_tile_assignment(mesh, partition_spec) + if len(tile_assignment.shape) > len(partition_spec): + # Use partial replication for sharding a tensor over a higher-rank mesh + sharding_type = ShardingType.PARTIAL + else: + sharding_type = _get_sharding_type(partition_spec, mesh.size()) + replicate_dims = {i for i, d in enumerate(partition_spec) if d is None} + group_assignment, replication_groups = _get_group_assignment( + sharding_type, tile_assignment, len(partition_spec), replicate_dims) + + tile_assignment = tile_assignment.tolist() + sharding_type = int(sharding_type) + return tile_assignment, group_assignment, replication_groups, sharding_type + + def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple): _partition_spec = list() for p in partition_spec: @@ -508,29 +505,19 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], assert len(t.shape) == len(partition_spec), \ f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." + op_sharding = mesh.get_op_sharding(partition_spec) + tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs( + mesh, partition_spec) if use_dynamo_custom_op: - tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding( - partition_spec, flatten_opsharding=True) - - if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( - t.global_tensor, tile_assignment, group_assignment, - replication_groups, sharding_type) - return t - else: - torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op( - t, tile_assignment, group_assignment, replication_groups, - sharding_type) - return XLAShardedTensor(t) + # Allows Dynamo to capture mark_sharding op + annotate_func = torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op + annotate_func( + unwrap_sharded_tensor(t), tile_assignment, group_assignment, + replication_groups, sharding_type) else: - op_sharding = mesh.get_op_sharding(partition_spec) - - if isinstance(t, XLAShardedTensor): - torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding) - return t - else: - torch_xla._XLAC._xla_mark_sharding(t, op_sharding) - return XLAShardedTensor(t) + annotate_func = torch_xla._XLAC._xla_mark_sharding + annotate_func(unwrap_sharded_tensor(t), op_sharding) + return wrap_as_sharded_tensor(t) def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: @@ -541,6 +528,20 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: return t +def wrap_as_sharded_tensor( + t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor: + if not isinstance(t, XLAShardedTensor): + return XLAShardedTensor(t) + return t + + +def unwrap_sharded_tensor( + t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor: + if isinstance(t, XLAShardedTensor): + return t.global_tensor + return t + + def wrap_if_sharded(x: Any) -> Any: """ If the input is a sharded tensor, return an XLAShardedTensor wrapping it.