Commit 80c8b46

add gqa support and benchmark results in readme (#38)

feifeibear authored Apr 11, 2024
1 parent e3b8cfd
Showing 7 changed files with 275 additions and 35 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -68,7 +68,16 @@ In the figure presented below, we contrast the performance of LongContextAttention

The best throughput is achieved with `ulysses_degree=2` and the `zigzag` ring attention implementation. We observed 18% and 31% throughput improvements for FWD+BWD and FWD-only, respectively.

-![head=8](./media/long_ctx_h2.png)
+![head=2](./media/long_ctx_h2.png)
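
To reproduce a configuration like this, here is a minimal sketch. The `set_seq_parallel_pg` and `LongContextAttention` calls mirror those in `benchmark/benchmark_longctx.py`; the import path and the tensor sizes are assumptions for illustration, not a definitive recipe.

```python
import torch
import torch.distributed as dist

# Import path assumed from the benchmark scripts.
from yunchang import LongContextAttention, set_seq_parallel_pg

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Hybrid sequence parallelism; assumes ulysses_degree divides world_size.
ulysses_degree = 2
ring_degree = world_size // ulysses_degree
set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size, False)  # last arg: use_ulysses_lowdim

attn = LongContextAttention(ring_impl_type="zigzag")

# Per-rank local tensors, shaped as in the benchmark: (batch, seqlen, heads, head_dim).
q, k, v = (
    torch.randn(2, 4096, 2, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
out = attn(q, k, v, dropout_p=0.0, causal=True, window_size=(-1, -1),
           alibi_slopes=None, deterministic=False, return_attn_probs=False)
```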


- GQA, head_num=64, group_num=8, seqlen=4K. Reproduce by running [./scripts/run_gqa.sh](./scripts/run_gqa.sh)


The best throughput is achieved with `ulysses_degree=8` and the `zigzag` ring attention implementation. We observed 15% and 11% throughput improvements for FWD+BWD and FWD-only, respectively.


![gqa](./media/gqa.png)
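
For reference, a minimal single-GPU sketch of the GQA tensor layout used above (shapes mirror the new benchmark scripts, the concrete numbers are the benchmark defaults, and `flash_attn_func` natively accepts K/V with fewer heads than Q):

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, group_num, d = 2, 4096, 64, 8, 128

# Q keeps all 64 heads; K and V carry nheads // group_num = 8 heads,
# so each K/V head is shared by group_num = 8 query heads.
q = torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seqlen, nheads // group_num, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seqlen, nheads // group_num, d, device="cuda", dtype=torch.bfloat16)

out = flash_attn_func(q, k, v, causal=True)  # out: (batch, seqlen, nheads, d)
```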

## Ulysses Attention
This repository re-implements the all-to-all communication functions and supports QKV packed together, following the principles of [DeepSpeed-Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md).
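
As a rough illustration of that principle (a sketch, not this repository's implementation), an all-to-all re-shards a tensor from a sequence-parallel to a head-parallel layout, so each rank can run attention over the full sequence for a subset of heads. The helper name and the contiguous-slice assumption below are hypothetical:

```python
import torch
import torch.distributed as dist

def seq_to_head_shard(x: torch.Tensor, group=None) -> torch.Tensor:
    """All-to-all: (bs, seq/P, h, d) sequence-sharded -> (bs, seq, h/P, d) head-sharded."""
    P = dist.get_world_size(group)
    bs, s_local, h, d = x.shape
    assert h % P == 0, "head count must be divisible by the parallel degree"
    # Split heads into P chunks and move the chunk axis first for all_to_all_single.
    x = x.reshape(bs, s_local, P, h // P, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)  # out[p] = rank p's sequence slice of our head chunk
    # Ranks hold contiguous sequence slices in rank order, so stacking restores the full sequence.
    return out.permute(1, 0, 2, 3, 4).reshape(bs, P * s_local, h // P, d)
```

After attention, the inverse all-to-all restores the sequence-sharded layout; DeepSpeed-Ulysses describes the same exchange for the packed-QKV case.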
86 changes: 60 additions & 26 deletions benchmark/benchmark_longctx.py
@@ -14,6 +14,9 @@
help="ring attn implementation type",
)
parser.add_argument("--nheads", type=int, default=2, help="head number")
parser.add_argument("--head_size", type=int, default=128, help="head size")
parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length")
parser.add_argument("--group_num", type=int, default=1, help="group number")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument(
"--fwd_only", action="store_true", help="benchmark forward pass only"
@@ -34,7 +37,6 @@
args = parser.parse_args()



def color_print(text):
print("\033[91m {}\033[00m".format(text))

@@ -46,24 +48,44 @@ def benchmark(num_iter=100, forward_only=True, log=True):
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

-    batch_size = 1
-    seqlen = 1024 * 8
+    batch_size = args.batch_size
+    seqlen = args.seq_len
     nheads = args.nheads
-    d = 128
+    group_num = args.group_num
+    d = args.head_size

dropout_p = 0
causal = True
deterministic = False

-    assert seqlen % (2 * world_size) == 0
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
assert d % 8 == 0
assert nheads % group_num == 0, f"nheads {nheads} group_num {group_num}"
assert (
nheads // group_num % args.ulysses_degree == 0
), f"nheads {nheads}, group_num {group_num}, ulysses_degree {args.ulysses_degree}"

-    q, k, v = torch.randn(
-        3, batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True
-    ).chunk(3, dim=0)
-    q = q.squeeze(0)
-    k = k.squeeze(0)
-    v = v.squeeze(0)
+    q = torch.randn(
+        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True
+    )
+    k = torch.randn(
+        batch_size,
+        seqlen,
+        nheads // group_num,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True,
+    )
+    v = torch.randn(
+        batch_size,
+        seqlen,
+        nheads // group_num,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True,
+    )
dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)

sp_ulysses_degree = min(args.ulysses_degree, world_size)
@@ -72,21 +94,34 @@ def benchmark(num_iter=100, forward_only=True, log=True):
set_seq_parallel_pg(
sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim
)
-    longctx_attn = LongContextAttention()
+    longctx_attn = LongContextAttention(ring_impl_type=args.ring_impl_type)

    # Warm up: run forward and backward twice before timing.
    out = longctx_attn(
        q,
        k,
        v,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=False,
    )
    out.backward(dout)

    out = longctx_attn(
        q,
        k,
        v,
        dropout_p=dropout_p,
        causal=causal,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=deterministic,
        return_attn_probs=False,
    )
    out.backward(dout)

begin = torch.cuda.Event(enable_timing=True)
begin.record()

@@ -122,7 +157,6 @@ def benchmark(num_iter=100, forward_only=True, log=True):
return_attn_probs=False,
)
out.backward(dout)

end = torch.cuda.Event(enable_timing=True)
end.record()

@@ -142,7 +176,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
torch.cuda.empty_cache()
if rank == 0:
color_print(
f"# long context attention. ulysses_degree : {args.ulysses_degree} fwd_only {forward_only} use_ulysses_lowdim {args.use_ulysses_lowdim}"
f"# long context attention {args.ring_impl_type}. ulysses_degree : {args.ulysses_degree} fwd_only {forward_only} use_ulysses_lowdim {args.use_ulysses_lowdim}"
)
torch.cuda.empty_cache()
benchmark(forward_only=forward_only, log=False)
Expand Down
11 changes: 7 additions & 4 deletions benchmark/benchmark_longctx_qkvpacked.py
@@ -16,6 +16,8 @@
help="ring attn implementation type",
)
parser.add_argument("--nheads", type=int, default=2, help="head number")
parser.add_argument("--head_size", type=int, default=128, help="head number")
parser.add_argument("--seq_len", type=int, default=4 * 1024, help="head number")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument(
"--fwd_only", action="store_true", help="benchmark forward pass only"
@@ -47,14 +49,15 @@ def benchmark(num_iter=100, forward_only=True, log=True):
torch.cuda.set_device(device)

batch_size = args.batch_size
-    seqlen = 1024 * 8
+    seqlen = args.seq_len
     nheads = args.nheads
-    d = 128
+    d = args.head_size

dropout_p = 0
causal = True
deterministic = False

-    assert seqlen % (2 * world_size) == 0
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
assert d % 8 == 0

qkv = torch.randn(
@@ -69,7 +72,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim
)

-    longctx_attn = LongContextAttentionQKVPacked()
+    longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=args.ring_impl_type)

longctx_attn(
qkv,
11 changes: 7 additions & 4 deletions benchmark/benchmark_qkvpacked_func.py
@@ -13,6 +13,8 @@
parser = argparse.ArgumentParser(description="Process some integers.")

parser.add_argument("--nheads", type=int, default=2, help="head number")
parser.add_argument("--head_size", type=int, default=128, help="head number")
parser.add_argument("--seq_len", type=int, default=4 * 1024, help="head number")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument(
"--fwd_only", action="store_true", help="benchmark forward pass only"
@@ -33,14 +35,15 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
torch.cuda.set_device(device)

batch_size = args.batch_size
-    seqlen = 1024 * 8
+    seqlen = args.seq_len
     nheads = args.nheads
-    d = 128
+    d = args.head_size

dropout_p = 0
causal = True
deterministic = False

-    assert seqlen % (2 * world_size) == 0
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
assert d % 8 == 0

qkv = torch.randn(
@@ -113,7 +116,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
forward_only = args.fwd_only

for f in [
-        # flash_attn_qkvpacked_func,
+        flash_attn_qkvpacked_func,
ring_flash_attn_qkvpacked_func,
zigzag_ring_flash_attn_qkvpacked_func,
stripe_flash_attn_qkvpacked_func,
159 changes: 159 additions & 0 deletions benchmark/benchmark_ring_func.py
@@ -0,0 +1,159 @@
from flash_attn import flash_attn_func
import torch
import torch.distributed as dist
from yunchang import (
ring_flash_attn_func,
zigzag_ring_flash_attn_func,
stripe_flash_attn_func,
)
import torch.cuda

import argparse

parser = argparse.ArgumentParser(description="Process some integers.")

parser.add_argument("--nheads", type=int, default=2, help="head number")
parser.add_argument("--head_size", type=int, default=128, help="head number")
parser.add_argument("--seq_len", type=int, default=4 * 1024, help="head number")
parser.add_argument("--group_num", type=int, default=1, help="group number")
parser.add_argument("--batch_size", type=int, default=2, help="batch size")
parser.add_argument(
"--fwd_only", action="store_true", help="benchmark forward pass only"
)

args = parser.parse_args()


def color_print(text):
print("\033[91m {}\033[00m".format(text))


def benchmark(f, num_iter=100, forward_only=True, log=True):
dtype = torch.bfloat16
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

batch_size = args.batch_size
seqlen = args.seq_len
nheads = args.nheads
d = args.head_size
group_num = args.group_num

dropout_p = 0
causal = True
deterministic = False

assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
assert d % 8 == 0
assert nheads % group_num == 0, f"nheads {nheads} group_num {group_num}"

q = torch.randn(
batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True
)
k = torch.randn(
batch_size,
seqlen,
nheads // group_num,
d,
device=device,
dtype=dtype,
requires_grad=True,
)
v = torch.randn(
batch_size,
seqlen,
nheads // group_num,
d,
device=device,
dtype=dtype,
requires_grad=True,
)
dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)

    # Warm up before timing: one forward-only call, then a forward+backward pass.
    _ = f(
q,
k,
v,
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
)
out = f(
q,
k,
v,
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
)
out.backward(dout)

begin = torch.cuda.Event(enable_timing=True)
begin.record()

if forward_only:
with torch.no_grad():
for _ in range(num_iter):
_ = f(
q,
k,
v,
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
)

else:
for _ in range(num_iter):
q.grad = None
k.grad = None
v.grad = None
out = f(
q,
k,
v,
dropout_p=dropout_p,
causal=causal,
window_size=(-1, -1),
alibi_slopes=None,
deterministic=deterministic,
return_attn_probs=False,
)
out.backward(dout)
end = torch.cuda.Event(enable_timing=True)
end.record()
torch.cuda.synchronize(device=device)
time = begin.elapsed_time(end) / 1000.0

if rank == 0 and log:
color_print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec")


if __name__ == "__main__":
dist.init_process_group("nccl")
rank = dist.get_rank()

forward_only = args.fwd_only

for f in [
flash_attn_func,
ring_flash_attn_func,
zigzag_ring_flash_attn_func,
stripe_flash_attn_func,
]:
torch.cuda.empty_cache()
if rank == 0:
color_print(f"# {f.__name__} fwd_only {forward_only}")
benchmark(f, forward_only=forward_only, log=False)
benchmark(f, forward_only=forward_only, log=True)
Binary file added media/gqa.png
