Version

- flash-attn 2.6.3
- transformer-engine 1.11.0+c27ee60
- flashattn-hopper 3.0.0b1
Bug report
The bug occurs when TorchDynamo traces this function:
```python
@jit_fuser
def flash_attn_fwd_out_correction(out, out_per_step, seq_dim, softmax_lse, softmax_lse_per_step):
    """Merge partial outputs of each step in Attention with context parallelism"""
    softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim)
    softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1)
    out_corrected = out_per_step * softmax_lse_corrected_exp
    out.add_(out_corrected)
```
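For context, the `movedim(2, seq_dim)` call assumes `softmax_lse` is 3-D with the sequence length on dimension 2, so the per-step correction factor can be broadcast against `out`. A toy sketch of the same log-sum-exp merging with assumed shapes (`(batch, heads, seq)` for the LSE tensors, `seq_dim=0` for the output; all shapes here are illustrative, not the library's actual layout):

```python
import torch

seq, batch, heads, dim = 4, 2, 3, 8
seq_dim = 0  # assumed: sequence axis of `out`

out = torch.randn(seq, batch, heads, dim)
out_per_step = torch.randn(seq, batch, heads, dim)
softmax_lse = torch.randn(batch, heads, seq)           # assumed 3-D layout
softmax_lse_per_step = torch.randn(batch, heads, seq)

# exp(lse_step - lse) rescales the per-step partial softmax sums so that
# outputs from different context-parallelism steps can be accumulated.
corr = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim).unsqueeze(-1)
out.add_(out_per_step * corr)  # corr: (seq, batch, heads, 1) broadcasts over dim
```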
The final tracing info is:
```
torch._dynamo.exc.TorchRuntimeError: Failed running call_method movedim(*(FakeTensor(..., device='cuda:4', size=(36, 13056)), 2, 1), **{}):
```
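Note that the FakeTensor in the trace is 2-D (`size=(36, 13056)`), so `movedim(2, 1)` requests a dimension that does not exist. A minimal sketch reproducing the same out-of-range error outside Dynamo (tensor contents are made up):

```python
import torch

# 2-D stand-in for the softmax_lse seen during tracing: size=(36, 13056).
lse = torch.randn(36, 13056)

# Asking to move dimension 2 of a 2-D tensor fails:
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
lse.movedim(2, 1)
```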