diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 90e07a687..2ff520dd2 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1011,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, template __launch_bounds__(TH, 1) -__global__ void kOptimizer32bit1State(T *g, T *p, +__global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) @@ -1057,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p, __syncthreads(); LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); __syncthreads(); - Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items); # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { g_vals[j] = gnorm_scale*((float)g_vals[j]); - if(weight_decay > 0.0f) + if(weight_decay > 0.0f && return_updates == nullptr) g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); } @@ -1080,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p, else s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + update_scale*(-lr*(s1_vals[j])); break; case LION: - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); break; case ADAGRAD: s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); - p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); break; } } } __syncthreads(); - Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items); __syncthreads(); StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); } @@ -1447,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c template __global__ void __launch_bounds__(1024, 1) -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, +kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, @@ -1503,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items); if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } @@ -1513,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(weight_decay > 0.0f) { + if(weight_decay > 0.0f && return_updates == nullptr) { switch(OPTIMIZER) { case ADAGRAD: case MOMENTUM: @@ -1536,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, else s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + (-lr*update_scale*(s1_vals[j])); break; case LION: - p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); break; case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); break; } @@ -1560,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, } } - StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items); __syncthreads(); StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); __syncthreads(); @@ -1893,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise( template __launch_bounds__(256, 3) __global__ void -kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, @@ -1957,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char __syncthreads(); LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f); new_local_abs_max1 = -FLT_MAX; @@ -1969,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char g_val *= gnorm_scale; if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) { - if(weight_decay > 0.0f) { + if(weight_decay > 0.0f && return_updates == nullptr) { switch(OPTIMIZER) { case MOMENTUM: case ADAGRAD: @@ -2032,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char switch(OPTIMIZER) { case MOMENTUM: - p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(s1_vals[j]); break; case LION: - p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - ((float)g_vals[j]); break; case RMSPROP: g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); break; case ADAGRAD: g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); break; } } @@ -3782,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16) #define MAKE_Optimizer32bit1State(oname, gtype) \ -template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, gtype* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ MAKE_Optimizer32bit1State(MOMENTUM, half) @@ -3847,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half) MAKE_PreconditionStatic8bit1State(ADAGRAD, float) #define MAKE_optimizerStatic8bit1State(oname, gtype) \ -template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, gtype* return_updates, unsigned char* state1, \ const float *unorm, const float max_unorm, const float param_norm, \ const float beta1, \ const float beta2, \ @@ -4002,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1) #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit1StateBlockwise( \ - gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char* state1, \ const float beta1, const float beta2, \ const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index e0213d997..376639993 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -38,7 +38,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, const int step, const float lr, const float gnorm_scale, const int n); template -__global__ void kOptimizer32bit1State(T* g, T* p, +__global__ void kOptimizer32bit1State(T* g, T* p, T* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); @@ -57,7 +57,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c template __global__ void -kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, +kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1, const float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr, @@ -96,7 +96,7 @@ template __global__ voi float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizerStatic8bit1StateBlockwise( - T* p, T* __restrict__ const g, unsigned char* state1, + T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, diff --git a/csrc/ops.cu b/csrc/ops.cu index ebe5b953f..e3c99a875 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -122,12 +122,12 @@ template void optimizer32bit(T* g, T* p, T* return_up CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case LION: // in lion, the momentum update after the parameter update - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + kOptimizer32bit1State<<>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); if(max_unorm > 0.0f) @@ -172,13 +172,13 @@ template void optimizerStatic8bit(T* p, T* g, T* retu CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case LION: // in lion, the momentum update happens after the parameter update - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit1State<<>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -239,7 +239,7 @@ template void optimizerStatic8bitBlockwise( case LION: num_blocks = n/BLOCKSIZE_1STATE; num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; - kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, + kOptimizerStatic8bit1StateBlockwise<<>>(p, g, return_updates, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break;