Reduce unnecessary tensor allocation in Adam and AdamW (#5700)
* Reduce unnecessary tensor allocation in adam and adamw

* fix lint
baoleai authored and bhavya01 committed Apr 22, 2024
1 parent 7c3faf9 commit 624fac9
Showing 5 changed files with 28 additions and 11 deletions.
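Background for the change: the sync-free optimizers exist so that torch_xla's AMP GradScaler can hand its found_inf tensor straight to the optimizer step instead of forcing a host sync on the inf/nan check. Below is a minimal sketch of that usage pattern, adapted from the usual torch_xla AMP recipe; the toy model, data, and hyperparameters are illustrative assumptions, not part of this commit. With the default amsgrad=False, this change means the optimizer no longer materializes a full-size max_exp_avg_sq buffer for every parameter.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.amp.syncfree as syncfree
from torch_xla.amp import GradScaler, autocast

device = xm.xla_device()
model = torch.nn.Linear(128, 10).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = syncfree.Adam(model.parameters(), lr=1e-3)  # amsgrad defaults to False
scaler = GradScaler()

data = torch.randn(32, 128, device=device)
target = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
with autocast(device):
  loss = loss_fn(model(data), target)
scaler.scale(loss).backward()
scaler.step(optimizer)  # passes found_inf to the sync-free step
scaler.update()
xm.mark_step()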
12 changes: 9 additions & 3 deletions torch_xla/amp/syncfree/adam.py
@@ -1,5 +1,6 @@
 import torch
 from torch import Tensor
+import torch_xla.core.xla_model as xm
 from . import _functional as F


@@ -86,9 +87,14 @@ def step(self, closure=None, found_inf: Tensor = None):
           # Exponential moving average of squared gradient values
           state['exp_avg_sq'] = torch.zeros_like(
               p, memory_format=torch.preserve_format)
-          # Maintains max of all exp. moving avg. of sq. grad. values
-          state['max_exp_avg_sq'] = torch.zeros_like(
-              p, memory_format=torch.preserve_format)
+
+          if group['amsgrad']:
+            # Maintains max of all exp. moving avg. of sq. grad. values
+            state['max_exp_avg_sq'] = torch.zeros_like(
+                p, memory_format=torch.preserve_format)
+          else:
+            state['max_exp_avg_sq'] = torch.empty(
+                0, dtype=torch.float, device=xm.xla_device())
 
         exp_avgs.append(state['exp_avg'])
         exp_avg_sqs.append(state['exp_avg_sq'])
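The Python-side edit above is a conditional-allocation pattern: only materialize the AMSGrad accumulator when amsgrad is enabled, and otherwise store a zero-element placeholder so the downstream functional call still receives a tensor in that slot. A standalone sketch of just that pattern, using plain CPU tensors and an illustrative helper name rather than the torch_xla code itself:

import torch

def init_adam_state(p: torch.Tensor, amsgrad: bool) -> dict:
  # Hypothetical helper mirroring the lazy state initialization above.
  state = {
      'step': torch.tensor(0.0),
      'exp_avg': torch.zeros_like(p, memory_format=torch.preserve_format),
      'exp_avg_sq': torch.zeros_like(p, memory_format=torch.preserve_format),
  }
  if amsgrad:
    # Full-size buffer is only needed when AMSGrad tracks a running max.
    state['max_exp_avg_sq'] = torch.zeros_like(
        p, memory_format=torch.preserve_format)
  else:
    # Zero-element placeholder: keeps the key present at no memory cost.
    state['max_exp_avg_sq'] = torch.empty(0, dtype=torch.float)
  return state

p = torch.randn(1024, 1024)
print(init_adam_state(p, amsgrad=False)['max_exp_avg_sq'].numel())  # 0
print(init_adam_state(p, amsgrad=True)['max_exp_avg_sq'].numel())   # 1048576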
12 changes: 9 additions & 3 deletions torch_xla/amp/syncfree/adamw.py
@@ -1,5 +1,6 @@
 import torch
 from torch import Tensor
+import torch_xla.core.xla_model as xm
 from . import _functional as F


@@ -84,9 +85,14 @@ def step(self, closure=None, found_inf: Tensor = None):
           # Exponential moving average of squared gradient values
           state['exp_avg_sq'] = torch.zeros_like(
               p, memory_format=torch.preserve_format)
-          # Maintains max of all exp. moving avg. of sq. grad. values
-          state['max_exp_avg_sq'] = torch.zeros_like(
-              p, memory_format=torch.preserve_format)
+
+          if group['amsgrad']:
+            # Maintains max of all exp. moving avg. of sq. grad. values
+            state['max_exp_avg_sq'] = torch.zeros_like(
+                p, memory_format=torch.preserve_format)
+          else:
+            state['max_exp_avg_sq'] = torch.empty(
+                0, dtype=torch.float, device=xm.xla_device())
 
         exp_avgs.append(state['exp_avg'])
         exp_avg_sqs.append(state['exp_avg_sq'])
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/adam_optimizer_step.cpp
@@ -29,7 +29,7 @@ AdamOptimizerStep::AdamOptimizerStep(
               {found_inf, step, param, grad, exp_avg, exp_avg_sq,
                max_exp_avg_sq, beta1, beta2, lr, weight_decay, eps},
               NodeOutputShape(step, param),
-              /*num_outputs=*/5,
+              /*num_outputs=*/(use_amsgrad ? 5 : 4),
               torch::lazy::MHash(use_weight_decay, use_amsgrad, use_adamw)),
       use_weight_decay_(use_weight_decay),
       use_amsgrad_(use_amsgrad),
4 changes: 3 additions & 1 deletion torch_xla/csrc/tensor_methods.cpp
@@ -526,7 +526,9 @@ void adam_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step,
   param->SetInPlaceIrValue(torch::lazy::Value(node, 1));
   exp_avg->SetInPlaceIrValue(torch::lazy::Value(node, 2));
   exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 3));
-  max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4));
+  if (amsgrad) {
+    max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4));
+  }
 }
 
 std::vector<XLATensorPtr> user_computation(
9 changes: 6 additions & 3 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1013,10 +1013,11 @@ std::vector<xla::XlaOp> BuildAdamOptimizerStep(
   xla::XlaOp new_exp_avg_sq =
       xla::Select(found_inf_cond, exp_avg_sq,
                   exp_avg_sq * beta2 + new_grad * new_grad * (one - beta2));
-  xla::XlaOp new_max_exp_avg_sq = xla::Select(
-      found_inf_cond, max_exp_avg_sq, xla::Max(max_exp_avg_sq, new_exp_avg_sq));
+  xla::XlaOp new_max_exp_avg_sq;
   xla::XlaOp denom;
   if (use_amsgrad) {
+    new_max_exp_avg_sq = xla::Select(found_inf_cond, max_exp_avg_sq,
+                                     xla::Max(max_exp_avg_sq, new_exp_avg_sq));
     denom = xla::Sqrt(new_max_exp_avg_sq) / xla::Sqrt(bias_correction2) + eps;
   } else {
     denom = xla::Sqrt(new_exp_avg_sq) / xla::Sqrt(bias_correction2) + eps;
@@ -1031,7 +1032,9 @@
   results.push_back(new_param);
   results.push_back(new_exp_avg);
   results.push_back(new_exp_avg_sq);
-  results.push_back(new_max_exp_avg_sq);
+  if (use_amsgrad) {
+    results.push_back(new_max_exp_avg_sq);
+  }
   return results;
 }

