diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 7cf83c84eff..76c32328d5e 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -156,6 +156,22 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
         xla_index, xla_copy_tensor, xla_input_tensor, op_name=in_place_op)
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
 
+  def test_einsum(self):
+    # einsum currently has no meta function to compute the output shape, so it
+    # falls back to XLA with FakeTensor inputs to infer the output shape.
+    def einsum_mm(a, b):
+      return torch.einsum('ijkl,ijlm->ijkm', a, b)
+
+    device = xm.xla_device()
+    a = torch.randn(4, 4, 4, 4).to(device)
+    b = torch.randn(4, 4, 4, 4).to(device)
+    xm.mark_step()
+
+    dynamo_einsum_mm = torch.compile(einsum_mm, backend="openxla")
+    res_xla_dynamo = dynamo_einsum_mm(a, b)
+    res_cpu = einsum_mm(a.cpu(), b.cpu())
+    self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
+
   def test_simple_model_with_different_input_shape(self):
     met.clear_counters()
     device = xm.xla_device()
diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp
index d49302ab667..c768312e23e 100644
--- a/torch_xla/csrc/aten_xla_bridge.cpp
+++ b/torch_xla/csrc/aten_xla_bridge.cpp
@@ -75,6 +75,15 @@ XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
   return impl->tensor();
 }
 
+std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors) {
+  std::vector<XLATensorPtr> xla_tensors;
+  xla_tensors.reserve(tensors.size());
+  for (const auto& tensor : tensors) {
+    xla_tensors.push_back(bridge::TryGetXlaTensor(tensor));
+  }
+  return xla_tensors;
+}
+
 bool IsXlaTensor(const at::Tensor& tensor) {
   return GetXlaTensorImpl(tensor) != nullptr;
 }
diff --git a/torch_xla/csrc/aten_xla_bridge.h b/torch_xla/csrc/aten_xla_bridge.h
index 474a2919aa8..7982509494f 100644
--- a/torch_xla/csrc/aten_xla_bridge.h
+++ b/torch_xla/csrc/aten_xla_bridge.h
@@ -16,6 +16,9 @@ namespace bridge {
 
 XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor);
 
+// Same as above, applied to a list of tensors.
+std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors);
+
 bool IsXlaTensor(const at::Tensor& tensor);
 
 // Extracts the XLATensorPtr out of our version of at::Tensor. Throws an
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index c9495f31577..b62a5fb5cef 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1150,12 +1150,19 @@ at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
                      [](unsigned char x) { return std::isspace(x); }),
       cleansed_equation.end());
 
-  std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(tensors);
+  std::vector<XLATensorPtr> xla_tensors = bridge::TryGetXlaTensors(tensors);
+  bool all_xla_tensors_are_valid = true;
+  for (const XLATensorPtr& xla_tensor : xla_tensors) {
+    if (!xla_tensor) {
+      all_xla_tensors_are_valid = false;
+      break;
+    }
+  }
 
   TORCH_LAZY_FN_COUNTER("xla::");
   // Einsum operations with more than 2 operands, like bilinear operations, are
   // not currently supported in XLA
-  if (tensors.size() < 1 || tensors.size() > 2 ||
+  if (tensors.size() < 1 || tensors.size() > 2 || !all_xla_tensors_are_valid ||
      !EinsumUtilities::EquationIsValid(cleansed_equation) ||
       TensorsAreOfType(xla_tensors, at::ScalarType::Long)) {
     TORCH_LAZY_COUNTER("EinsumFallback", 1);
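
Note on the end-to-end behavior: when dynamo traces einsum, the inputs are FakeTensors rather than XLA tensors, so bridge::TryGetXlaTensors returns null XLATensorPtrs, all_xla_tensors_are_valid becomes false, and the op takes the EinsumFallback path to infer the output shape. A minimal sketch for observing this from Python is below; it assumes the existing torch_xla.debug.metrics counter helpers, and the counter name "EinsumFallback" is taken from the hunk above.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def einsum_mm(a, b):
  return torch.einsum('ijkl,ijlm->ijkm', a, b)

a = torch.randn(4, 4, 4, 4).to(xm.xla_device())
b = torch.randn(4, 4, 4, 4).to(xm.xla_device())

met.clear_counters()
res = torch.compile(einsum_mm, backend="openxla")(a, b)
# If the fallback fired (e.g. during FakeTensor shape inference at trace
# time), the counter is set; counter_value returns None for a counter that
# never fired.
print(met.counter_value("EinsumFallback"))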