Implement mark_sharding as a custom op to support dynamo spmd activation sharding #5712
Merged
Commits (14)
53e7ec0  Implement mark_sharding as a custom op to support dynamo spmd activat… (wonjoolee95)
f0e8a94  Update to include OpSharding as an input (wonjoolee95)
7891b42  Rebase with master and run linter (wonjoolee95)
6aeeecf  Update unit tests (wonjoolee95)
dc19b9b  Refine custom marking sharding op (wonjoolee95)
a20d710  Re-run linter due to wrong version (wonjoolee95)
bd169c2  Add new API for custom mark sharding op and update tests (wonjoolee95)
ae05c9a  Add torch pin (wonjoolee95)
08d6296  Clean up some code (wonjoolee95)
9aaa533  Update code to transfer pylist to xla::OpSharding (wonjoolee95)
a98bfb2  Update unit test and run linter (wonjoolee95)
a4318a6  Address comments -- fix typos and variable names (wonjoolee95)
e35ca64  Run linter (wonjoolee95)
ae721ae  Add metric assertions to unit tests (wonjoolee95)
Diff view
```diff
@@ -29,6 +29,7 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl_bind.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
+#include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/helpers.h"
@@ -1559,72 +1560,39 @@ void InitXlaModuleBindings(py::module m) {
             tile_assignment, group_assignment, replication_groups,
             ShardingUtil::ShardingType(sharding_type));
       }));
-  m.def("_xla_mark_sharding", [](const at::Tensor& input,
-                                 xla::OpSharding sharding) {
-    TORCH_LAZY_COUNTER("XlaMarkSharding", 1);
-    XLA_CHECK(UseVirtualDevice())
-        << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
-    XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-    auto new_sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
-        sharding, MakeShapeWithDeviceLayout(
-                      xtensor->shape(),
-                      static_cast<XlaDeviceType>(xtensor->GetDevice().type())));
-
-    // For Non DeviceData IR values, we directly attach the sharding spec
-    // to the xtensor.
-    const DeviceData* device_data_node = nullptr;
-    if (xtensor->CurrentIrValue()) {
-      device_data_node = DeviceData::Cast(xtensor->CurrentIrValue().node.get());
-      if (!device_data_node) {
-        tensor_methods::custom_sharding_(xtensor, new_sharding_spec);
-        return;
-      }
-    }
-
-    // For data, we need to deal with the data transfers between
-    // host and device.
-    at::Tensor cpu_tensor;
-    if (xtensor->CurrentTensorData().has_value()) {
-      TORCH_LAZY_COUNTER("VirtualDeviceUsage", 1);
-      // When virtual device is enabled for SPMD, we defer the initial
-      // data transfer to the device and retain the original data on the
-      // host, until the sharded data transfer.
-      cpu_tensor = xtensor->CurrentTensorData().value();
-    } else {
-      // A new input tensor is not expected to be sharded. But sometimes,
-      // the same input is called for sharding annotation over multiple steps,
-      // in which case we can skip if it's the same sharding; however, if it's
-      // the same input with a different sharding then we block & ask the user
-      // to clear the existing sharding first.
-      auto current_sharding_spec = xtensor->sharding_spec();
-      if (current_sharding_spec && (current_sharding_spec->sharding.type() !=
-                                    xla::OpSharding::REPLICATED)) {
-        XLA_CHECK(ShardingUtil::EqualShardingSpecs(*new_sharding_spec,
-                                                   *current_sharding_spec))
-            << "Existing annotation must be cleared first.";
-        return;
-      }
-
-      // If the at::Tensor data is not present, we need to re-download the
-      // tensor from the physical device to CPU. In that case, the value
-      // must be present on the backend device.
-      XLA_CHECK((xtensor->CurrentDataHandle() &&
-                 xtensor->CurrentDataHandle()->HasValue()) ||
-                device_data_node != nullptr)
-          << "Cannot shard tensor. Data does not present on any device.";
-      std::vector<XLATensorPtr> xla_tensors{xtensor};
-      cpu_tensor = XLAGraphExecutor::Get()->GetTensors(&xla_tensors)[0];
-    }
-    auto xla_data = CreateTensorsData(
-        std::vector<at::Tensor>{cpu_tensor},
-        std::vector<XLATensor::ShardingSpecPtr>{new_sharding_spec},
-        std::vector<std::string>{GetVirtualDevice().toString()})[0];
-    xtensor->SetXlaData(xla_data);
-    xtensor->SetShardingSpec(*new_sharding_spec);
-
-    // Register sharded tensor data.
-    XLAGraphExecutor::Get()->RegisterTensor(xtensor->data());
-  });
+  m.def("_xla_mark_sharding",
+        [](const at::Tensor& input, xla::OpSharding sharding) {
+          ShardingUtil::xla_mark_sharding(input, sharding);
+        });
+  m.def("_xla_mark_sharding_dynamo_custom_op",
+        [](const at::Tensor& input, const py::list& tile_assignment,
+           const py::list& group_assignment,
+           const py::list& replication_groups, int sharding_type) {
+          c10::List<at::IntArrayRef> tile_assignment_list =
+              c10::List<at::IntArrayRef>();
+          for (auto t : tile_assignment) {
+            tile_assignment_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> group_assignment_list =
+              c10::List<at::IntArrayRef>();
+          for (auto t : group_assignment) {
+            group_assignment_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          c10::List<at::IntArrayRef> replication_groups_list =
+              c10::List<at::IntArrayRef>();
+          for (auto t : replication_groups) {
+            replication_groups_list.push_back(
+                at::IntArrayRef(t.cast<std::vector<int64_t>>()));
+          }
+
+          xla_mark_sharding_dynamo_custom_op(
+              input, tile_assignment_list, group_assignment_list,
+              replication_groups_list, sharding_type);
+        });
   m.def("_xla_clear_sharding", [](const at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
     xtensor->ClearShardingSpec();
```

Inline review comment, anchored on the py::list conversion loops in the new `_xla_mark_sharding_dynamo_custom_op` lambda: "Can we move the following data processing logic into … similar to what you've done for mark_sharding."
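Why the second binding takes plain nested int lists rather than an `xla::OpSharding` object: a dynamo custom op can only carry arguments that fit a custom-op schema (tensors, ints, lists of ints), not an arbitrary pybind object like the one `_xla_mark_sharding` accepts. The sketch below is a hypothetical Python-side wrapper over the new binding, included only to make that calling convention concrete; the wrapper name is an assumption, not code from this PR.

```python
# Hypothetical wrapper (not from this PR) over the binding added above.
# `_xla_mark_sharding_dynamo_custom_op` accepts only schema-friendly types
# (a tensor, nested int lists, an int enum), which is what lets a dynamo
# graph capture the sharding annotation, unlike `_xla_mark_sharding`,
# which takes a pybind `xla::OpSharding` object.
import torch
import torch_xla


def annotate_activation_sharding(tensor: torch.Tensor,
                                 tile_assignment: list,
                                 group_assignment: list,
                                 replication_groups: list,
                                 sharding_type: int) -> None:
    # The C++ lambda converts these lists into c10::List<at::IntArrayRef>
    # (see the loops in the diff) before invoking the custom op.
    torch_xla._XLAC._xla_mark_sharding_dynamo_custom_op(
        tensor, tile_assignment, group_assignment, replication_groups,
        sharding_type)
```

In practice, user code would normally go through torch_xla's sharding helpers, which derive the tile assignment and related lists from a device mesh and partition spec, rather than calling the private binding directly.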
Review conversation:

"@yeounoh, I updated the unit test for doing mark_sharding inside torch.compile to be part of the existing SimpleLinear test. Here, I just do a mark_sharding call inside the forward function. Please let me know if you think this will suffice."

Reply: "SG"
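To make the test being discussed concrete, here is a minimal sketch of a SimpleLinear-style module that annotates an activation inside forward and runs under torch.compile. It is illustrative only: the module body, shapes, and mesh layout are assumptions, and whether the plain mark_sharding helper or a dedicated dynamo-aware variant added by this PR is the right call inside forward depends on the API this PR exposes; the sketch uses the plain helper from torch_xla.experimental.xla_sharding for readability.

```python
# Illustrative sketch only -- not the test merged in this PR. Assumes the
# torch_xla SPMD helpers (xr.use_spmd, xs.Mesh, xs.mark_sharding) and the
# "openxla" dynamo backend; shapes and the mesh layout are arbitrary.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs


class SimpleLinear(torch.nn.Module):
    def __init__(self, mesh):
        super().__init__()
        self.fc = torch.nn.Linear(128, 64)
        self.mesh = mesh

    def forward(self, x):
        y = self.fc(x)
        # Annotate the activation inside the compiled region so the sharding
        # is visible to the dynamo-traced graph (the point of the custom op
        # introduced in this PR).
        xs.mark_sharding(y, self.mesh, (0, None))
        return y


xr.use_spmd()
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
# 1-D mesh over all devices, sharding the batch dimension of the activation.
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1))

model = SimpleLinear(mesh).to(device)
compiled = torch.compile(model, backend="openxla")
out = compiled(torch.randn(16, 128, device=device))
```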