From 2c735b8b4d024d515e5dff5548d1967b933445b7 Mon Sep 17 00:00:00 2001 From: rpsilva-aws Date: Wed, 4 Dec 2024 08:21:15 +0000 Subject: [PATCH 01/15] Extend operation indices to int32 --- test/cpp/test_aten_xla_tensor_1.cpp | 754 +++++++++++++--------------- torch_xla/csrc/ops/index_ops.cpp | 18 +- torch_xla/csrc/tensor_ops.cpp | 12 +- 3 files changed, 382 insertions(+), 402 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index d204b344808..2d8b392870e 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -3,6 +3,7 @@ #include +#include "absl/strings/str_cat.h" #include "test/cpp/cpp_test_util.h" #include "test/cpp/torch_xla_test.h" #include "torch_xla/csrc/aten_xla_bridge.h" @@ -2131,58 +2132,6 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) { ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestIndexSelect) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); - for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { - torch::Tensor b = - torch::empty({2}, torch::TensorOptions(index_scalar_type)); - b[0] = 0; - b[1] = 2; - torch::Tensor c0 = torch::index_select(a, 0, b); - torch::Tensor c1 = torch::index_select(a, 1, b); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = CopyToDevice(b, device); - torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b); - torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b); - AllEqual(c0, xla_c0); - AllEqual(c1, xla_c1); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_select", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestIndexSelectRank0) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor a = - isFloatingType(scalar_type) - ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); - torch::Tensor b = - torch::scalar_tensor(2, torch::TensorOptions(torch::kLong)); - torch::Tensor c0 = torch::index_select(a, 0, b); - torch::Tensor c1 = torch::index_select(a, 1, b); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_a = CopyToDevice(a, device); - torch::Tensor xla_b = CopyToDevice(b, device); - torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b); - torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b); - AllEqual(c0, xla_c0); - AllEqual(c1, xla_c1); - }); - } -} - TEST_F(AtenXlaTensorTest, TestInverse) { // TODO: Renable after the LAPACK dependency issue is resolved. 
GTEST_SKIP(); @@ -3010,204 +2959,271 @@ TEST_F(AtenXlaTensorTest, TestMaskIndexPut) { } } -TEST_F(AtenXlaTensorTest, TestIndexPutImpl) { +class IndexOpsAtenXlaTensorTest + : public AtenXlaTensorTest, + public ::testing::WithParamInterface< + std::tuple> { + protected: + torch::ScalarType GetIndexType() const { return std::get<0>(GetParam()); } + torch::ScalarType GetValueType() const { return std::get<1>(GetParam()); } +}; + +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexSelect) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor a = + isFloatingType(scalar_type) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); + torch::Tensor b = torch::empty({2}, torch::TensorOptions(GetIndexType())); + b[0] = 0; + b[1] = 2; + torch::Tensor c0 = torch::index_select(a, 0, b); + torch::Tensor c1 = torch::index_select(a, 1, b); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + torch::Tensor xla_b = CopyToDevice(b, device); + torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b); + torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b); + AllEqual(c0, xla_c0); + AllEqual(c1, xla_c1); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_select", cpp_test::GetIgnoredCounters()); +} + +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexSelectRank0) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor a = + isFloatingType(scalar_type) + ? torch::rand({3, 4}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type)); + torch::Tensor b = + torch::scalar_tensor(2, torch::TensorOptions(GetIndexType())); + torch::Tensor c0 = torch::index_select(a, 0, b); + torch::Tensor c1 = torch::index_select(a, 1, b); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_a = CopyToDevice(a, device); + torch::Tensor xla_b = CopyToDevice(b, device); + torch::Tensor xla_c0 = torch::index_select(xla_a, 0, xla_b); + torch::Tensor xla_c1 = torch::index_select(xla_a, 1, xla_b); + AllEqual(c0, xla_c0); + AllEqual(c1, xla_c1); + }); +} + +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexPutImpl) { torch::Tensor indices = - torch::randint(-3, 3, {2, 4, 3}, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor values = - torch::ones({3, 5, 6, 7}, torch::TensorOptions(scalar_type)); - for (bool accumulate : {false, true}) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor params = - isFloatingType(scalar_type) - ? 
torch::rand({4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type)) - : torch::randint(100, {4, 3, 5, 6, 7}, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_params = CopyToDevice(params.clone(), device); - torch::Tensor result = torch::_index_put_impl_( - params, {indices}, values, accumulate, /*unsafe=*/true); - torch::Tensor xla_indices = CopyToDevice(indices, device); - torch::Tensor xla_values = CopyToDevice(values, device); - torch::Tensor xla_result = torch::_index_put_impl_( - xla_params, {xla_indices}, xla_values, accumulate, /*unsafe=*/true); - AllEqual(result, xla_result); - AllEqual(params, xla_params); - }); + torch::randint(-3, 3, {2, 4, 3}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor values = + torch::ones({3, 5, 6, 7}, torch::TensorOptions(scalar_type)); + for (bool accumulate : {false, true}) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor params = + isFloatingType(scalar_type) + ? torch::rand({4, 3, 5, 6, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {4, 3, 5, 6, 7}, + torch::TensorOptions(scalar_type)); + torch::Tensor xla_params = CopyToDevice(params.clone(), device); + torch::Tensor result = torch::_index_put_impl_( + params, {indices}, values, accumulate, /*unsafe=*/true); + torch::Tensor xla_indices = CopyToDevice(indices, device); + torch::Tensor xla_values = CopyToDevice(values, device); + torch::Tensor xla_result = torch::_index_put_impl_( + xla_params, {xla_indices}, xla_values, accumulate, /*unsafe=*/true); + AllEqual(result, xla_result); + AllEqual(params, xla_params); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::_index_put_impl_", - cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::_index_put_impl_", + cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithScalar) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalar) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); torch::Scalar value = 42; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor result = torch::index_fill(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_result = - torch::index_fill(xla_base, dim, xla_index, value); - AllEqual(result, xla_result); - }); + torch::Tensor base = + isFloatingType(scalar_type) + ? 
torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor result = torch::index_fill(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_result = + torch::index_fill(xla_base, dim, xla_index, value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithScalarInPlace) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalarInPlace) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); torch::Scalar value = 42; int rank = 3; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - for (int dim = -rank; dim < rank; ++dim) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_base = CopyToDevice(base.clone(), device); - torch::Tensor result = base.index_fill_(dim, index, value); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_result = xla_base.index_fill_(dim, xla_index, value); - AllEqual(result, xla_result); - AllEqual(base, xla_base); - }); + for (int dim = -rank; dim < rank; ++dim) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, + torch::TensorOptions(scalar_type)); + torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor result = base.index_fill_(dim, index, value); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_result = xla_base.index_fill_(dim, xla_index, value); + AllEqual(result, xla_result); + AllEqual(base, xla_base); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithTensor) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensor) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? 
torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); - torch::Tensor value = - torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor result = torch::index_fill(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_fill(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); + torch::Tensor value = + torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor result = torch::index_fill(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_fill(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillWithTensorInPlace) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensorInPlace) { torch::Tensor index = - torch::tensor({0, 2}, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor value = - torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); - int rank = 3; - for (int dim = -rank; dim < rank; ++dim) { - ForEachDevice([&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_base = CopyToDevice(base.clone(), device); - torch::Tensor result = base.index_fill_(dim, index, value); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - xla_base.index_fill_(dim, xla_index, xla_value); - AllEqual(result, xla_result); - AllEqual(base, xla_base); - }); + torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor value = + torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); + int rank = 3; + for (int dim = -rank; dim < rank; ++dim) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor base = + isFloatingType(scalar_type) + ? 
torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, + torch::TensorOptions(scalar_type)); + torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor result = base.index_fill_(dim, index, value); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + xla_base.index_fill_(dim, xla_index, xla_value); + AllEqual(result, xla_result); + AllEqual(base, xla_base); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexFillRank0) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillRank0) { torch::Tensor index = - torch::scalar_tensor(2, torch::TensorOptions(torch::kLong)); - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); - torch::Tensor value = - torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor result = torch::index_fill(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_fill(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + torch::scalar_tensor(2, torch::TensorOptions(GetIndexType())); + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({3, 4, 5}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {3, 4, 5}, torch::TensorOptions(scalar_type)); + torch::Tensor value = + torch::scalar_tensor(42, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor result = torch::index_fill(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_fill(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_fill_", cpp_test::GetIgnoredCounters()); } } -TEST_F(AtenXlaTensorTest, TestIndexAdd) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAdd) { int index_size = 10; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? 
torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = torch::randint(0, base.size(dim), {index_size}, + torch::TensorOptions(GetIndexType())); + std::vector value_sizes(base.sizes().begin(), base.sizes().end()); + int canonical_dim = dim < 0 ? dim + rank : dim; + value_sizes[canonical_dim] = index_size; + torch::Tensor value = isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) { + ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) + : torch::randint(100, value_sizes, + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_add(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_add(xla_base, dim, xla_index, xla_value); + AllClose(result, xla_result); + }); + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); + } +} + +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAddInPlace) { + int index_size = 10; + int rank = 3; + std::vector alphas{0.0, 1.0, 2.0}; + torch::ScalarType scalar_type = GetValueType(); + for (int dim = -rank; dim < rank; ++dim) { + for (double alpha : alphas) { + ForEachDevice([&](const torch::Device& device) { + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(50, {5, 3, 7}, + torch::TensorOptions(scalar_type)); torch::Tensor index = torch::randint(0, base.size(dim), {index_size}, - torch::TensorOptions(index_scalar_type)); + torch::TensorOptions(GetIndexType())); std::vector value_sizes(base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; @@ -3215,43 +3231,101 @@ TEST_F(AtenXlaTensorTest, TestIndexAdd) { torch::Tensor value = isFloatingType(scalar_type) ? 
torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, + : torch::randint(50, value_sizes, torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_add(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_add(xla_base, dim, xla_index, xla_value); - AllClose(result, xla_result); - }); - } + torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + xla_base.index_add_(dim, xla_index, xla_value, alpha); + torch::Tensor result = base.index_add_(dim, index, value, alpha); + AllClose(result, xla_result); + AllClose(base, xla_base); + }); } } ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); } -TEST_F(AtenXlaTensorTest, TestIndexAddInPlace) { +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAddRank0) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, + torch::TensorOptions(GetIndexType())); + std::vector value_sizes(base.sizes().begin(), base.sizes().end()); + int canonical_dim = dim < 0 ? dim + rank : dim; + value_sizes[canonical_dim] = 1; + torch::Tensor value = + isFloatingType(scalar_type) + ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) + : torch::randint(100, value_sizes, + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_add(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_add(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); + + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); + } +} + +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopy) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = + torch::randperm(base.size(dim), torch::TensorOptions(GetIndexType())); + torch::Tensor value = + isFloatingType(scalar_type) + ? 
torch::rand(base.sizes(), torch::TensorOptions(scalar_type)) + : torch::randint(100, base.sizes(), + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_copy(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_copy(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); + } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); +} + +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopyInPlace) { int index_size = 10; int rank = 3; - std::vector alphas{0.0, 1.0, 2.0}; - - for (torch::ScalarType scalar_type : - {torch::kByte, torch::kFloat, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - for (int dim = -rank; dim < rank; ++dim) { - for (double alpha : alphas) { - ForEachDevice([&](const torch::Device& device) { + torch::ScalarType scalar_type = GetValueType(); + for (int dim = -rank; dim < rank; ++dim) { + ForEachDevice( + {XlaDeviceType::CPU, XlaDeviceType::TPU}, + [&](const torch::Device& device) { torch::Tensor base = isFloatingType(scalar_type) ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(50, {5, 3, 7}, + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); torch::Tensor index = torch::randint(0, base.size(dim), {index_size}, - torch::TensorOptions(torch::kLong)); + torch::TensorOptions(GetIndexType())); std::vector value_sizes(base.sizes().begin(), base.sizes().end()); int canonical_dim = dim < 0 ? dim + rank : dim; @@ -3259,174 +3333,70 @@ TEST_F(AtenXlaTensorTest, TestIndexAddInPlace) { torch::Tensor value = isFloatingType(scalar_type) ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(50, value_sizes, + : torch::randint(100, value_sizes, torch::TensorOptions(scalar_type)); torch::Tensor xla_base = CopyToDevice(base.clone(), device); + torch::Tensor result = base.index_copy(dim, index, value); torch::Tensor xla_index = CopyToDevice(index, device); torch::Tensor xla_value = CopyToDevice(value, device); torch::Tensor xla_result = - xla_base.index_add_(dim, xla_index, xla_value, alpha); - torch::Tensor result = base.index_add_(dim, index, value, alpha); - AllClose(result, xla_result); - AllClose(base, xla_base); - }); - } - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestIndexAddRank0) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, - torch::TensorOptions(torch::kLong)); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); - int canonical_dim = dim < 0 ? dim + rank : dim; - value_sizes[canonical_dim] = 1; - torch::Tensor value = - isFloatingType(scalar_type) - ? 
torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_add(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_add(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + xla_base.index_copy(dim, xla_index, xla_value); + AllEqual(result, xla_result); + AllEqual(base, xla_base); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); - } - } -} - -TEST_F(AtenXlaTensorTest, TestIndexCopy) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor index = - torch::randperm(base.size(dim), torch::TensorOptions(torch::kLong)); - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand(base.sizes(), torch::TensorOptions(scalar_type)) - : torch::randint(100, base.sizes(), - torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_copy(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_copy(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); - } - } - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); -} - -TEST_F(AtenXlaTensorTest, TestIndexCopyInPlace) { - int index_size = 10; - int rank = 3; - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - for (int dim = -rank; dim < rank; ++dim) { - ForEachDevice( - {XlaDeviceType::CPU, XlaDeviceType::TPU}, - [&](const torch::Device& device) { - torch::Tensor base = - isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, - torch::TensorOptions(scalar_type)); - torch::Tensor index = - torch::randint(0, base.size(dim), {index_size}, - torch::TensorOptions(torch::kLong)); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); - int canonical_dim = dim < 0 ? dim + rank : dim; - value_sizes[canonical_dim] = index_size; - torch::Tensor value = - isFloatingType(scalar_type) - ? 
torch::rand(value_sizes, - torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type)); - torch::Tensor xla_base = CopyToDevice(base.clone(), device); - torch::Tensor result = base.index_copy(dim, index, value); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - xla_base.index_copy(dim, xla_index, xla_value); - AllEqual(result, xla_result); - AllEqual(base, xla_base); - - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_copy", - cpp_test::GetIgnoredCounters()); - }); - } + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_copy", + cpp_test::GetIgnoredCounters()); + }); } } -TEST_F(AtenXlaTensorTest, TestIndexCopyRank0) { - for (torch::ScalarType scalar_type : - {torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt, - torch::kLong}) { - torch::Tensor base = +TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopyRank0) { + torch::ScalarType scalar_type = GetValueType(); + torch::Tensor base = + isFloatingType(scalar_type) + ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) + : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); + int rank = base.dim(); + for (int dim = -rank; dim < rank; ++dim) { + torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, + torch::TensorOptions(GetIndexType())); + std::vector value_sizes(base.sizes().begin(), base.sizes().end()); + int canonical_dim = dim < 0 ? dim + rank : dim; + value_sizes[canonical_dim] = 1; + torch::Tensor value = isFloatingType(scalar_type) - ? torch::rand({5, 3, 7}, torch::TensorOptions(scalar_type)) - : torch::randint(100, {5, 3, 7}, torch::TensorOptions(scalar_type)); - int rank = base.dim(); - for (int dim = -rank; dim < rank; ++dim) { - torch::Tensor index = torch::randint(0, base.size(dim), at::IntArrayRef{}, - torch::TensorOptions(torch::kLong)); - std::vector value_sizes(base.sizes().begin(), - base.sizes().end()); - int canonical_dim = dim < 0 ? dim + rank : dim; - value_sizes[canonical_dim] = 1; - torch::Tensor value = - isFloatingType(scalar_type) - ? torch::rand(value_sizes, torch::TensorOptions(scalar_type)) - : torch::randint(100, value_sizes, - torch::TensorOptions(scalar_type)); - torch::Tensor result = torch::index_copy(base, dim, index, value); - ForEachDevice([&](const torch::Device& device) { - torch::Tensor xla_base = CopyToDevice(base, device); - torch::Tensor xla_index = CopyToDevice(index, device); - torch::Tensor xla_value = CopyToDevice(value, device); - torch::Tensor xla_result = - torch::index_copy(xla_base, dim, xla_index, xla_value); - AllEqual(result, xla_result); - }); + ? 
torch::rand(value_sizes, torch::TensorOptions(scalar_type)) + : torch::randint(100, value_sizes, + torch::TensorOptions(scalar_type)); + torch::Tensor result = torch::index_copy(base, dim, index, value); + ForEachDevice([&](const torch::Device& device) { + torch::Tensor xla_base = CopyToDevice(base, device); + torch::Tensor xla_index = CopyToDevice(index, device); + torch::Tensor xla_value = CopyToDevice(value, device); + torch::Tensor xla_result = + torch::index_copy(xla_base, dim, xla_index, xla_value); + AllEqual(result, xla_result); + }); - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); - } - } -} + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); + } +} + +INSTANTIATE_TEST_SUITE_P( + IndexOpsAtenXlaTensor, IndexOpsAtenXlaTensorTest, + ::testing::Combine(::testing::Values(torch::kLong, torch::kInt), + ::testing::Values(torch::kFloat, torch::kByte, + torch::kChar, torch::kShort, + torch::kInt, torch::kLong)), + [](const testing::TestParamInfo& + info) { + const auto& params = info.param; + return absl::StrCat("IndexType", torch::toString(std::get<0>(params)), + "_", "ValueType", + torch::toString(std::get<1>(params))); + }); TEST_F(AtenXlaTensorTest, TestRelu) { torch::Tensor input = diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 1472f05b78a..357435cdd6d 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -365,8 +365,10 @@ torch::lazy::Value IndexPutByTensors( torch::lazy::NodePtr IndexFill(const XLATensorPtr& base, int64_t dim, const XLATensorPtr& index, const at::Scalar& value) { - XLA_CHECK_EQ(index->dtype(), at::ScalarType::Long) - << "Fill index is expected to be of scalar type Long, but it is " + XLA_CHECK(index->dtype() == at::ScalarType::Long || + index->dtype() == at::ScalarType::Int) + << "Fill index is expected to be of scalar type Long or scalar type Int, " + "but it is " << index->dtype(); XLA_CHECK_LE(index->shape().get().rank(), 1) << "Fill index is supposed to be a vector"; @@ -379,8 +381,10 @@ torch::lazy::NodePtr IndexFill(const XLATensorPtr& base, int64_t dim, torch::lazy::NodePtr IndexFill(const XLATensorPtr& base, int64_t dim, const XLATensorPtr& index, const XLATensorPtr& value) { - XLA_CHECK_EQ(index->dtype(), at::ScalarType::Long) - << "Fill index is expected to be of scalar type Long, but it is " + XLA_CHECK(index->dtype() == at::ScalarType::Long || + index->dtype() == at::ScalarType::Int) + << "Fill index is expected to be of scalar type Long or scalar type Int, " + "but it is " << index->dtype(); XLA_CHECK_LE(index->shape().get().rank(), 1) << "Fill index is supposed to be a vector"; @@ -407,8 +411,10 @@ torch::lazy::Value IndexAdd(const XLATensorPtr& base, int64_t dim, torch::lazy::Value IndexCopy(const XLATensorPtr& base, int64_t dim, const XLATensorPtr& index, const XLATensorPtr& source) { - XLA_CHECK_EQ(index->dtype(), at::ScalarType::Long) - << "Copy index is expected to be of scalar type Long, but it is " + XLA_CHECK(index->dtype() == at::ScalarType::Long || + index->dtype() == at::ScalarType::Int) + << "Add index is expected to be of scalar type Long or scalar type Int, " + "but it is " << index->dtype(); XLA_CHECK_LE(index->shape().get().rank(), 1) << "Copy index is supposed to be a vector"; diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index 
676ec730bbc..66c45b30fe6 100644
--- a/torch_xla/csrc/tensor_ops.cpp
+++ b/torch_xla/csrc/tensor_ops.cpp
@@ -196,8 +196,10 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
                                     const XLATensorPtr& indices,
                                     int64_t num_weights, int64_t padding_idx,
                                     bool scale_grad_by_freq) {
-  XLA_CHECK_EQ(indices->dtype(), at::ScalarType::Long)
-      << "Embedding indices are expected to be of scalar type Long";
+  XLA_CHECK(indices->dtype() == at::ScalarType::Long ||
+            indices->dtype() == at::ScalarType::Int)
+      << "Embedding indices are expected to be of scalar type Long or Int, "
+      << "but got " << indices->dtype();
   auto indices_shape_ref = indices->shape();
   // The weight must be of rank 2, which means the rank of grad_output is one
   // more than the indices.
@@ -245,8 +247,10 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
 XLATensorPtr Embedding(const XLATensorPtr& weight,
                        const XLATensorPtr& indices) {
   XLA_CHECK_EQ(weight->shape().get().rank(), 2);
-  XLA_CHECK(indices->dtype() == at::kLong || indices->dtype() == at::kInt);
-
+  XLA_CHECK(indices->dtype() == at::ScalarType::Long ||
+            indices->dtype() == at::ScalarType::Int)
+      << "Embedding indices are expected to be of scalar type Long or Int, "
+      << "but got " << indices->dtype();
   if (indices->shape().get().rank() == 1) {
     return tensor_methods::index_select(weight, 0, indices);
   }

From de2ff402d5efe76f1bb75b90810626b1076d81c2 Mon Sep 17 00:00:00 2001
From: rpsilva-aws
Date: Wed, 4 Dec 2024 08:21:35 +0000
Subject: [PATCH 02/15] Neuron S64/U64 downcast

---
 torch_xla/csrc/dtype.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index 923f1152c9d..82acc9ef87d 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -143,9 +143,11 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
       return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
                                         : xla::PrimitiveType::S16;
     case xla::PrimitiveType::S64:
-      return xla::PrimitiveType::S64;
+      return CheckNeuronDevice(hw_type) ? xla::PrimitiveType::S32
+                                        : xla::PrimitiveType::S64;
     case xla::PrimitiveType::U64:
-      return xla::PrimitiveType::U64;
+      return !CheckNeuronDevice(hw_type) ? 
xla::PrimitiveType::U64 + : xla::PrimitiveType::U32; case xla::PrimitiveType::C128: return xla::PrimitiveType::C128; default: From f3dadf322c7d61dd6f18f3fc94bec053ad762f7a Mon Sep 17 00:00:00 2001 From: qihqi Date: Wed, 27 Nov 2024 15:10:08 -0800 Subject: [PATCH 03/15] Use regular `torch.Tensor` for CPU tensors (#8416) --- .../torch_xla2/examples/basic_training.py | 31 ++-- .../torch_xla2/test/llama/test_llama.py | 172 +++++++++--------- experimental/torch_xla2/test/test_context.py | 7 + .../torch_xla2/test/test_core_aten_ops.py | 5 + .../torch_xla2/test/test_functions.py | 1 + .../torch_xla2/test/test_libraries.py | 9 +- experimental/torch_xla2/test/test_ops.py | 5 + .../torch_xla2/test/test_tf_integration.py | 2 +- .../test/test_unbounded_dynamism.py | 10 +- .../torch_xla2/torch_xla2/__init__.py | 10 + experimental/torch_xla2/torch_xla2/config.py | 2 +- experimental/torch_xla2/torch_xla2/export.py | 4 + .../torch_xla2/torch_xla2/ops/jtorch.py | 11 +- experimental/torch_xla2/torch_xla2/tensor.py | 13 +- 14 files changed, 165 insertions(+), 117 deletions(-) diff --git a/experimental/torch_xla2/examples/basic_training.py b/experimental/torch_xla2/examples/basic_training.py index a723f647ca8..fb814fcf978 100644 --- a/experimental/torch_xla2/examples/basic_training.py +++ b/experimental/torch_xla2/examples/basic_training.py @@ -51,7 +51,8 @@ def matplotlib_imshow(img, one_channel=False): plt.imshow(npimg, cmap="Greys") else: plt.imshow(np.transpose(npimg, (1, 2, 0))) - +#torch_xla2.env.config.debug_print_each_op = True +#torch_xla2.env.config.debug_mixed_tensor = True dataiter = iter(training_loader) images, labels = next(dataiter) @@ -80,15 +81,15 @@ def forward(self, x): return x -model = GarmentClassifier() +model = GarmentClassifier().to('jax') loss_fn = torch.nn.CrossEntropyLoss() # NB: Loss functions expect data in batches, so we're creating batches of 4 # Represents the model's confidence in each of the 10 classes for a given input -dummy_outputs = torch.rand(4, 10) +dummy_outputs = torch.rand(4, 10, device='jax') # Represents the correct class among the 10 being tested -dummy_labels = torch.tensor([1, 5, 3, 7]) +dummy_labels = torch.tensor([1, 5, 3, 7], device='jax') print(dummy_outputs) print(dummy_labels) @@ -110,6 +111,8 @@ def train_one_epoch(epoch_index, tb_writer=None): # Every data instance is an input + label pair # NEW: Move model to XLA device inputs, labels = data + inputs = inputs.to('jax') + labels = labels.to('jax') # Zero your gradients for every batch! optimizer.zero_grad() @@ -162,7 +165,9 @@ def train_one_epoch(epoch_index, tb_writer=None): # Disable gradient computation and reduce memory consumption. with torch.no_grad(): for i, vdata in enumerate(validation_loader): - # NOTE: move to XLA device + vinputs, vlabels = vdata + vinputs = vinputs.to('jax') + vlabels = vlabels.to('jax') voutputs = model(vinputs) # call model's forward vloss = loss_fn(voutputs, vlabels) running_vloss += vloss @@ -172,15 +177,11 @@ def train_one_epoch(epoch_index, tb_writer=None): # Log the running loss averaged per batch # for both training and validation - writer.add_scalars('Training vs. 
Validation Loss', - { 'Training' : avg_loss, 'Validation' : avg_vloss }, - epoch_number + 1) - writer.flush() - - # Track best performance, and save the model's state - if avg_vloss < best_vloss: - best_vloss = avg_vloss - model_path = 'model_{}_{}'.format(timestamp, epoch_number) - torch.save(model.state_dict(), model_path) + + # # Track best performance, and save the model's state + # if avg_vloss < best_vloss: + # best_vloss = avg_vloss + # model_path = 'model_{}_{}'.format(timestamp, epoch_number) + # torch.save(model.state_dict(), model_path) epoch_number += 1 diff --git a/experimental/torch_xla2/test/llama/test_llama.py b/experimental/torch_xla2/test/llama/test_llama.py index 03dfef27d0e..a47e8572186 100644 --- a/experimental/torch_xla2/test/llama/test_llama.py +++ b/experimental/torch_xla2/test/llama/test_llama.py @@ -12,101 +12,101 @@ class LlamaTest(test_base.TestCase): def test_can_run(self): - sample_args = ( - torch.randint(0, 32000, (1, 2048)), - torch.arange(0, 2048), - ) - sample_args = pytree.tree_map(tensor.t2j, sample_args) + with torch_xla2.default_env(): + sample_args = ( + torch.randint(0, 32000, (1, 2048), device='jax:0'), + torch.arange(0, 2048, device='jax:0'), + ) - model_args = llama_model.ModelArgs( - block_size=2048, - vocab_size=32000, - n_layer=2, - n_head=4, - dim=256, - ) - m = llama_model.Transformer(model_args) - m.to(torch.bfloat16) - m.setup_caches(1, 2048) + model_args = llama_model.ModelArgs( + block_size=2048, + vocab_size=32000, + n_layer=2, + n_head=4, + dim=256, + ) + m = llama_model.Transformer(model_args) + m.to(torch.bfloat16) + m.setup_caches(1, 2048) + m = m.to('jax') + + print(m(*sample_args)) - # NOTE: this API does NOT use torch export - weights, jax_func = torch_xla2.extract_jax(m) - print(jax_func(weights, sample_args)) def test_can_run_exportable(self): - model_args = model_exportable.ModelArgs( - vocab_size=32000, - n_layers=2, - n_heads=4, - dim=256, - ) - m = model_exportable.Transformer(model_args) - context_length = 2048 - input_shape_prefill = (1, context_length) - input_shape_decode = (1, 1) + model_args = model_exportable.ModelArgs( + vocab_size=32000, + n_layers=2, + n_heads=4, + dim=256, + ) + m = model_exportable.Transformer(model_args) + context_length = 2048 + input_shape_prefill = (1, context_length) + input_shape_decode = (1, 1) - def make_cache(args, batch_size): - n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - n_local_heads = args.n_heads - n_local_kv_heads = n_kv_heads - n_rep = n_local_heads // n_local_kv_heads - head_dim = args.dim // args.n_heads - res = [] - for i in range(args.n_layers): - if batch_size is None: - size = ( - args.max_seq_len, - n_local_kv_heads, - head_dim, - ) - else: - size = ( - batch_size, - args.max_seq_len, - n_local_kv_heads, - head_dim, - ) - res.append( - (torch.zeros( - size, - dtype=torch.bfloat16 if args.bf16_enable else torch.float), - torch.zeros( - size, - dtype=torch.bfloat16 if args.bf16_enable else torch.float))) - return res + def make_cache(args, batch_size): + n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + n_local_heads = args.n_heads + n_local_kv_heads = n_kv_heads + n_rep = n_local_heads // n_local_kv_heads + head_dim = args.dim // args.n_heads + res = [] + for i in range(args.n_layers): + if batch_size is None: + size = ( + args.max_seq_len, + n_local_kv_heads, + head_dim, + ) + else: + size = ( + batch_size, + args.max_seq_len, + n_local_kv_heads, + head_dim, + ) + res.append( + (torch.zeros( + size, + 
dtype=torch.bfloat16 if args.bf16_enable else torch.float), + torch.zeros( + size, + dtype=torch.bfloat16 if args.bf16_enable else torch.float))) + return res - prefill_caches = make_cache(model_args, 1) + prefill_caches = make_cache(model_args, 1) - sample_input_prefill = ( - torch.randint(0, 1000, input_shape_prefill, - dtype=torch.int32), # len seq length - torch.arange(0, context_length, dtype=torch.int32), # input indexes - torch.arange(0, context_length, dtype=torch.int32), # context indexes - prefill_caches, - True, # prefil - ) - with torch.no_grad(): - m_prefill = torch.export.export(m, sample_input_prefill) + sample_input_prefill = ( + torch.randint(0, 1000, input_shape_prefill, + dtype=torch.int32), # len seq length + torch.arange(0, context_length, dtype=torch.int32), # input indexes + torch.arange(0, context_length, dtype=torch.int32), # context indexes + prefill_caches, + True, # prefil + ) + with torch.no_grad(): + m_prefill = torch.export.export(m, sample_input_prefill) - weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) - sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, - sample_input_prefill) - print('Prefill', mj_prefill(weights, sample_inputs)) + weights, mj_prefill = torch_xla2.export.exported_program_to_jax(m_prefill) + sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, + sample_input_prefill) + print('Prefill', mj_prefill(weights, sample_inputs)) - sample_input_decode = ( - torch.randint(0, 1000, input_shape_decode, - dtype=torch.int32), # len = 1 - torch.tensor([0], dtype=torch.int32), - torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), - prefill_caches, - False # prefill - ) - with torch.no_grad(): - m_decode = torch.export.export(m, sample_input_decode) - weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) - sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, - sample_input_decode) - print('Decode', mj_decode(weights, sample_inputs)) + sample_input_decode = ( + torch.randint(0, 1000, input_shape_decode, + dtype=torch.int32), # len = 1 + torch.tensor([0], dtype=torch.int32), + torch.roll(torch.arange(context_length, dtype=torch.int32), 1, 0), + prefill_caches, + False # prefill + ) + with torch.no_grad(): + m_decode = torch.export.export(m, sample_input_decode) + weights, mj_decode = torch_xla2.export.exported_program_to_jax(m_decode) + sample_inputs = pytree.tree_map_only(torch.Tensor, tensor.t2j, + sample_input_decode) + print('Decode', mj_decode(weights, sample_inputs)) if __name__ == "__main__": diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py index 16bcedf7931..5255f415ee1 100644 --- a/experimental/torch_xla2/test/test_context.py +++ b/experimental/torch_xla2/test/test_context.py @@ -10,6 +10,13 @@ class TestContext(unittest.TestCase): + def setUp(self): + self.old_var = xla_env.config.use_torch_native_for_cpu_tensor + xla_env.config.use_torch_native_for_cpu_tensor = False + + def tearDown(self): + xla_env.config.use_torch_native_for_cpu_tensor = self.old_var + def test_mode_context_manager(self): with xla_env: x = torch.full((3, 3), -1) diff --git a/experimental/torch_xla2/test/test_core_aten_ops.py b/experimental/torch_xla2/test/test_core_aten_ops.py index d207bc22a82..e60086db087 100644 --- a/experimental/torch_xla2/test/test_core_aten_ops.py +++ b/experimental/torch_xla2/test/test_core_aten_ops.py @@ -66,6 +66,11 @@ def setUp(self): super().setUp() torch.manual_seed(0) self.env = tensor.Environment() + 
self.old_var = self.env.config.use_torch_native_for_cpu_tensor + self.env.config.use_torch_native_for_cpu_tensor = False + + def tearDown(self): + self.env.config.use_torch_native_for_cpu_tensor = self.old_var def test_aten_abs_0(self): args = (torch.randn((10, 10)).to(torch.float32),) diff --git a/experimental/torch_xla2/test/test_functions.py b/experimental/torch_xla2/test/test_functions.py index 9e291dc802a..aab34bd1472 100644 --- a/experimental/torch_xla2/test/test_functions.py +++ b/experimental/torch_xla2/test/test_functions.py @@ -10,6 +10,7 @@ class TestTorchFunctions(parameterized.TestCase): def setUp(self): self.env = torch_xla2.tensor.Environment() + self.env.config.use_torch_native_for_cpu_tensor = False torch_xla2.enable_accuracy_mode() @parameterized.named_parameters( diff --git a/experimental/torch_xla2/test/test_libraries.py b/experimental/torch_xla2/test/test_libraries.py index 019c967db56..492d15467d5 100644 --- a/experimental/torch_xla2/test/test_libraries.py +++ b/experimental/torch_xla2/test/test_libraries.py @@ -1,11 +1,9 @@ import unittest -import jax import torch -import torch.nn as nn import torch.nn.functional as F from torch.library import Library, impl, impl_abstract import torch_xla2 -from torch_xla2 import tensor +import torch_xla2.export from torch_xla2.ops import jaten from torch_xla2.ops import jlibrary @@ -56,6 +54,7 @@ class LibraryTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) + torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False def test_basic_sdpa_library(self): @@ -78,3 +77,7 @@ def forward(self, q,k,v): ## stablehlo.composite ops. self.assertIn("call @mylib.scaled_dot_product_attention", module_str) self.assertIn("call @mylib.softmax", module_str) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/experimental/torch_xla2/test/test_ops.py b/experimental/torch_xla2/test/test_ops.py index 28d0f29f0c1..d79b35e533a 100644 --- a/experimental/torch_xla2/test/test_ops.py +++ b/experimental/torch_xla2/test/test_ops.py @@ -192,6 +192,11 @@ def setUp(self): torch_xla2.enable_accuracy_mode() #self.env.config.debug_accuracy_for_each_op = True torch.manual_seed(0) + self.old_var = self.env.config.use_torch_native_for_cpu_tensor + self.env.config.use_torch_native_for_cpu_tensor = False + + def tearDown(self): + self.env.config.use_torch_native_for_cpu_tensor = self.old_var # Replaces all values in the input torch_tensor that are less than the given threshold # with the threshold value itself. 
diff --git a/experimental/torch_xla2/test/test_tf_integration.py b/experimental/torch_xla2/test/test_tf_integration.py index ff9da220c57..4562ba8cb0c 100644 --- a/experimental/torch_xla2/test/test_tf_integration.py +++ b/experimental/torch_xla2/test/test_tf_integration.py @@ -1,6 +1,6 @@ -import jax import os import tempfile +import numpy as np import tensorflow as tf import torch import torch.nn.functional as F diff --git a/experimental/torch_xla2/test/test_unbounded_dynamism.py b/experimental/torch_xla2/test/test_unbounded_dynamism.py index 0cd800cb1a7..06d7b19b149 100644 --- a/experimental/torch_xla2/test/test_unbounded_dynamism.py +++ b/experimental/torch_xla2/test/test_unbounded_dynamism.py @@ -2,10 +2,10 @@ import sys import unittest -import numpy as np import torch from torch.export import Dim, export from torch_xla2.export import exported_program_to_stablehlo as exp2shlo +import torch_xla2 ## This file is copied from `xla/test/stablehlo/test_unbounded_dynamism.py` ## To test that torch_xla2 has identical behavior. @@ -44,6 +44,14 @@ def forward(self, *args): class UnboundedDynamismExportTest(unittest.TestCase): + def setUp(self): + self.env = torch_xla2.default_env() + self.env.config.use_torch_native_for_cpu_tensor = False + torch_xla2.enable_accuracy_mode() + + def tearDown(self): + self.env.config.use_torch_native_for_cpu_tensor = True + def test_add(self): args = (torch.rand((10, 197, 768)), torch.rand((10, 197, 768))) dynamic_shapes = (({0: Dim("dim")}, {0: Dim("dim")}),) diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py index ef6cd058429..f36a0737c00 100644 --- a/experimental/torch_xla2/torch_xla2/__init__.py +++ b/experimental/torch_xla2/torch_xla2/__init__.py @@ -1,3 +1,4 @@ +import contextlib from typing import List, Dict, Any, Optional import dataclasses import jax @@ -73,6 +74,15 @@ def disable_globally(): global env default_env().__exit__(None, None, None) +@contextlib.contextmanager +def disable_temporarily(): + prev = default_env().enabled + if prev: + disable_globally() + yield() + if prev: + enable_globally() + torch.utils.rename_privateuse1_backend('jax') unsupported_dtype = [torch.quint8] diff --git a/experimental/torch_xla2/torch_xla2/config.py b/experimental/torch_xla2/torch_xla2/config.py index 8a0870996a2..351d137df57 100644 --- a/experimental/torch_xla2/torch_xla2/config.py +++ b/experimental/torch_xla2/torch_xla2/config.py @@ -14,5 +14,5 @@ class Configuration: # device treat_cuda_as_jax_device: bool = True - use_torch_native_for_cpu_tensor: bool = False + use_torch_native_for_cpu_tensor: bool = True internal_respect_torch_return_dtypes: bool = False diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 2744d931de4..3fdbedc8474 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -31,6 +31,10 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: print('Running ', target.name(), '--------') op = ops_registry.all_aten_ops.get(target) + if op is None: + op = ops_registry.all_aten_ops.get(target.overloadpacket) + assert op is not None, target + assert op.is_jax_function, op if op is None: op = ops_registry.all_aten_ops.get(target.overloadpacket) if op is None: diff --git a/experimental/torch_xla2/torch_xla2/ops/jtorch.py b/experimental/torch_xla2/torch_xla2/ops/jtorch.py index f21c5b8f671..4d541cd04d1 100644 --- a/experimental/torch_xla2/torch_xla2/ops/jtorch.py +++ 
b/experimental/torch_xla2/torch_xla2/ops/jtorch.py @@ -103,13 +103,12 @@ def _sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) - if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) @@ -249,14 +248,14 @@ def _aten_isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False): def _ones(*size: int, dtype=None, **kwargs): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] - return torch.ops.aten.ones(size, dtype=dtype) + return jaten._ones(size, dtype=dtype) -@register_function(torch.zeros, is_jax_function=False) +@register_function(torch.zeros, is_jax_function=True) def _zeros(*size: int, dtype=None, **kwargs): if len(size) == 1 and isinstance(size[0], collections.abc.Iterable): size = size[0] - return torch.ops.aten.zeros(size, dtype=dtype) + return jaten._zeros(size, dtype=dtype) @register_function(torch.eye) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index d14eb9a68e1..35d69eb7326 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -247,6 +247,8 @@ def _name_of_func(func): torch.randn, torch.rand, torch.randint, + torch.full, + torch.as_tensor, } @@ -285,7 +287,8 @@ def get_as_jax_device(self, device: Any): if isinstance(device, torch.device): device = str(device) - if self.config.use_torch_native_for_cpu_tensor and device.startswith('cpu'): + if (self.config.use_torch_native_for_cpu_tensor and + not device.startswith('jax') and not device.startswith('cuda')): return None if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'): @@ -338,7 +341,7 @@ def _to_copy(self, the_tensor, new_dtype, new_device): arr = jax.device_put(arr, jax_device) else: with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): - return torch_tensor.to(new_device) + return the_tensor.to(new_device) return XLATensor2(arr, self) @@ -358,7 +361,6 @@ def _handle_tensor_constructor(self, func, args, kwargs): # let torch handle it with mode_utils.no_dispatch(), torch._C.DisableTorchFunction(): return func(*args, **kwargs) - with jax.default_device(jax_device): op = self._ops.get(func) res = op.func(*args, **kwargs) @@ -396,7 +398,8 @@ def dispatch(self, func, types, args, kwargs): # If the func doesn't act on XLATensor2, and is not a tensor constructor, # We should skip and let torch handle it. 
- tensor_args = [t for t in args if isinstance(t, torch.Tensor)] + + tensor_args = [t for t in torch_pytree.tree_flatten(args)[0] if isinstance(t, torch.Tensor)] if tensor_args and all(not isinstance(t, XLATensor2) for t in tensor_args): return func(*args, **kwargs) @@ -444,11 +447,13 @@ def dispatch(self, func, types, args, kwargs): def __enter__(self): self._dispatch_mode.__enter__() self._function_mode.__enter__() + self.enabled = True return self def __exit__(self, *exc): self._function_mode.__exit__(*exc) self._dispatch_mode.__exit__(*exc) + self.enabled = False def _move_one_value(self, val): if isinstance(val, torch.nn.Module): From c00bf53118812b3e0fe4272da831aaaec8f4b138 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:06:16 -0800 Subject: [PATCH 04/15] Support SegmentID when doing data parallel SPMD (#8425) --- test/test_pallas_spmd.py | 130 +++++++++++++++++++++++- torch_xla/experimental/custom_kernel.py | 23 +++-- 2 files changed, 146 insertions(+), 7 deletions(-) diff --git a/test/test_pallas_spmd.py b/test/test_pallas_spmd.py index e88b8b2caff..713def2b8b1 100644 --- a/test/test_pallas_spmd.py +++ b/test/test_pallas_spmd.py @@ -3,6 +3,7 @@ import unittest import torch +import numpy as np from torch import nn as nn import torch_xla @@ -22,8 +23,24 @@ class PallasTest(unittest.TestCase): - def _attention(self, q, k, v): + # This creates a diagonal mask where only elements within the same segment + # can attend to each other. Since the mask marks the positions to mask out + # (the irrelevant parts), we use != instead of ==. + def _make_attention_mask_from_segment_ids(self, q_segment_ids, + kv_segment_ids): + return q_segment_ids.view(q_segment_ids.shape[0], 1, + q_segment_ids.shape[1], 1) != kv_segment_ids.view( + kv_segment_ids.shape[0], 1, 1, + kv_segment_ids.shape[1]) + + def _attention(self, q, k, v, *, attn_mask=None, ab=None): attn_weight = q @ k.transpose(-2, -1) + if attn_mask is not None: + # Mask out the irrelevant parts.
+ attn_weight = attn_weight.masked_fill(attn_mask, + torch.finfo(attn_weight.dtype).min) + if ab is not None: + attn_weight = attn_weight + ab attn_weight = nn.functional.softmax(attn_weight, dim=-1) attn_output = attn_weight @ v return attn_output @@ -98,6 +115,117 @@ def test_flash_attention_backward_spmd_data_parallel(self): self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) jax.config.update('jax_default_matmul_precision', "default") + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_wrapper_segment_ids_spmd(self): + from torch_xla.experimental.custom_kernel import flash_attention + from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as jax_flash_attention, SegmentIds + xs.set_global_mesh(xs.get_1d_mesh("data")) + + q = torch.randn(3, 2, 128, 4) + k = torch.randn(3, 2, 128, 4) + v = torch.randn(3, 2, 128, 4) + zeros = torch.zeros(3, 32) + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + segment_ids_xla = segment_ids.to("xla") + # only shard data dimension + o = flash_attention( + q.to("xla"), + k.to("xla"), + v.to("xla"), + False, + segment_ids_xla, + segment_ids.to("xla"), + partition_spec=("data", None, None, None)) + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(o), + f"{{devices=[{xr.global_runtime_device_count()},1,1,1]0,1,2,3}}") + + jax_q = jnp.array(q.numpy(), dtype=jnp.float32) + jax_k = jnp.array(k.numpy(), dtype=jnp.float32) + jax_v = jnp.array(v.numpy(), dtype=jnp.float32) + jax_segment_ids = jnp.array(segment_ids.numpy(), dtype=jnp.float32) + expected_o = torch.from_numpy( + np.array( + jax_flash_attention( + jax_q, + jax_k, + jax_v, + segment_ids=SegmentIds(jax_segment_ids, jax_segment_ids), + ))) + + self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)) + jax.config.update('jax_default_matmul_precision', "default") + + @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3, + "This test only works on TPUv3+.") + def test_flash_attention_backward_segment_ids_spmd(self): + jax.config.update("jax_default_matmul_precision", "highest") + from torch_xla.experimental.custom_kernel import flash_attention + n_devices = xr.global_runtime_device_count() + xs.set_global_mesh(xs.get_1d_mesh("data")) + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + zeros = torch.zeros(4, 32).to("xla") + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = flash_attention( + q, + k, + v, + False, + segment_ids, + segment_ids, + partition_spec=("data", None, None, None)) + loss = o.sum() + loss.backward() + q_grad = q.grad + k_grad = k.grad + v_grad = v.grad + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(o), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(q_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(k_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + self.assertEqual( + torch_xla._XLAC._get_xla_sharding_spec(v_grad), + f"{{devices=[{n_devices},1,1,1]0,1,2,3}}") + torch_xla.sync() + + torch.manual_seed(42) + q = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + k = torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + v 
= torch.randn(4, 2, 128, 8, requires_grad=True).to("xla") + zeros = torch.zeros(4, 32).to("xla") + segment_ids = torch.cat([zeros, zeros + 1, zeros + 2, zeros + 3], dim=1) + q.retain_grad() + k.retain_grad() + v.retain_grad() + + o = self._attention( + q, + k, + v, + attn_mask=self._make_attention_mask_from_segment_ids( + segment_ids, segment_ids)) + loss = o.sum() + loss.backward() + xm.mark_step() + + for i in [(q, q_grad), (k, k_grad), (v, v_grad)]: + self.assertTrue(torch.allclose(i[0].grad.cpu(), i[1].cpu(), atol=1e-05)) + jax.config.update("jax_default_matmul_precision", "default") + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index fdc5992c3b0..5e30ffba26a 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -266,7 +266,15 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, dtypes.append(torch.float32) with torch.no_grad(): - segment_ids, q_segment_ids, kv_segment_ids = FlashAttention.prepare_segment_ids( + if partition_spec is not None and q_segment_ids is not None and kv_segment_ids is not None: + # partition_spec is for q,k,v with shape [batch, num_head, seq_len, head_dim], segment id + # is of shape [batch, seq_len], hence we need to tweak it a bit + segment_id_partition_spec = (partition_spec[0], partition_spec[2]) + q_segment_ids = xs.enable_manual_sharding( + q_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor + kv_segment_ids = xs.enable_manual_sharding( + kv_segment_ids, segment_id_partition_spec, mesh=mesh).global_tensor + segment_ids, q_segment_ids_fa, kv_segment_ids_fa = FlashAttention.prepare_segment_ids( q_segment_ids, kv_segment_ids) ctx.segment_ids = segment_ids @@ -297,7 +305,7 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, if ab is not None: args += [ab] if segment_ids is not None: - args += [q_segment_ids, kv_segment_ids] + args += [q_segment_ids_fa, kv_segment_ids_fa] o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes) if not save_residuals: @@ -319,20 +327,23 @@ def forward(ctx, q, k, v, causal, q_segment_ids, kv_segment_ids, sm_scale, ab, m = xs.disable_manual_sharding( m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor - ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids, - kv_segment_ids, full_ab) + # q_segment_ids and kv_segment_ids are sharded here if partition_spec is provided + # but it should be OK as the backward will use the same partition_spec + ctx.save_for_backward(full_q, full_k, full_v, o, l, m, q_segment_ids_fa, + kv_segment_ids_fa, full_ab) return o @staticmethod def backward(ctx, grad_output): from jax.experimental.pallas.ops.tpu.flash_attention import _flash_attention_bwd_dq, _flash_attention_bwd_dkv - q, k, v, o, l, m, q_segment_ids, kv_segment_ids, ab = ctx.saved_tensors + q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = ctx.saved_tensors causal = ctx.causal sm_scale = ctx.sm_scale partition_spec = ctx.partition_spec mesh = ctx.mesh full_shape = ctx.full_shape + # this segment_ids only reflects the local shape of segment_ids segment_ids = ctx.segment_ids grad_q = grad_k = grad_v = grad_ab = None @@ -398,7 +409,7 @@ def backward(ctx, grad_output): if ab is not None: args += [ab] if segment_ids is not None: - args += [q_segment_ids, kv_segment_ids] + args += [q_segment_ids_fa, kv_segment_ids_fa] args += [expanded_l, expanded_m, 
grad_output, expanded_grad_i] outputs = [q] From dde47a723b1f6e1e1b2e0425c4c7935f9dcbb062 Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Mon, 2 Dec 2024 01:07:44 -0800 Subject: [PATCH 05/15] Reenable the distributed checkpointing test (#8424) --- test/tpu/run_tests.sh | 3 +-- torch_xla/csrc/tensor.cpp | 12 +++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 8d5e74bde03..6ad06b07740 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -8,8 +8,7 @@ python3 test/pjrt/test_collective_ops_tpu.py python3 test/spmd/test_mp_input_sharding.py python3 test/spmd/test_xla_sharding.py python3 test/spmd/test_xla_virtual_device.py -# TODO(JackCaoG): to reenable -# python3 test/spmd/test_xla_distributed_checkpoint.py +python3 test/spmd/test_xla_distributed_checkpoint.py python3 test/spmd/test_train_spmd_linear_model.py python3 test/spmd/test_xla_spmd_python_api_interaction.py python3 test/spmd/test_xla_auto_sharding.py diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 095a6ce4163..01306c53d38 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -562,7 +562,17 @@ void XLATensor::UpdateFromTensor(at::Tensor tensor, bool sync) { at::Tensor coyped_tensor = torch::lazy::CopyTensor(tensor, dtype()); SetTensorData(coyped_tensor); data()->handle = nullptr; - data()->sharding = nullptr; + // If the new tensor's shape differs from the sharding spec's shape, clear the sharding. + if (data()->sharding) { + auto coyped_tensor_dims = XlaHelpers::I64List(coyped_tensor.sizes()); + auto sharding_dims = data()->sharding->shape.dimensions(); + if (coyped_tensor_dims != sharding_dims) { + // The sharding shape from the original tensor is different from the new cpu + // tensor, so we need to clear the sharding here. + ClearShardingSpec(); + } + } + // ClearShardingSpec(); AssignIrValue(torch::lazy::Value()); if (data()->view != nullptr) { torch::lazy::Value ir_value = GetIrValueForTensor(coyped_tensor, device); From eb40b02c2dfbdcd70e9a7926b9c8494c47e5679e Mon Sep 17 00:00:00 2001 From: Emmanuel Ferdman Date: Tue, 3 Dec 2024 23:53:15 +0200 Subject: [PATCH 06/15] Update `train_resnet_benchmark.py` reference (#8428) Signed-off-by: Emmanuel Ferdman --- docs/source/learn/troubleshoot.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/learn/troubleshoot.md b/docs/source/learn/troubleshoot.md index 3014ee1d33c..fdc97f8a0b8 100644 --- a/docs/source/learn/troubleshoot.md +++ b/docs/source/learn/troubleshoot.md @@ -254,7 +254,7 @@ the following resources: Take a look at: -[examples/train_resnet_benchmark.py](https://github.com/pytorch/xla/blob/master/examples/train_resnet_benchmark.py) +[examples/debug/train_resnet_benchmark.py](https://github.com/pytorch/xla/blob/master/examples/debug/train_resnet_benchmark.py) for how to benchmark a PyTorch/XLA model.
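As a rough illustration of what such a benchmark does, the sketch below times individual training steps and forces the pending lazy graph to execute before reading the clock. It is not the referenced script: the toy `torch.nn.Linear` model and the step count are stand-ins, and it only assumes the `torch_xla.device()` and `torch_xla.sync()` APIs that appear elsewhere in this patch series.

```python
# Minimal step-timing sketch; illustrative only, not train_resnet_benchmark.py.
import time

import torch
import torch_xla

device = torch_xla.device()
model = torch.nn.Linear(128, 10).to(device)  # toy stand-in for a real model
data = torch.randn(64, 128, device=device)

for step in range(5):
  start = time.perf_counter()
  loss = model(data).sum()
  loss.backward()
  torch_xla.sync()  # materialize the lazy graph before stopping the clock
  print(f"step {step}: {time.perf_counter() - start:.4f}s")
```

The first step typically includes compilation time, so per-step numbers should be read with that in mind; the script linked above remains the authoritative reference.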
## Known Performance Caveats From af0773a5590000463330bdb6d14134fea3999fff Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Tue, 3 Dec 2024 16:05:06 -0800 Subject: [PATCH 07/15] Switch to libtpu package from the libtpu-wheels registry (#8409) --- .github/workflows/_tpu_ci.yml | 2 +- CONTRIBUTING.md | 4 +++- README.md | 8 +++++--- docs/source/contribute/configure-environment.md | 4 +++- infra/ansible/roles/build_plugin/tasks/main.yaml | 2 +- scripts/build_developer.sh | 4 +++- scripts/build_torch_wheels.sh | 4 +++- setup.py | 14 ++++++++++---- 8 files changed, 29 insertions(+), 13 deletions(-) diff --git a/.github/workflows/_tpu_ci.yml b/.github/workflows/_tpu_ci.yml index facb3aa50b2..f04c2c3b099 100644 --- a/.github/workflows/_tpu_ci.yml +++ b/.github/workflows/_tpu_ci.yml @@ -25,7 +25,7 @@ jobs: pip install rich # Jax nightly is needed for pallas tests. pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html + pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html pip install --upgrade protobuf - name: Run Tests env: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 33c59a57cc9..06aca135e37 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,7 +63,9 @@ We recommend you to use our prebuilt Docker image to start your development work cd pytorch/xla python setup.py develop # Optional: if you're using TPU, install libtpu - pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html + pip install torch_xla[tpu] \ + -f https://storage.googleapis.com/libtpu-wheels/index.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html ``` * Test your build diff --git a/README.md b/README.md index 96ffe9e6deb..6c4994637c5 100644 --- a/README.md +++ b/README.md @@ -26,14 +26,14 @@ started: To install PyTorch/XLA stable build in a new TPU VM: ``` -pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html +pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html ``` To install PyTorch/XLA nightly build in a new TPU VM: ``` pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html +pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html ``` ### GPU Plugin @@ -147,7 +147,9 @@ can now install the main build with `pip install torch_xla`. 
To also install the Cloud TPU plugin corresponding to your installed `torch_xla`, install the optional `tpu` dependencies after installing the main build with ``` -pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +pip install torch_xla[tpu] \ + -f https://storage.googleapis.com/libtpu-wheels/index.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html ``` GPU and nightly builds are available in our public GCS bucket. diff --git a/docs/source/contribute/configure-environment.md b/docs/source/contribute/configure-environment.md index 972765108b0..4a5bf35d02d 100644 --- a/docs/source/contribute/configure-environment.md +++ b/docs/source/contribute/configure-environment.md @@ -87,7 +87,9 @@ via the Command Palette (`Python: Create Environment`). Install the latest PyTorch and PyTorch/XLA releases: ``` bash -pip install numpy torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +pip install numpy torch torch_xla[tpu] \ + -f https://storage.googleapis.com/libtpu-wheels/index.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html ``` Create a file `test.py`: diff --git a/infra/ansible/roles/build_plugin/tasks/main.yaml b/infra/ansible/roles/build_plugin/tasks/main.yaml index 2e2590b150a..142d29c3718 100644 --- a/infra/ansible/roles/build_plugin/tasks/main.yaml +++ b/infra/ansible/roles/build_plugin/tasks/main.yaml @@ -28,5 +28,5 @@ - name: Install libtpu ansible.builtin.pip: name: torch_xla[tpu] - extra_args: -f https://storage.googleapis.com/libtpu-releases/index.html + extra_args: -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html when: accelerator == "tpuvm" diff --git a/scripts/build_developer.sh b/scripts/build_developer.sh index 2590860424e..680c4a3e8f7 100755 --- a/scripts/build_developer.sh +++ b/scripts/build_developer.sh @@ -23,7 +23,9 @@ python3 setup.py develop # libtpu is needed to talk to the TPUs. If TPUs are not present, # installing this wouldn't hurt either. -pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +pip install torch_xla[tpu] \ + -f https://storage.googleapis.com/libtpu-wheels/index.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html # Test that the library is installed correctly. 
python3 -c 'import torch_xla as xla; print(xla.device())' diff --git a/scripts/build_torch_wheels.sh b/scripts/build_torch_wheels.sh index 6258f4b9124..5e3ada94cd2 100755 --- a/scripts/build_torch_wheels.sh +++ b/scripts/build_torch_wheels.sh @@ -280,7 +280,9 @@ function build_and_install_torch_xla() { python setup.py bdist_wheel pip install dist/*.whl if [ "$TPUVM_MODE" == "1" ]; then - pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html + pip install torch_xla[tpu] \ + -f https://storage.googleapis.com/libtpu-wheels/index.html \ + -f https://storage.googleapis.com/libtpu-releases/index.html sudo apt-get install -y google-perftools fi diff --git a/setup.py b/setup.py index fab85548aec..e4db95ef14b 100644 --- a/setup.py +++ b/setup.py @@ -65,8 +65,8 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) _date = '20241122' -_libtpu_version = f'0.1.dev{_date}' -_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl' +_libtpu_version = f'0.0.5.dev{_date}' +_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-{_libtpu_version}+nightly-py3-none-linux_x86_64.whl' _jax_version = f'0.4.36.dev{_date}' @@ -312,8 +312,14 @@ def run(self): }, extras_require={ # On Cloud TPU VM install with: - # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html - 'tpu': [f'libtpu-nightly=={_libtpu_version}', 'tpu-info'], + # pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html + 'tpu': [ + f'libtpu=={_libtpu_version}', + 'tpu-info', + # This special version removes `libtpu.so` from any `libtpu-nightly` installations, + # since we have migrated to using the `libtpu.so` from the `libtpu` package. 
+ "libtpu-nightly==0.1.dev20241010+nightly.cleanup" + ], # pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html 'pallas': [f'jaxlib=={_jax_version}', f'jax=={_jax_version}'], }, From ddcb2b9306e56df173e573c02aa0a7b2323c43f9 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Tue, 3 Dec 2024 18:37:19 -0800 Subject: [PATCH 08/15] [scan] Add a test under SPMD (#8419) --- test/run_tests.sh | 1 + test/scan/test_scan_spmd.py | 51 +++++++++++++++++++++++++++++++++++++ test/tpu/run_tests.sh | 1 + 3 files changed, 53 insertions(+) create mode 100644 test/scan/test_scan_spmd.py diff --git a/test/run_tests.sh b/test/run_tests.sh index 543bc5f8403..a3a8c74cedd 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -209,6 +209,7 @@ function run_xla_op_tests2 { run_test "$CDIR/pjrt/test_dtypes.py" run_test "$CDIR/test_while_loop.py" run_test "$CDIR/scan/test_scan.py" + run_test "$CDIR/scan/test_scan_spmd.py" run_test "$CDIR/scan/test_scan_layers.py" run_test "$CDIR/test_autocast.py" run_test "$CDIR/eager/test_eager.py" diff --git a/test/scan/test_scan_spmd.py b/test/scan/test_scan_spmd.py new file mode 100644 index 00000000000..cde7fb7bb65 --- /dev/null +++ b/test/scan/test_scan_spmd.py @@ -0,0 +1,51 @@ +import sys +import unittest + +import torch +import torch_xla +from torch_xla.experimental.scan import scan +from torch_xla.distributed.spmd import mark_sharding, set_global_mesh, get_1d_mesh +import torch_xla.runtime as xr + + +class ScanSpmdTest(unittest.TestCase): + + def setUp(self): + # Activate SPMD + xr.use_spmd() + + # Set up a simple SPMD mesh for these tests. + self.spmd_mesh = get_1d_mesh(axis_name="model") + set_global_mesh(self.spmd_mesh) + self.device = torch_xla.device() + + @unittest.skipUnless(xr.global_runtime_device_count() >= 4, + "Multiple devices required") + def test_scan_cumsum(self): + """This test uses `scan` to implement `torch.cumsum`.""" + + def fn(carry, x): + new_carry = carry + x + y = new_carry + return new_carry, y + + init = torch.zeros(1024, requires_grad=True, device=self.device) + mark_sharding(init, self.spmd_mesh, ('model',)) + xs = torch.randn([8, 1024], requires_grad=True, device=self.device) + mark_sharding(xs, self.spmd_mesh, (None, 'model')) + final_carry, ys = scan(fn, init, xs) + torch_xla.sync() + + # Check the input and output sharding. Note that we do this after + # `torch_xla.sync()` to ensure the output tensors are materialized and + # have taken on sharding annotations propagated by the compiler. 
+ for tensor in [init, xs, final_carry, ys]: + self.assertIn('ShardingSpec: {devices=[', + torch_xla._XLAC._get_xla_tensor_debug_info(tensor)) + self.assertIn('OpSharding: {devices=[', + torch_xla._XLAC._get_xla_tensor_debug_info(tensor)) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 6ad06b07740..fb5cdd51c8e 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -25,6 +25,7 @@ python3 test/pjrt/test_dtypes.py python3 test/pjrt/test_dynamic_plugin_tpu.py python3 test/test_while_loop.py python3 test/scan/test_scan.py +python3 test/scan/test_scan_spmd.py python3 test/scan/test_scan_layers.py python3 test/test_pallas.py -v python3 test/test_pallas_spmd.py From b14d01f5c7aaa3a5f650e82babd0c2f6e6935168 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Wed, 4 Dec 2024 13:20:11 +0100 Subject: [PATCH 09/15] Align rrelu_with_noise schema to reflect noise mutation [branch] (#8363) --- torch_xla/csrc/aten_xla_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 6e98726063f..e308263e4a5 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3019,7 +3019,7 @@ at::Tensor XLANativeFunctions::roll(const at::Tensor& self, } at::Tensor XLANativeFunctions::rrelu_with_noise( - const at::Tensor& self, const at::Tensor& noise, const at::Scalar& lower, + const at::Tensor& self, at::Tensor& noise, const at::Scalar& lower, const at::Scalar& upper, bool training, std::optional generator) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); From 317d382ebd4e05a3a46ee2dae3f42737c10d540d Mon Sep 17 00:00:00 2001 From: Michael Green <59619482+mikegre-google@users.noreply.github.com> Date: Wed, 4 Dec 2024 22:07:12 +0000 Subject: [PATCH 10/15] [Documentation] Added a section pointing readers to the AI-Hypercomputer/tpu-recipies repo for reference model implementations. (#8412) --- CONTRIBUTING.md | 102 ++++++++++++++++++------------ README.md | 5 ++ docs/source/learn/xla-overview.md | 5 ++ 3 files changed, 71 insertions(+), 41 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 06aca135e37..7d5ba68e077 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,66 +1,85 @@ # Contribute To PyTorch/XLA -We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. -You are very welcome to pick issues from [good first issue](https://github.com/pytorch/xla/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) and [help wanted](https://github.com/pytorch/xla/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) labels. +We appreciate all contributions. If you are planning to contribute a bug fix for +an open issue, please comment on the thread and we're happy to provide guidance. +You are welcome to pick issues with [good first issue](https://github.com/pytorch/xla/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) +and [help wanted](https://github.com/pytorch/xla/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) +labels to get started. -If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. 
-Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. +If you plan to contribute new features or extensions to this repository, first +open an issue and discuss the feature with us. Sending a PR without discussion +might result in a rejected PR, because we might be taking the repository in a +different direction. ## Building from source -We recommend you to use our prebuilt Docker image to start your development work using one of the two following methods. +We recommend you use our prebuilt Docker image to start your development work +using either VS Code or a local container: ### Visual Studio Code Dev Container -* Create an empty directory (optionally on a remote host via SSH) and open it in VSCode. Then, clone - PyTorch, TorchVision, and PyTorch/XLA: +* Create an empty directory for your workspace on your development host. These + instructions assume you are using a remote host and are connecting to it over + SSH. + +* Clone PyTorch, TorchVision, and PyTorch/XLA into your workspace directory: - ```bash +```bash git clone --recursive --depth=1 https://github.com/pytorch/pytorch.git - # Optional: install TorchVision if you need to run tests that involve vision modules + + # Install TorchVision if you need to run tests that involve vision modules git clone --recursive --depth=1 https://github.com/pytorch/vision.git + + # Clone with HTTPS if you use a GitHub a personal access token git clone https://github.com/pytorch/xla.git pytorch/xla - # Optional: use git@github.com:pytorch/xla.git instead if you prefer to use SSH with key forwarding - ``` -* Link (or copy) VSCode configuration to your workspace directory: + # Or clone with SSH if you prefer: + git clone git@github.com:pytorch/xla.git pytorch/xla +``` + +* Create links to VS Code configuration files in your workspace directory: - ```bash +```bash ln -s pytorch/xla/.devcontainer/ .devcontainer ln -s pytorch/xla/contrib/vscode/ .vscode ln -s pytorch/xla/.style.yapf .style.yapf ln -s pytorch/xla/.clang-format .clang-format - ``` - -* From VSCode's command menu, run `Reopen in Container` from the command palette - (F1 key) to open your workspace in one of our pre-built Docker containers. - Select the correct container config based on your local accelerator (default to - `tpu-contributor` if you are not sure). - - * If you cannot find `Reopen in Container`, make sure the `Dev Containers` - VSCode extension is installed, then open the `pytorch/xla` folder in VSCode. - -* Since you are running as root in this container, teach `git` to recognize the - repositories you just cloned (outside of docker) as safe: +``` - ```bash +* Start VS Code and ensure you have the [`Remote Development` Extension Pack](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.vscode-remote-extensionpack) + installed. It includes the [`Remote - SSH`](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) and + [`Dev Containers`](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) + extensions. + +* From VS Code, connect to your remote host and open your workspace directory. + You will be prompted to reopen your workspace in container. Choose the + appropriate container. Use `tpu-contributor` if you are unsure of which to use. 
+ If you are not prompted to reopen in a container, in the VS Code command + pallete, type `Dev Containers: Reopen in Container` to open your workspace in + one of our pre-built Docker containers. Select the correct container based on + your local accelerator. If you are unsure, use `tpu-contributor`. + +* Open a new terminal window in VS Code. Since you are running as root in this + container, mark the repository directories as safe. The commands below assume + your workspace directory is `torch`, update the commands to use your workspace + directory. + +```bash git config --global --add safe.directory /workspaces/torch/pytorch git config --global --add safe.directory /workspaces/torch/pytorch/xla git config --global --add safe.directory /workspaces/torch/vision - ``` - -* Build PyTorch, TorchVision, and PyTorch/XLA: +``` +* In the terminal window, run the following commands to build PyTorch, + TorchVision, and PyTorch/XLA: - ```bash +```bash cd pytorch # pytorch/xla requires pytorch wheel to be presented under pytorch/dist python setup.py bdist_wheel python setup.py install - cd .. - cd vision + cd ../vision python setup.py develop - cd .. - cd pytorch/xla + cd ../pytorch/xla python setup.py develop # Optional: if you're using TPU, install libtpu pip install torch_xla[tpu] \ @@ -68,17 +87,18 @@ We recommend you to use our prebuilt Docker image to start your development work -f https://storage.googleapis.com/libtpu-releases/index.html ``` -* Test your build +* If you are running on a TPU VM, ensure `torch` and `torch_xla` were built and + installed correctly: - ```bash +```bash python -c 'import torch_xla as xla; print(xla.device())' # Output: xla:0 - ``` +``` -**Subsequent builds**: after setting up the source checkouts and building them -for the first time, you may find the need to build everything again after e.g. -`git pull`. You can run `scripts/build_developer.sh` which will build PyTorch, -TorchVision, and PyTorch/XLA according to the above. +**Subsequent builds**: after building the packages from source code for the +first time, you may need to build everything again, for example, after a +`git pull`. You can run `scripts/build_developer.sh` which will rebuild PyTorch, +TorchVision, and PyTorch/XLA. ### Manually build in Docker container diff --git a/README.md b/README.md index 6c4994637c5..eac298af575 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,11 @@ Our comprehensive user guides are available at: VM](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm) * [GPU guide](docs/gpu.md) +## Reference implementations + +The [AI-Hypercomputer/tpu-recipies](https://github.com/AI-Hypercomputer/tpu-recipes) +repo. contains examples for training and serving many LLM and diffusion models. + ## Available docker images and wheels ### Python packages diff --git a/docs/source/learn/xla-overview.md b/docs/source/learn/xla-overview.md index 7fdb6b05237..4eaf5e473ac 100644 --- a/docs/source/learn/xla-overview.md +++ b/docs/source/learn/xla-overview.md @@ -175,6 +175,11 @@ sudo apt-get install libopenblas-dev -y sudo apt-get update && sudo apt-get install libgl1 -y # diffusion specific ``` +## Reference implementations + +The [AI-Hypercomputer/tpu-recipies](https://github.com/AI-Hypercomputer/tpu-recipes) +repo. contains examples for training and serving many LLM and diffusion models. 
+ ## Converting code to PyTorch XLA General guidelines to modify your code: From 61a70f51dd98e640ca0de24da2219a1b99b9900b Mon Sep 17 00:00:00 2001 From: Michael Green <59619482+mikegre-google@users.noreply.github.com> Date: Wed, 4 Dec 2024 22:08:14 +0000 Subject: [PATCH 11/15] Updates Contributing.md (#8343) From 5d575a8035d291e28ba7cebb43d35c0582d77160 Mon Sep 17 00:00:00 2001 From: qihqi Date: Wed, 4 Dec 2024 14:29:30 -0800 Subject: [PATCH 12/15] Update torch version to 2.5.1 (#8452) --- experimental/torch_xla2/dev-requirements.txt | 4 ++-- experimental/torch_xla2/torch_xla2/export.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/experimental/torch_xla2/dev-requirements.txt b/experimental/torch_xla2/dev-requirements.txt index 70dc73047d2..6195bcd3a1b 100644 --- a/experimental/torch_xla2/dev-requirements.txt +++ b/experimental/torch_xla2/dev-requirements.txt @@ -1,4 +1,4 @@ -f https://download.pytorch.org/whl/torch -torch==2.4.0; sys_platform == 'darwin' # macOS -torch==2.4.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU +torch==2.5.1; sys_platform == 'darwin' # macOS +torch==2.5.1+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU ruff~=0.3.5 diff --git a/experimental/torch_xla2/torch_xla2/export.py b/experimental/torch_xla2/torch_xla2/export.py index 3fdbedc8474..2d9c1684697 100644 --- a/experimental/torch_xla2/torch_xla2/export.py +++ b/experimental/torch_xla2/torch_xla2/export.py @@ -169,9 +169,10 @@ def _build_symbolic_constraints(symbol_name, torch_constraint): symbol = sympy.Symbol(symbol_name) if torch_constraint.lower != 2: constraints.append(symbol >= torch_constraint.lower) - if not torch_constraint.upper.is_infinite: + from sympy.core.singleton import S + if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity: constraints.append(symbol <= torch_constraint.upper) - + return tuple(sympy.pretty(c, use_unicode=False) for c in constraints) def _build_symbolic_shape(sym, constraint, free_symbols): From 582801f2f31eac72fc50ca74abf281c9e24ea299 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Wed, 4 Dec 2024 19:40:34 -0800 Subject: [PATCH 13/15] Fix an edge case issue when storing output in the paged attention. (#8431) --- test/test_tpu_paged_attention_kernel.py | 80 ++++++++++++++++++- .../multi_queries_paged_attention_kernel.py | 6 +- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py index d2d0f4f19a9..746439ba4d0 100644 --- a/test/test_tpu_paged_attention_kernel.py +++ b/test/test_tpu_paged_attention_kernel.py @@ -11,7 +11,7 @@ # Set up paged_attention inputs. 
def _generate_qkv( - kv_seq_lens, + batch_size, page_size, max_kv_len, query_len, @@ -23,7 +23,6 @@ def _generate_qkv( ): assert max_kv_len % page_size == 0 pages_per_sequence = max_kv_len // page_size - batch_size = len(kv_seq_lens) total_pages = batch_size * pages_per_sequence k1, k2, k3, k4 = jax.random.split(prng_key, 4) k_pages = jax.random.normal( @@ -138,7 +137,7 @@ def test_paged_attention_without_query_padding( assert max_kv_len <= total_num_pages * page_size q, k_pages, v_pages, page_indices = _generate_qkv( - kv_seq_lens, + batch_size, page_size, max_kv_len, query_len, @@ -235,7 +234,7 @@ def test_paged_attention_with_query_padding( total_num_pages = batch_size * pages_per_sequence assert max_kv_len <= total_num_pages * page_size q, k_pages, v_pages, page_indices = _generate_qkv( - kv_seq_lens, + batch_size, page_size, max_kv_len, query_len, @@ -292,6 +291,79 @@ def test_paged_attention_with_query_padding( atol=atol, rtol=rtol)) + def test_paged_attention_store_to_output_correctly(self,): + # Make sure the internal FA store_to_output correctly. + dtype = jnp.float32 + page_size = 16 + num_kv_heads = 8 + q_kv_head_ratio = 4 + head_dim = 256 + num_queries_per_compute_block = 32 + block_kv_size = 256 + + max_kv_len = 512 + query_len = max_kv_len + batch_size = 3 + # Set various edge case testing the internal flash attention can store_to_output correct + kv_seq_lens = jnp.array( + [block_kv_size - 1, block_kv_size + 1, 2 * block_kv_size]) + assert len(kv_seq_lens) == batch_size + effective_q_lens = jax.random.randint( + jax.random.key(0), (batch_size,), 0, kv_seq_lens) + for cur_effec_q_len, cur_kv_seq_len in zip(effective_q_lens, kv_seq_lens): + assert cur_effec_q_len <= cur_kv_seq_len, f'The effective query len {cur_effec_q_len} should be less than or equal to the kv_len {cur_kv_seq_len} in the current sequence.' + + pages_per_sequence = max_kv_len // page_size + total_num_pages = batch_size * pages_per_sequence + assert max_kv_len <= total_num_pages * page_size + q, k_pages, v_pages, page_indices = _generate_qkv( + batch_size, + page_size, + max_kv_len, + query_len, + num_kv_heads, + num_kv_heads * q_kv_head_ratio, + head_dim, + jax.random.key(0), + dtype, + ) + + num_kv_pages_per_compute_block = block_kv_size // page_size + actual_output = paged_attention( + q, + k_pages, + v_pages, + kv_seq_lens, + page_indices, + effective_q_lens, + num_kv_pages_per_compute_block=block_kv_size // page_size, + num_queries_per_compute_block=num_queries_per_compute_block, + ) + actual_output = jax.block_until_ready(actual_output) + + # Run the ref impl. 
+ expected_output = _ref_jax_extended_paged_attention( + q, + k_pages, + v_pages, + kv_seq_lens, + page_indices, + effective_q_lens, + ) + + self.assertEqual(actual_output.shape, expected_output.shape) + + atol = 2e-2 + rtol = 1e-2 + for b in range(batch_size): + effective_q_len = effective_q_lens[b] + self.assertTrue( + jnp.allclose( + expected_output[b, :effective_q_len], + actual_output[b, :effective_q_len], + atol=atol, + rtol=rtol)) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index 0bb572c49e1..557f8ad5ec3 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -198,7 +198,10 @@ def start_new_sequence(): o_curr = jax.lax.dot(p.astype(v.dtype), v, preferred_element_type=jnp.float32) acc_scratch_ref[q_head_idx_per_kv] += o_curr * l_broadcast(l_next_inv_safe) - @pl.when(kv_blk_idx == kv_len // kv_seq_len_per_kv_compute_blk) + # The condition comes from the one "@pl.when(kv_blk_idx * compute_blk_size_kv < kv_len)" controlling if we should run the function get_kv_and_run_flash_attention. + # If kv_len=512, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 1. + # If kv_len=513, kv_seq_len_per_kv_compute_blk=256, then last kv_blk_idx that we need to store_to_output is 2. + @pl.when(kv_blk_idx == pl.cdiv(kv_len, kv_seq_len_per_kv_compute_blk) - 1) def store_to_output(): o_ref[0, q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( o_ref.dtype) @@ -385,6 +388,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable MIN_BLOCK_SIZE = 128 +@jax.profiler.annotate_function @functools.partial( jax.jit, static_argnames=[ From 4d4d9f24d795baefac938d4d944cd4fbe523236b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Thu, 5 Dec 2024 17:02:33 -0300 Subject: [PATCH 14/15] DLPack: fix test using PyTorch API. 
(#8348) --- test/test_operations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index cc3a73c4580..892a02ddb0b 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2993,8 +2993,7 @@ def test_dlpack_xla_to_pytorch_cuda(self): @onlyIfPJRTDeviceIsCUDA def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self): xla_t1 = torch.arange(5).to(xm.xla_device()) - caps_t1 = torch.utils.dlpack.to_dlpack(xla_t1) - cuda_t1 = torch.utils.dlpack.from_dlpack(caps_t1) + cuda_t1 = torch.utils.dlpack.from_dlpack(xla_t1) self.assertEqual(cuda_t1.device.type, 'cuda') self.assertEqual(cuda_t1.device.index, xla_t1.device.index) cuda_t1[0] = cuda_t1[0] + 20 From 8c25d80ee4104ffa7c7191b79b3985fcad955337 Mon Sep 17 00:00:00 2001 From: rpsilva-aws Date: Fri, 6 Dec 2024 00:15:14 +0000 Subject: [PATCH 15/15] Fix parameterized tests --- test/cpp/test_aten_xla_tensor_1.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/cpp/test_aten_xla_tensor_1.cpp b/test/cpp/test_aten_xla_tensor_1.cpp index 2d8b392870e..33b3be50f2d 100644 --- a/test/cpp/test_aten_xla_tensor_1.cpp +++ b/test/cpp/test_aten_xla_tensor_1.cpp @@ -2968,7 +2968,7 @@ class IndexOpsAtenXlaTensorTest torch::ScalarType GetValueType() const { return std::get<1>(GetParam()); } }; -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexSelect) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexSelect) { torch::ScalarType scalar_type = GetValueType(); torch::Tensor a = isFloatingType(scalar_type) @@ -2991,7 +2991,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexSelect) { ExpectCounterChanged("xla::index_select", cpp_test::GetIgnoredCounters()); } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexSelectRank0) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexSelectRank0) { torch::ScalarType scalar_type = GetValueType(); torch::Tensor a = isFloatingType(scalar_type) @@ -3011,7 +3011,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexSelectRank0) { }); } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexPutImpl) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexPutImpl) { torch::Tensor indices = torch::randint(-3, 3, {2, 4, 3}, torch::TensorOptions(GetIndexType())); torch::ScalarType scalar_type = GetValueType(); @@ -3041,7 +3041,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexPutImpl) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalar) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalar) { torch::Tensor index = torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); torch::ScalarType scalar_type = GetValueType(); @@ -3066,7 +3066,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalar) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalarInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalarInPlace) { torch::Tensor index = torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); torch::ScalarType scalar_type = GetValueType(); @@ -3092,7 +3092,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithScalarInPlace) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensor) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensor) { torch::Tensor index = torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); torch::ScalarType scalar_type = GetValueType(); @@ -3119,7 +3119,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensor) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensorInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensorInPlace) { 
torch::Tensor index = torch::tensor({0, 2}, torch::TensorOptions(GetIndexType())); torch::ScalarType scalar_type = GetValueType(); @@ -3148,7 +3148,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillWithTensorInPlace) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillRank0) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexFillRank0) { torch::Tensor index = torch::scalar_tensor(2, torch::TensorOptions(GetIndexType())); torch::ScalarType scalar_type = GetValueType(); @@ -3175,7 +3175,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexFillRank0) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAdd) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexAdd) { int index_size = 10; torch::ScalarType scalar_type = GetValueType(); torch::Tensor base = @@ -3208,7 +3208,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAdd) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAddInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexAddInPlace) { int index_size = 10; int rank = 3; std::vector alphas{0.0, 1.0, 2.0}; @@ -3248,7 +3248,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAddInPlace) { ExpectCounterChanged("xla::index_add", cpp_test::GetIgnoredCounters()); } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAddRank0) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexAddRank0) { torch::ScalarType scalar_type = GetValueType(); torch::Tensor base = isFloatingType(scalar_type) @@ -3281,7 +3281,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexAddRank0) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopy) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexCopy) { torch::ScalarType scalar_type = GetValueType(); torch::Tensor base = isFloatingType(scalar_type) @@ -3310,7 +3310,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopy) { ExpectCounterChanged("xla::index_copy", cpp_test::GetIgnoredCounters()); } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopyInPlace) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexCopyInPlace) { int index_size = 10; int rank = 3; torch::ScalarType scalar_type = GetValueType(); @@ -3351,7 +3351,7 @@ TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopyInPlace) { } } -TEST_F(IndexOpsAtenXlaTensorTest, TestIndexCopyRank0) { +TEST_P(IndexOpsAtenXlaTensorTest, TestIndexCopyRank0) { torch::ScalarType scalar_type = GetValueType(); torch::Tensor base = isFloatingType(scalar_type)