From 5d853180e3566a5bef0be4e6cb5a2dcf37e83fd6 Mon Sep 17 00:00:00 2001 From: Vincenzo Scotti Date: Sat, 10 Aug 2024 07:26:48 +0200 Subject: [PATCH 1/8] Add AdamE optimizer classes --- bitsandbytes/optim/adame.py | 397 ++++++++++++++++++++++++++++++++++++ 1 file changed, 397 insertions(+) create mode 100644 bitsandbytes/optim/adame.py diff --git a/bitsandbytes/optim/adame.py b/bitsandbytes/optim/adame.py new file mode 100644 index 000000000..5ba7717c0 --- /dev/null +++ b/bitsandbytes/optim/adame.py @@ -0,0 +1,397 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer2State + +class AdamE(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + lasso=1e-2, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + Base AdamE optimizer. + This is a variant of the Adam optimizer implementing decoupled Lasso and Weight Decay regularisation, + resulting in an Elastic Net regularization. + The two regularizations are independent one of the other and can be switched off at will setting their coefficient to 0. + + (Developer's note: the name "AdamE" is pronounced as the French word "madame" without the initial 'm'; + I have never been good in French classes, but the name sounds nice) + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + lasso (`float`, defaults to 1e-2): + The lasso regularization coefficient value for the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + lasso, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + + +class AdamE8bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + lasso=1e-2, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 8-bit AdamE optimizer. 
+ + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + lasso (`float`, defaults to 1e-2): + The lasso regularization coefficient value for the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + lasso, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + + +class AdamE32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + lasso=1e-2, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + is_paged=False, + ): + """ + 32-bit AdamE optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + lasso (`float`, defaults to 1e-2): + The lasso regularization coefficient value for the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. 
+ """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + lasso, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=is_paged, + ) + + +class PagedAdamE(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + lasso=1e-2, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged AdamE optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + lasso (`float`, defaults to 1e-2): + The lasso regularization coefficient value for the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + lasso, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + + +class PagedAdamE8bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + lasso=1e-2, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 8-bit AdamE optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + lasso (`float`, defaults to 1e-2): + The lasso regularization coefficient value for the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. 
+ percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. + """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + lasso, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) + + +class PagedAdamE32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + lasso=1e-2, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + """ + Paged 32-bit AdamE optimizer. + + Arguments: + params (`torch.tensor`): + The input parameters to optimize. + lr (`float`, defaults to 1e-3): + The learning rate. + betas (`tuple(float, float)`, defaults to (0.9, 0.999)): + The beta values are the decay rates of the first and second-order moment of the optimizer. + eps (`float`, defaults to 1e-8): + The epsilon value prevents division by zero in the optimizer. + lasso (`float`, defaults to 1e-2): + The lasso regularization coefficient value for the optimizer. + weight_decay (`float`, defaults to 1e-2): + The weight decay value for the optimizer. + amsgrad (`bool`, defaults to `False`): + Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead. + optim_bits (`int`, defaults to 32): + The number of bits of the optimizer state. + args (`object`, defaults to `None`): + An object with additional arguments. + min_8bit_size (`int`, defaults to 4096): + The minimum number of elements of the parameter tensors for 8-bit optimization. + percentile_clipping (`int`, defaults to 100): + Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability. + block_wise (`bool`, defaults to `True`): + Whether to independently quantize each block of tensors to reduce outlier effects and improve stability. + is_paged (`bool`, defaults to `False`): + Whether the optimizer is a paged optimizer or not. 
+ """ + super().__init__( + "adam", + params, + lr, + betas, + eps, + lasso, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + is_paged=True, + ) From 454a26463bfd161474adfcf54ff28cdf04ecde33 Mon Sep 17 00:00:00 2001 From: Vincenzo Scotti Date: Sat, 10 Aug 2024 07:37:36 +0200 Subject: [PATCH 2/8] Add support for L1 regularization in optimizer interfaces --- bitsandbytes/optim/optimizer.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index e9c857d49..1c5a2da05 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -297,6 +297,7 @@ def get_config(self, gindex, pindex, group): config = {} config["betas"] = group["betas"] config["eps"] = group["eps"] + config["lasso"] = group["lasso"] config["weight_decay"] = group["weight_decay"] config["lr"] = group["lr"] config["optim_bits"] = self.args.optim_bits @@ -345,6 +346,7 @@ def __init__( lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + lasso=0.0, weight_decay=0.0, optim_bits=32, args=None, @@ -369,6 +371,8 @@ def __init__( The beta values for the optimizer. eps (`float`, defaults to 1e-8): The epsilon value for the optimizer. + lasso (`float`, defaults to 0.0): + The lasso regularization coefficient value for the optimizer. weight_decay (`float`, defaults to 0.0): The weight decay value for the optimizer. optim_bits (`int`, defaults to 32): @@ -399,9 +403,11 @@ def __init__( for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + if not 0.0 <= lasso: + raise ValueError(f"Invalid lasso regularization coefficient value: {lasso}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, lasso=lasso, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) if args is None: @@ -508,6 +514,7 @@ def update_step(self, group, p, gindex, pindex): config["lr"], state["state2"], config["betas"][1], + config["lasso"], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, @@ -533,6 +540,7 @@ def update_step(self, group, p, gindex, pindex): state["max2"], state["new_max1"], state["new_max2"], + config["lasso"], config["weight_decay"], gnorm_scale=gnorm_scale, unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, @@ -558,6 +566,7 @@ def update_step(self, group, p, gindex, pindex): state["qmap2"], state["absmax1"], state["absmax2"], + config["lasso"], config["weight_decay"], gnorm_scale=gnorm_scale, skip_zeros=config["skip_zeros"], @@ -572,6 +581,7 @@ def __init__( lr=1e-3, betas=(0.9, 0.0), eps=1e-8, + lasso=0.0, weight_decay=0.0, optim_bits=32, args=None, @@ -596,6 +606,8 @@ def __init__( The beta values for the optimizer. eps (`float`, defaults to 1e-8): The epsilon value for the optimizer. + lasso (`float`, defaults to 0.0): + The lasso regularization coefficient value for the optimizer. weight_decay (`float`, defaults to 0.0): The weight decay value for the optimizer. 
optim_bits (`int`, defaults to 32): @@ -622,9 +634,11 @@ def __init__( for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + if not 0.0 <= lasso: + raise ValueError(f"Invalid lasso regularization coefficient value: {lasso}") if not 0.0 <= weight_decay: raise ValueError(f"Invalid weight_decay value: {weight_decay}") - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, lasso=lasso, weight_decay=weight_decay) super().__init__(params, defaults, optim_bits, is_paged) if args is None: @@ -723,6 +737,7 @@ def update_step(self, group, p, gindex, pindex): config["lr"], None, config["betas"][1], + config["lasso"], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, @@ -748,6 +763,7 @@ def update_step(self, group, p, gindex, pindex): None, state["new_max1"], None, + config["lasso"], config["weight_decay"], gnorm_scale, state["unorm_vec"] if config["max_unorm"] > 0.0 else None, @@ -771,6 +787,7 @@ def update_step(self, group, p, gindex, pindex): None, state["absmax1"], None, + config["lasso"], config["weight_decay"], gnorm_scale=gnorm_scale, skip_zeros=config["skip_zeros"], From cccd9299383d6b186b4bb91e7e39aab0ba77cc8f Mon Sep 17 00:00:00 2001 From: Vincenzo Scotti Date: Sat, 10 Aug 2024 07:38:07 +0200 Subject: [PATCH 3/8] Register AdamE as part of the optim module --- bitsandbytes/optim/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index b4c95793a..1bb72f6ae 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -5,6 +5,14 @@ from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit +from .adame import ( + AdamE, + AdamE8bit, + AdamE32bit, + PagedAdamE, + PagedAdamE8bit, + PagedAdamE32bit, +) from .adamw import ( AdamW, AdamW8bit, From 82acad82f02c1d98429f51e6f9898aa626ae606c Mon Sep 17 00:00:00 2001 From: Vincenzo Scotti Date: Sat, 10 Aug 2024 07:48:36 +0200 Subject: [PATCH 4/8] Update optimizer step functions to include lasso regularization coefficient --- bitsandbytes/functional.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index cea3179a1..3917ff76a 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1534,6 +1534,7 @@ def optimizer_update_32bit( lr: float, state2: Optional[torch.Tensor] = None, beta2: float = 0.0, + lasso: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, unorm_vec: Optional[torch.Tensor] = None, @@ -1559,6 +1560,8 @@ def optimizer_update_32bit( Optimizer beta1. eps : float Optimizer epsilon. + lasso : float + Lasso regularization coefficient. weight_decay : float Weight decay. step : int @@ -1608,6 +1611,7 @@ def optimizer_update_32bit( ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), + ct.c_float(lasso), ct.c_float(weight_decay), ct.c_int32(step), ct.c_float(lr), @@ -1635,6 +1639,7 @@ def optimizer_update_8bit( max2: Optional[torch.Tensor], new_max1: Tensor, new_max2: Optional[torch.Tensor], + lasso: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, unorm_vec: Optional[torch.Tensor] = None, @@ -1664,6 +1669,8 @@ def optimizer_update_8bit( Adam beta2. eps : float Adam epsilon. + lasso : float + Lasso regularization coefficient. weight_decay : float Weight decay. 
step : int @@ -1716,6 +1723,7 @@ def optimizer_update_8bit( get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), + ct.c_float(lasso), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()), @@ -1740,6 +1748,7 @@ def optimizer_update_8bit( get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), + ct.c_float(lasso), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel()), @@ -1766,6 +1775,7 @@ def optimizer_update_8bit_blockwise( qmap2: Optional[torch.Tensor], absmax1: Tensor, absmax2: Optional[torch.Tensor], + lasso: float = 0.0, weight_decay: float = 0.0, gnorm_scale: float = 1.0, skip_zeros=False, @@ -1806,6 +1816,7 @@ def optimizer_update_8bit_blockwise( get_ptr(qmap2), get_ptr(absmax1), get_ptr(absmax2), + ct.c_float(lasso), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), From 051bdfc0932bd7c1a1b44891af69ab1e50aa9399 Mon Sep 17 00:00:00 2001 From: Vincenzo Scotti Date: Sat, 10 Aug 2024 09:50:01 +0200 Subject: [PATCH 5/8] Update C++ code to include lasso regularization parameter --- csrc/kernels.cu | 130 ++++++++++++++++++++++++++++----------- csrc/kernels.cuh | 18 +++--- csrc/ops.cu | 38 ++++++------ csrc/ops.cuh | 6 +- csrc/pythonInterface.cpp | 24 ++++---- 5 files changed, 136 insertions(+), 80 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index e4d459961..5fa6ca190 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -918,7 +918,7 @@ template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) { @@ -997,7 +997,7 @@ template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { @@ -1065,7 +1065,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p, s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); - if(weight_decay > 0.0f) + + if(lasso > 0.0f && weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + else if(lasso > 0.0f) + p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + else if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } break; @@ -1085,7 +1090,7 @@ template __launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n) { @@ -1166,7 +1171,7 @@ template __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(T *g, T *p, float 
*state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { @@ -1216,8 +1221,12 @@ __global__ void kOptimizer32bit1State(T *g, T *p, for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { g_vals[j] = gnorm_scale*((float)g_vals[j]); - if(weight_decay > 0.0f) - g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + if(lasso > 0.0f && weight_decay > 0.0f) + p_vals[j] = (float)g_vals[j] + (((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso) + (((float)p_vals[j])*weight_decay); + else if(lasso > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso); + else if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); } # pragma unroll 4 @@ -1387,7 +1396,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, const int n) { @@ -1488,7 +1497,11 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) { p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); - if(weight_decay > 0.0f) + if(lasso > 0.0f && weight_decay > 0.0f) + p_vals[j] = update_scale*(((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso))); + else if(lasso > 0.0f) + p_vals[j] = update_scale*(((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso))); + else if(weight_decay > 0.0f) p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -1511,7 +1524,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c const float eps, const int step, float* __restrict__ const quantiles1, float* max1, float* new_max1, - const float weight_decay, + const float lasso, const float weight_decay, const float gnorm_scale, const int n) { const int n_full = gridDim.x * NUM_PER_BLOCK; @@ -1601,7 +1614,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* max1, float* new_max1, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, const int n) { @@ -1661,7 +1674,27 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(weight_decay > 0.0f) { + if(lasso > 0.0f && weight_decay > 0.0f) + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso + ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + break; + } + else if(lasso > 0.0f) + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso; + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 
0)))*((float)(lr*lasso)); + break; + } + else if(weight_decay > 0.0f) { switch(OPTIMIZER) { case MOMENTUM: case RMSPROP: @@ -1770,7 +1803,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n) { @@ -1913,8 +1946,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); - if(weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + if(lasso > 0.0f && weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + else if(lasso > 0.0f) + p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + else if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } } @@ -1958,7 +1995,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n) { @@ -2029,18 +2066,37 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char g_val *= gnorm_scale; if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { - if(weight_decay > 0.0f) { + if(lasso > 0.0f && weight_decay > 0.0f) switch(OPTIMIZER) { - case MOMENTUM: - case ADAGRAD: - case RMSPROP: - g_val += ((float)p_vals[j])*weight_decay; - break; - case LION: - p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); - break; - } + case MOMENTUM: + case RMSPROP: + g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso + ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + break; + } + else if(lasso > 0.0f) + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso; + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + break; } + else if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; @@ -3849,7 +3905,7 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, const int n); \ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) @@ -3864,7 +3920,7 @@ 
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -3879,7 +3935,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, const int n); \ MAKE_PreconditionOptimizer32bit2State(ADAM, float) @@ -3887,11 +3943,11 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half) MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16) template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -3901,7 +3957,7 @@ template __global__ void kPreconditionOptimizerStatic8bit1State(gt const float eps, const int step, \ float* __restrict__ const quantiles1, \ float* max1, float* new_max1, \ - const float weight_decay, \ + const float lasso, const float weight_decay, \ const float gnorm_scale, \ const int n); \ @@ 
-3920,7 +3976,7 @@ template __global__ void kOptimizerStatic8bit1State(gtype* p, gtyp const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, \ float* max1, float* new_max1, \ - float weight_decay, \ + float lasso, float weight_decay, \ const float gnorm_scale, \ const int n); \ @@ -3951,7 +4007,7 @@ template __global__ void kOptimizerStatic8bit2State(gtype* p, gtyp const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, \ + float lasso, float weight_decay, \ const float gnorm_scale, \ const int n); \ @@ -4048,7 +4104,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n); template __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n); template __global__ void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template @@ -51,7 +51,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c const float eps, const int step, float* __restrict__ const quantiles1, float* max1, float* new_max1, - const float weight_decay, + const float lasso, const float weight_decay, const float gnorm_scale, const int n); @@ -63,7 +63,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* max1, float* new_max1, - float weight_decay, const float gnorm_scale, const int n); + float lasso, float weight_decay, const float gnorm_scale, const int n); @@ -86,13 +86,13 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, const float gnorm_scale, const int n); + float lasso, float weight_decay, const float gnorm_scale, const int n); template __global__ void kOptimizerStatic8bit2StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, 
unsigned char* state2, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + float* absmax1, float* absmax2, float lasso, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -100,7 +100,7 @@ template __global__ voi const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 3a6ffdda8..ed5a0c2fa 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -101,7 +101,7 @@ template void dequantizeBlockwise(float *code, unsign template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { int num_blocks = n/4096; @@ -112,10 +112,10 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -124,22 +124,22 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case LION: // in lion, the momentum update after the parameter update - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, 
p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } break; @@ -153,7 +153,7 @@ template void optimizerStatic8bit(T* p, T* g, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, int n) { int num_blocks = n/4096; @@ -169,27 +169,27 @@ template void optimizerStatic8bit(T* p, T* g, kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + quantiles1, quantiles2, max1, max2, new_max1, new_max2, lasso, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, lasso, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + quantiles1, max1, new_max1, lasso, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case LION: // in lion, the momentum update happens after the parameter update kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + quantiles1, max1, new_max1, lasso, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, lasso, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; default: @@ -204,7 +204,7 @@ template void optimizerStatic8bit(T* p, T* g, template void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float lasso, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) { int num_blocks = 0; @@ -214,7 +214,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g num_blocks = n/BLOCKSIZE_2STATE; num_blocks = n % BLOCKSIZE_2STATE == 0 ? 
num_blocks : num_blocks + 1; kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + quantiles1, quantiles2, absmax1, absmax2, lasso, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -224,7 +224,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + quantiles1, absmax1, lasso, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -808,7 +808,7 @@ template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); MAKE_optimizer32bit(ADAM, half) @@ -831,7 +831,7 @@ template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, \ + float lasso, float weight_decay, \ const float gnorm_scale, int n); \ MAKE_optimizerStatic8bit(ADAM, half) @@ -846,7 +846,7 @@ MAKE_optimizerStatic8bit(LION, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float lasso, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 8b9a4f449..314732345 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -148,7 +148,7 @@ template void dequantizeBlockwise(float *code, unsign template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, float eps, float weight_decay, + float beta1, float beta2, float eps, float lasso, float weight_decay, int step, float lr, const float gnorm_scale, bool skip_zeros, int n); template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, @@ -157,12 +157,12 @@ template void optimizerStatic8bit(T* p, T* g, unsigne float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, + float lasso, float weight_decay, const float gnorm_scale, int n); template void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* 
absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float lasso, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index ea2283504..703499ad4 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -52,9 +52,9 @@ MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) #define MAKE_FUNC32(fname, oname, gtype, gbits) \ void fname##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, \ const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ -{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ +{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) @@ -76,10 +76,10 @@ void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, float gnorm_scale, int n) \ + float lasso, float weight_decay, float gnorm_scale, int n) \ { \ optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, lasso, weight_decay, gnorm_scale, n); \ } \ MAKE_FUNC8(adam, ADAM, float, 32) @@ -94,8 +94,8 @@ MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ -{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float lasso, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ +{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, lasso, weight_decay, gnorm_scale, skip_zeros, n); }\ MAKE_BLOCKWISE8(adam, ADAM, half, fp16) MAKE_BLOCKWISE8(adam, ADAM, float, fp32) @@ -224,9 +224,9 @@ extern "C" #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ + const float beta1, const float beta2, const float eps, const float lasso, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ - { 
name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, lasso, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ MAKE_CFUNC32(adam, float, fp32) MAKE_CFUNC32(adam, half, fp16) @@ -248,10 +248,10 @@ extern "C" float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, float gnorm_scale, int n) \ + float lasso, float weight_decay, float gnorm_scale, int n) \ { \ name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, lasso, weight_decay, gnorm_scale, n); \ } \ MAKE_CFUNC8(adam, float, 32) @@ -266,8 +266,8 @@ extern "C" #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float lasso, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ + { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, lasso, weight_decay, gnorm_scale, skip_zeros, n); } \ MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) From 4c5da80e362ff6af19fc7316ae39ec98efc59ca3 Mon Sep 17 00:00:00 2001 From: Vincenzo Scotti Date: Sat, 10 Aug 2024 14:25:34 +0200 Subject: [PATCH 6/8] Add temporary fix for operands precision issue in sign function computation --- csrc/kernels.cu | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 5fa6ca190..8a9c68209 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1067,9 +1067,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p, if(lasso > 0.0f && weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)); else if(lasso > 0.0f) - p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)); else if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -1222,9 +1222,9 @@ __global__ void kOptimizer32bit1State(T *g, T *p, { g_vals[j] = gnorm_scale*((float)g_vals[j]); if(lasso > 0.0f && weight_decay > 0.0f) - p_vals[j] = (float)g_vals[j] + (((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso) + (((float)p_vals[j])*weight_decay); + p_vals[j] = (float)g_vals[j] + (((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 
0)))*lasso) + (((float)p_vals[j])*weight_decay); else if(lasso > 0.0f) - g_vals[j] = (float)g_vals[j] + (((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso); + g_vals[j] = (float)g_vals[j] + (((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso); else if(weight_decay > 0.0f) g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); } @@ -1498,9 +1498,9 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha { p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); if(lasso > 0.0f && weight_decay > 0.0f) - p_vals[j] = update_scale*(((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso))); + p_vals[j] = update_scale*(((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso))); else if(lasso > 0.0f) - p_vals[j] = update_scale*(((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso))); + p_vals[j] = update_scale*(((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso))); else if(weight_decay > 0.0f) p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -1678,20 +1678,20 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, switch(OPTIMIZER) { case MOMENTUM: case RMSPROP: - g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso + ((float)p_vals[j])*weight_decay; + g_val += ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso + ((float)p_vals[j])*weight_decay; break; case LION: - p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)); break; } else if(lasso > 0.0f) switch(OPTIMIZER) { case MOMENTUM: case RMSPROP: - g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso; + g_val += ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso; break; case LION: - p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)); break; } else if(weight_decay > 0.0f) { @@ -1947,9 +1947,9 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char { p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); if(lasso > 0.0f && weight_decay > 0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)); else if(lasso > 0.0f) - p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso)); + p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)); else if(weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } @@ -2070,20 +2070,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char switch(OPTIMIZER) { case MOMENTUM: case RMSPROP: - g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso + ((float)p_vals[j])*weight_decay; + g_val += ((float)(((float)p_vals[j] > 0) - 
           break;
         case LION:
-          p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso));
+          p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
           break;
       }
     else if(lasso > 0.0f)
       switch(OPTIMIZER) {
         case MOMENTUM:
         case RMSPROP:
-          g_val += ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*lasso;
+          g_val += ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso;
           break;
         case LION:
-          p_vals[j] = ((float)p_vals[j]) - ((float)((p_vals[j] > 0) - (p_vals[j] < 0)))*((float)(lr*lasso));
+          p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
           break;
       }
     else if(weight_decay > 0.0f) {
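Note: per parameter, the hunks above amount to a decoupled Elastic Net step applied in two forms. A minimal scalar sketch in C of what the branches compute (illustrative only, not part of the patch; the helper names are hypothetical):

    /* Gradient-side form used by the MOMENTUM/RMSPROP branches:
       g <- g + sign(p)*lasso + p*weight_decay */
    static float elastic_net_grad(float g, float p, float lasso, float weight_decay)
    {
        float sign_p = (float)((p > 0.0f) - (p < 0.0f));
        return g + sign_p*lasso + p*weight_decay;
    }

    /* Decoupled parameter-side form used by the Adam-style and LION branches:
       p <- p*(1 - lr*weight_decay) - sign(p)*lr*lasso */
    static float elastic_net_param(float p, float lr, float lasso, float weight_decay)
    {
        float sign_p = (float)((p > 0.0f) - (p < 0.0f));
        return p*(1.0f - lr*weight_decay) - sign_p*(lr*lasso);
    }
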
From 1d4e114ad09b2fca9eb297e4888666f0bd8153b6 Mon Sep 17 00:00:00 2001
From: Vincenzo Scotti
Date: Sun, 11 Aug 2024 14:57:30 +0200
Subject: [PATCH 7/8] Add threshold to lasso update to ensure convergence

---
 csrc/kernels.cu | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 8a9c68209..d3ec01dfe 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1067,9 +1067,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
     if(lasso > 0.0f && weight_decay > 0.0f)
-      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
     else if(lasso > 0.0f)
-      p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+      p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
     else if(weight_decay > 0.0f)
       p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
   }
@@ -1222,9 +1222,9 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
   {
     g_vals[j] = gnorm_scale*((float)g_vals[j]);
     if(lasso > 0.0f && weight_decay > 0.0f)
-      p_vals[j] = (float)g_vals[j] + (((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso) + (((float)p_vals[j])*weight_decay);
+      p_vals[j] = (float)g_vals[j] + (((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso) + (((float)p_vals[j])*weight_decay);
     else if(lasso > 0.0f)
-      g_vals[j] = (float)g_vals[j] + (((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso);
+      g_vals[j] = (float)g_vals[j] + (((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso);
     else if(weight_decay > 0.0f)
       g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
   }
@@ -1498,9 +1498,9 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
   {
     p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
     if(lasso > 0.0f && weight_decay > 0.0f)
-      p_vals[j] = update_scale*(((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)));
+      p_vals[j] = update_scale*(((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso)));
     else if(lasso > 0.0f)
-      p_vals[j] = update_scale*(((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso)));
+      p_vals[j] = update_scale*(((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso)));
     else if(weight_decay > 0.0f)
       p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
   }
@@ -1678,20 +1678,20 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
       switch(OPTIMIZER) {
         case MOMENTUM:
         case RMSPROP:
-          g_val += ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso + ((float)p_vals[j])*weight_decay;
+          g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso + ((float)p_vals[j])*weight_decay;
           break;
         case LION:
-          p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+          p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
           break;
       }
     else if(lasso > 0.0f)
       switch(OPTIMIZER) {
         case MOMENTUM:
         case RMSPROP:
-          g_val += ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso;
+          g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso;
           break;
         case LION:
-          p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+          p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
           break;
       }
     else if(weight_decay > 0.0f) {
@@ -1947,9 +1947,9 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
   {
     p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
     if(lasso > 0.0f && weight_decay > 0.0f)
-      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
     else if(lasso > 0.0f)
-      p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+      p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
     else if(weight_decay > 0.0f)
       p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
   }
@@ -2070,20 +2070,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
       switch(OPTIMIZER) {
         case MOMENTUM:
         case RMSPROP:
-          g_val += ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*lasso;
+          g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso;
           break;
         case LION:
-          p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > 0) - ((float)p_vals[j] < 0)))*((float)(lr*lasso));
+          p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
           break;
       }
     else if(weight_decay > 0.0f) {
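Note: the comparisons introduced in the patch above replace sign(p) with a dead-zone test against lr*lasso, so the L1 shrinkage is only applied while |p| still exceeds one full lasso step; a weight inside that band is left untouched rather than oscillating around zero, which is presumably the convergence concern named in the subject line. A standalone C sketch of the parameter-side case (illustrative only; lasso_shrink is not part of the patch):

    /* Matches the (p > lr*lasso) - (p < -lr*lasso) comparisons above. */
    static float lasso_shrink(float p, float lr, float lasso)
    {
        float step = lr*lasso;
        if (p > step)
            return p - step;   /* shrink positive weights toward zero */
        if (p < -step)
            return p + step;   /* shrink negative weights toward zero */
        return p;              /* |p| <= step: leave the weight as is */
    }
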
From f29c7e07eb3523828facba1b1865dc09f70edd7b Mon Sep 17 00:00:00 2001
From: Vincenzo Scotti
Date: Sun, 11 Aug 2024 15:17:54 +0200
Subject: [PATCH 8/8] Refactor switch cases code

---
 csrc/kernels.cu | 116 +++++++++++++++++++++++-------------------------
 1 file changed, 55 insertions(+), 61 deletions(-)

diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d3ec01dfe..6bab7aeb1 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1674,36 +1674,33 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
     g_val = float(g_vals[j]);
     g_val *= gnorm_scale;
-    if(lasso > 0.0f && weight_decay > 0.0f)
-      switch(OPTIMIZER) {
-        case MOMENTUM:
-        case RMSPROP:
-          g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso + ((float)p_vals[j])*weight_decay;
-          break;
-        case LION:
-          p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
-          break;
-      }
-    else if(lasso > 0.0f)
-      switch(OPTIMIZER) {
-        case MOMENTUM:
-        case RMSPROP:
-          g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso;
-          break;
-        case LION:
-          p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
-          break;
-      }
-    else if(weight_decay > 0.0f) {
-      switch(OPTIMIZER) {
-        case MOMENTUM:
-        case RMSPROP:
-          g_val += ((float)p_vals[j])*weight_decay;
-          break;
-        case LION:
-          p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
-          break;
-      }
+    if(lasso > 0.0f || weight_decay > 0.0f)
+    {
+      switch (OPTIMIZER)
+      {
+        case MOMENTUM:
+        case ADAGRAD:
+        case RMSPROP:
+        {
+          if (lasso > 0.0f && weight_decay > 0.0f)
+            g_val += ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * lasso + ((float) p_vals[j]) * weight_decay;
+          else if (lasso > 0.0f)
+            g_val += ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * lasso;
+          else if (weight_decay > 0.0f)
+            g_val += ((float) p_vals[j]) * weight_decay;
+          break;
+        }
+        case LION:
+        {
+          if (lasso > 0.0f && weight_decay > 0.0f)
+            p_vals[j] = ((float) p_vals[j]) * (1.0f - lr * weight_decay) - ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * ((float) (lr * lasso));
+          else if (lasso > 0.0f)
+            p_vals[j] = ((float) p_vals[j]) - ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * ((float) (lr * lasso));
+          else if (weight_decay > 0.0f)
+            p_vals[j] = ((float) p_vals[j]) * (1.0f - lr * weight_decay);
+          break;
+        }
+      }
     }
     s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
@@ -2066,37 +2063,34 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
     g_val *= gnorm_scale;
     if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
     {
-      if(lasso > 0.0f && weight_decay > 0.0f)
-        switch(OPTIMIZER) {
-          case MOMENTUM:
-          case RMSPROP:
-            g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso + ((float)p_vals[j])*weight_decay;
-            break;
-          case LION:
-            p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
-            break;
-        }
-      else if(lasso > 0.0f)
-        switch(OPTIMIZER) {
-          case MOMENTUM:
-          case RMSPROP:
-            g_val += ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*lasso;
-            break;
-          case LION:
-            p_vals[j] = ((float)p_vals[j]) - ((float)(((float)p_vals[j] > ((float)(lr*lasso))) - ((float)p_vals[j] < -((float)(lr*lasso)))))*((float)(lr*lasso));
-            break;
-        }
-      else if(weight_decay > 0.0f) {
-        switch(OPTIMIZER) {
-          case MOMENTUM:
-          case RMSPROP:
-            g_val += ((float)p_vals[j])*weight_decay;
-            break;
-          case LION:
-            p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
-            break;
-        }
-      }
+      if(lasso > 0.0f || weight_decay > 0.0f)
+      {
+        switch (OPTIMIZER)
+        {
+          case MOMENTUM:
+          case ADAGRAD:
+          case RMSPROP:
+          {
+            if (lasso > 0.0f && weight_decay > 0.0f)
+              g_val += ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * lasso + ((float) p_vals[j]) * weight_decay;
+            else if (lasso > 0.0f)
+              g_val += ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * lasso;
+            else if (weight_decay > 0.0f)
+              g_val += ((float) p_vals[j]) * weight_decay;
+            break;
+          }
+          case LION:
+          {
+            if (lasso > 0.0f && weight_decay > 0.0f)
+              p_vals[j] = ((float) p_vals[j]) * (1.0f - lr * weight_decay) - ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * ((float) (lr * lasso));
+            else if (lasso > 0.0f)
+              p_vals[j] = ((float) p_vals[j]) - ((float) (((float) p_vals[j] > ((float) (lr * lasso))) - ((float) p_vals[j] < -((float) (lr * lasso))))) * ((float) (lr * lasso));
+            else if (weight_decay > 0.0f)
+              p_vals[j] = ((float) p_vals[j]) * (1.0f - lr * weight_decay);
+            break;
+          }
+        }
+      }
     s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];