From fe604e97a99f3291d8fb8e3b8a5030e38a995c56 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 2 Feb 2024 15:01:05 -0500 Subject: [PATCH] Force guards when performing boolean operations on XLASymNode (#6433) --- torch_xla/csrc/tensor.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 6aba73ee9cf..99a3698b66a 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -761,23 +761,28 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) { << " has not been implemented."; } -// It is used to compute contiguity fields on tensors like "is non overlapping -// and dense" and it's never fetched. If they are never fetched it is fine for -// them to error only if poked. +// Force guards when performing these logical operations + c10::SymNode XLASymNodeImpl::sym_or(const c10::SymNode& other) { - auto error_node = torch::lazy::MakeNode(); - return c10::make_intrusive(error_node, PyType::BOOL); + auto a = + guard_bool(__FILE__, __LINE__) || other->guard_bool(__FILE__, __LINE__); + auto cnst = torch::lazy::MakeNode(a); + return c10::make_intrusive(cnst, PyType::BOOL); } c10::SymNode XLASymNodeImpl::sym_and(const c10::SymNode& other) { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + auto a = + guard_bool(__FILE__, __LINE__) && other->guard_bool(__FILE__, __LINE__); + auto cnst = torch::lazy::MakeNode(a); + return c10::make_intrusive(cnst, PyType::BOOL); } c10::SymNode XLASymNodeImpl::sym_not() { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + auto a = !guard_bool(__FILE__, __LINE__); + auto cnst = torch::lazy::MakeNode(a); + return c10::make_intrusive(cnst, PyType::BOOL); } + // NB: self is ignored here, only the arguments are used c10::SymNode XLASymNodeImpl::is_contiguous(at::ArrayRef sizes, at::ArrayRef strides) {