Implement mark_sharding as a custom op to support dynamo spmd activation sharding #5712
Conversation
Force-pushed from 63f2673 to d6b5852
@@ -1626,6 +1626,11 @@ void InitXlaModuleBindings(py::module m) {
  // Register sharded tensor data.
  XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
m.def("_xla_mark_sharding_custom_op",
nit: for future reference, can we make it explicit by calling `_xla_mark_sharding_dynamo_custom_op`?
Sounds good, updated. Also added a new API in `xla_sharding.py` named `mark_sharding_dynamo_custom_op` specifically for interacting with this new custom op.
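(For readers following the thread, a rough usage sketch of that new API. It assumes the function mirrors the existing mark_sharding(tensor, mesh, partition_spec) signature and an SPMD-enabled environment; the mesh shape and tensor below are illustrative, not taken from the PR.)

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

# Illustrative 1 x N device mesh over all available devices.
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices))

t = torch.randn(8, 128, device=xm.xla_device())

# New variant added in this PR: routes the annotation through the registered
# custom op so Dynamo can capture it while tracing (signature assumed).
xs.mark_sharding_dynamo_custom_op(t, mesh, (0, 1))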
test/spmd/test_dynamo_spmd.py
Outdated
y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                 dtype=torch.float,
                 device=xm.xla_device())
ys = xs.mark_sharding(y, self._get_mesh((1, self.n_devices)), (0, 1))
Yes, this should test the activation sharding use-case. cc @wonjoolee95
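(To make the use case concrete, a sketch of what such a test could look like: sharding an intermediate activation inside the forward pass of a module that is compiled with the openxla backend. The module, shapes, and mesh here are illustrative, not the actual test code.)

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n_devices), (1, n_devices))

class SimpleLinear(nn.Module):

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(128, 64)
    self.fc2 = nn.Linear(64, 1)

  def forward(self, x):
    y = self.fc1(x)
    # Activation sharding: the annotation lives inside the traced region,
    # so it must be captured by Dynamo as an op rather than executed eagerly.
    xs.mark_sharding(y, mesh, (0, 1), dynamo_custom_op=True)
    return self.fc2(y)

device = xm.xla_device()
model = SimpleLinear().to(device)
x = torch.randn(8, 128, device=device)
dynamo_model = torch.compile(model, backend="openxla")
out = dynamo_model(x)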
torch_xla/csrc/xla_lower_util.cpp
Outdated
xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device,
                                   const xla::XlaOp& input,
                                   const xla::XlaOp sharding) {
  return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding",
Could we re-use the existing `CustomSharding` op? `mark_sharding` can be translated into the `CustomSharding` op for an IR, which is the activation node. Unless this is needed for the Dynamo-side change to register the custom ops, we can reuse the existing `extern const OpKindWrapper xla_custom_sharding`. Or are there more details/differences here?
The original intention here was to have a separate custom op that includes all the logic for `mark_sharding`, so Dynamo can properly capture `mark_sharding` as an op. All the logic currently in `xla_mark_sharding` in `init_python_bindings.cpp` would be moved into the lowering logic of this new custom op, so Dynamo can trace and recognize it properly.
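(As a side note on "trace and recognize": a generic way to check what Dynamo actually captures is to compile with a debug backend that prints the FX graph it receives. This is plain PyTorch and not tied to the torch_xla changes in this PR; a properly registered custom op would show up as a call_function node in such a graph instead of causing a graph break.)

import torch

def debug_backend(gm: torch.fx.GraphModule, example_inputs):
  # Print the graph Dynamo captured, then run it unchanged.
  gm.graph.print_tabular()
  return gm.forward

@torch.compile(backend=debug_backend)
def fn(x):
  return torch.relu(x) + 1

fn(torch.randn(4))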
Force-pushed from ad4cc4f to e88c7d3
Did some tests locally; it seems like we need to move the rest of …
Force-pushed from 1bceb4a to 3c51d5e
Force-pushed from 3c51d5e to 08d6296
Force-pushed from 4b453c7 to 9aaa533
Force-pushed from 281d49a to fbe5c79
Thanks @wonjoolee95 for this PR.
Force-pushed from f468ed6 to 3604c15
Force-pushed from 3604c15 to a4318a6
test/spmd/test_dynamo_spmd.py
Outdated
dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())
Can we add a counter check? We want to make sure we are not recompiling across different runs. You can either add it here or add a separate test.
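(A sketch of the kind of counter check meant here, using torch_xla.debug.metrics; the exact metric the test asserts on may differ, and the linear module/input are illustrative.)

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
linear = nn.Linear(128, 1).to(device)
xla_x = torch.randn(8, 128, device=device)
dynamo_linear = torch.compile(linear, backend="openxla")

met.clear_all()
dynamo_linear(xla_x).cpu()
compile_count = met.metric_data('CompileTime')[0]

# A second run with identical input shapes should reuse the cached
# computation, so the compile metric count should not increase.
dynamo_linear(xla_x).cpu()
assert met.metric_data('CompileTime')[0] == compile_count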
Updated
mostly lgtm, minor nits
    int(sharding_type))

# If flatten_opsharding = True, return the flattened version of OpSharding
if flatten_opsharding:
What is flattening here? Just returning as a tuple? If so, maybe call it `as_tuple`. I am still hesitant to override the return type here... you can try this using the accessor methods instead, since you only need the `sharding_type` value in your work. Let me create a follow-up PR and add you; let's land this for now. Thanks @wonjoolee95
@@ -471,6 +479,9 @@ def mark_sharding(
    >> mesh_shape = (4, 2)
    >> partition_spec = (0, None)

    dynamo_custom_op (bool): if set to True, it calls the dynamo custom op variant of mark_sharding
      to make itself recognizeable and traceable by dynamo.
nit. "recognizeable and traceable" --> "traceable"
[](const at::Tensor& input, const py::list& tile_assignment,
   const py::list& group_assignment, const py::list& replication_groups,
   int sharding_type) {
  c10::List<at::IntArrayRef> tile_assignment_list =
Can we move the following data processing logic into `xla_mark_sharding_dynamo_custom_op` and just call

xla_mark_sharding_dynamo_custom_op(
    input, tile_assignment_list, group_assignment_list,
    replication_groups_list, sharding_type);
});

similar to what you've done for `mark_sharding`?
LGTM, have some comments that we can address in a follow-up PR. Thank you @wonjoolee95
Implement mark_sharding as a custom op to support dynamo spmd activation sharding.

The PR is a bit messy as it includes a fair amount of refactoring; here is a quick summary:

- Moved the `_xla_mark_sharding` logic out to a helper function, so it can be called by the new custom op `xla_mark_sharding_dynamo_custom_op`.
- Moved the `TORCH_LIBRARY` registration from `aten_autograd_ops.h` to `init_python_bindings.cpp`, since torch custom ops can be registered in only one location.
- Updated the `mark_sharding` function in `torch_xla/experimental/xla_sharding.py` to accept an additional boolean flag `dynamo_custom_op`. When set to true, it calls this new custom op `xla_mark_sharding_dynamo_custom_op` instead of the existing `_xla_mark_sharding` (see the usage sketch after the test plan).

Test plan:
python test/spmd/test_dynamo_spmd.py DynamoSpmdInferenceTest.test_mark_sharding_inside_compile
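(For reference, a minimal sketch of the new flag outside of the test, plus one way to check that the annotation landed on the tensor. The mesh and tensor are illustrative, and `_get_xla_sharding_spec` is assumed here as the internal helper that reads back the sharding string.)

import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (1, n))

t = torch.randn(8, 128, device=xm.xla_device())

# New flag added in this PR: route mark_sharding through the Dynamo-traceable
# custom op instead of the existing _xla_mark_sharding binding.
xs.mark_sharding(t, mesh, (0, 1), dynamo_custom_op=True)

# The resulting annotation should be readable just like on the existing path
# (assumed internal helper; returns the HLO sharding string for an XLA tensor).
print(torch_xla._XLAC._get_xla_sharding_spec(t))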