diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 6e1936c258a..48af793d114 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -1281,6 +1281,9 @@ def convert_fn(tensors): tensors, devices=[], wait=True, sync_xla_data=True) if not convert: return tensors + for t in tensors: + if torch._is_functional_tensor(t): + torch._functionalize_sync(t) return torch_xla._XLAC._xla_get_cpu_tensors(tensors) def select_fn(v):