diff --git a/torch_xla/amp/syncfree/adam.py b/torch_xla/amp/syncfree/adam.py index 1edee07238c8..4201933ca590 100644 --- a/torch_xla/amp/syncfree/adam.py +++ b/torch_xla/amp/syncfree/adam.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +import torch_xla.core.xla_model as xm from . import _functional as F @@ -86,9 +87,14 @@ def step(self, closure=None, found_inf: Tensor = None): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( p, memory_format=torch.preserve_format) - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + else: + state['max_exp_avg_sq'] = torch.empty( + 0, dtype=torch.float, device=xm.xla_device()) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/amp/syncfree/adamw.py b/torch_xla/amp/syncfree/adamw.py index 60f3745d23a3..83e11d46fad9 100644 --- a/torch_xla/amp/syncfree/adamw.py +++ b/torch_xla/amp/syncfree/adamw.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +import torch_xla.core.xla_model as xm from . import _functional as F @@ -84,9 +85,14 @@ def step(self, closure=None, found_inf: Tensor = None): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like( p, memory_format=torch.preserve_format) - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) + + if group['amsgrad']: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + else: + state['max_exp_avg_sq'] = torch.empty( + 0, dtype=torch.float, device=xm.xla_device()) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) diff --git a/torch_xla/csrc/ops/adam_optimizer_step.cpp b/torch_xla/csrc/ops/adam_optimizer_step.cpp index d5b6d03b9104..dd12df326a0f 100644 --- a/torch_xla/csrc/ops/adam_optimizer_step.cpp +++ b/torch_xla/csrc/ops/adam_optimizer_step.cpp @@ -29,7 +29,7 @@ AdamOptimizerStep::AdamOptimizerStep( {found_inf, step, param, grad, exp_avg, exp_avg_sq, max_exp_avg_sq, beta1, beta2, lr, weight_decay, eps}, NodeOutputShape(step, param), - /*num_outputs=*/5, + /*num_outputs=*/(use_amsgrad ? 5 : 4), torch::lazy::MHash(use_weight_decay, use_amsgrad, use_adamw)), use_weight_decay_(use_weight_decay), use_amsgrad_(use_amsgrad), diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 6f2cfbf6b8e1..e036c0e70778 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -525,7 +525,9 @@ void adam_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step, param->SetInPlaceIrValue(torch::lazy::Value(node, 1)); exp_avg->SetInPlaceIrValue(torch::lazy::Value(node, 2)); exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 3)); - max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4)); + if (amsgrad) { + max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4)); + } } std::vector user_computation( diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 8d87ac02e1a1..374e7569ca04 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -1013,10 +1013,11 @@ std::vector BuildAdamOptimizerStep( xla::XlaOp new_exp_avg_sq = xla::Select(found_inf_cond, exp_avg_sq, exp_avg_sq * beta2 + new_grad * new_grad * (one - beta2)); - xla::XlaOp new_max_exp_avg_sq = xla::Select( - found_inf_cond, max_exp_avg_sq, xla::Max(max_exp_avg_sq, new_exp_avg_sq)); + xla::XlaOp new_max_exp_avg_sq; xla::XlaOp denom; if (use_amsgrad) { + new_max_exp_avg_sq = xla::Select(found_inf_cond, max_exp_avg_sq, + xla::Max(max_exp_avg_sq, new_exp_avg_sq)); denom = xla::Sqrt(new_max_exp_avg_sq) / xla::Sqrt(bias_correction2) + eps; } else { denom = xla::Sqrt(new_exp_avg_sq) / xla::Sqrt(bias_correction2) + eps; @@ -1031,7 +1032,9 @@ std::vector BuildAdamOptimizerStep( results.push_back(new_param); results.push_back(new_exp_avg); results.push_back(new_exp_avg_sq); - results.push_back(new_max_exp_avg_sq); + if (use_amsgrad) { + results.push_back(new_max_exp_avg_sq); + } return results; }