Workaround for einsum with FakeTensor input (#5597)
* Workaround for einsum with FakeTensor input

* typo
JackCaoG authored Sep 18, 2023
1 parent 2ae3fa7 commit 786722c
Showing 4 changed files with 37 additions and 2 deletions.
16 changes: 16 additions & 0 deletions test/dynamo/test_dynamo.py
@@ -156,6 +156,22 @@ def forward(self, index, copy_tensor, input_tensor, op_name):
         xla_index, xla_copy_tensor, xla_input_tensor, op_name=in_place_op)
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

+  def test_einsum(self):
+    # einsum currently has no meta function to compute its output shape, so
+    # dynamo falls back to XLA with FakeTensor inputs to infer that shape.
+    def einsum_mm(a, b):
+      return torch.einsum('ijkl,ijlm->ijkm', a, b)
+
+    device = xm.xla_device()
+    a = torch.randn(4, 4, 4, 4).to(device)
+    b = torch.randn(4, 4, 4, 4).to(device)
+    xm.mark_step()
+
+    dynamo_einsum_mm = torch.compile(einsum_mm, backend="openxla")
+    res_xla_dynamo = dynamo_einsum_mm(a, b)
+    res_cpu = einsum_mm(a.cpu(), b.cpu())
+    self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

   def test_simple_model_with_different_input_shape(self):
     met.clear_counters()
     device = xm.xla_device()
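The comment in the new test is the heart of the change: under torch.compile, dynamo traces the program with FakeTensors, which carry shape, dtype, and device metadata but no data. A minimal sketch of that metadata-only propagation, using only stock PyTorch (FakeTensorMode lives in the internal torch._subclasses namespace, so treat this as illustrative rather than stable API):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def einsum_mm(a, b):
  return torch.einsum('ijkl,ijlm->ijkm', a, b)

# Inside FakeTensorMode, factory functions produce FakeTensors: tensors with
# metadata (shape, dtype, device) but no storage to read.
with FakeTensorMode():
  a = torch.randn(4, 4, 4, 4)
  b = torch.randn(4, 4, 4, 4)
  out = einsum_mm(a, b)
  print(out.shape)  # torch.Size([4, 4, 4, 4]) -- shape inferred without data

When an op has no meta kernel of its own, dynamo instead calls the backend's real kernel with these FakeTensor inputs, which is exactly the situation the C++ changes below have to survive.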
9 changes: 9 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -75,6 +75,15 @@ XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
   return impl->tensor();
 }

+std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors) {
+  std::vector<XLATensorPtr> xla_tensors;
+  xla_tensors.reserve(tensors.size());
+  for (const auto& tensor : tensors) {
+    xla_tensors.push_back(bridge::TryGetXlaTensor(tensor));
+  }
+  return xla_tensors;
+}

 bool IsXlaTensor(const at::Tensor& tensor) {
   return GetXlaTensorImpl(tensor) != nullptr;
 }
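Unlike the existing GetXlaTensors, which expects every operand to be a real XLA tensor, the new TryGetXlaTensors returns a null XLATensorPtr for any operand without XLA backing, such as a FakeTensor seen during tracing. A rough Python analogue (try_get_xla_tensors is a hypothetical helper, not part of the torch_xla API), with None standing in for the null pointer:

from typing import List, Optional

import torch
from torch._subclasses.fake_tensor import FakeTensor

def try_get_xla_tensors(tensors) -> List[Optional[torch.Tensor]]:
  # None plays the role of a null XLATensorPtr: a FakeTensor can report an
  # 'xla' device for tracing purposes while owning no real device data.
  return [
      t if t.device.type == 'xla' and not isinstance(t, FakeTensor) else None
      for t in tensors
  ]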
3 changes: 3 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.h
@@ -16,6 +16,9 @@ namespace bridge {

 XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor);

+// Same as TryGetXlaTensor above, applied to a list of tensors; returns a
+// null XLATensorPtr for any tensor that is not an XLA tensor.
+std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors);

 bool IsXlaTensor(const at::Tensor& tensor);

 // Extracts the XLATensorPtr out of our version of at::Tensor. Throws an
 // exception if the tensor is not an XLA tensor.
11 changes: 9 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1150,12 +1150,19 @@ at::Tensor XLANativeFunctions::einsum(c10::string_view equation,
                      [](unsigned char x) { return std::isspace(x); }),
       cleansed_equation.end());

-  std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(tensors);
+  std::vector<XLATensorPtr> xla_tensors = bridge::TryGetXlaTensors(tensors);
+  bool all_xla_tensors_are_valid = true;
+  for (const XLATensorPtr& xla_tensor : xla_tensors) {
+    if (!xla_tensor) {
+      all_xla_tensors_are_valid = false;
+      break;
+    }
+  }

   TORCH_LAZY_FN_COUNTER("xla::");
   // Einsum operations with more than 2 operands, like bilinear operations, are
   // not currently supported in XLA
-  if (tensors.size() < 1 || tensors.size() > 2 ||
+  if (tensors.size() < 1 || tensors.size() > 2 || !all_xla_tensors_are_valid ||
       !EinsumUtilities::EquationIsValid(cleansed_equation) ||
       TensorsAreOfType(xla_tensors, at::ScalarType::Long)) {
     TORCH_LAZY_COUNTER("EinsumFallback", 1);
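With the helper in place, the guard falls back to the decomposed ATen einsum whenever any operand lacks a real XLA handle, in addition to the existing operand-count, equation, and dtype checks. A compact Python restatement of the new decision (a sketch; the equation-validity and Long-dtype checks are elided):

from typing import Optional, Sequence

def should_fallback(xla_tensors: Sequence[Optional[object]]) -> bool:
  # XLA lowers einsum only for one or two operands, and only when every
  # operand has a real XLA handle; a None entry corresponds to a FakeTensor
  # encountered during dynamo tracing.
  if not 1 <= len(xla_tensors) <= 2:
    return True
  return any(t is None for t in xla_tensors)

The fallback path still bumps the EinsumFallback counter, so the behavior exercised by the new test_einsum remains observable in torch_xla's metrics.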
