Skip to content

Commit

Permalink
Open XLA pin update (pytorch#5675)
Browse files Browse the repository at this point in the history
Open XLA pin update - updated to 20231010
  • Loading branch information
qihqi authored and chunnienc committed Dec 14, 2023
1 parent ee8e559 commit 10022f7
Show file tree
Hide file tree
Showing 41 changed files with 283 additions and 171 deletions.
20 changes: 10 additions & 10 deletions .circleci/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,16 @@ function run_torch_xla_python_tests() {
else
./test/run_tests.sh

# GPU tests
# CUDA tests
if [ -x "$(command -v nvidia-smi)" ]; then
# These tests fail on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# Syncfree SGD optimizer tests
if [ -d ./torch_xla/amp/syncfree ]; then
echo "Running Syncfree Optimizer Test"
PJRT_DEVICE=GPU python test/test_syncfree_optimizers.py
PJRT_DEVICE=CUDA python test/test_syncfree_optimizers.py

# Following test scripts are mainly useful for
# performance evaluation & comparison among different
Expand Down Expand Up @@ -192,9 +192,9 @@ function run_torch_xla_cpp_tests() {
if [ "$USE_COVERAGE" != "0" ]; then
# TODO(yeounoh) shard the coverage testing
if [ -x "$(command -v nvidia-smi)" ]; then
PJRT_DEVICE=GPU test/cpp/run_tests.sh $EXTRA_ARGS -L""
PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L""
cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov1.dat
PJRT_DEVICE=GPU test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
cp $XLA_DIR/bazel-out/_coverage/_coverage_report.dat /tmp/cov2.dat
lcov --add-tracefile /tmp/cov1.dat -a /tmp/cov2.dat -o /tmp/merged.dat
else
Expand All @@ -206,8 +206,8 @@ function run_torch_xla_cpp_tests() {
else
# Shard GPU testing
if [ -x "$(command -v nvidia-smi)" ]; then
PJRT_DEVICE=GPU test/cpp/run_tests.sh $EXTRA_ARGS -L""
PJRT_DEVICE=GPU test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
PJRT_DEVICE=CUDA test/cpp/run_tests.sh $EXTRA_ARGS -L""
PJRT_DEVICE=CUDA test/cpp/run_tests.sh -X early_sync -F AtenXlaTensorTest.TestEarlySyncLiveTensors -L"" $EXTRA_ARGS
else
PJRT_DEVICE=CPU test/cpp/run_tests.sh $EXTRA_ARGS -L""
fi
Expand Down
11 changes: 5 additions & 6 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ http_archive(
patch_tool = "patch",
patches = [
"//openxla_patches:cache_urls.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:pjrt_api_tsl_logging.diff",
"//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_topk_rewriter.diff",
],
strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18",
strip_prefix = "xla-51b59cfb1999c6f1b3ec59851675044b2c502aae",
urls = [
"https://github.com/openxla/xla/archive/97a5f819faf9ff793b7ba68ff1f31f74f9459c18.tar.gz",
"https://github.com/openxla/xla/archive/51b59cfb1999c6f1b3ec59851675044b2c502aae.tar.gz",
],
)

Expand Down
184 changes: 184 additions & 0 deletions openxla_patches/gpu_topk_rewriter.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc
index da872d962..1b7141055 100644
--- a/xla/service/topk_rewriter.cc
+++ b/xla/service/topk_rewriter.cc
@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
return std::nullopt;
}
const int64_t sort_dim = sort->sort_dimension();
+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+ const bool has_batch = data->shape().rank() == 2;

bool supported = true;
std::optional<int64_t> k;
@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
supported = false;
break;
}
- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) {
- if (i != sort_dim &&
- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) {
- // Slicing along a non-sort dimension isn't supported.
- supported = false;
- break;
- }
- }
- if (!supported) {
+ if (has_batch && slice->slice_limits(batch_dim) !=
+ slice->operand(0)->shape().dimensions(batch_dim)) {
+ // Slicing along the batch dimension isn't supported.
+ supported = false;
break;
}
if (k == std::nullopt) {
@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
HloInstruction* data = sort->mutable_operand(0);
const PrimitiveType element_type = data->shape().element_type();
- const Shape data_shape = data->shape();

- if (element_type != F32 && element_type != BF16) {
+ if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
+ (element_type != F32 && element_type != BF16)) {
continue;
}

- // Sort dimension must be the first or last dimension.
const int64_t sort_dim = sort->sort_dimension();
- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) {
- continue;
- }
+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+ const bool has_batch = data->shape().rank() == 2;

// Profitability check.
if (!is_profitable_to_convert_(sort, *k)) {
continue;
}

- HloInstruction* input = data;
- const bool has_batch = data_shape.rank() >= 2;
- const int64_t input_size = data_shape.dimensions(sort_dim);
- int64_t batch_size = 1;
- Shape topk_input_shape;
-
- if (has_batch) {
- // The TopK custom call expects either a 1d tensor or a 2d tensor with
- // the last dimension being the sort dimension. An input with rank > 2
- // is reshaped into a 2d tensor by combining non-sort dimensions into a
- // single batch dimension. The original non-sort dimensions are
- // restored for the outputs with another reshape after the custom call.
- batch_size =
- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim);
- topk_input_shape =
- ShapeUtil::MakeShape(element_type, {batch_size, input_size});
-
- if (data_shape.rank() > 2) {
- // Reshape to 2d.
- input = comp->AddInstruction(HloInstruction::CreateReshape(
- sort_dim == 0
- ? ShapeUtil::MakeShape(element_type, {input_size, batch_size})
- : ShapeUtil::MakeShape(element_type,
- {batch_size, input_size}),
- input));
- }
-
- if (sort_dim == 0) {
- // Transpose for the custom call when sorting the first dimension.
- input = comp->AddInstruction(
- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
- }
- } else {
- topk_input_shape = data_shape;
+ const int64_t batch_size =
+ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
+ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
+ HloInstruction* input = sort->mutable_operand(0);
+ if (has_batch && sort_dim == 0) {
+ input = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
+ {1, 0}));
}

Shape topk_shape =
@@ -331,26 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(1), topk, 1));

