When an attn_qkv `Layer` is set with `n_fused > 1` and `reversed=False`, the shape of its sliced weight is incorrect. The root cause seems to be here:
parallelformers/parallelformers/parallel/slicing.py, lines 79 to 95 in 436573b
For an attn_qkv weight, the `dim` argument is 0. So when `reversed=False` and `n_fused > 1`, the tensor is chunked along dim 0 but then concatenated along dim 1, which makes its shape incorrect.