diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 7503ad73c..1132380be 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1555,9 +1555,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
 def optimizer_update_32bit(
     optimizer_name: str,
-    g: Tensor,
-    p: Tensor,
-    state1: Tensor,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
     beta1: float,
     eps: float,
     step: int,
@@ -1571,6 +1571,7 @@ def optimizer_update_32bit(
     unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
     skip_zeros=False,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     """
     Performs an inplace optimizer update with one or two optimizer states.
@@ -1613,6 +1614,8 @@ def optimizer_update_32bit(
         The maximum update norm relative to the weight norm.
     skip_zeros : bool
         Whether to skip zero-valued gradients or not (default: False).
+    return_updates: Optional[torch.Tensor]
+        When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
     """
 
     param_norm = 0.0
@@ -1636,6 +1639,7 @@ def optimizer_update_32bit(
         optim_func(
             get_ptr(g),
             get_ptr(p),
+            get_ptr(return_updates),
             get_ptr(state1),
             get_ptr(state2),
             get_ptr(unorm_vec),
@@ -1658,25 +1662,26 @@ def optimizer_update_32bit(
 
 def optimizer_update_8bit(
     optimizer_name: str,
-    g: Tensor,
-    p: Tensor,
-    state1: Tensor,
+    g: torch.Tensor,
+    p: torch.Tensor,
+    state1: torch.Tensor,
     state2: Optional[torch.Tensor],
     beta1: float,
     beta2: float,
     eps: float,
     step: int,
     lr: float,
-    qmap1: Tensor,
+    qmap1: torch.Tensor,
     qmap2: Optional[torch.Tensor],
-    max1: Tensor,
+    max1: torch.Tensor,
     max2: Optional[torch.Tensor],
-    new_max1: Tensor,
+    new_max1: torch.Tensor,
     new_max2: Optional[torch.Tensor],
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     unorm_vec: Optional[torch.Tensor] = None,
     max_unorm: float = 0.0,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     """
     Performs an inplace Adam update.
@@ -1726,6 +1731,8 @@ def optimizer_update_8bit(
         The tensor for the update norm.
     max_unorm : float
         The maximum update norm relative to the weight norm.
+    return_updates: Optional[torch.Tensor]
+        When provided, updates are written to this tensor and not applied directly to `p`. (default: None)
     """
 
     param_norm = 0.0
@@ -1738,6 +1745,7 @@ def optimizer_update_8bit(
             str2optimizer8bit[optimizer_name][0](
                 get_ptr(p),
                 get_ptr(g),
+                get_ptr(return_updates),
                 get_ptr(state1),
                 get_ptr(state2),
                 get_ptr(unorm_vec),
@@ -1762,6 +1770,7 @@ def optimizer_update_8bit(
             str2optimizer8bit[optimizer_name][1](
                 get_ptr(p),
                 get_ptr(g),
+                get_ptr(return_updates),
                 get_ptr(state1),
                 get_ptr(state2),
                 get_ptr(unorm_vec),
@@ -1809,6 +1818,7 @@ def optimizer_update_8bit_blockwise(
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     skip_zeros=False,
+    return_updates: Optional[torch.Tensor] = None,
 ) -> None:
     optim_func = None
     prev_device = pre_call(g.device)
@@ -1835,6 +1845,7 @@ def optimizer_update_8bit_blockwise(
     optim_func(
         get_ptr(p),
         get_ptr(g),
+        get_ptr(return_updates),
         get_ptr(state1),
         get_ptr(state2),
         ct.c_float(beta1),
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 07174c38d..05252246b 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -9,6 +9,7 @@
     AdamW,
     AdamW8bit,
     AdamW32bit,
+    GaLoreAdamW8bit,
     PagedAdamW,
     PagedAdamW8bit,
     PagedAdamW32bit,
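The new `return_updates` argument only changes where the computed step is written: with the default `None`, the kernels keep updating `p` in place; when a buffer is passed, the raw optimizer step is written into that buffer and `p` is left untouched, so the caller decides how to apply it (GaLoreAdamW8bit below projects the low-rank update back and applies weight decay itself). A minimal sketch of the functional-level usage follows; the first four positional arguments are taken from the hunk above, while the remaining keyword names (`lr`, `state2`, `beta2`, ...) are assumed to match the existing `optimizer_update_32bit` signature, which this patch does not change.

import torch
import bitsandbytes.functional as F

p = torch.randn(4096, device="cuda")   # parameter
g = torch.randn_like(p)                # gradient
state1 = torch.zeros_like(p)           # Adam first moment
state2 = torch.zeros_like(p)           # Adam second moment
update = torch.zeros_like(p)           # buffer that receives the step

# With return_updates set, the kernel writes the Adam step into `update`
# instead of applying it to `p` in place.
F.optimizer_update_32bit(
    "adam",
    g,
    p,
    state1,
    beta1=0.9,
    eps=1e-8,
    step=1,
    lr=1e-3,
    state2=state2,
    beta2=0.999,
    return_updates=update,
)

# The caller applies the captured step itself, e.g. after a projection.
p.data.add_(update)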
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index 4bf3f6436..a3c6402da 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -2,7 +2,17 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from bitsandbytes.optim.optimizer import Optimizer2State
+import torch
+
+from bitsandbytes.optim.optimizer import GaLoreWrappedParameter, Optimizer2State
+
+_galore_available = False
+try:
+    from galore_torch.galore_projector import GaLoreProjector
+
+    _galore_available = True
+except ImportError:
+    pass
 
 
 class AdamW(Optimizer2State):
@@ -127,6 +137,117 @@ def __init__(
         )
 
 
+class GaLoreAdamW8bit(Optimizer2State):
+    def __init__(
+        self,
+        params,
+        lr=1e-3,
+        betas=(0.9, 0.999),
+        eps=1e-8,
+        weight_decay=1e-2,
+        amsgrad=False,
+        optim_bits=8,
+        args=None,
+        min_8bit_size=4096,
+        percentile_clipping=100,
+        block_wise=True,
+        is_paged=False,
+    ):
+        if not _galore_available:
+            raise RuntimeError("The galore_torch package must be installed to use GaLoreAdamW8bit.")
+        super().__init__(
+            "adam",
+            params,
+            lr,
+            betas,
+            eps,
+            weight_decay,
+            optim_bits,
+            args,
+            min_8bit_size,
+            percentile_clipping,
+            block_wise,
+            is_paged=is_paged,
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        overflows = []
+
+        if not self.initialized:
+            self.check_overrides()
+            self.to_gpu()  # needed for fairseq pure fp16 training
+            self.initialized = True
+
+        # if self.is_paged: self.page_mng.prefetch_all()
+        for gindex, group in enumerate(self.param_groups):
+            for pindex, p in enumerate(group["params"]):
+                if p.grad is None:
+                    continue
+                state = self.state[p]
+
+                if "step" not in state:
+                    state["step"] = 0
+
+                # GaLore Projection
+                if "rank" in group:
+                    if "projector" not in state:
+                        state["projector"] = GaLoreProjector(
+                            group["rank"],
+                            update_proj_gap=group["update_proj_gap"],
+                            scale=group["scale"],
+                            proj_type=group["proj_type"],
+                        )
+
+                    grad = state["projector"].project(p.grad, state["step"])
+
+                    # suboptimal implementation
+                    # p.saved_data = p.data.clone()
+                    # p.data = grad.clone().to(p.data.dtype).to(p.data.device)
+                    # p.data.zero_()
+                    # p.grad = grad
+                    lor_update = torch.zeros_like(
+                        grad, dtype=p.data.dtype, device=p.data.device, requires_grad=grad.requires_grad
+                    )
+
+                if "state1" not in state:
+                    self.init_state(group, p, gindex, pindex)
+
+                self.prefetch_state(p)
+
+                if "rank" in group:
+                    galore_p = GaLoreWrappedParameter(p=p, grad=grad)
+                    self.update_step(group, galore_p, gindex, pindex, return_updates=lor_update)
+
+                    # GaLore Projection Back
+                    p.data.add_(state["projector"].project_back(lor_update))
+
+                    if "weight_decay" in group and group["weight_decay"] > 0:
+                        p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
+                else:
+                    self.update_step(group, p, gindex, pindex)
+
+                torch.cuda.synchronize()
+
+        if self.is_paged:
+            # all paged operation are asynchronous, we need
+            # to sync to make sure all tensors are in the right state
+            torch.cuda.synchronize()
+
+        return loss
+
+
 class AdamW32bit(Optimizer2State):
     def __init__(
         self,
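For reference, a usage sketch of the new optimizer (illustrative values only; it requires the `galore_torch` package, and GaLore is applied to any param group that defines the `rank`, `update_proj_gap`, `scale`, and `proj_type` keys read in `step()` above, while other groups fall back to the regular 8-bit AdamW update):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096, bias=True).cuda()

optimizer = bnb.optim.GaLoreAdamW8bit(
    [
        # regular 8-bit AdamW update for parameters without a GaLore config
        {"params": [model.bias]},
        # low-rank GaLore update for the large weight matrix
        {
            "params": [model.weight],
            "rank": 128,
            "update_proj_gap": 200,
            "scale": 0.25,
            "proj_type": "std",
        },
    ],
    lr=1e-4,
    weight_decay=0.0,
)

loss = model(torch.randn(8, 4096, device="cuda")).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()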
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 03e0e01d7..2c0f77295 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -4,8 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 from collections import abc as container_abcs, defaultdict
 from copy import deepcopy
+from dataclasses import dataclass
 from itertools import chain
-from typing import Optional
+from typing import Any, Dict, Optional, Union
 
 import torch
 
@@ -18,6 +19,12 @@ def __init__(self, initial_data):
             setattr(self, key, initial_data[key])
 
 
+@dataclass
+class GaLoreWrappedParameter:
+    p: torch.Tensor
+    grad: torch.Tensor
+
+
 class GlobalOptimManager:
     """
     A global optimizer manager for enabling custom optimizer configs.
@@ -320,7 +327,7 @@ def get_config(self, gindex, pindex, group):
     def init_state(self, group, p, gindex, pindex):
         raise NotImplementedError("init_state method needs to be overridden")
 
-    def update_step(self, group, p, gindex, pindex):
+    def update_step(self, group, p, gindex, pindex, return_updates):
         raise NotImplementedError("The update_step method needs to be overridden")
 
     def get_state_buffer(self, p, dtype=torch.float32):
@@ -494,13 +501,25 @@ def init_state(self, group, p, gindex, pindex):
             state["unorm_vec"] = torch.zeros((1,), device=p.device)
 
     @torch.no_grad()
-    def update_step(self, group, p, gindex, pindex):
-        # avoid update error from non-contiguous memory layout
-        p.data = p.data.contiguous()
-        p.grad = p.grad.contiguous()
+    def update_step(
+        self,
+        group: Dict[str, Any],
+        p: Union[torch.Tensor, GaLoreWrappedParameter],
+        gindex: int,
+        pindex: int,
+        return_updates: Optional[torch.Tensor] = None,
+    ):
+        if isinstance(p, GaLoreWrappedParameter):
+            # Unwrap for GaLore
+            param_to_optimize = p.p
+        else:
+            param_to_optimize = p
 
-        state = self.state[p]
-        grad = p.grad
+        state = self.state[param_to_optimize]
+
+        # avoid update error from non-contiguous memory layout
+        param_to_optimize.data = param_to_optimize.data.contiguous()
+        grad = p.grad.contiguous()
 
         config = self.get_config(gindex, pindex, group)
 
@@ -521,7 +540,7 @@ def update_step(self, group, p, gindex, pindex):
             F.optimizer_update_32bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 config["betas"][0],
                 config["eps"],
@@ -536,13 +555,14 @@ def update_step(self, group, p, gindex, pindex):
                 state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
                 skip_zeros=config["skip_zeros"],
+                return_updates=return_updates,
             )
 
         elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
             F.optimizer_update_8bit(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -560,6 +580,7 @@ def update_step(self, group, p, gindex, pindex):
                 gnorm_scale=gnorm_scale,
                 unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                 max_unorm=config["max_unorm"],
+                return_updates=return_updates,
             )
 
             # swap maxes
@@ -569,7 +590,7 @@ def update_step(self, group, p, gindex, pindex):
             F.optimizer_update_8bit_blockwise(
                 self.optimizer_name,
                 grad,
-                p,
+                param_to_optimize,
                 state["state1"],
                 state["state2"],
                 config["betas"][0],
@@ -586,6 +607,7 @@ def update_step(self, group, p, gindex, pindex):
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
                 skip_zeros=config["skip_zeros"],
+                return_updates=return_updates,
             )
 
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 867390f2c..2ff520dd2 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -68,27 +68,6 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
   }
 }
 
-__device__ float d2DequantizeFP4(unsigned char val)
-{
-  float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-  if((val & 0b0110) == 0)
-  {
-    // subnormal
-    if((val & 0b0001) == 0)
-      return 0.0f;
-    else
-      return sign*0.0625f;
-  }
-  else
-  {
-    // normal
-    float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f);
-    float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f;
-
-    return sign*exponent*fraction;
-  }
-}
-
 __device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
 {
   float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
@@ -165,60 +144,6 @@ __device__ unsigned char dQuantizeFP4(float x)
   return 0b0000+sign;
 }
 
-__device__ half dhDequantizeNF4(unsigned char val)
-{
-  // the values for this tree was generated by test_normal_map_tree
-  // in the file tests/test_functional.py
-  if((val & 0b1000) == 8)
-    if((val & 0b0100) == 4) // 1
-      if((val & 0b0010) == 2) // 11
-        if((val & 0b0001) == 1) // 111
-          return 1.0f;
-        else
-          return 0.7229568362236023f;
-      else
-        if((val & 0b0001) == 1) // 110
-          return 0.5626170039176941f;
-        else
-          return 0.44070982933044434f;
-    else
-      if((val & 0b0010) == 2) //10
-        if((val & 0b0001) == 1) // 101
-          return 0.33791524171829224f;
-        else
-          return 0.24611230194568634f;
-      else
-        if((val & 0b0001) == 1) // 100
-          return 0.16093020141124725f;
-        else
-          return 0.07958029955625534f;
-
-  else
-    if((val & 0b0100) == 4) // 0
-      if((val & 0b0010) == 2) //01
-        if((val & 0b0001) == 1) // 011
-          return 0.0f;
-        else
-          return -0.09105003625154495f;
-      else
-        if((val & 0b0001) == 1) // 010
-          return -0.18477343022823334f;
-        else
-          return -0.28444138169288635f;
-    else
-      if((val & 0b0010) == 2) //00
-        if((val & 0b0001) == 1) // 001
-          return -0.39491748809814453f;
-        else
-          return -0.5250730514526367f;
-      else
-        if((val & 0b0001) == 1) // 000
-          return -0.6961928009986877f;
-        else
-          return -1.0f;
-
-}
-
 __device__ float dDequantizeNF4(unsigned char val)
 {
@@ -872,7 +797,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
 
 template
 __launch_bounds__(TH, 1)
-__global__ void kOptimizer32bit2State(T* g, T* p,
+__global__ void kOptimizer32bit2State(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -931,7 +856,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
       __syncthreads();
       LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
       __syncthreads();
-      Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
+      Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
       // Load additional state1 data for AdEMAMix
       // TODO: Make constexpr after updating min compiler
@@ -975,17 +900,22 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
           {
             s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
             s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
-            p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
-            if(weight_decay > 0.0f)
-              p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
+            if (return_updates == nullptr) {
+              p_vals[j] = (T)(((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))));
+
+              if(weight_decay > 0.0f)
+                p_vals[j] = (T)(((float)p_vals[j])*(1.0f-(lr*weight_decay)));
+            } else {
+              p_vals[j] = (T)(update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+            }
           }
           break;
      }
    }
 
    __syncthreads();
-   Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+   Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
    __syncthreads();
    StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
    __syncthreads();
@@ -1081,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
 
 template
 __launch_bounds__(TH, 1)
-__global__ void kOptimizer32bit1State(T *g, T *p,
+__global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
                 float *state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1127,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
     __syncthreads();
     LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
     __syncthreads();
-    Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
+    Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
     # pragma unroll 4
     for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
     {
       g_vals[j] = gnorm_scale*((float)g_vals[j]);
-      if(weight_decay > 0.0f)
+      if(weight_decay > 0.0f && return_updates == nullptr)
         g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
     }
@@ -1150,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
            else
              s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
 
-           p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
+           p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + update_scale*(-lr*(s1_vals[j]));
            break;
          case LION:
-           p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
+           p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
            s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
            break;
          case RMSPROP:
            s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
-           p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
+           p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
            break;
          case ADAGRAD:
            s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
-           p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
+           p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
            break;
        }
      }
    }
 
    __syncthreads();
-   Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
+   Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
    __syncthreads();
    StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
  }
@@ -1298,7 +1228,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
 template
 __global__ void
 __launch_bounds__(NUM_THREADS2, 1)
-kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
+kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* state1, unsigned char* state2,
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -1369,7 +1299,7 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         __syncthreads();
         LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
         if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
@@ -1404,12 +1334,16 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
         # pragma unroll 4
         for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
         {
-            p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
-            if(weight_decay > 0.0f)
-                p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
+            if (return_updates == nullptr) {
+                p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
+                if(weight_decay > 0.0f)
+                    p_vals[j] = (T)(update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)));
+            } else {
+                p_vals[j] = (T)((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))));
+            }
         }
 
-        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
         __syncthreads();
         StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
         __syncthreads();
@@ -1513,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
 template
 __global__ void
 __launch_bounds__(1024, 1)
-kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
+kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -1569,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
         __syncthreads();
         LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
         if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
@@ -1579,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
             g_val = float(g_vals[j]);
             g_val *= gnorm_scale;
-            if(weight_decay > 0.0f) {
+            if(weight_decay > 0.0f && return_updates == nullptr) {
               switch(OPTIMIZER) {
                 case ADAGRAD:
                 case MOMENTUM:
@@ -1602,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
                 else
                   s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
 
-                p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
+                p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + (-lr*update_scale*(s1_vals[j]));
                 break;
               case LION:
-                p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
+                p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
                 s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
                 break;
              case RMSPROP:
                s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
-                p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
+                p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
                break;
            }
@@ -1626,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
            }
         }
 
-        StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+        StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
         __syncthreads();
         StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
         __syncthreads();
@@ -1687,6 +1621,7 @@ __global__ void
 kOptimizerStatic8bit2StateBlockwise(
     T* p,
     T* __restrict__ const g,
+    T* __restrict__ return_updates,
     unsigned char* state1,
     unsigned char* state2,
     const float beta1,
@@ -1881,7 +1816,7 @@ kOptimizerStatic8bit2StateBlockwise(
     }
 
     __syncthreads();
-    LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+    LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);
     // reduce: 2.67/1.69 -> 2.67/1.70
     # pragma unroll N_PER_TH
     for(unsigned int j = 0; j < N_PER_TH; j++)
@@ -1895,18 +1830,24 @@ kOptimizerStatic8bit2StateBlockwise(
                         (sqrtf(s2_vals[j]) / correction2) + eps
                     )
                 ));
+
+                if (weight_decay > 0.0f)
+                    p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
            }
            else
            {
-                p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+                if (return_updates == nullptr) {
+                    p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+                    if(weight_decay > 0.0f)
+                        p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
+                } else {
+                    p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
+                }
            }
-
-            if(weight_decay > 0.0f)
-                p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
-          }
+        }
     }
 
     // store: 0.85/1.44 -> 2.48/1.57
     __syncthreads();
-    StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
+    StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
 
     // quantizaztion: 2.67/1.70 -> 3.4/3.3
     # pragma unroll N_PER_TH
@@ -1952,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise(
 
 template
 __launch_bounds__(256, 3)
 __global__ void
-kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
+kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
                 float* __restrict__ const quantiles1,
@@ -2016,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
         __syncthreads();
         LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
         __syncthreads();
-        LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
+        LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);
 
         new_local_abs_max1 = -FLT_MAX;
@@ -2028,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
             g_val *= gnorm_scale;
             if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
             {
-              if(weight_decay > 0.0f) {
+              if(weight_decay > 0.0f && return_updates == nullptr) {
                 switch(OPTIMIZER) {
                   case MOMENTUM:
                   case ADAGRAD:
@@ -2091,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
             switch(OPTIMIZER)
             {
                 case MOMENTUM:
-                  p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(s1_vals[j]);
                   break;
                 case LION:
-                  p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - ((float)g_vals[j]);
                   break;
                 case RMSPROP:
                   g_val = g_vals[j];
-                  p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
                   break;
                 case ADAGRAD:
                   g_val = g_vals[j];
-                  p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+                  p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
                   break;
             }
         }
@@ -3841,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
 MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)
 
 #define MAKE_Optimizer32bit1State(oname, gtype) \
-template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
+template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, gtype* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, \
 const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
 
 MAKE_Optimizer32bit1State(MOMENTUM, half)
@@ -3870,17 +3811,17 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
 MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
 MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)
 
-template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -3906,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
 MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
 
 #define MAKE_optimizerStatic8bit1State(oname, gtype) \
-template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \
+template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, gtype* return_updates, unsigned char* state1, \
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, \
                 const float beta2, \
@@ -3941,7 +3882,8 @@ MAKE_PreconditionStatic8bit2State(ADAM, half)
 MAKE_PreconditionStatic8bit2State(ADAM, float)
 
 #define MAKE_optimizerStatic8bit2State(oname, gtype) \
-template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
+template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, gtype* return_updates, \
+                unsigned char* state1, unsigned char* state2, \
                 const float *unorm, const float max_unorm, const float param_norm, \
                 const float beta1, const float beta2, \
                 const float eps, const int step, const float lr, \
@@ -4041,7 +3983,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General
 template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);
 
 #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
-template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
+template __global__ void kOptimizerStatic8bit2StateBlockwise( \
+    gtype* p, gtype* __restrict__ const g, gtype* __restrict__ return_updates, \
+    unsigned char* state1, unsigned char* state2, \
 const float beta1, const float beta2, const float beta3, const float alpha, \
 const float eps, const int step, const float lr, \
 float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
@@ -4058,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)
 
 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit1StateBlockwise( \
-    gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
+    gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char* state1, \
     const float beta1, const float beta2, \
     const float eps, const int step, const float lr, \
     float* __restrict__ const quantiles1, \
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index ec6daebe5..376639993 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -25,7 +25,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
                 const int step, const float lr, const float gnorm_scale, const int n);
 
 template
-__global__ void kOptimizer32bit2State(T* g, T* p,
+__global__ void kOptimizer32bit2State(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha,
                 const float eps, const float weight_decay,
@@ -38,7 +38,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                 const int step, const float lr, const float gnorm_scale, const int n);
 
 template
-__global__ void kOptimizer32bit1State(T* g, T* p,
+__global__ void kOptimizer32bit1State(T* g, T* p, T* return_updates,
                 float* state1, float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -57,7 +57,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
 
 template
 __global__ void
-kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
+kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
                 const float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -80,7 +80,7 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
 
 template
 __global__ void
-kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
+kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* state1, unsigned char* state2,
                 const float *unorm, const float max_unorm, const float param_norm,
                 const float beta1, const float beta2,
                 const float eps, const int step, const float lr,
@@ -89,13 +89,14 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
                 float weight_decay, const float gnorm_scale, const int n);
 
 template __global__ void kOptimizerStatic8bit2StateBlockwise(
-    T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
+    T* p, T* __restrict__ const g, T* __restrict__ return_updates,
+    unsigned char* state1, unsigned char* state2,
     const float beta1, const float beta2, const float beta3, const float alpha,
     const float eps, const int step, const float lr,
     float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
     float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
 
 template __global__ void kOptimizerStatic8bit1StateBlockwise(
-    T* p, T* __restrict__ const g, unsigned char* state1,
+    T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
     const float beta1, const float beta2,
     const float eps, const int step, const float lr,
     float* __restrict__ const quantiles1,
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 7ca854baf..e3c99a875 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -92,7 +92,7 @@ template void dequantizeBlockwise(float *code, unsign
 
 
 
-template void optimizer32bit(T* g, T* p,
+template void optimizer32bit(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
@@ -109,7 +109,7 @@ template void optimizer32bit(T* g, T* p,
         kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit2State<<>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
@@ -122,12 +122,12 @@ template void optimizer32bit(T* g, T* p,
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
 
-      kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<<>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case LION:
       // in lion, the momentum update after the parameter update
-      kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<<>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
 
       if(max_unorm > 0.0f)
@@ -140,7 +140,7 @@ template void optimizer32bit(T* g, T* p,
   }
 }
 
-template void optimizerStatic8bit(T* p, T* g,
+template void optimizerStatic8bit(T* p, T* g, T* return_updates,
                 unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2,
@@ -162,7 +162,7 @@ template void optimizerStatic8bit(T* p, T* g,
       CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
       kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit2State<<>>(p, g, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                             quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
@@ -172,13 +172,13 @@ template void optimizerStatic8bit(T* p, T* g,
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
       kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit1State<<>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                             quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case LION:
       // in lion, the momentum update happens after the parameter update
-      kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit1State<<>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                             quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -199,6 +199,7 @@ template void optimizerStatic8bit(T* p, T* g,
 template void optimizerStatic8bitBlockwise(
     T* p,
     T* g,
+    T* return_updates,
     unsigned char* state1,
     unsigned char* state2,
     float beta1,
@@ -226,7 +227,7 @@ template void optimizerStatic8bitBlockwise(
       num_blocks = n/BLOCKSIZE_2STATE;
       num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
       kOptimizerStatic8bit2StateBlockwise<<>>(
-          p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
+          p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
           quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
           skip_zeros, n
       );
@@ -238,7 +239,7 @@ template void optimizerStatic8bitBlockwise(
     case LION:
       num_blocks = n/BLOCKSIZE_1STATE;
       num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
-      kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit1StateBlockwise<<>>(p, g, return_updates, state1, beta1, beta2, eps, step, lr,
                                             quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
@@ -807,7 +808,7 @@ template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char
 template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream);
 
 #define MAKE_optimizer32bit(name, gtype) \
-template void optimizer32bit(gtype* g, gtype* p, \
+template void optimizer32bit(gtype* g, gtype* p, gtype* return_updates, \
                 float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
                 const float beta1, const float beta2, const float beta3, const float alpha, \
                 const float eps, const float weight_decay, \
@@ -833,7 +834,8 @@ MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
 MAKE_optimizer32bit(ADEMAMIX, float)
 
 #define MAKE_optimizerStatic8bit(name, gtype) \
-template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+template void optimizerStatic8bit(gtype* p, gtype* g, gtype* return_updates, \
+                unsigned char* state1, unsigned char* state2, \
                 float *unorm, float max_unorm, float param_norm, \
                 float beta1, float beta2, \
                 float eps, int step, float lr, \
@@ -855,7 +857,7 @@ MAKE_optimizerStatic8bit(ADAGRAD, float)
 
 
 #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
-template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \
+template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, gtype* return_updates, \
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
 
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index b0ecc4622..f61de4095 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -148,12 +148,13 @@ void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t s
 template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
 template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream);
 
-template void optimizer32bit(T* g, T* p,
+template void optimizer32bit(T* g, T* p, T* return_updates,
                 float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay,
                 int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
 
-template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
+template void optimizerStatic8bit(T* p, T* g, T* return_updates,
+                unsigned char* state1, unsigned char* state2,
                 float *unorm, float max_unorm, float param_norm,
                 float beta1, float beta2,
                 float eps, int step, float lr,
@@ -162,10 +163,11 @@ template void optimizerStatic8bit(T* p, T* g, unsigne
                 float weight_decay,
                 const float gnorm_scale, int n);
 
-template void optimizerStatic8bitBlockwise(T* p, T* g,
+
+template void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
                 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
                 float eps, int step, float lr,
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
-                bool skip_zeros, int n);
+                bool skip_zeros, int n);
 
 template void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index f0ee84c29..14ceb17b8 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -50,12 +50,12 @@ MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
 
 
 #define MAKE_FUNC32(fname, oname, gtype, gbits) \
-void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
+void fname##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                const float beta1, const float beta2, const float beta3, const float alpha, \
                const float eps, const float weight_decay, \
                const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
-{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+{ optimizer32bit(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 
 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -75,7 +75,7 @@ MAKE_FUNC32(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
 
 
 #define MAKE_FUNC8(fname, oname, gtype, gbits) \
-void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, unsigned char* state1, unsigned char* state2, \
                float *unorm, float max_unorm, float param_norm, \
                float beta1, float beta2, \
                float eps, int step, float lr, \
@@ -83,7 +83,7 @@ void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1,
                float* max1, float* max2, float* new_max1, float* new_max2, \
                float weight_decay, float gnorm_scale, int n) \
 { \
-    optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+    optimizerStatic8bit(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
                         quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
 } \
 
@@ -97,10 +97,10 @@ MAKE_FUNC8(lion, LION, float, 32)
 MAKE_FUNC8(lion, LION, half, 16)
 
 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
-void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
+void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
-{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
+{ optimizerStatic8bitBlockwise(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
 
 MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
 MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
@@ -233,12 +233,13 @@ extern "C"
   void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); }
 
   #define MAKE_CFUNC32(name, gtype, gbits) \
-  void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
+  void c##name##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \
                float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
                const float beta1, const float beta2, const float beta3, const float alpha, \
                const float eps, const float weight_decay, \
                const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
-  { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+  { name##32bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
+
 
   MAKE_CFUNC32(adam, float, fp32)
   MAKE_CFUNC32(adam, half, fp16)
@@ -257,7 +258,8 @@ extern "C"
   MAKE_CFUNC32(ademamix, __nv_bfloat16, bf16)
 
   #define MAKE_CFUNC8(name, gtype, gbits) \
-  void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
+  void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
+               unsigned char* state1, unsigned char* state2, \
                float *unorm, float max_unorm, float param_norm, \
                float beta1, float beta2, \
                float eps, int step, float lr, \
@@ -265,7 +267,7 @@ extern "C"
                float* max1, float* max2, float* new_max1, float* new_max2, \
                float weight_decay, float gnorm_scale, int n) \
   { \
-      name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
+      name##_static_8bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
                                       quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
   } \
 
@@ -279,10 +281,11 @@ extern "C"
   MAKE_CFUNC8(lion, half, 16)
 
   #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
-  void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
+  void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
                unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,  \
                float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
-  { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
+  { fname##_8bit_blockwise_grad_##gbits(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
+
 
   MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
   MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
diff --git a/setup.py b/setup.py
index 3a1bcb574..434a2eaf4 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@
 
 
 def read(fname):
-    return open(os.path.join(os.path.dirname(__file__), fname)).read()
+    return open(os.path.join(os.path.dirname(__file__), fname), encoding="utf8").read()
 
 
 # Tested with wheel v0.29.0
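As a quick sanity check of the two code paths, the sketch below (an assumption about the intended behavior, not a test shipped with this patch) runs the same 32-bit Adam step once in place and once through `return_updates` with `weight_decay=0`, then verifies that applying the captured update by hand reproduces the in-place result while leaving the wrapped parameter untouched. Keyword names beyond the first four positional arguments are assumed to match the existing `optimizer_update_32bit` signature.

import torch
import bitsandbytes.functional as F

torch.manual_seed(0)
p0 = torch.randn(1024, device="cuda")
g = torch.randn_like(p0)


def adam_step(p, update=None):
    # Fresh optimizer state so both runs compute the identical step.
    s1, s2 = torch.zeros_like(p), torch.zeros_like(p)
    F.optimizer_update_32bit(
        "adam", g, p, s1,
        beta1=0.9, eps=1e-8, step=1, lr=1e-3,
        state2=s2, beta2=0.999, weight_decay=0.0,
        return_updates=update,
    )


# In-place path: the parameter itself is updated.
p_inplace = p0.clone()
adam_step(p_inplace)

# return_updates path: the parameter stays untouched, the step is captured.
p_captured = p0.clone()
update = torch.zeros_like(p0)
adam_step(p_captured, update=update)

assert torch.equal(p_captured, p0)
assert torch.allclose(p_inplace, p0 + update, atol=1e-6)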