Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
stevhliu committed Feb 27, 2024
1 parent 5e1cda5 commit 9f77a71
Show file tree
Hide file tree
Showing 18 changed files with 958 additions and 17 deletions.
8 changes: 4 additions & 4 deletions bitsandbytes/optim/adagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
initial_accumulator_value (`int`, defaults to 0):
The initial momemtum values.
eps (`float`, defaults to 1e-10):
The epsilon value for the optimizer.
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
initial_accumulator_value (`int`, defaults to 0):
The initial momemtum values.
eps (`float`, defaults to 1e-10):
The epsilon value for the optimizer.
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 8):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
Expand Down Expand Up @@ -174,7 +174,7 @@ def __init__(
initial_accumulator_value (`int`, defaults to 0):
The initial momemtum values.
eps (`float`, defaults to 1e-10):
The epsilon value for the optimizer.
The epsilon value prevents division by zero in the optimizer.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
Expand All @@ -185,7 +185,7 @@ def __init__(
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
"""
"""
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
Expand Down
174 changes: 174 additions & 0 deletions bitsandbytes/optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,205 @@
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Base Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

class Adam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

class Adam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
32-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Paged Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
8-bit paged Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
"""
Paged 32-bit Adam optimizer.
Arguments:
params (`torch.tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
betas (`tuple(float, float)`, defaults to (0.9, 0.999)):
The beta values are the decay rates of the first and second-order moment of the optimizer.
eps (`float`, defaults to 1e-8):
The epsilon value prevents division by zero in the optimizer.
weight_decay (`float`, defaults to 0.0):
The weight decay value for the optimizer.
amsgrad (`bool`, defaults to `False`):
Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
args (`dict`, defaults to `None`):
A dictionary with additional arguments.
min_8bit_size (`int`, defaults to 4096):
The minimum number of elements of the parameter tensors for 8-bit optimization.
percentile_clipping (`int`, defaults to 100):
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
is_paged (`bool`, defaults to `False`):
Whether the optimizer is a paged optimizer or not.
"""
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class AnalysisAdam(torch.optim.Optimizer):
Expand Down
Loading

0 comments on commit 9f77a71

Please sign in to comment.