From acdd21b0d299d72f3fcbbc8467757a70d3d4fc1e Mon Sep 17 00:00:00 2001
From: Yeounoh Chung
Date: Wed, 22 Nov 2023 10:54:34 -0800
Subject: [PATCH] debugging

---
 torch_xla/distributed/spmd/xla_sharding.py | 56 ++++++++++------------
 1 file changed, 26 insertions(+), 30 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py
index 546126777a9f..f1eeac611f42 100644
--- a/torch_xla/distributed/spmd/xla_sharding.py
+++ b/torch_xla/distributed/spmd/xla_sharding.py
@@ -80,6 +80,29 @@ def get_axis_name_idx(self, name: str) -> int:
       return None
     return self.axis_names.index(name)
 
+  def _get_op_sharding_args(self, partition_spec: Tuple):
+    partition_spec = _translate_named_partition_spec(self, partition_spec)
+    flat_specs = np.hstack([d for d in partition_spec])
+    specs = [d for d in flat_specs if d is not None]
+    assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+      f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+    assert len(specs) == len(np.unique(specs)), \
+      f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
+    tile_assignment = _get_tile_assignment(self, partition_spec)
+    if len(tile_assignment.shape) > len(partition_spec):
+      # Use partial replication for sharding a tensor over a higher-rank mesh
+      sharding_type = ShardingType.PARTIAL
+    else:
+      sharding_type = _get_sharding_type(partition_spec, self.size())
+    replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
+    group_assignment, replication_groups = _get_group_assignment(
+        sharding_type, tile_assignment, len(partition_spec), replicate_dims)
+
+    tile_assignment = tile_assignment.tolist()
+    sharding_type = int(sharding_type)
+    return tile_assignment, group_assignment, replication_groups, sharding_type
+
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(self,
                       partition_spec: Tuple) -> torch_xla._XLAC.OpSharding:
@@ -87,8 +110,8 @@ def get_op_sharding(self,
     Return the OpSharding for the given partition spec. This is an expensive
     operation as the mesh grows, so the value is cached for reuse.
     """
-    tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs(
-        self, partition_spec)
+    tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
+        partition_spec)
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 
@@ -410,30 +433,6 @@ def _get_group_assignment(sharding_type: ShardingType,
   return group_assignment, replication_groups
 
 
-def _extract_op_sharding_specs(mesh: Mesh, partition_spec: Tuple):
-  partition_spec = _translate_named_partition_spec(mesh, partition_spec)
-  flat_specs = np.hstack([d for d in partition_spec])
-  specs = [d for d in flat_specs if d is not None]
-  assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
-    f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
-  assert len(specs) == len(np.unique(specs)), \
-    f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
-
-  tile_assignment = _get_tile_assignment(mesh, partition_spec)
-  if len(tile_assignment.shape) > len(partition_spec):
-    # Use partial replication for sharding a tensor over a higher-rank mesh
-    sharding_type = ShardingType.PARTIAL
-  else:
-    sharding_type = _get_sharding_type(partition_spec, mesh.size())
-  replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
-  group_assignment, replication_groups = _get_group_assignment(
-      sharding_type, tile_assignment, len(partition_spec), replicate_dims)
-
-  tile_assignment = tile_assignment.tolist()
-  sharding_type = int(sharding_type)
-  return tile_assignment, group_assignment, replication_groups, sharding_type
-
-
 def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
   _partition_spec = list()
   for p in partition_spec:
@@ -507,14 +506,11 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
-  tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs(
-      mesh, partition_spec)
   if use_dynamo_custom_op:
     # Allows Dynamo to capture mark_sharding op
     annotate_func = torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op
     annotate_func(
-        unwrap_sharded_tensor(t), tile_assignment, group_assignment,
-        replication_groups, sharding_type)
+        unwrap_sharded_tensor(t), *mesh._get_op_sharding_args(partition_spec))
   else:
     op_sharding = mesh.get_op_sharding(partition_spec)
     annotate_func = torch_xla._XLAC._xla_mark_sharding
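A minimal end-to-end sketch (not part of the patch) of the code path being refactored, using the public torch_xla SPMD entry points xr.use_spmd(), xs.Mesh, and xs.mark_sharding; the one-dimensional mesh, the axis name "data", and the tensor shape below are illustrative assumptions.

# Illustrative sketch: mark_sharding() internally goes through
# Mesh.get_op_sharding() / Mesh._get_op_sharding_args() as refactored above.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode

# Build a 1D device mesh over all available devices, with a named axis "data".
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Shard dim 0 of the tensor across the "data" mesh axis and replicate dim 1
# (None in the partition spec means the tensor dim is replicated).
t = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('data', None))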