rename to regularize_embeddings
AkshitaB committed May 6, 2024
1 parent d2f6ea2 commit 717925e
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion olmo/config.py
@@ -482,7 +482,7 @@ class OptimizerConfig(BaseConfig):
     If not set, defaults to the wandb `log_interval`.
     """

-    reverse_embedding_decay: bool = False
+    regularize_embeddings: bool = False
     """
     Applying weight decay to embeddings may make them too small, potentially causing spikes.
     Setting this parameter to true is a way of applying "reverse weight decay" to embeddings.
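The docstring above only hints at what "reverse weight decay" looks like in practice. Below is a minimal sketch of one way embeddings could be regularized toward a larger scale, assuming a hypothetical target standard deviation and a loss-side penalty; it illustrates the idea behind the flag, not the committed OLMo implementation.

import torch
import torch.nn as nn


def embedding_norm_penalty(embedding: nn.Embedding, target_std: float = 0.02) -> torch.Tensor:
    # Illustrative "reverse weight decay": instead of shrinking the embeddings
    # the way ordinary weight decay would, penalize them when their standard
    # deviation drops below a target, which pushes them back up.
    # target_std is a hypothetical hyperparameter (chosen here to mirror the
    # init_std in the fixture below); it is not a field on OptimizerConfig.
    emb_std = embedding.weight.std()
    return torch.clamp(target_std - emb_std, min=0.0) ** 2


# Usage sketch: add the penalty to the training loss with some coefficient.
wte = nn.Embedding(50304, 768)
loss = torch.tensor(0.0)  # stand-in for the language-modeling loss
loss = loss + 10.0 * embedding_norm_penalty(wte, target_std=0.02)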
6 changes: 3 additions & 3 deletions olmo/optim.py
@@ -43,7 +43,7 @@ def clip_grads_and_collect_metrics(
         global_step: int,
         collect_param_metrics: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
-        reverse_embedding_decay: bool = False,
+        regularize_embeddings: bool = False,
     ) -> Dict[str, torch.Tensor]:
         """
         Clips gradients for every group that has the field `max_grad_norm`.
@@ -91,7 +91,7 @@ def clip_grads_and_collect_metrics(
             # other metrics.
             tensors: List[Optional[torch.Tensor]] = [p.grad]
             prefixes: List[str] = [f"grad/{name}"]
-            if collect_param_metrics or (reverse_embedding_decay and is_embedding_group):
+            if collect_param_metrics or (regularize_embeddings and is_embedding_group):
                 state = self.get_state_for_param(p)
                 sorted_state_keys = sorted([k for k in state.keys()])
                 tensors.extend([p] + [state[key] for key in sorted_state_keys])
@@ -647,7 +647,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]:
             elif pn.endswith("weight") and isinstance(m, nn.Embedding):
                 if cfg.optimizer.decay_embeddings:
                     decay.add(fpn)
-                elif cfg.optimizer.reverse_embedding_decay:
+                elif cfg.optimizer.regularize_embeddings:
                     embeddings_decay.add(fpn)
                 else:
                     no_decay.add(fpn)
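For context, get_param_groups routes every parameter into one of several sets so the optimizer can apply a different weight-decay setting per group; with regularize_embeddings enabled, embedding weights land in their own embeddings_decay group. The sketch below shows the general pattern on a toy model; the group_name tag, the exact group layout, and the AdamW call are illustrative assumptions rather than the OLMo code.

import torch
import torch.nn as nn


def build_param_groups(model: nn.Module, weight_decay: float = 0.1):
    # Illustrative decay / no_decay / embeddings_decay split: Linear weights get
    # ordinary weight decay, everything else gets none, and embedding weights go
    # into their own group so a separate (possibly norm-dependent) coefficient
    # can be applied to them later.
    decay, no_decay, embeddings_decay = set(), set(), set()
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f"{mn}.{pn}" if mn else pn
            if pn.endswith("weight") and isinstance(m, nn.Embedding):
                embeddings_decay.add(fpn)  # analogous to regularize_embeddings=True
            elif pn.endswith("weight") and isinstance(m, nn.Linear):
                decay.add(fpn)
            else:
                no_decay.add(fpn)
    all_params = dict(model.named_parameters())
    return [
        {"group_name": "decay", "weight_decay": weight_decay,
         "params": [all_params[n] for n in sorted(decay)]},
        {"group_name": "no_decay", "weight_decay": 0.0,
         "params": [all_params[n] for n in sorted(no_decay)]},
        # The embedding group starts at 0.0; a trainer could adjust this value
        # on the fly from the collected embedding-norm metric (see the next sketch).
        {"group_name": "embeddings", "weight_decay": 0.0,
         "params": [all_params[n] for n in sorted(embeddings_decay)]},
    ]


toy = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))
optimizer = torch.optim.AdamW(build_param_groups(toy), lr=1e-3)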
2 changes: 1 addition & 1 deletion olmo/train.py
@@ -719,7 +719,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             # passing this process group here ensures metrics are reduced correctly when we're using
             # HYBRID sharding.
             process_group=self.fsdp_model.process_group,
-            reverse_embedding_decay=self.cfg.optimizer.reverse_embedding_decay,
+            regularize_embeddings=self.cfg.optimizer.regularize_embeddings,
         )

         emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
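The emb_norm line above suggests the trainer reads the embedding norm collected by clip_grads_and_collect_metrics and uses it to drive the regularization. A sketch of one way such a metric could feed back into the embedding param group follows; the target norm, the sign convention, and the group_name lookup (reusing the hypothetical tag from the previous sketch) are assumptions for illustration, not the committed implementation.

import torch


def adjust_embedding_weight_decay(
    optimizer: torch.optim.Optimizer,
    emb_norm: torch.Tensor,
    target_norm: float,
    coefficient: float = 1e-2,
) -> None:
    # Illustrative feedback loop: when the collected embedding norm falls below
    # a target, give the embedding param group a negative weight-decay value,
    # which grows the weights on the next optimizer step instead of shrinking
    # them ("reverse weight decay"). target_norm and coefficient are hypothetical.
    deficit = max(target_norm - float(emb_norm), 0.0)
    for group in optimizer.param_groups:
        if group.get("group_name") == "embeddings":
            group["weight_decay"] = -coefficient * deficit


# e.g. after metrics are collected in a train step (hypothetical call site):
# adjust_embedding_weight_decay(
#     optimizer, optim_metrics["param/transformer.wte.weight.norm"], target_norm=1.0)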
2 changes: 1 addition & 1 deletion test_fixtures/reverse_wd.yaml
@@ -22,7 +22,7 @@ model:
   init_std: 0.02
 optimizer:
   learning_rate: 0.001
-  reverse_embedding_decay: true
+  regularize_embeddings: true
   metrics_log_interval: 100
 scheduler:
   name: "cosine_with_warmup"
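The fixture shows the renamed key in a full training config. A quick, loader-agnostic way to sanity-check the rename is a plain YAML round trip, sketched below; reading the file with PyYAML instead of OLMo's own TrainConfig loader is an assumption made to keep the example self-contained.

import yaml  # PyYAML

with open("test_fixtures/reverse_wd.yaml") as f:
    cfg = yaml.safe_load(f)

# After this commit the optimizer block carries the new key...
assert cfg["optimizer"]["regularize_embeddings"] is True
# ...and the old one is gone.
assert "reverse_embedding_decay" not in cfg["optimizer"]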
