Support list of ShardingSpec in MpDeviceLoader #5789
base: master
Conversation
Force-pushed from 8cd4e2e to ab617f8
for sharding in input_sharding:
  if sharding.can_apply(tensor):
    shardings[i] = sharding.xla_spec(tensor)
    break
Maybe we should also add a comment noting that the first matching spec is the one applied?
Done, thanks Yeounoh!
Is first-match a common practice? Or would it be better to have either:
- one sharding spec applied to everything; or
- one sharding spec for each input, matched in order.
That way the behavior is well defined. Otherwise, this "compatibility check" is a complete black box for the user.
It's definitely not ideal... It would probably be better to have the user provide sharding specs in a structure matching that of the inputs and avoid the black-box compatibility check, e.g.:
# Inputs of the form:
{"a": torch.randn(16), "b": torch.randn(16, 16), "c": torch.randn(16, 16, 16)}
# Then input_sharding would be:
{"a": ShardingSpec(mesh, (0,)), "b": ShardingSpec(mesh, (0, None)), "c": ShardingSpec(mesh, (0, None, None))}
This is a bit of a refactor of the TensorToXlaArena code though, so I took the easy way out. I'll go ahead and take a stab at the cleaner approach now.
No worries. Take your time.
LGTM
In #5768, there is a use case that requires sharding input tensors of different ranks. Currently, MpDeviceLoader's input_sharding accepts only a single spec, which is applied to all compatible tensors. With this change, multiple specs can be provided at construction, and each input tensor has the first compatible spec in the list applied to it.
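A rough usage sketch of the list form described above. Module paths reflect the SPMD-era torch_xla APIs and may differ across releases (e.g. torch_xla.experimental.xla_sharding in older ones); the dummy dataset and variable names are made up for illustration, and the list-valued input_sharding is the behavior added by this PR.

# Minimal sketch, assuming current torch_xla SPMD module paths.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs

num_devices = xm.xrt_world_size()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Dummy loader yielding one rank-2 and one rank-3 tensor per batch.
dataset = torch.utils.data.TensorDataset(
    torch.randn(64, 16), torch.randn(64, 16, 16))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Each input tensor gets the first spec in the list that is compatible
# with it (here, batch-dim sharding for rank-2 and rank-3 inputs).
input_sharding = [
    xs.ShardingSpec(mesh, (0, None)),
    xs.ShardingSpec(mesh, (0, None, None)),
]

device_loader = pl.MpDeviceLoader(
    train_loader, xm.xla_device(), input_sharding=input_sharding)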