diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 614f56f0f5..2f0dcb890d 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -684,12 +684,13 @@ def forward(
             use_cache=use_cache,
         )

-        out = outputs.last_hidden_state.to(self.transformer.wte.weight.device)
         if self.transformer.lm_head is not None:
-            logits = self.transformer.lm_head(out)
+            logits = self.transformer.lm_head(outputs.last_hidden_state)
         else:
             # move outputs to same device as weights for token embedding
             # needed to support HF `device_map`
+            out = outputs.last_hidden_state
+            out = out.to(self.transformer.wte.weight.device)
             logits = self.transformer.wte(out, True)

         if self.logit_scale is not None:
diff --git a/mcli/mcli-1b-max-seq-len-8k.yaml b/mcli/mcli-1b-max-seq-len-8k.yaml
index 24af39234c..c804eb10e1 100644
--- a/mcli/mcli-1b-max-seq-len-8k.yaml
+++ b/mcli/mcli-1b-max-seq-len-8k.yaml
@@ -1,10 +1,13 @@
 integrations:
 - integration_type: git_repo
-  git_repo: mosaicml/llm-foundry
-  git_branch: v0.3.0
+  git_repo: vchiley/llm-foundry
+  git_branch: notie_embd
   # git_commit: # OR use your commit hash
   pip_install: -e .[gpu]
   ssh_clone: false  # Should be true if using a private repo
+- integration_type: wandb
+  entity: mosaic-ml
+  project: notie_embd_test

 # We are fetching, converting, and training on the 'val' split
 # as it is small and quick to get going for this demo.
@@ -18,10 +21,12 @@ command: |
     --concat_tokens 8192 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>'
   composer train/train.py /mnt/config/parameters.yaml
 image: mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
-name: mpt-1b-ctx-8k-gpus-8
+
+name: mpt-1b-ctx-8k-gpus-8-notieembd

 compute:
   gpus: 8  # Number of GPUs to use
+  cluster: r1z1

 ## These configurations are optional
 # cluster: TODO # Name of the cluster to use for this run
@@ -48,6 +53,7 @@ parameters:
     expansion_ratio: 4
     max_seq_len: ${max_seq_len}
     vocab_size: 50368
+    tie_word_embeddings: false
     attn_config:
       attn_impl: triton

@@ -102,7 +108,7 @@ parameters:
       clipping_type: norm
       clipping_threshold: 1.0

-  max_duration: 24800ba  # ~ 26B tokens
+  max_duration: 500ba  # ~ 26B tokens
   eval_interval: 2000ba
   eval_first: false
   eval_subset_num_batches: -1
@@ -111,7 +117,7 @@ parameters:
   # System
   seed: 17
   device_eval_batch_size: 1
-  device_train_microbatch_size: 1
+  device_train_microbatch_size: 4
   # device_train_microbatch_size: auto
   precision: amp_bf16

@@ -136,8 +142,8 @@ parameters:
     lr_monitor: {}
     memory_monitor: {}
     runtime_estimator: {}
-# loggers:
-#   wandb: {}
+  loggers:
+    wandb: {}

 # Checkpoint to local filesystem or remote object store
 # save_interval: 2000ba
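For context on the first hunk: with `tie_word_embeddings: false` the model has a separate `lm_head`, so the hidden states no longer need to be copied to the token-embedding device before the output projection; the `.to(self.transformer.wte.weight.device)` move is only performed on the tied-embedding path, where `wte` doubles as the unembedding matrix (needed to support HF `device_map`). Below is a minimal sketch of that control flow, written as a standalone function with a plain `nn.Embedding` standing in for the shared-embedding module whose `wte(out, True)` call unembeds with the embedding weights; names and shapes are illustrative, not the actual `MPTForCausalLM.forward`.

```python
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


def compute_logits(
    last_hidden_state: torch.Tensor,
    lm_head: Optional[nn.Linear],
    wte: nn.Embedding,
    logit_scale: Optional[float] = None,
) -> torch.Tensor:
    """Sketch of the logits path after the patch (not the actual method)."""
    if lm_head is not None:
        # Untied embeddings: project with the dedicated output head.
        # No device move is needed on this path.
        logits = lm_head(last_hidden_state)
    else:
        # Tied embeddings: move hidden states to the same device as the
        # token-embedding weights (needed to support HF `device_map`),
        # then reuse the embedding matrix as the output projection.
        out = last_hidden_state.to(wte.weight.device)
        logits = F.linear(out, wte.weight)
    if logit_scale is not None:
        logits *= logit_scale
    return logits


# Example usage with toy dimensions (hypothetical sizes, illustration only).
hidden = torch.randn(2, 8, 64)       # (batch, seq, d_model)
embedding = nn.Embedding(100, 64)    # vocab_size=100, d_model=64
print(compute_logits(hidden, None, embedding).shape)  # torch.Size([2, 8, 100])
```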