-
Notifications
You must be signed in to change notification settings - Fork 492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reverse weight decay #567
base: main
Are you sure you want to change the base?
Reverse weight decay #567
Changes from 11 commits
6240dc9
0f5e28f
b7dc57e
1fc07cd
4c5c4b1
49a6f83
9fae31a
d6d5345
70d12b8
2f8beef
d2f6ea2
717925e
962b983
465d143
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ def clip_grads_and_collect_metrics( | |
global_step: int, | ||
collect_param_metrics: bool = True, | ||
process_group: Optional[dist.ProcessGroup] = None, | ||
reverse_embedding_decay: bool = False, | ||
dirkgr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> Dict[str, torch.Tensor]: | ||
""" | ||
Clips gradients for every group that has the field `max_grad_norm`. | ||
|
@@ -83,13 +84,14 @@ def clip_grads_and_collect_metrics( | |
# with ReLoRa, for example. | ||
assert group.get("sharded", True) is True | ||
|
||
is_embedding_group = group["name"] == "embedding_decay_group" | ||
for name, p in zip(group["param_names"], group["params"]): | ||
name = self._clean_param_name(name) | ||
# Always need to collect the norm of gradients for clipping, even if we're not collecting | ||
# Always need to collect the norm of gradients and parameters for clipping, even if we're not collecting | ||
# other metrics. | ||
tensors: List[Optional[torch.Tensor]] = [p.grad] | ||
prefixes: List[str] = [f"grad/{name}"] | ||
if collect_param_metrics: | ||
if collect_param_metrics or (reverse_embedding_decay and is_embedding_group): | ||
state = self.get_state_for_param(p) | ||
sorted_state_keys = sorted([k for k in state.keys()]) | ||
tensors.extend([p] + [state[key] for key in sorted_state_keys]) | ||
|
@@ -232,7 +234,7 @@ def is_grad_norm_metric(metric_name: str) -> bool: | |
all_metrics["clipping_rate"] = clipping_rate | ||
return all_metrics | ||
else: | ||
return {} | ||
return all_metrics | ||
|
||
@torch.no_grad() | ||
def _do_adaptive_clipping( | ||
|
@@ -617,6 +619,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] | |
# Separate out parameters that we don't want to apply weight decay to, like norms and biases. | ||
decay = set() | ||
no_decay = set() | ||
embeddings_decay = set() | ||
all_params = {} | ||
for mn, m in model.named_modules(): | ||
for pn, p in m.named_parameters(): | ||
|
@@ -644,12 +647,14 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] | |
elif pn.endswith("weight") and isinstance(m, nn.Embedding): | ||
if cfg.optimizer.decay_embeddings: | ||
decay.add(fpn) | ||
elif cfg.optimizer.reverse_embedding_decay: | ||
embeddings_decay.add(fpn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if these are both set? We should check against that somewhere. |
||
else: | ||
no_decay.add(fpn) | ||
|
||
# Validate that we've considered every parameter | ||
inter_params = decay & no_decay | ||
union_params = decay | no_decay | ||
inter_params = decay & no_decay & embeddings_decay | ||
union_params = decay | no_decay | embeddings_decay | ||
assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!" | ||
assert ( | ||
len(all_params.keys() - union_params) == 0 | ||
|
@@ -658,12 +663,15 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] | |
# Create the pytorch optimizer groups. | ||
decay_sorted = sorted(list(decay)) | ||
no_decay_sorted = sorted(list(no_decay)) | ||
embeddings_decay_sorted = sorted(list(embeddings_decay)) | ||
|
||
param_groups = [] | ||
if len(decay_sorted) > 0: | ||
param_groups.append( | ||
{ | ||
"params": [all_params[pn] for pn in decay_sorted], | ||
"param_names": decay_sorted, | ||
"name": "decay_group", | ||
**param_group_defaults, | ||
} | ||
) | ||
|
@@ -673,6 +681,17 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] | |
"params": [all_params[pn] for pn in no_decay_sorted], | ||
"param_names": no_decay_sorted, | ||
"weight_decay": 0.0, | ||
"name": "no_decay_group", | ||
**param_group_defaults, | ||
} | ||
) | ||
if len(embeddings_decay_sorted) > 0: | ||
# the weight_decay value will be multiplied by emb_decay_factor in olmo/train.py | ||
param_groups.append( | ||
{ | ||
"params": [all_params[pn] for pn in embeddings_decay_sorted], | ||
"param_names": embeddings_decay_sorted, | ||
"name": "embedding_decay_group", | ||
**param_group_defaults, | ||
} | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,11 +19,10 @@ | |
import torch | ||
import torch.distributed as dist | ||
import torch.nn.functional as F | ||
import wandb | ||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
from torch.utils.data import DataLoader | ||
|
||
import wandb | ||
|
||
from .aliases import PathOrStr | ||
from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer | ||
from .config import ( | ||
|
@@ -720,8 +719,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> | |
# passing this process group here ensures metrics are reduced correctly when we're using | ||
# HYBRID sharding. | ||
process_group=self.fsdp_model.process_group, | ||
reverse_embedding_decay=self.cfg.optimizer.reverse_embedding_decay, | ||
) | ||
|
||
emb_norm = optim_metrics["param/transformer.wte.weight.norm"] | ||
emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size | ||
emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe the denominator should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. update: @AkshitaB and I discussed this, we think we need to calculate this metric separately in We also talked about how this standard deviation will be a little biased since it will include parts of the embedding that never are never used, since we inflate the embedding size beyond vocab size to be a multiple of 128. But this is probably okay since that's only a small part of the embeddings. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I think this is a big problem. Embeddings will want to be small, so this will push them up. Unused, or rarely used embeddings will never get updated, so they will get bigger and bigger, skewing the calculation of the stddev more and more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Figuring out which embeddings to exclude from the stddev computation is going to be tricky in the distributed setting though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thinking out loud here... what if we force the unused params to be zero from the beginning? They would still bias standard deviation by as much as they are different from the mean, but they would always be zero.. I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would work if we were starting with this from scratch, but what about the case when we want to use this to "rescue" a run? Can we explicitly make the unused embeddings zero when we load the model? And will it matter if we do so halfway through training? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think that's our best bet. I can't think of any issues that would introduce in the middle of training. I suspect those parameters are 0 anyway due to weight decay and zero gradients. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rare tokens would still be an issue, but not any more than they always are. |
||
emb_decay_factor = 1.0 - emb_std | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're using this to plug into the value for WD, that means it needs to be negative when we want to pull up the values. So then it would be |
||
|
||
# Adjust the learning rate. | ||
for group in self.optim.param_groups: | ||
# TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group | ||
|
@@ -737,6 +742,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> | |
self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max | ||
) | ||
|
||
if group["name"] == "embedding_decay_group": | ||
group["weight_decay"] *= emb_decay_factor | ||
Comment on lines
+745
to
+746
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does't this multiply up |
||
|
||
# Optimizer step. | ||
self.optim.step() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
run_name: "reverse_test" | ||
save_folder: "/tmp/olmo-train-tiny" | ||
wandb: | ||
name: ${run_name} | ||
project: reverse-test | ||
model: | ||
d_model: 128 | ||
n_heads: 4 | ||
n_layers: 4 | ||
mlp_ratio: 4 | ||
alibi: false | ||
alibi_bias_max: 8.0 | ||
attention_dropout: 0.1 | ||
attention_layer_norm: false | ||
residual_dropout: 0.1 | ||
embedding_dropout: 0.1 | ||
max_sequence_length: 512 | ||
vocab_size: 50257 | ||
eos_token_id: 50256 | ||
pad_token_id: 50256 | ||
init_device: null | ||
init_std: 0.02 | ||
optimizer: | ||
learning_rate: 0.001 | ||
reverse_embedding_decay: true | ||
metrics_log_interval: 100 | ||
scheduler: | ||
name: "cosine_with_warmup" | ||
t_warmup: 10 | ||
data: | ||
paths: | ||
- "/net/nfs.cirrascale/allennlp/llm-data/c4/en/c4-train.00000-00099.npy" | ||
persistent_workers: false | ||
num_workers: 0 | ||
prefetch_factor: null | ||
tokenizer: | ||
identifier: "gpt2" | ||
save_overwrite: true | ||
max_duration: 16 | ||
stop_at: ${max_duration} | ||
global_train_batch_size: 8 | ||
device_train_microbatch_size: 8 | ||
precision: "fp32" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This name also needs to be updated.