Refactor and clean SPMD+Dynamo integration code
yeounoh committed Nov 15, 2023
1 parent d8c6a83 commit b4e7a5a
Showing 5 changed files with 166 additions and 92 deletions.
72 changes: 71 additions & 1 deletion .vscode/settings.json
@@ -18,5 +18,75 @@
"./bazel-out/_coverage/_coverage_report.dat"
],
"python.formatting.provider": "yapf",
"editor.formatOnSave": true
"editor.formatOnSave": true,
"files.associations": {
".*/BUILD": "starlark",
".*/METADATA": "starlark",
".*/WORKSPACE": "starlark",
"*.gss": "css",
"__bit_reference": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__split_buffer": "cpp",
"__threading_support": "cpp",
"__tree": "cpp",
"__verbose_abort": "cpp",
"array": "cpp",
"atomic": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"charconv": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"complex": "cpp",
"condition_variable": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"deque": "cpp",
"exception": "cpp",
"fstream": "cpp",
"initializer_list": "cpp",
"iomanip": "cpp",
"ios": "cpp",
"iosfwd": "cpp",
"iostream": "cpp",
"istream": "cpp",
"limits": "cpp",
"list": "cpp",
"locale": "cpp",
"map": "cpp",
"mutex": "cpp",
"new": "cpp",
"optional": "cpp",
"ostream": "cpp",
"ratio": "cpp",
"set": "cpp",
"shared_mutex": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"thread": "cpp",
"tuple": "cpp",
"typeinfo": "cpp",
"unordered_map": "cpp",
"unordered_set": "cpp",
"variant": "cpp",
"vector": "cpp",
"algorithm": "cpp"
}
}
4 changes: 2 additions & 2 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1564,7 +1564,7 @@ void InitXlaModuleBindings(py::module m) {
}));
m.def("_xla_mark_sharding",
[](const at::Tensor& input, xla::OpSharding sharding) {
ShardingUtil::xla_mark_sharding(input, sharding);
ShardingUtil::XlaMarkSharding(input, sharding);
});
m.def("_xla_mark_sharding_dynamo_custom_op",
[](const at::Tensor& input, const py::list& tile_assignment,
@@ -1591,7 +1591,7 @@ void InitXlaModuleBindings(py::module m) {
at::IntArrayRef(t.cast<std::vector<int64_t>>()));
}

xla_mark_sharding_dynamo_custom_op(
ShardingUtil::XlaMarkShardingDynamoCustomOp(
input, tile_assignment_list, group_assignment_list,
replication_groups_list, sharding_type);
});
62 changes: 32 additions & 30 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -32,6 +32,34 @@
#include "xla/xla.pb.h"

namespace torch_xla {

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::ShardingUtil::XlaMarkShardingDynamoCustomOp)));
}
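
(Illustrative aside, not part of this file: once the schema above is registered under the `xla` namespace, the operator becomes reachable from Python as torch.ops.xla.xla_mark_sharding_dynamo_custom_op, which is what lets Dynamo capture the sharding annotation in a traced graph. A minimal Python sketch follows, assuming torch_xla is built with this change and SPMD is enabled; the helper name `annotate` is hypothetical, and the flattened spec lists are assumed to come from _extract_op_sharding_specs in the Python diff further down.)

import torch
import torch_xla  # assumption: torch_xla built with this change, SPMD enabled

def annotate(t, tile_assignment, group_assignment, replication_groups, sharding_type):
  # Calls the op registered by TORCH_LIBRARY(xla, ...) above; because it is a
  # real torch operator rather than a pybind binding, torch.compile/Dynamo can
  # trace the call instead of graph-breaking on it.
  torch.ops.xla.xla_mark_sharding_dynamo_custom_op(
      t, tile_assignment, group_assignment, replication_groups, sharding_type)
  return t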

namespace {

using tsl::ERROR;
@@ -740,8 +768,8 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
}

void ShardingUtil::xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding) {
void ShardingUtil::XlaMarkSharding(const at::Tensor& input,
xla::OpSharding sharding) {
TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
XLA_CHECK(UseVirtualDevice())
<< "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
@@ -807,7 +835,7 @@ void ShardingUtil::xla_mark_sharding(const at::Tensor& input,
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
}

void xla_mark_sharding_dynamo_custom_op(
void ShardingUtil::XlaMarkShardingDynamoCustomOp(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type) {
@@ -842,33 +870,7 @@ void xla_mark_sharding_dynamo_custom_op(
tile_assignment_py, group_assignment_py, replication_groups_py,
ShardingUtil::ShardingType(sharding_type));

ShardingUtil::xla_mark_sharding(input, op_sharding);
}

// Macro for defining a function that will be run at static initialization time
// to define a library of operators in the namespace. Used to define a new set
// of custom operators that do not already exist in PyTorch.
TORCH_LIBRARY(xla, m) {
m.def(
"max_pool2d_forward(Tensor self, int[2] kernel_size, int[2] stride=[], "
"int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_forward)));

m.def(
"max_pool2d_backward(Tensor grad_output, Tensor self, int[2] "
"kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False) "
"-> Tensor",
torch::dispatch(
c10::DispatchKey::XLA,
TORCH_FN(torch_xla::aten_autograd_ops::max_pool2d_backward)));
m.def(
"xla_mark_sharding_dynamo_custom_op(Tensor input, int[][] "
"tile_assignment, int[][] group_assignment, int[][] replication_groups, "
"int sharding_type) -> ()",
torch::dispatch(c10::DispatchKey::XLA,
TORCH_FN(torch_xla::xla_mark_sharding_dynamo_custom_op)));
ShardingUtil::XlaMarkSharding(input, op_sharding);
}

} // namespace torch_xla
15 changes: 8 additions & 7 deletions torch_xla/csrc/xla_sharding_util.h
@@ -151,15 +151,16 @@ class ShardingUtil {
const std::vector<std::string>& devices,
const XLATensor::ShardingSpecPtr& sharding_spec);

static void xla_mark_sharding(const at::Tensor& input,
xla::OpSharding sharding);
//////////////////////////// Dynamo Integration ////////////////////////////

static void XlaMarkSharding(const at::Tensor& input,
xla::OpSharding sharding);
static void XlaMarkShardingDynamoCustomOp(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type);
};

void xla_mark_sharding_dynamo_custom_op(
const at::Tensor& input, c10::List<at::IntArrayRef> tile_assignment,
c10::List<at::IntArrayRef> group_assignment,
c10::List<at::IntArrayRef> replication_groups, int64_t sharding_type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_SHARDING_UTIL_H_
105 changes: 53 additions & 52 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -82,43 +82,16 @@ def get_axis_name_idx(self, name: str) -> int:

@functools.lru_cache(maxsize=None)
def get_op_sharding(self,
partition_spec: Tuple,
flatten_opsharding=False) -> torch_xla._XLAC.OpSharding:
partition_spec: Tuple) -> torch_xla._XLAC.OpSharding:
"""
Return the OpSharding for the given partition spec. This is an expensive
operation as the mesh grows, so the value is cached for reuse.
"""
partition_spec = _translate_named_partition_spec(self, partition_spec)
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tile_assignment = _get_tile_assignment(self, partition_spec)
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
sharding_type = ShardingType.PARTIAL
else:
sharding_type = _get_sharding_type(partition_spec, self.size())
replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
group_assignment, replication_groups = _get_group_assignment(
sharding_type, tile_assignment, len(partition_spec), replicate_dims)

# If flatten_opsharding = True, return the flattened version of OpSharding
if flatten_opsharding:
return (tile_assignment.tolist(), group_assignment, replication_groups,
int(sharding_type))
else:
return torch_xla._XLAC.OpSharding(tile_assignment.tolist(),
group_assignment, replication_groups,
int(sharding_type))


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
return torch_xla._XLAC.OpSharding(
_extract_op_sharding_specs(self, partition_spec))
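
(Usage sketch, not part of the diff: assuming an existing Mesh named `mesh` with axis names ('data', 'model'), repeated lookups for the same partition spec are served from the functools.lru_cache above.)

op_sharding = mesh.get_op_sharding(('data', None))         # computed once
op_sharding_cached = mesh.get_op_sharding(('data', None))  # cache hit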


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
class HybridMesh(Mesh):
"""Creates a hybrid device mesh of devices connected with ICI and DCN networks.
The shape of logical mesh should be ordered by increasing network-intensity
@@ -435,6 +408,30 @@ def _get_group_assignment(sharding_type: ShardingType,
return group_assignment, replication_groups


def _extract_op_sharding_specs(mesh: Mesh, partition_spec: Tuple):
partition_spec = _translate_named_partition_spec(mesh, partition_spec)
flat_specs = np.hstack([d for d in partition_spec])
specs = [d for d in flat_specs if d is not None]
assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
assert len(specs) == len(np.unique(specs)), \
f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."

tile_assignment = _get_tile_assignment(mesh, partition_spec)
if len(tile_assignment.shape) > len(partition_spec):
# Use partial replication for sharding a tensor over a higher-rank mesh
sharding_type = ShardingType.PARTIAL
else:
sharding_type = _get_sharding_type(partition_spec, mesh.size())
replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
group_assignment, replication_groups = _get_group_assignment(
sharding_type, tile_assignment, len(partition_spec), replicate_dims)

tile_assignment = tile_assignment.tolist()
sharding_type = int(sharding_type)
return tile_assignment, group_assignment, replication_groups, sharding_type
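
(Hedged illustration, not part of the diff: _extract_op_sharding_specs is module-internal, and the `mesh` below is assumed to exist. The flat lists and int it returns feed the Dynamo custom op directly, while Mesh.get_op_sharding packs the same values into a torch_xla._XLAC.OpSharding.)

tile_assignment, group_assignment, replication_groups, sharding_type = (
    _extract_op_sharding_specs(mesh, ('data', None)))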


def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
_partition_spec = list()
for p in partition_spec:
@@ -508,29 +505,19 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
assert len(t.shape) == len(partition_spec), \
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

op_sharding = mesh.get_op_sharding(partition_spec)
tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs(
mesh, partition_spec)
if use_dynamo_custom_op:
tile_assignment, group_assignment, replication_groups, sharding_type = mesh.get_op_sharding(
partition_spec, flatten_opsharding=True)

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
t.global_tensor, tile_assignment, group_assignment,
replication_groups, sharding_type)
return t
else:
torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
t, tile_assignment, group_assignment, replication_groups,
sharding_type)
return XLAShardedTensor(t)
# Allows Dynamo to capture mark_sharding op
annotate_func = torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op
annotate_func(
unwrap_sharded_tensor(t), tile_assignment, group_assignment,
replication_groups, sharding_type)
else:
op_sharding = mesh.get_op_sharding(partition_spec)

if isinstance(t, XLAShardedTensor):
torch_xla._XLAC._xla_mark_sharding(t.global_tensor, op_sharding)
return t
else:
torch_xla._XLAC._xla_mark_sharding(t, op_sharding)
return XLAShardedTensor(t)
annotate_func = torch_xla._XLAC._xla_mark_sharding
annotate_func(unwrap_sharded_tensor(t), op_sharding)
return wrap_as_sharded_tensor(t)
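
(For context, a minimal end-to-end sketch of the public API touched by this refactor; the module paths and mesh shape are assumptions based on this file's location, and torch_xla.runtime.use_spmd() is the switch referenced by the C++ check above.)

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # route tensors through the SPMD virtual device
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t = torch.randn(16, 128).to(xm.xla_device())
# Shard dim 0 across the 'data' axis, replicate dim 1; returns an XLAShardedTensor.
sharded = xs.mark_sharding(t, mesh, ('data', None))

# Under torch.compile, pass use_dynamo_custom_op=True so the annotation goes
# through the custom op registered in xla_sharding_util.cpp.
u = torch.randn(16, 128).to(xm.xla_device())
sharded_dynamo = xs.mark_sharding(u, mesh, ('data', None), use_dynamo_custom_op=True)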


def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
@@ -541,6 +528,20 @@ def clear_sharding(t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
return t


def wrap_as_sharded_tensor(
t: Union[torch.Tensor, XLAShardedTensor]) -> XLAShardedTensor:
if not isinstance(t, XLAShardedTensor):
return XLAShardedTensor(t)
return t


def unwrap_sharded_tensor(
t: Union[torch.Tensor, XLAShardedTensor]) -> torch.Tensor:
if isinstance(t, XLAShardedTensor):
return t.global_tensor
return t


def wrap_if_sharded(x: Any) -> Any:
"""
If the input is a sharded tensor, return an XLAShardedTensor wrapping it.
