diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index a10a9dd7b0..91f8473525 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -266,11 +266,13 @@ def flash_attn_fn(
     batch_size, seqlen = query.shape[:2]
 
-    indices_q = flash_attn_padding_info['indices_q']
-    indices_k = flash_attn_padding_info['indices_k']
-    indices_v = flash_attn_padding_info['indices_v']
-    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
-    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
+    # In the following lines we move the tensors to the same devices as query, key, and value respectively. These operations should be no-ops during training.
+    # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
+    indices_q = flash_attn_padding_info['indices_q'].to(query.device)
+    indices_k = flash_attn_padding_info['indices_k'].to(key.device)
+    indices_v = flash_attn_padding_info['indices_v'].to(value.device)
+    cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'].to(query.device)
+    cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'].to(key.device)
     max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
     max_seqlen_k = flash_attn_padding_info['max_seqlen_k']
@@ -667,6 +669,10 @@ def _apply_rotary_embeddings(
         else:
             (cos, sin) = rotary_emb(x=value, seq_len=seq_len)
         if is_transformers_version_gte('4.38'):
+            # In the following lines we move the cos and sin tensors to the same device as query. These operations should be no-ops during training.
+            # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
+            cos = cos.to(query.device)
+            sin = sin.to(query.device)
             query, key = apply_rotary_pos_emb(
                 q=query,
                 k=key,
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 401acacfc6..c6988b7bd7 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -198,7 +198,9 @@ def forward(
             m = self.norm_2(x)
 
         n = self.apply_ffn(attention_mask, m)
-        x = x + self.resid_ffn_dropout(n)
+        # In the following line we move the `x` tensor to the same device as the output of the ffn layer. This operation should be a no-op during training.
+        # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
+        x = x.to(device=n.device) + self.resid_ffn_dropout(n)
         return x, attn_weights, past_key_value
 
     def apply_ffn(
diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py
index 6190dbc6ea..59508e0a50 100644
--- a/llmfoundry/models/layers/dmoe.py
+++ b/llmfoundry/models/layers/dmoe.py
@@ -280,9 +280,13 @@ def forward(
             expert_tokens = x[None, token_list].reshape(-1, hidden_size)
             mlp_output = self.mlp(expert_tokens, expert_idx)
+            # In the following lines we move tensors to the same device as the output of mlp. These operations should be no-ops during training.
+            # This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
+            expert_weights = expert_weights.to(mlp_output.device)
             expert_out = mlp_output * expert_weights[token_list, topk_list, None]
-
+            out = out.to(mlp_output.device)
+            token_idx = token_idx.to(mlp_output.device)
             out.index_add_(0, token_idx, expert_out)
 
         out = out.view(in_shape)
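
For reviewers, a minimal sketch (not part of the diff) of the pattern all of these changes rely on: `torch.Tensor.to(device)` returns the tensor unchanged when it already lives on the target device, so the added calls are free during ordinary single-device training, but they relocate cached tensors (e.g. the `flash_attn_padding_info` entries or the rotary `cos`/`sin` buffers) when pipeline-parallel generation via `hf.generate` places consecutive stages on different devices. The helper name `move_to_reference_device` below is hypothetical and exists only for illustration.

```python
import torch


def move_to_reference_device(
    tensor: torch.Tensor,
    reference: torch.Tensor,
) -> torch.Tensor:
    """Illustrative only: place ``tensor`` on the same device as ``reference``.

    ``Tensor.to`` returns ``tensor`` itself when it is already on the target
    device, which is why the calls added in this PR should not change behavior
    during single-device training.
    """
    return tensor.to(reference.device)


if __name__ == '__main__':
    activations = torch.randn(4, 8)  # stands in for query / mlp_output on the current stage
    cached = torch.arange(4)         # stands in for cached indices or cos/sin buffers
    moved = move_to_reference_device(cached, activations)
    # On a single device this returns the original tensor; across devices it copies.
    assert moved.device == activations.device
```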