From 549ca7f123472b2d68cd63ae5d31aa98c8d219d8 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 31 Jan 2024 12:23:06 -0800 Subject: [PATCH] Force guards when performing boolean operations on XLASymNode Needed for https://github.com/pytorch/pytorch/pull/118579 Signed-off-by: Edward Z. Yang --- torch_xla/csrc/tensor.cpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 6aba73ee9cf..2f857f3b3b0 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -761,22 +761,18 @@ c10::SymNode XLASymNodeImpl::sym_max(const c10::SymNode& other) { << " has not been implemented."; } -// It is used to compute contiguity fields on tensors like "is non overlapping -// and dense" and it's never fetched. If they are never fetched it is fine for -// them to error only if poked. +// Force guards when performing these logical operations + c10::SymNode XLASymNodeImpl::sym_or(const c10::SymNode& other) { - auto error_node = torch::lazy::MakeNode(); - return c10::make_intrusive(error_node, PyType::BOOL); + return bool_() || other.bool_(); } c10::SymNode XLASymNodeImpl::sym_and(const c10::SymNode& other) { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + return bool_() && other.bool_(); } c10::SymNode XLASymNodeImpl::sym_not() { - XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__ - << " has not been implemented."; + return !bool_(); } // NB: self is ignored here, only the arguments are used c10::SymNode XLASymNodeImpl::is_contiguous(at::ArrayRef sizes,