
Implement mark_sharding as a custom op to support dynamo spmd activation sharding #5712

Merged
merged 14 commits into from
Nov 11, 2023

Conversation

wonjoolee95
Collaborator

@wonjoolee95 wonjoolee95 commented Oct 19, 2023

Implement mark_sharding as a custom op to support dynamo spmd activation sharding.

The PR is a bit messy as it includes a fair bit of refactoring; here is a quick summary:

  • In init_python_bindings.cpp
    • Move the existing _xla_mark_sharding logic out to a helper function, so it can be called by the new custom op xla_mark_sharding_dynamo_custom_op.
    • Move the torch custom op registration TORCH_LIBRARY from aten_autograd_ops.h to init_python_bindings.cpp, since torch custom ops can be registered in only one location.
  • Update the existing mark_sharding function in torch_xla/experimental/xla_sharding.py to accept an additional boolean flag dynamo_custom_op. When set to True, it calls the new custom op xla_mark_sharding_dynamo_custom_op instead of the existing `_xla_mark_sharding` (see the usage sketch after the test plan).

Test plan:
python test/spmd/test_dynamo_spmd.py DynamoSpmdInferenceTest.test_mark_sharding_inside_compile
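
For context, a minimal usage sketch of the new flag (not taken from the PR; it assumes xs is torch_xla.experimental.xla_sharding, mesh is an xs.Mesh built over the available devices, and the exact signature may differ from the final code):

import torch
import torch_xla.experimental.xla_sharding as xs

def fn(x, mesh):
    y = x + 1
    # With dynamo_custom_op=True, mark_sharding dispatches to the traceable
    # custom op xla_mark_sharding_dynamo_custom_op instead of _xla_mark_sharding,
    # so the activation sharding survives dynamo tracing.
    xs.mark_sharding(y, mesh, (0, 1), dynamo_custom_op=True)
    return y

compiled_fn = torch.compile(fn, backend="openxla")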

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch 3 times, most recently from 63f2673 to d6b5852 Compare October 24, 2023 23:07
@@ -1626,6 +1626,11 @@ void InitXlaModuleBindings(py::module m) {
// Register sharded tensor data.
XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
});
m.def("_xla_mark_sharding_custom_op",
Contributor

nit. for future reference, can we make it explicit by calling _xla_mark_sharding_dynamo_custom_op?

Collaborator Author

Sounds good, updated. Also added a new API in xla_sharding.py named mark_sharding_dynamo_custom_op specifically for interacting with this new custom op.
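
For illustration only (only the API name comes from the comment above; the positional arguments are assumed to mirror mark_sharding):

# Equivalent to xs.mark_sharding(t, mesh, partition_spec, dynamo_custom_op=True),
# but exposed as a dedicated entry point for code paths that target the custom op.
sharded = xs.mark_sharding_dynamo_custom_op(t, mesh, partition_spec)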

y = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
                 dtype=torch.float,
                 device=xm.xla_device())
ys = xs.mark_sharding(y, self._get_mesh((1, self.n_devices)), (0, 1))
Contributor

Yes, this should test the activation sharding use-case. cc @wonjoolee95

xla::XlaOp BuildCustomMarkSharding(const torch::lazy::BackendDevice& device,
                                   const xla::XlaOp& input,
                                   const xla::XlaOp sharding) {
  return xla::CustomCall(input.builder(), /*call_target_name=*/"MarkSharding",
Contributor

@yeounoh yeounoh Oct 25, 2023

Could we re-use the existing CustomSharding op? mark_sharding can be translated into the CustomSharding op for an IR node, which here is the activation node. Unless this is needed for the Dynamo-side custom op registration, we can reuse the existing extern const OpKindWrapper xla_custom_sharding. Or are there more details/differences here?

Collaborator Author

The original intention here was to have a separate custom op that includes all the logic for mark_sharding so Dynamo can properly capture mark_sharding as an op. All the logic currently in xla_mark_sharding in init_python_bindings.cpp would be moved to the lowering logic of this new custom op, so Dynamo can trace and recognize it properly.
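
To make the intent above concrete, here is a small, self-contained Python sketch (not the PR's actual C++ TORCH_LIBRARY registration; the demo_xla namespace, schema, and implementation are made up for illustration) of why a registered custom op is traceable: dynamo records calls to a registered op as graph nodes instead of tracing into, or dropping, the side-effecting binding call.

import torch

# Hypothetical library/op used only for this illustration.
lib = torch.library.Library("demo_xla", "DEF")
lib.define("mark_sharding(Tensor t, int[] tile_assignment) -> Tensor")

def mark_sharding_impl(t, tile_assignment):
    # Placeholder body; the real lowering would attach the sharding annotation.
    return t

lib.impl("mark_sharding", mark_sharding_impl, "CompositeExplicitAutograd")

def fn(x):
    # Shows up as a call to torch.ops.demo_xla.mark_sharding in the captured graph.
    x = torch.ops.demo_xla.mark_sharding(x, [1, 1])
    return x + 1

compiled = torch.compile(fn, backend="eager")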

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch 3 times, most recently from ad4cc4f to e88c7d3 Compare October 26, 2023 23:56
@wonjoolee95
Collaborator Author

Did some tests locally; it seems like we need to move the rest of the mark_sharding logic that's currently at tensor_methods.cpp::custom_mark_sharding to the actual lowering logic of the new CustomMarkSharding op for dynamo to properly recognize and trace this new op. Now working on this fix.

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch from 1bceb4a to 3c51d5e Compare October 31, 2023 23:38
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch from 3c51d5e to 08d6296 Compare November 2, 2023 18:32
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch from 4b453c7 to 9aaa533 Compare November 7, 2023 05:08
@wonjoolee95 wonjoolee95 changed the title [WIP] Implement mark_sharding as a custom op to support dynamo spmd activation sharding Implement mark_sharding as a custom op to support dynamo spmd activation sharding Nov 7, 2023
@wonjoolee95 wonjoolee95 marked this pull request as ready for review November 7, 2023 05:13
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch 2 times, most recently from 281d49a to fbe5c79 Compare November 7, 2023 06:49
@miladm
Collaborator

miladm commented Nov 7, 2023

Thanks @wonjoolee95 for this PR.
As a follow-up to this PR, once it lands, can you please evaluate the performance gain from activation sharding on the llama2 model?

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch 3 times, most recently from f468ed6 to 3604c15 Compare November 8, 2023 17:41
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-custom-op branch from 3604c15 to a4318a6 Compare November 8, 2023 19:57
@wonjoolee95 wonjoolee95 requested a review from JackCaoG November 8, 2023 19:58

dynamo_linear = torch.compile(linear, backend="openxla")
dynamo_res = dynamo_linear(xla_x)
torch.allclose(xla_res.cpu(), dynamo_res.cpu())
Collaborator

Can we add a counter check? We want to make sure we are not recompiling across different runs. You can either add it here or add a separate test.

Collaborator Author

Updated
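
For reference, the requested check could look roughly like this (a sketch using torch_xla.debug.metrics; dynamo_linear and xla_x refer to the snippet quoted above, and the exact metric consulted may differ in the final test):

import torch_xla.debug.metrics as met

met.clear_all()
dynamo_res = dynamo_linear(xla_x)                  # first run compiles
compile_count = met.metric_data('CompileTime')[0]  # number of compilations so far

dynamo_res = dynamo_linear(xla_x)                  # second run should reuse the cache
assert met.metric_data('CompileTime')[0] == compile_count, \
    "unexpected recompilation across runs"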

Collaborator

@JackCaoG JackCaoG left a comment

mostly lgtm, minor nits

@wonjoolee95 wonjoolee95 requested a review from JackCaoG November 9, 2023 05:34
int(sharding_type))

# If flatten_opsharding = True, return the flattened version of OpSharding
if flatten_opsharding:
Contributor

What is flattening here? Just returning it as a tuple? If so, maybe call it as_tuple. I am still hesitant to override the return type here... you could use the accessor methods instead, since you only need the sharding_type value in your work. Let me create a follow-up PR and add you; let's land this for now. Thanks @wonjoolee95

@@ -471,6 +479,9 @@ def mark_sharding(
>> mesh_shape = (4, 2)
>> partition_spec = (0, None)

dynamo_custom_op (bool): if set to True, it calls the dynamo custom op variant of mark_sharding
to make itself recognizable and traceable by dynamo.
Contributor

nit. "recognizeable and traceable" --> "traceable"

[](const at::Tensor& input, const py::list& tile_assignment,
   const py::list& group_assignment, const py::list& replication_groups,
   int sharding_type) {
  c10::List<at::IntArrayRef> tile_assignment_list =
Contributor

@yeounoh yeounoh Nov 11, 2023

Can we move the following data processing logic into xla_mark_sharding_dynamo_custom_op and just call

xla_mark_sharding_dynamo_custom_op(
    input, tile_assignment_list, group_assignment_list,
    replication_groups_list, sharding_type);

similar to what you've done for mark_sharding.

Contributor

@yeounoh yeounoh left a comment

LGTM, have some comments that we can address in a follow-up PR. Thank you @wonjoolee95

@wonjoolee95 wonjoolee95 merged commit 367f47f into master Nov 11, 2023
17 checks passed
mbzomowski pushed a commit to mbzomowski-test-org/xla that referenced this pull request Nov 16, 2023
zpcore pushed a commit that referenced this pull request Nov 21, 2023
lsy323 pushed a commit to lsy323/xla that referenced this pull request Nov 28, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
ManfeiBai pushed a commit that referenced this pull request Nov 29, 2023
chunnienc pushed a commit to chunnienc/xla that referenced this pull request Dec 14, 2023
bhavya01 pushed a commit that referenced this pull request Apr 22, 2024