From 717925eb31cbc674b16ea60b890108ca73b333fd Mon Sep 17 00:00:00 2001
From: Akshita Bhagia
Date: Mon, 6 May 2024 12:32:16 -0700
Subject: [PATCH] rename to regularize_embeddings

---
 olmo/config.py                | 2 +-
 olmo/optim.py                 | 6 +++---
 olmo/train.py                 | 2 +-
 test_fixtures/reverse_wd.yaml | 2 +-
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/olmo/config.py b/olmo/config.py
index bc1819d50..6ddf55142 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -482,7 +482,7 @@ class OptimizerConfig(BaseConfig):
     If not set, defaults to the wandb `log_interval`.
     """
 
-    reverse_embedding_decay: bool = False
+    regularize_embeddings: bool = False
     """
     Applying weight decay to embeddings may make them too small, potentially causing spikes.
     Setting this parameter to true is a way of applying "reverse weight decay" to embeddings.
diff --git a/olmo/optim.py b/olmo/optim.py
index 9a25b5cfb..f9e73fa3e 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -43,7 +43,7 @@ def clip_grads_and_collect_metrics(
         global_step: int,
         collect_param_metrics: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
-        reverse_embedding_decay: bool = False,
+        regularize_embeddings: bool = False,
     ) -> Dict[str, torch.Tensor]:
         """
         Clips gradients for every group that has the field `max_grad_norm`.
@@ -91,7 +91,7 @@ def clip_grads_and_collect_metrics(
                 # other metrics.
                 tensors: List[Optional[torch.Tensor]] = [p.grad]
                 prefixes: List[str] = [f"grad/{name}"]
-                if collect_param_metrics or (reverse_embedding_decay and is_embedding_group):
+                if collect_param_metrics or (regularize_embeddings and is_embedding_group):
                     state = self.get_state_for_param(p)
                     sorted_state_keys = sorted([k for k in state.keys()])
                     tensors.extend([p] + [state[key] for key in sorted_state_keys])
@@ -647,7 +647,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
             elif pn.endswith("weight") and isinstance(m, nn.Embedding):
                 if cfg.optimizer.decay_embeddings:
                     decay.add(fpn)
-                elif cfg.optimizer.reverse_embedding_decay:
+                elif cfg.optimizer.regularize_embeddings:
                     embeddings_decay.add(fpn)
                 else:
                     no_decay.add(fpn)
diff --git a/olmo/train.py b/olmo/train.py
index cd4e625fb..ec1570667 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -719,7 +719,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
                 # passing this process group here ensures metrics are reduced correctly when we're using
                 # HYBRID sharding.
                 process_group=self.fsdp_model.process_group,
-                reverse_embedding_decay=self.cfg.optimizer.reverse_embedding_decay,
+                regularize_embeddings=self.cfg.optimizer.regularize_embeddings,
             )
 
             emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
diff --git a/test_fixtures/reverse_wd.yaml b/test_fixtures/reverse_wd.yaml
index 5d208542c..c75da3274 100644
--- a/test_fixtures/reverse_wd.yaml
+++ b/test_fixtures/reverse_wd.yaml
@@ -22,7 +22,7 @@ model:
   init_std: 0.02
 optimizer:
   learning_rate: 0.001
-  reverse_embedding_decay: true
+  regularize_embeddings: true
   metrics_log_interval: 100
 scheduler:
   name: "cosine_with_warmup"
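
Note for reviewers: the config docstring above describes "reverse weight decay", i.e.
nudging embedding norms back up when ordinary weight decay has shrunk them. The sketch
below only illustrates that idea under assumed names; it is not the OLMo code touched by
this patch, and `embedding_decay_factor` and `target_norm` are hypothetical.

    import torch

    def embedding_decay_factor(embedding: torch.Tensor, target_norm: float = 1.0) -> float:
        # Multiplier for the embedding weight-decay term: it goes negative when the
        # average per-row norm falls below target_norm, so applying "decay" then grows
        # the embeddings instead of shrinking them (reverse weight decay).
        avg_norm = embedding.norm(dim=-1).mean().item()
        return (avg_norm - target_norm) / max(target_norm, 1e-8)

    # Usage: an embedding table shrunk by ordinary weight decay yields a negative factor.
    wte = torch.nn.Embedding(100, 16)
    with torch.no_grad():
        wte.weight.mul_(0.01)  # simulate embeddings that decayed too small
    print(embedding_decay_factor(wte.weight))  # < 0 -> regularization pushes norms back up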