diff --git a/torch_xla/torch_xla.py b/torch_xla/torch_xla.py index fab890d0a4d..f5f56ba9cbc 100644 --- a/torch_xla/torch_xla.py +++ b/torch_xla/torch_xla.py @@ -62,6 +62,9 @@ def step(): works with `xla.step` but does not follow best practices will become errors in future releases. See https://github.com/pytorch/xla/issues/6751 for context. """ + # Clear pending operations + xm.mark_step() + try: yield finally: