Skip to content

Commit

Permalink
Fix ADOPT on older PyTorch (tested back to 1.13)
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Nov 8, 2024
1 parent ec857fc commit 6db2710
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions timm/optim/adopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,13 @@ def _get_scalar_dtype(is_fused=None):
)


def _is_compiling():
return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False


def _get_value(x):
# item is significantly faster than a cpu tensor in eager mode
if not torch.jit.is_scripting() and torch.compiler.is_compiling():
if not torch.jit.is_scripting() and _is_compiling():
return x
else:
return x.item() if isinstance(x, torch.Tensor) else x
Expand Down Expand Up @@ -271,7 +275,7 @@ def _single_tensor_adopt(
step_t = state_steps[i]

# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
if capturable and not _is_compiling():
from torch.optim.optimizer import _get_capturable_supported_devices
capturable_supported_devices = _get_capturable_supported_devices()
assert (
Expand Down Expand Up @@ -340,7 +344,7 @@ def _multi_tensor_adopt(
)

# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
if capturable and not _is_compiling():
from torch.optim.optimizer import _get_capturable_supported_devices
capturable_supported_devices = _get_capturable_supported_devices(
supports_xla=False
Expand Down Expand Up @@ -384,7 +388,7 @@ def _multi_tensor_adopt(
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
# wrapped it once now. The alpha is required to assure we go to the right overload.
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
if not _is_compiling() and device_state_steps[0].is_cpu:
torch._foreach_add_(
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
)
Expand Down Expand Up @@ -457,9 +461,7 @@ def adopt(

# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
if not _is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
Expand Down

0 comments on commit 6db2710

Please sign in to comment.