diff --git a/README.md b/README.md
index a9049d1..73a53c8 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,16 @@ In the figure presented below, we contrast the performance of LongContextAttenti
 The best throughput is achieved when `ulysses_degree`=2 and `ring_attn_impl` is set to `zigzag`. We observed throughput improvements of 18% and 31% for FWD+BWD and FWD-only, respectively.
 
-![head=8](./media/long_ctx_h2.png)
+![head=2](./media/long_ctx_h2.png)
+
+
+- GQA, head_num=64, group_num=8, seqlen=4K. Reproduce by running [./scripts/run_gqa.sh](./scripts/run_gqa.sh).
+
+
+The best throughput is achieved when `ulysses_degree`=8 and `ring_attn_impl` is set to `zigzag`. We observed throughput improvements of 15% and 11% for FWD+BWD and FWD-only, respectively.
+
+
+![gqa](./media/gqa.png)
 
 ## Ulysses Attention
 
 This repository re-implements the all-to-all communication functions and supports QKV packed together, following the principles of [DeepSpeed-Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md).
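The GQA configuration above is constrained by the head-count assertion added to benchmark/benchmark_longctx.py: after grouping, the KV head count must still divide evenly across Ulysses ranks. A minimal sketch of that check (`valid_ulysses_degree` is a hypothetical helper, not part of the repo):

```python
# Minimal sketch of the head-count constraint asserted in
# benchmark/benchmark_longctx.py; `valid_ulysses_degree` is a
# hypothetical helper, not part of the repository.
def valid_ulysses_degree(nheads: int, group_num: int, ulysses_degree: int) -> bool:
    if nheads % group_num != 0:          # heads must split evenly into groups
        return False
    kv_heads = nheads // group_num       # KV heads seen by the all-to-all
    return kv_heads % ulysses_degree == 0

# head_num=64, group_num=8 -> 8 KV heads, so ulysses_degree 1/2/4/8 all work:
assert all(valid_ulysses_degree(64, 8, u) for u in (1, 2, 4, 8))
assert not valid_ulysses_degree(64, 8, 16)
```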
diff --git a/benchmark/benchmark_longctx.py b/benchmark/benchmark_longctx.py
index 9e089aa..3f395a2 100644
--- a/benchmark/benchmark_longctx.py
+++ b/benchmark/benchmark_longctx.py
@@ -14,6 +14,9 @@
     help="ring attn implementation type",
 )
 parser.add_argument("--nheads", type=int, default=2, help="head number")
+parser.add_argument("--head_size", type=int, default=128, help="head size")
+parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length")
+parser.add_argument("--group_num", type=int, default=1, help="group number")
 parser.add_argument("--batch_size", type=int, default=2, help="batch size")
 parser.add_argument(
     "--fwd_only", action="store_true", help="benchmark forward pass only"
@@ -34,7 +37,6 @@
 args = parser.parse_args()
 
 
-
 def color_print(text):
     print("\033[91m {}\033[00m".format(text))
@@ -46,24 +48,44 @@ def benchmark(num_iter=100, forward_only=True, log=True):
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
 
-    batch_size = 1
-    seqlen = 1024 * 8
+    batch_size = args.batch_size
+    seqlen = args.seq_len
     nheads = args.nheads
-    d = 128
+    group_num = args.group_num
+    d = args.head_size
+
     dropout_p = 0
     causal = True
     deterministic = False
 
-    assert seqlen % (2 * world_size) == 0
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
     assert d % 8 == 0
+    assert nheads % group_num == 0, f"nheads {nheads} group_num {group_num}"
+    assert (
+        nheads // group_num % args.ulysses_degree == 0
+    ), f"nheads {nheads}, group_num {group_num}, ulysses_degree {args.ulysses_degree}"
 
-    q, k, v = torch.randn(
-        3, batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True
-    ).chunk(3, dim=0)
-    q = q.squeeze(0)
-    k = k.squeeze(0)
-    v = v.squeeze(0)
-
+    q = torch.randn(
+        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True
+    )
+    k = torch.randn(
+        batch_size,
+        seqlen,
+        nheads // group_num,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True,
+    )
+    v = torch.randn(
+        batch_size,
+        seqlen,
+        nheads // group_num,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True,
+    )
     dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
 
     sp_ulysses_degree = min(args.ulysses_degree, world_size)
@@ -72,21 +94,34 @@ def benchmark(num_iter=100, forward_only=True, log=True):
     set_seq_parallel_pg(
         sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim
     )
-    longctx_attn = LongContextAttention()
+    longctx_attn = LongContextAttention(ring_impl_type=args.ring_impl_type)
 
     out = longctx_attn(
-        q,
-        k,
-        v,
-        dropout_p=dropout_p,
-        causal=causal,
-        window_size=(-1, -1),
-        alibi_slopes=None,
-        deterministic=deterministic,
-        return_attn_probs=False,
-    )
+        q,
+        k,
+        v,
+        dropout_p=dropout_p,
+        causal=causal,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=deterministic,
+        return_attn_probs=False,
+    )
+    out.backward(dout)
+
+    out = longctx_attn(
+        q,
+        k,
+        v,
+        dropout_p=dropout_p,
+        causal=causal,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=deterministic,
+        return_attn_probs=False,
+    )
     out.backward(dout)
-
+
     begin = torch.cuda.Event(enable_timing=True)
     begin.record()
@@ -122,7 +157,6 @@ def benchmark(num_iter=100, forward_only=True, log=True):
                 return_attn_probs=False,
             )
             out.backward(dout)
-
     end = torch.cuda.Event(enable_timing=True)
     end.record()
@@ -142,7 +176,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
         torch.cuda.empty_cache()
         if rank == 0:
             color_print(
-                f"# long context attention. ulysses_degree : {args.ulysses_degree} fwd_only {forward_only} use_ulysses_lowdim {args.use_ulysses_lowdim}"
+                f"# long context attention {args.ring_impl_type}. ulysses_degree : {args.ulysses_degree} fwd_only {forward_only} use_ulysses_lowdim {args.use_ulysses_lowdim}"
             )
         torch.cuda.empty_cache()
         benchmark(forward_only=forward_only, log=False)
diff --git a/benchmark/benchmark_longctx_qkvpacked.py b/benchmark/benchmark_longctx_qkvpacked.py
index a320b5c..ff4c324 100644
--- a/benchmark/benchmark_longctx_qkvpacked.py
+++ b/benchmark/benchmark_longctx_qkvpacked.py
@@ -16,6 +16,8 @@
     help="ring attn implementation type",
 )
 parser.add_argument("--nheads", type=int, default=2, help="head number")
+parser.add_argument("--head_size", type=int, default=128, help="head size")
+parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length")
 parser.add_argument("--batch_size", type=int, default=2, help="batch size")
 parser.add_argument(
     "--fwd_only", action="store_true", help="benchmark forward pass only"
@@ -47,14 +49,15 @@ def benchmark(num_iter=100, forward_only=True, log=True):
     torch.cuda.set_device(device)
 
     batch_size = args.batch_size
-    seqlen = 1024 * 8
+    seqlen = args.seq_len
     nheads = args.nheads
-    d = 128
+    d = args.head_size
+
     dropout_p = 0
     causal = True
     deterministic = False
 
-    assert seqlen % (2 * world_size) == 0
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
     assert d % 8 == 0
 
     qkv = torch.randn(
@@ -69,7 +72,7 @@ def benchmark(num_iter=100, forward_only=True, log=True):
         sp_ulysses_degree, sp_ring_degree, rank, world_size, args.use_ulysses_lowdim
     )
 
-    longctx_attn = LongContextAttentionQKVPacked()
+    longctx_attn = LongContextAttentionQKVPacked(ring_impl_type=args.ring_impl_type)
 
     longctx_attn(
         qkv,
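For orientation, here is the call pattern benchmark_longctx.py exercises, distilled into a minimal standalone sketch. It assumes a torchrun launch; the degrees and shapes are illustrative, and we omit the `use_ulysses_lowdim` argument the benchmark passes to `set_seq_parallel_pg`, assuming its default.

```python
# Hedged sketch of the GQA call pattern in benchmark_longctx.py.
# Assumes `torchrun` launch; degrees/shapes below are illustrative only.
import torch
import torch.distributed as dist
from yunchang import LongContextAttention, set_seq_parallel_pg

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# Hybrid parallelism: heads go to the Ulysses group, sequence to the ring group.
# (Assumption: ring degree is world_size // ulysses_degree, as in the benchmark.)
sp_ulysses_degree = min(8, world_size)
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

attn = LongContextAttention(ring_impl_type="zigzag")

# GQA shapes from the benchmark: q carries all heads, k/v carry nheads // group_num.
batch, seqlen, nheads, group_num, d = 2, 4096, 64, 8, 128
device, dtype = torch.device(f"cuda:{rank}"), torch.bfloat16
q = torch.randn(batch, seqlen, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, nheads // group_num, d, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, nheads // group_num, d, device=device, dtype=dtype)

out = attn(
    q, k, v,
    dropout_p=0.0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
)
```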
diff --git a/benchmark/benchmark_qkvpacked_func.py b/benchmark/benchmark_qkvpacked_func.py
index b2d7ecb..429a1c5 100644
--- a/benchmark/benchmark_qkvpacked_func.py
+++ b/benchmark/benchmark_qkvpacked_func.py
@@ -13,6 +13,8 @@
 parser = argparse.ArgumentParser(description="Process some integers.")
 
 parser.add_argument("--nheads", type=int, default=2, help="head number")
+parser.add_argument("--head_size", type=int, default=128, help="head size")
+parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length")
 parser.add_argument("--batch_size", type=int, default=2, help="batch size")
 parser.add_argument(
     "--fwd_only", action="store_true", help="benchmark forward pass only"
@@ -33,14 +35,15 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
     torch.cuda.set_device(device)
 
     batch_size = args.batch_size
-    seqlen = 1024 * 8
+    seqlen = args.seq_len
     nheads = args.nheads
-    d = 128
+    d = args.head_size
+
     dropout_p = 0
     causal = True
     deterministic = False
 
-    assert seqlen % (2 * world_size) == 0
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
     assert d % 8 == 0
 
     qkv = torch.randn(
@@ -113,7 +116,7 @@ def benchmark(f, num_iter=100, forward_only=True, log=True):
     forward_only = args.fwd_only
 
     for f in [
-        # flash_attn_qkvpacked_func,
+        flash_attn_qkvpacked_func,
         ring_flash_attn_qkvpacked_func,
         zigzag_ring_flash_attn_qkvpacked_func,
         stripe_flash_attn_qkvpacked_func,
diff --git a/benchmark/benchmark_ring_func.py b/benchmark/benchmark_ring_func.py
new file mode 100644
index 0000000..cb0f88f
--- /dev/null
+++ b/benchmark/benchmark_ring_func.py
@@ -0,0 +1,159 @@
+from flash_attn import flash_attn_func
+import torch
+import torch.distributed as dist
+from yunchang import (
+    ring_flash_attn_func,
+    zigzag_ring_flash_attn_func,
+    stripe_flash_attn_func,
+)
+import torch.cuda
+
+import argparse
+
+parser = argparse.ArgumentParser(description="Process some integers.")
+
+parser.add_argument("--nheads", type=int, default=2, help="head number")
+parser.add_argument("--head_size", type=int, default=128, help="head size")
+parser.add_argument("--seq_len", type=int, default=4 * 1024, help="sequence length")
+parser.add_argument("--group_num", type=int, default=1, help="group number")
+parser.add_argument("--batch_size", type=int, default=2, help="batch size")
+parser.add_argument(
+    "--fwd_only", action="store_true", help="benchmark forward pass only"
+)
+
+args = parser.parse_args()
+
+
+def color_print(text):
+    print("\033[91m {}\033[00m".format(text))
+
+
+def benchmark(f, num_iter=100, forward_only=True, log=True):
+    dtype = torch.bfloat16
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+
+    batch_size = args.batch_size
+    seqlen = args.seq_len
+    nheads = args.nheads
+    d = args.head_size
+    group_num = args.group_num
+
+    dropout_p = 0
+    causal = True
+    deterministic = False
+
+    assert seqlen % (2 * world_size) == 0, f"seqlen {seqlen} world_size {world_size}"
+    assert d % 8 == 0
+    assert nheads % group_num == 0, f"nheads {nheads} group_num {group_num}"
+
+    q = torch.randn(
+        batch_size, seqlen, nheads, d, device=device, dtype=dtype, requires_grad=True
+    )
+    k = torch.randn(
+        batch_size,
+        seqlen,
+        nheads // group_num,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True,
+    )
+    v = torch.randn(
+        batch_size,
+        seqlen,
+        nheads // group_num,
+        d,
+        device=device,
+        dtype=dtype,
+        requires_grad=True,
+    )
+    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
+
+    _ = f(
+        q,
+        k,
+        v,
+        dropout_p=dropout_p,
+        causal=causal,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=deterministic,
+        return_attn_probs=False,
+    )
+    out = f(
+        q,
+        k,
+        v,
+        dropout_p=dropout_p,
+        causal=causal,
+        window_size=(-1, -1),
+        alibi_slopes=None,
+        deterministic=deterministic,
+        return_attn_probs=False,
+    )
+    out.backward(dout)
+
+    begin = torch.cuda.Event(enable_timing=True)
+    begin.record()
+
+    if forward_only:
+        with torch.no_grad():
+            for _ in range(num_iter):
+                _ = f(
+                    q,
+                    k,
+                    v,
+                    dropout_p=dropout_p,
+                    causal=causal,
+                    window_size=(-1, -1),
+                    alibi_slopes=None,
+                    deterministic=deterministic,
+                    return_attn_probs=False,
+                )
+
+    else:
+        for _ in range(num_iter):
+            q.grad = None
+            k.grad = None
+            v.grad = None
+            out = f(
+                q,
+                k,
+                v,
+                dropout_p=dropout_p,
+                causal=causal,
+                window_size=(-1, -1),
+                alibi_slopes=None,
+                deterministic=deterministic,
+                return_attn_probs=False,
+            )
+            out.backward(dout)
+    end = torch.cuda.Event(enable_timing=True)
+    end.record()
+    torch.cuda.synchronize(device=device)
+    time = begin.elapsed_time(end) / 1000.0
+
+    if rank == 0 and log:
+        color_print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec")
+
+
+if __name__ == "__main__":
+    dist.init_process_group("nccl")
+    rank = dist.get_rank()
+
+    forward_only = args.fwd_only
+
+    for f in [
+        flash_attn_func,
+        ring_flash_attn_func,
+        zigzag_ring_flash_attn_func,
+        stripe_flash_attn_func,
+    ]:
+        torch.cuda.empty_cache()
+        if rank == 0:
+            color_print(f"# {f.__name__} fwd_only {forward_only}")
+        benchmark(f, forward_only=forward_only, log=False)
+        benchmark(f, forward_only=forward_only, log=True)
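All of the benchmark scripts in this diff time iterations the same way; distilled into a standalone helper (`cuda_timed_seconds` is our illustrative name, not a repo function):

```python
import torch

# Hypothetical helper distilling the timing pattern the benchmark scripts
# use: CUDA events bracket the launches, and a device synchronize makes the
# end event's timestamp valid before elapsed_time is read.
def cuda_timed_seconds(fn, num_iter: int = 100) -> float:
    begin = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    begin.record()
    for _ in range(num_iter):
        fn()
    end.record()
    torch.cuda.synchronize()
    return begin.elapsed_time(end) / 1000.0  # elapsed_time is in milliseconds
```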
diff --git a/media/gqa.png b/media/gqa.png
new file mode 100644
index 0000000..e43985b
Binary files /dev/null and b/media/gqa.png differ
diff --git a/scripts/run_gqa.sh b/scripts/run_gqa.sh
new file mode 100644
index 0000000..b103f81
--- /dev/null
+++ b/scripts/run_gqa.sh
@@ -0,0 +1,32 @@
+export PYTHONPATH=$PWD:$PYTHONPATH
+# export NCCL_PXN_DISABLE=1
+# export NCCL_DEBUG=INFO
+# export NCCL_SOCKET_IFNAME=eth0
+# export NCCL_IB_GID_INDEX=3
+# export NCCL_IB_DISABLE=0
+# export NCCL_NET_GDR_LEVEL=2
+# export NCCL_IB_QPS_PER_CONNECTION=4
+# export NCCL_IB_TC=160
+# export NCCL_IB_TIMEOUT=22
+
+# uncomment the next line for fwd-only; leave it commented for fwd+bwd
+# FWD_FLAG="--fwd_only"
+
+NHEADS=64
+SEQLEN=4096 #131072
+GROUP_NUM=8
+GPU_NUM=8
+ULYSSES_DEGREE=1
+
+
+# RING_IMPL_TYPE="zigzag"
+
+# make sure (NHEADS / GROUP_NUM) % ULYSSES_DEGREE == 0
+for ULYSSES_DEGREE in 8 4 2 1; do
+for RING_IMPL_TYPE in "zigzag"; do
+torchrun --nproc_per_node $GPU_NUM benchmark/benchmark_longctx.py --nheads $NHEADS --group_num $GROUP_NUM --batch_size 2 $FWD_FLAG --seq_len $SEQLEN --ulysses_degree $ULYSSES_DEGREE --ring_impl_type $RING_IMPL_TYPE
+done
+done
+
+torchrun --nproc_per_node $GPU_NUM benchmark/benchmark_ring_func.py --nheads $NHEADS --group_num $GROUP_NUM --batch_size 2 $FWD_FLAG --seq_len $SEQLEN
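As an aside, the `seqlen % (2 * world_size) == 0` assertions in the benchmarks come from the zigzag layout: the sequence is cut into `2 * world_size` chunks so each rank holds one chunk from the front and one from the back of the causal mask, balancing per-rank work. A sketch under the assumption that this repo follows the standard zigzag ring-attention layout (`zigzag_split` is illustrative, not a repo function):

```python
import torch

# Sketch of the zigzag sequence layout implied by the asserts above.
# Assumption: this mirrors the standard zigzag ring-attention scheme,
# where rank i keeps chunks i and 2*world_size-1-i of the sequence.
def zigzag_split(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    chunks = x.chunk(2 * world_size, dim=1)  # x: [batch, seqlen, heads, dim]
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=1)
```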