
Commit

debugging
yeounoh committed Nov 22, 2023
1 parent a6dac44 commit acdd21b
Showing 1 changed file with 26 additions and 30 deletions.
torch_xla/distributed/spmd/xla_sharding.py: 26 additions & 30 deletions
@@ -80,15 +80,38 @@ def get_axis_name_idx(self, name: str) -> int:
       return None
     return self.axis_names.index(name)
 
+  def _get_op_sharding_args(self, partition_spec: Tuple):
+    partition_spec = _translate_named_partition_spec(self, partition_spec)
+    flat_specs = np.hstack([d for d in partition_spec])
+    specs = [d for d in flat_specs if d is not None]
+    assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+      f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+    assert len(specs) == len(np.unique(specs)), \
+      f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
+    tile_assignment = _get_tile_assignment(self, partition_spec)
+    if len(tile_assignment.shape) > len(partition_spec):
+      # Use partial replication for sharding a tensor over a higher-rank mesh
+      sharding_type = ShardingType.PARTIAL
+    else:
+      sharding_type = _get_sharding_type(partition_spec, self.size())
+    replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
+    group_assignment, replication_groups = _get_group_assignment(
+        sharding_type, tile_assignment, len(partition_spec), replicate_dims)
+
+    tile_assignment = tile_assignment.tolist()
+    sharding_type = int(sharding_type)
+    return tile_assignment, group_assignment, replication_groups, sharding_type
+
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(self,
                       partition_spec: Tuple) -> torch_xla._XLAC.OpSharding:
     """
     Return the OpSharding for the given partition spec. This is an expensive
     operation as the mesh grows, so the value is cached for reuse.
     """
-    tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs(
-        self, partition_spec)
+    tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
+        partition_spec)
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 
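For context, a minimal usage sketch of the cached get_op_sharding path above (not part of this commit). It assumes the torch_xla.distributed.spmd API as shown in this file; the device count, mesh shape, and axis names below are illustrative.

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Build a mesh over all addressable devices; (num_devices, 1) keeps the
# example valid for any device count.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

# partition_spec must be a hashable tuple so functools.lru_cache can key on it.
# The second call with an identical spec returns the cached OpSharding rather
# than recomputing tile/group/replication assignments.
sharding = mesh.get_op_sharding(('data', 'model'))
sharding_cached = mesh.get_op_sharding(('data', 'model'))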
@@ -410,30 +433,6 @@ def _get_group_assignment(sharding_type: ShardingType,
   return group_assignment, replication_groups
 
 
-def _extract_op_sharding_specs(mesh: Mesh, partition_spec: Tuple):
-  partition_spec = _translate_named_partition_spec(mesh, partition_spec)
-  flat_specs = np.hstack([d for d in partition_spec])
-  specs = [d for d in flat_specs if d is not None]
-  assert all(d >= 0 and d < len(mesh.mesh_shape) for d in specs), \
-    f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
-  assert len(specs) == len(np.unique(specs)), \
-    f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
-
-  tile_assignment = _get_tile_assignment(mesh, partition_spec)
-  if len(tile_assignment.shape) > len(partition_spec):
-    # Use partial replication for sharding a tensor over a higher-rank mesh
-    sharding_type = ShardingType.PARTIAL
-  else:
-    sharding_type = _get_sharding_type(partition_spec, mesh.size())
-  replicate_dims = {i for i, d in enumerate(partition_spec) if d is None}
-  group_assignment, replication_groups = _get_group_assignment(
-      sharding_type, tile_assignment, len(partition_spec), replicate_dims)
-
-  tile_assignment = tile_assignment.tolist()
-  sharding_type = int(sharding_type)
-  return tile_assignment, group_assignment, replication_groups, sharding_type
-
-
 def _translate_named_partition_spec(mesh: Mesh, partition_spec: Tuple):
   _partition_spec = list()
   for p in partition_spec:
@@ -507,14 +506,11 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor],
   assert len(t.shape) == len(partition_spec), \
     f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
-  tile_assignment, group_assignment, replication_groups, sharding_type = _extract_op_sharding_specs(
-      mesh, partition_spec)
   if use_dynamo_custom_op:
     # Allows Dynamo to capture mark_sharding op
     annotate_func = torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op
     annotate_func(
-        unwrap_sharded_tensor(t), tile_assignment, group_assignment,
-        replication_groups, sharding_type)
+        unwrap_sharded_tensor(t), *mesh._get_op_sharding_args(partition_spec))
   else:
     op_sharding = mesh.get_op_sharding(partition_spec)
     annotate_func = torch_xla._XLAC._xla_mark_sharding
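To illustrate the two call paths this hunk touches (again, not part of the commit): after the refactor the Dynamo custom-op branch unpacks mesh._get_op_sharding_args(...) directly, while the default branch still goes through the cached mesh.get_op_sharding(...). The tensor shapes, mesh layout, SPMD setup call, and the use_dynamo_custom_op keyword default below are assumptions about the surrounding API.

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # assumption: SPMD mode is enabled before annotating shardings
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ('data', 'model'))

t1 = torch.randn(16, 128, device=xm.xla_device())
t2 = torch.randn(16, 128, device=xm.xla_device())

# Default path: builds (and caches) a torch_xla._XLAC.OpSharding via
# mesh.get_op_sharding(partition_spec).
xs.mark_sharding(t1, mesh, ('data', 'model'))

# Dynamo-capturable path shown in the diff above: forwards the raw
# (tile_assignment, group_assignment, replication_groups, sharding_type)
# produced by mesh._get_op_sharding_args to the custom op.
xs.mark_sharding(t2, mesh, ('data', 'model'), use_dynamo_custom_op=True)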