- if (has_batch) {
- if (sort_dim == 0) {
- // Transpose back.
- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
- value_gte, {1, 0}));
- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
- {1, 0}));
- }
- if (data_shape.rank() > 2) {
- // Reshape back.
- std::vector<int64_t> shape_dim(data_shape.dimensions().begin(),
- data_shape.dimensions().end());
- shape_dim[sort_dim] = k.value();
- value_gte = comp->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(element_type, shape_dim), value_gte));
- index_gte = comp->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(S32, shape_dim), index_gte));
- }
+ if (has_batch && sort_dim == 0) {
+ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
+ value_gte, {1, 0}));
+ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
+ {1, 0}));
}

for (HloInstruction* user : sort->users()) {
diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc
index 36e723737..25ce150e0 100644
--- a/xla/service/topk_rewriter_test.cc
+++ b/xla/service/topk_rewriter_test.cc
@@ -326,42 +326,6 @@ ENTRY cluster {
EXPECT_THAT(cc->custom_call_target(), "TopK");
}

-TEST_F(TopkRewriterTest, RewriteReshape) {
- const std::string hlo_string = R"(
-HloModule module
-)" + getComparator() + R"(
-ENTRY cluster {
- %arg_tuple.1 = f32[3,8,1234567] parameter(0)
- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
- dimensions={2}, is_stable=true, to_apply=%compare
- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0
- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TopkRewriter rewriter(
- [](const HloSortInstruction*, int64_t) { return true; });
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
- TF_ASSERT_OK(HloDCE().Run(module.get()).status());
- EXPECT_TRUE(changed);
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
- m::Reshape(m::GetTupleElement(
- m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));
- const HloInstruction* cc = module->entry_computation()
- ->root_instruction()
- ->operand(0)
- ->operand(0)
- ->operand(0);
- EXPECT_THAT(cc->custom_call_target(), "TopK");
-}
-
TEST_F(TopkRewriterTest, RewriteNoIota) {
const std::string hlo_string = R"(
HloModule module
21 changes: 0 additions & 21 deletions openxla_patches/pjrt_api_tsl_logging.diff

This file was deleted.

76 changes: 0 additions & 76 deletions openxla_patches/pjrt_c_api_dynamic_dimensions.diff

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_libtpu_version = '0.1.dev20230825'
_libtpu_version = '0.1.dev20231010'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'


Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_aten_xla_tensor_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1512,7 +1512,7 @@ TEST_F(AtenXlaTensorTest, TestGroupNormBackward) {
/*cudnn_enabled=*/false);
};
torch::Tensor undef;
ForEachDevice({XlaDeviceType::GPU, XlaDeviceType::TPU},
ForEachDevice({XlaDeviceType::CUDA, XlaDeviceType::TPU},
[&](const torch::Device& device) {
TestBackward({input, undef_weight ? undef : weight,
undef_weight ? undef : bias},
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_aten_xla_tensor_6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
TEST_F(AtenXlaTensorTest, TestAmpUpdateScale) {
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetDefaultDevice()->type());
if (hw_type != XlaDeviceType::GPU && hw_type != XlaDeviceType::CPU) {
if (hw_type != XlaDeviceType::CUDA && hw_type != XlaDeviceType::CPU) {
return;
}
torch::Tensor growth_tracker =
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/test_replication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class ReplicationTest : public AtenXlaTensorTestBase {};

TEST_F(ReplicationTest, TestNSingleReplication) {
WithAllDevices(
{XlaDeviceType::TPU, XlaDeviceType::GPU},
{XlaDeviceType::TPU, XlaDeviceType::CUDA},
[&](const std::vector<torch::lazy::BackendDevice>& devices,
const std::vector<torch::lazy::BackendDevice>& all_devices) {
TestSingleReplication(devices, all_devices);
Expand Down
2 changes: 1 addition & 1 deletion test/pjrt/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _ddp_init(index: int = ...):
def test_ddp_init(self):
pjrt.run_multiprocess(self._ddp_init)

@absltest.skipIf(xr.device_type() == 'GPU',
@absltest.skipIf(xr.device_type() in ('GPU', 'CUDA', 'ROCM'),
"GPU device is not supported by pjrt.spawn_threads")
def test_ddp_init_threaded(self):
pjrt.spawn_threads(self._ddp_init)
Expand Down
Loading

0 comments on commit 10022f7

Please sign in to comment.