overlap kv comm with output rescale (#6017)
Co-authored-by: Edenzzzz <[email protected]>
Edenzzzz authored Aug 19, 2024
1 parent 26493b9 commit f1c3266
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions colossalai/shardformer/layer/attn.py
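For orientation, here is a minimal sketch of the double-buffered inner-ring step that this commit reorders. It is not the ColossalAI implementation: the communicator interface (send_recv/wait), attn_fn, and merge_fn are hypothetical stand-ins inferred from this diff.

    import torch

    def ring_step_sketch(q, kv_buffers, comms, streams, attn_fn, merge_fn, num_steps):
        # One inner-ring pass: attend to the local KV block while the next block is in flight.
        attn_done = torch.cuda.Event()
        out, lse = None, None

        def kv_comm(i):
            # Re-allocate the recv buffer if a still-running attention kernel may be reading it.
            if not attn_done.query():
                kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
            if i < num_steps - 1:
                comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

        for i in range(num_steps):
            with torch.cuda.stream(streams[i % 2]):
                if i > 0:
                    comms[(i + 1) % 2].wait()   # the KV block for this step has arrived
                if i == 0:
                    kv_comm(i)                  # prime the pipeline
                blk_out, blk_lse = attn_fn(q, kv_buffers[i % 2])
                attn_done.record()
                kv_comm(i + 1)                  # issue the next exchange before the rescale below
                out, lse = merge_fn(out, lse, blk_out, blk_lse)  # log-sum-exp output correction
        return out, lse

The point of the change is the placement of kv_comm(i + 1): it is now issued right after the attention kernel records ATTN_DONE, so the transfer has the output correction (not only the next attention kernel) to hide behind.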
@@ -690,6 +690,13 @@ def _forward(q, k, v, causal):
)
return out, softmax_lse, rng_state

def _kv_comm(i):
# Avoid overwriting attn input when it shares mem with buffer
if not RingAttention.ATTN_DONE.query():
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

def _local_ring_forward():
# (Hopefully) overlap output correction with next flash attn
for i in range(local_sp_size):
@@ -698,12 +705,8 @@ def _local_ring_forward():
# NOTE: waiting outside the current stream will NOT correctly synchronize.
if i > 0:
local_kv_comms[(i + 1) % 2].wait()

# Avoid overwriting attn input when it shares mem with buffer
if not RingAttention.ATTN_DONE.query():
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])
if i == 0:
_kv_comm(i)

if i == 0:
# Compute with local KV; no mask
@@ -734,6 +737,9 @@ def _local_ring_forward():
rng_states[i],
) = _forward(q_block, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()
# Pipeline the next KV comm with output correction instead of the next flash attn
# to minimize idle time when comm takes longer than attn.
_kv_comm(i + 1)

block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
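Read as a per-iteration schedule, the reordering looks roughly like this (a sketch inferred from the comments above; the exact overlap also depends on the alternating CUDA streams and on how send_recv is enqueued):

    # before:  wait(kv_i) -> send_recv(kv for step i+1) -> attn(kv_i) -> rescale(out_i)
    # after:   wait(kv_i) -> attn(kv_i) -> send_recv(kv for step i+2) -> rescale(out_i)
    #
    # The exchange issued after ATTN_DONE.record() is consumed two iterations later, so it can
    # proceed during the output rescale (and, with alternating streams, during the next attention
    # kernel), which reduces idle time when the KV transfer takes longer than one attention call.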
@@ -761,15 +767,13 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse):
# all new KVs from the previous inner ring
for i in range(local_sp_size):
with torch.cuda.stream(sp_streams[i % 2]):
if not RingAttention.ATTN_DONE.query():
kv_buffers[(i + 1) % 2] = torch.empty_like(kv_buffers[i % 2])
if i < local_sp_size - 1:
local_kv_comms[i % 2].send_recv(kv_buffers[i % 2], kv_buffers[(i + 1) % 2])

# Send & recv KV
if i > 0:
local_kv_comms[(i + 1) % 2].wait()

if i == 0:
_kv_comm(i)

if ring_num_idx > inter_ring_rank:
kv_block = kv_buffers[i % 2]
(
@@ -778,6 +782,8 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse):
rng_states[i + local_sp_size * ring_num_idx],
) = _forward(q1, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()

_kv_comm(i + 1)
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
)
@@ -792,6 +798,8 @@ def _other_ring_forward(ring_num_idx, out, softmax_lse):
rng_states[i + local_sp_size * ring_num_idx],
) = _forward(q, kv_block[0], kv_block[1], causal=False)
RingAttention.ATTN_DONE.record()

_kv_comm(i + 1)
block_softmax_lse[i % 2] = (
block_softmax_lse[i % 2].transpose(0, 1).unsqueeze(-1).contiguous().float()
)