Skip to content

Commit

Permalink
Get rid of numpy dtypes in scheduler, prevent numpy dtypes ending up …
Browse files Browse the repository at this point in the history
…in train checkpoints
  • Loading branch information
rwightman committed Nov 24, 2024
1 parent ea524b9 commit 5f83962
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/open_clip_train/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy as np
import math


def assign_learning_rate(optimizer, new_lr):
Expand All @@ -18,6 +18,7 @@ def _lr_adjuster(step):
lr = base_lr
assign_learning_rate(optimizer, lr)
return lr

return _lr_adjuster


Expand All @@ -33,10 +34,11 @@ def _lr_adjuster(step):
e = step - start_cooldown_step
es = steps - start_cooldown_step
# linear decay if power == 1; polynomial decay otherwise;
decay = (1 - (e/es)) ** cooldown_power
decay = (1 - (e / es)) ** cooldown_power
lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
assign_learning_rate(optimizer, lr)
return lr

return _lr_adjuster


Expand All @@ -47,7 +49,9 @@ def _lr_adjuster(step):
else:
e = step - warmup_length
es = steps - warmup_length
lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
lr = 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr
assign_learning_rate(optimizer, lr)
return lr

return _lr_adjuster

0 comments on commit 5f83962

Please sign in to comment.