
Support list of ShardingSpec in MpDeviceLoader #5789

Open
jonb377 wants to merge 2 commits into master

Conversation

@jonb377 (Collaborator) commented Nov 10, 2023

In #5768, there is a use case to shard input tensors of different ranks. Currently, the MpDeviceLoader sharding accepts only a single spec, which is applied to compatible tensors. With this change, multiple specs can be provided at construction, and each input tensor will have the first compatible spec applied to it.

for sharding in input_sharding:
  if sharding.can_apply(tensor):
    # Apply the first compatible spec to this input tensor.
    shardings[i] = sharding.xla_spec(tensor)
    break
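
For context, here is a minimal usage sketch of the list form this PR adds. The mesh construction and the `input_sharding` keyword follow the existing MpDeviceLoader API; the particular specs, `num_devices`, and `train_loader` are illustrative assumptions.

# Illustrative sketch: pass a list of ShardingSpecs, and each input tensor
# receives the first spec in the list that is compatible with it.
import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.experimental.xla_sharding as xs

num_devices = xm.xrt_world_size()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

input_sharding = [
    xs.ShardingSpec(mesh, (0, None)),        # applied to rank-2 inputs
    xs.ShardingSpec(mesh, (0, None, None)),  # applied to rank-3 inputs
]
# train_loader is an ordinary torch DataLoader yielding tensors of mixed rank.
loader = pl.MpDeviceLoader(train_loader, xm.xla_device(), input_sharding=input_sharding)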
Contributor


Maybe we should also add a comment noting that the first match is applied?

Collaborator Author


Done, thanks Yeounoh!

Collaborator


Is first-match a common practice? Or would it be better to have:

  1. either one sharding spec for everything;
  2. or one sharding spec for each input, matched in order?

That way the behavior is well defined. Otherwise, this "compatibility check" is totally a black box for the user.

Collaborator Author

@jonb377 Nov 14, 2023


It's definitely not ideal... It would probably be better to have the user provide sharding specs in a structure matching that of the inputs and avoid the black-box compatibility check, e.g.:

# Inputs of the form:
{"a": torch.randn(16), "b": torch.randn(16, 16), "c": torch.randn(16, 16, 16)}

# Then input_sharding would be:
{"a": ShardingSpec(mesh, (0,)), "b": ShardingSpec(mesh, (0, None)), "c": ShardingSpec(mesh, (0, None, None))}

This is a bit of a refactor of the TensorToXlaArena code though, so I took the easy way out. I'll go ahead and take a stab at the cleaner approach now.
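
As a rough sketch of that structure-matched alternative (purely hypothetical; `apply_structured_sharding` and `apply_fn` are illustrative names, not part of this PR or the torch_xla API), the specs could be walked in lockstep with the inputs:

# Hypothetical helper: recurse through the inputs and a spec structure of the
# same shape, applying the spec found at the same position to each tensor.
# A None spec leaves the corresponding tensor unsharded.
def apply_structured_sharding(inputs, input_sharding, apply_fn):
  if isinstance(inputs, dict):
    return {k: apply_structured_sharding(v, input_sharding[k], apply_fn)
            for k, v in inputs.items()}
  if isinstance(inputs, (list, tuple)):
    return type(inputs)(
        apply_structured_sharding(v, s, apply_fn)
        for v, s in zip(inputs, input_sharding))
  # Leaf: a single tensor paired with its ShardingSpec (or None).
  return apply_fn(inputs, input_sharding) if input_sharding is not None else inputs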

Collaborator


No worries. Take your time.

Contributor

@yeounoh left a comment


LGTM

@jonb377 added the DO_NOT_MERGE_YET (For PRs which cannot be merged, despite tests passing) label on Nov 30, 2023
Labels
backport_2.2, DO_NOT_MERGE_YET (For PRs which cannot be merged, despite tests passing), SPMD / Distributed
Projects
None yet

4 participants