diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index e9f4756c21..5ad6b76ed5 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -666,7 +666,7 @@ def forward(
         )
         out = outputs.last_hidden_state.to(self.transformer.wte.weight.device)
-        if self.unembed is not None:
+        if self.transformer.unembed is not None:
             logits = self.transformer.unembed(out)
         else:
             # move outputs to same device as weights for token embedding
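
A minimal sketch (not part of the patch) of why the attribute lookup must go through self.transformer: in this layout the causal-LM wrapper stores the backbone under `transformer`, so an optional `unembed` module lives on the inner model, and checking `self.unembed` on the wrapper would raise AttributeError instead of reaching the module used on the next line. Class and attribute names below are simplified assumptions for illustration, not the real modeling_mpt.py code.

import torch.nn as nn

# Simplified stand-ins (assumed names, not the real MPT classes).
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(16, 8)
        self.unembed = nn.Linear(8, 16, bias=False)  # optional output projection

class ForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = Backbone()  # backbone stored under `transformer`, as in the wrapper class

lm = ForCausalLM()
print(hasattr(lm, "unembed"))              # False: `self.unembed` would raise AttributeError
print(lm.transformer.unembed is not None)  # True: the check the patch switches to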