diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 86e49c315d..8caed031b0 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -627,14 +627,12 @@ def forward(
                 value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
             elif rotary_emb_w_meta_info['impl'] == 'hf':
                 (cos, sin) = rotary_emb(value, seq_len)
-                # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
-                query, key = apply_rotary_pos_emb(query, key, cos, sin,
-                                                  offset_info)
-                # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
-                query = query.transpose(1, 2)
-                key = key.transpose(1, 2)
+                query, key = apply_rotary_pos_emb(query,
+                                                  key,
+                                                  cos,
+                                                  sin,
+                                                  offset_info,
+                                                  unsqueeze_dim=2)
                 query = query.view(bsz, seqlen, self.d_model)
                 key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
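
Note (not part of the patch): the removed transposes existed only because the Hugging Face `apply_rotary_pos_emb` helper assumed the head axis sat at dim 1, while llm-foundry keeps `query`/`key` laid out as `(bsz, seqlen, n_heads, head_dim)`. The `unsqueeze_dim=2` argument tells the helper where to insert the broadcast axis for `cos`/`sin`, so the layout no longer needs to be shuffled back and forth. Below is a minimal, self-contained sketch of that equivalence; `apply_rope` is a simplified stand-in for the transformers helper, and the shapes and names are illustrative assumptions, not the library's exact internals.

```python
import torch

def rotate_half(x):
    # Rotate the last dimension by splitting it in half, as in the
    # Hugging Face rotary-embedding implementation.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin, unsqueeze_dim):
    # Simplified stand-in for transformers' apply_rotary_pos_emb:
    # cos/sin are (seqlen, head_dim); a singleton axis is inserted at
    # unsqueeze_dim so they broadcast over the head dimension of q/k.
    cos = cos.unsqueeze(0).unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(0).unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, seqlen, n_heads, head_dim = 2, 5, 4, 8
q = torch.randn(bsz, seqlen, n_heads, head_dim)  # heads on dim 2, as in llm-foundry
k = torch.randn(bsz, seqlen, n_heads, head_dim)

# Toy rotary cos/sin tables of shape (seqlen, head_dim).
pos = torch.arange(seqlen, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(pos, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# Old workaround: move heads to dim 1, apply RoPE with the default
# broadcast axis (dim 1), then move heads back.
q_old, k_old = apply_rope(q.transpose(1, 2), k.transpose(1, 2), cos, sin, unsqueeze_dim=1)
q_old, k_old = q_old.transpose(1, 2), k_old.transpose(1, 2)

# New path: keep (bsz, seqlen, n_heads, head_dim) and broadcast cos/sin
# over the head axis directly with unsqueeze_dim=2.
q_new, k_new = apply_rope(q, k, cos, sin, unsqueeze_dim=2)

assert torch.allclose(q_old, q_new) and torch.allclose(k_old, k_new)
```

Both paths produce identical rotated tensors; the `unsqueeze_dim=2` form simply avoids the two extra memory-layout changes per call.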