AdamE optimizer with decoupled L1 and L2 regularization #1314

Draft · wants to merge 8 commits into base: main
11 changes: 11 additions & 0 deletions bitsandbytes/functional.py
@@ -1534,6 +1534,7 @@ def optimizer_update_32bit(
lr: float,
state2: Optional[torch.Tensor] = None,
beta2: float = 0.0,
lasso: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
@@ -1559,6 +1560,8 @@ def optimizer_update_32bit(
Optimizer beta1.
eps : float
Optimizer epsilon.
lasso : float
Lasso regularization coefficient.
weight_decay : float
Weight decay.
step : int
@@ -1608,6 +1611,7 @@ def optimizer_update_32bit(
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_float(lasso),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
@@ -1635,6 +1639,7 @@ def optimizer_update_8bit(
max2: Optional[torch.Tensor],
new_max1: Tensor,
new_max2: Optional[torch.Tensor],
lasso: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
unorm_vec: Optional[torch.Tensor] = None,
@@ -1664,6 +1669,8 @@ def optimizer_update_8bit(
Adam beta2.
eps : float
Adam epsilon.
lasso : float
Lasso regularization coefficient.
weight_decay : float
Weight decay.
step : int
@@ -1716,6 +1723,7 @@ def optimizer_update_8bit(
get_ptr(max2),
get_ptr(new_max1),
get_ptr(new_max2),
ct.c_float(lasso),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_int32(g.numel()),
@@ -1740,6 +1748,7 @@ def optimizer_update_8bit(
get_ptr(max2),
get_ptr(new_max1),
get_ptr(new_max2),
ct.c_float(lasso),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_int32(g.numel()),
@@ -1766,6 +1775,7 @@ def optimizer_update_8bit_blockwise(
qmap2: Optional[torch.Tensor],
absmax1: Tensor,
absmax2: Optional[torch.Tensor],
lasso: float = 0.0,
weight_decay: float = 0.0,
gnorm_scale: float = 1.0,
skip_zeros=False,
@@ -1806,6 +1816,7 @@ def optimizer_update_8bit_blockwise(
get_ptr(qmap2),
get_ptr(absmax1),
get_ptr(absmax2),
ct.c_float(lasso),
ct.c_float(weight_decay),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
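For context, here is a minimal PyTorch sketch of what the new `lasso` argument represents at the functional level: a decoupled L1 penalty applied directly to the weights, next to the decoupled L2 weight decay, after the adaptive gradient step. This is not the CUDA kernel these bindings call; the function name and step layout are illustrative only, and `step` is assumed to start at 1.

```python
import torch

def adame_step_sketch(p, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
                      eps=1e-8, lasso=0.0, weight_decay=0.0):
    # Standard Adam first/second moment updates with bias correction.
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)

    # Adaptive gradient step.
    p.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)

    # Decoupled regularization: applied to the weights themselves,
    # not folded into the gradient.
    if weight_decay != 0.0:
        p.mul_(1 - lr * weight_decay)             # decoupled L2 (AdamW-style)
    if lasso != 0.0:
        p.add_(torch.sign(p), alpha=-lr * lasso)  # decoupled L1 ("lasso") shrinkage
    return p
```

The L1 term pushes weights toward exact zeros (sparsity), while the L2 term shrinks them proportionally; keeping both decoupled from the gradient mirrors the AdamW treatment of weight decay.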
8 changes: 8 additions & 0 deletions bitsandbytes/optim/__init__.py
@@ -5,6 +5,14 @@

from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
from .adame import (
AdamE,
AdamE8bit,
AdamE32bit,
PagedAdamE,
PagedAdamE8bit,
PagedAdamE32bit,
)
from .adamw import (
AdamW,
AdamW8bit,
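A hypothetical usage example of the optimizer classes exported above. The `AdamE` constructor signature is not shown in this diff, so the `lasso` keyword argument here is an assumption carried over from the functional-level parameters added above; treat this as a sketch rather than the confirmed API. The 8-bit optimizers also require a CUDA device.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 64).cuda()

# `lasso` (decoupled L1) and `weight_decay` (decoupled L2) kwargs are assumed.
optimizer = bnb.optim.AdamE8bit(
    model.parameters(),
    lr=1e-3,
    lasso=1e-4,
    weight_decay=1e-2,
)

loss = model(torch.randn(4, 128, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```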