use extract_local for test_hybrid_attn.py (#74)
feifeibear authored Sep 14, 2024
1 parent 567ca32 commit 5ac27d8
Showing 11 changed files with 179 additions and 136 deletions.
83 changes: 15 additions & 68 deletions README.md
@@ -12,7 +12,7 @@ The project is built on [zhuzilin/ring-flash-attention](https://github.com/zhuzi



## What's wrong with Ulysses and Ring?
## Why not apply Ulysses and Ring Attention Individually?

- Ulysses is sensitive to the number of attention heads.
The parallelism degree in Ulysses cannot exceed the number of heads.
@@ -25,89 +25,37 @@ Even with the communication and computation processes fully overlapped, the tota
Furthermore, Ring-Attention uses asynchronous peer-to-peer communication, which not only has lower bandwidth utilization than collective communication but also risks communication deadlocks in large-scale deployments.


## LongContextAttention, a.k.a Unified Sequence Parallelism and Hybrid Sequence Parallelism
## LongContextAttention, also known as Unified Sequence Parallelism and Hybrid Sequence Parallelism

`LongContextAttention` is a **unified sequence parallel** (also known as **hybrid sequence parallel**) attention that combines DeepSpeed-Ulysses-Attention and Ring-Attention, thereby addressing the limitations of both methods.

<p align="center">
<img src="./media/hybrid_seqparallel.png">
<img src="./media/usp.png">
</p>

### Usage

Please refer to [test/test_hybrid_qkvpacked_attn.py](./test/test_hybrid_qkvpacked_attn.py) and [test/test_hybrid_attn.py](./test/test_hybrid_attn.py) for usage.

In short, we take the `zigzag` ring attention implementation as an example:

1. Apply `set_seq_parallel_pg` to set up the process group.
2. Extract the local tensors with `zigzag_extract_local`. The input tokens or input tensors need to be reordered for load-balanced ring attention.
3. Apply `LongContextAttention(ring_impl_type="zigzag")` as a drop-in replacement for the attention implementation; a minimal sketch follows below.
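
For concreteness, here is a minimal forward-only sketch assembled from [test/test_hybrid_attn.py](./test/test_hybrid_attn.py). The shapes, degrees, and the `dropout_p`/`causal` keyword names are illustrative assumptions (the test passes the full flash-attn style argument list), and it assumes a `torchrun` launch with one process per GPU and the top-level `yunchang` import path:

```python
# Minimal forward-only sketch assembled from test/test_hybrid_attn.py.
# Launch with: torchrun --nproc_per_node=4 usp_example.py  (one process per GPU)
import torch
import torch.distributed as dist
from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device(f"cuda:{rank}")

# 1. set up the Ulysses x Ring process groups (degrees here are illustrative)
sp_ulysses_degree = 2
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

# full, replicated q/k/v: (batch, seqlen, nheads, head_dim)
q = torch.randn(2, 1024, 4, 128, device=device, dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# 2. reorder + shard along the sequence dim for load-balanced zigzag ring attention
extract = EXTRACT_FUNC_DICT["zigzag"]
local_q = extract(q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)
local_k = extract(k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)
local_v = extract(v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree)

# 3. drop-in attention replacement
usp_attn = LongContextAttention(ring_impl_type="zigzag")
# keyword names below mirror the flash-attn style call in the test and are assumptions
local_out = usp_attn(local_q, local_k, local_v, dropout_p=0.0, causal=True)

dist.destroy_process_group()
```

The extract step matters: zigzag ring attention assumes the sequence has been reordered before sharding, which is exactly what the extract helper does.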

### Install

Option 1: install from PyPI.

`pip install yunchang==0.3` (flash_attn >= 2.6.0)
`pip install yunchang` (flash_attn >= 2.6.0)

`pip install yunchang==0.2` (flash_attn < 2.6.0)

Option 2: build from source.

`pip install .`

### Install for AMD GPU

Supported GPUs: MI300X, MI308X

GPU arch: gfx942

Step 1: prepare the Docker environment

Two recommended Docker containers to start with:

- rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 : hosted on Docker Hub, no conda
- [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04) : customized Dockerfile with a conda virtual env and development kit support

An example of creating a Docker container:

```bash
# create docker container
IMG=rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
tag=py310-rocm6.2-distattn-dev

docker_args=$(echo -it --privileged \
--name $tag \
--ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \
--device=/dev/kfd --device=/dev/dri \
--ipc=host \
--security-opt seccomp=unconfined \
--shm-size 16G \
--group-add video \
-v $(readlink -f `pwd`):/workspace \
--workdir /workspace \
--cpus=$((`nproc` / 2 - 1)) \
$IMG
)

docker_args=($docker_args)

docker container create "${docker_args[@]}"

# start it
docker start -a -i $tag
```

Update the ROCm SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh):

```bash
# e.g.:
ROCM_VERSION=6.3 bash rocm/update_sdk.sh
```

Step 2: build from source.

> MAX_JOBS=$(nproc) pip install .[amd] --verbose
**Features:**

1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms.

2. Covers the Capabilities of Both Ulysses and Ring: By setting the ulysses_degree to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the ulysses_degree to 1 mirrors the functionality of Ring (see the sketch after this list).

3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations.

4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing.
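
A minimal sketch of the two degenerate configurations described in item 2, assuming the positional `set_seq_parallel_pg(ulysses_degree, ring_degree, rank, world_size)` signature used in [test/test_hybrid_attn.py](./test/test_hybrid_attn.py) and the top-level `yunchang` import path:

```python
import torch.distributed as dist
from yunchang import set_seq_parallel_pg  # assumed top-level import path

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()

# Pure Ulysses: ulysses degree equals the sequence-parallel degree, ring degree 1.
set_seq_parallel_pg(world_size, 1, rank, world_size)

# Pure Ring: ulysses degree 1, ring over the whole group.
# (Pick one configuration per run; the process groups are set up once.)
# set_seq_parallel_pg(1, world_size, rank, world_size)
```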
Install for AMD GPU: [install_amd.md](./docs/install_amd.md)

### Verified in Megatron-LM
The loss curves for Data Parallel (DP) and Unified Sequence Parallel (ulysses=2+ring=2) are closely aligned, as illustrated in the figure. This alignment confirms the correctness of unified sequence parallelism.
@@ -121,7 +69,6 @@ In Megatron-LM, you can reorder the input tokens before feeding them into the model
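A hedged sketch of that reordering step; the helper below is hypothetical, and it assumes the extract functions shard an input tensor along its sequence dimension (dim 1), as they do for q/k/v in [test/test_hybrid_attn.py](./test/test_hybrid_attn.py):

```python
# Hypothetical pre-processing hook: reorder/shard a replicated batch for the local
# sequence-parallel rank before it enters the Megatron-LM model.
import torch.distributed as dist
from yunchang import EXTRACT_FUNC_DICT  # assumed top-level import path

def shard_for_usp(tensor, sp_ring_degree, sp_ulysses_degree, ring_impl_type="zigzag"):
    # Assumption: the extract helper slices dim=1 (the sequence dimension).
    rank, world_size = dist.get_rank(), dist.get_world_size()
    return EXTRACT_FUNC_DICT[ring_impl_type](
        tensor, rank, world_size=world_size,
        rd=sp_ring_degree, ud=sp_ulysses_degree,
    )
```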

## Best Practice for 4D Parallelism


We analyze the impact of introducing Sequence Parallelism to Data/ZeRO/Tensor/Pipeline Parallelism in a technical report, which can be found [here](https://arxiv.org/abs/2405.07719).

Some best practices are listed here:
@@ -183,7 +130,7 @@ I am honored that this repository has contributed to the following projects:
6. [FlagOpen/FlagScale](https://github.com/FlagOpen/FlagScale/commit/f98ee1e293bd906cc77f512f7a884b2030c10a12)
7. [zhiyuanhubj/LongRecipe](https://github.com/zhiyuanhubj/LongRecipe)

## Citation
## Cite Us
```
@article{fang2024unified,
title={USP: A Unified Sequence Parallelism Approach for Long Context Generative AI},
62 changes: 62 additions & 0 deletions docs/install_amd.md
@@ -0,0 +1,62 @@
## Install for AMD GPU

Supported GPUs: MI300X, MI308X

GPU arch: gfx942

Step 1: prepare the Docker environment

Two recommended Docker containers to start with:

- rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0 : hosted on Docker Hub, no conda
- [dockerhub repo](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/Dockerfile.rocm62.ubuntu-22.04) : customized Dockerfile with a conda virtual env and development kit support

An example of creating a Docker container:

```bash
# create docker container
IMG=rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
tag=py310-rocm6.2-distattn-dev

docker_args=$(echo -it --privileged \
--name $tag \
--ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \
--device=/dev/kfd --device=/dev/dri \
--ipc=host \
--security-opt seccomp=unconfined \
--shm-size 16G \
--group-add video \
-v $(readlink -f `pwd`):/workspace \
--workdir /workspace \
--cpus=$((`nproc` / 2 - 1)) \
$IMG
)

docker_args=($docker_args)

docker container create "${docker_args[@]}"

# start it
docker start -a -i $tag
```

Update the ROCm SDK using this [script](https://github.com/yiakwy-xpu-ml-framework-team/Tools-dockerhub/blob/main/rocm/update_sdk.sh):

```bash
# e.g.:
ROCM_VERSION=6.3 bash rocm/update_sdk.sh
```

Step 2: build from source.

> MAX_JOBS=$(nproc) pip install .[amd] --verbose
**Features:**

1. No Limitation on the Number of Heads: Our approach does not impose a restriction on the number of heads, providing greater flexibility for various attention mechanisms.

2. Covers the Capabilities of Both Ulysses and Ring: By setting the ulysses_degree to the sequence parallel degree, the system operates identically to Ulysses. Conversely, setting the ulysses_degree to 1 mirrors the functionality of Ring.

3. Enhanced Performance: We achieve superior performance benchmarks over both Ulysses and Ring, offering a more efficient solution for attention mechanism computations.

4. Compatibility with Advanced Parallel Strategies: LongContextAttention is fully compatible with other sophisticated parallelization techniques, including Tensor Parallelism, ZeRO, and Pipeline Parallelism, ensuring seamless integration with the latest advancements in parallel computing.
Binary file removed media/hybrid_seqparallel.png
Binary file added media/usp.png
10 changes: 8 additions & 2 deletions setup.py
@@ -1,9 +1,15 @@
from setuptools import setup, find_packages
import os

# Read the version information from yunchang/__version__.py
version_file = os.path.join(os.path.dirname(__file__), 'yunchang', '__version__.py')
with open(version_file, 'r') as f:
exec(f.read())

setup(
name="yunchang",
version="0.3",
author="Jiarui Fang, Zilin Zhu, Yang Yu",
version=__version__,
author="[email protected]",
url="https://github.com/feifeibear/long-context-attention",
packages=find_packages(exclude=['test', 'benchmark']),
install_requires=[
93 changes: 62 additions & 31 deletions test/test_hybrid_attn.py
@@ -2,6 +2,7 @@
AsyncLongContextAttention,
LongContextAttention,
set_seq_parallel_pg,
EXTRACT_FUNC_DICT
)
import torch
import torch.distributed as dist
@@ -34,31 +35,35 @@ def log(msg, a, rank0_only=False):
)
dist.barrier()


# test it with:
# torchrun --nproc_per_node=4 test/test_hybrid_attn_v2.py
if __name__ == "__main__":
torch.random.manual_seed(0)

use_bwd = False
use_bwd = True
dist.init_process_group("nccl")

rank = dist.get_rank()
world_size = dist.get_world_size()

assert world_size == 4, f"torchrun --nproc_per_node=4 test/test_hybrid_attn_v2.py"
# Inference mainly uses fp16; ROCm flash attention with bf16 has a slightly larger numerical error, which will be fixed soon
dtype = torch.float16
dtype = torch.bfloat16
device = torch.device(f"cuda:{rank}")

batch_size = 2
seqlen = 3816
nheads = 2
seqlen = 1024
nheads = 4
d = 128
dropout_p = 0
causal = True

deterministic = False

use_async_all_to_all = True
assert seqlen % world_size == 0
assert d % 8 == 0
# assert batch_size == 1

ring_impl_type = "zigzag" # You can change this to "basic" or "zigzag" if needed

# Prepare inputs
q = torch.randn(
@@ -77,30 +82,34 @@ def log(msg, a, rank0_only=False):
dist.broadcast(v, src=0)
dist.broadcast(dout, src=0)

local_q = q.chunk(world_size, dim=1)[rank].detach().clone()
local_q.requires_grad = True
local_k = k.chunk(world_size, dim=1)[rank].detach().clone()
local_k.requires_grad = True
local_v = v.chunk(world_size, dim=1)[rank].detach().clone()
local_v.requires_grad = True

local_dout = dout.chunk(world_size, dim=1)[rank].detach().clone()

# prepare process group for hybrid sequence parallelism
use_ring_low_dim = True

sp_ulysses_degree = min(nheads, world_size)
sp_ulysses_degree = 2
sp_ring_degree = world_size // sp_ulysses_degree

print(
f"rank {rank}, sp_ulysses_degree: {sp_ulysses_degree}, sp_ring_degree: {sp_ring_degree}"
)

set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

if use_async_all_to_all:
hybrid_seq_parallel_attn = AsyncLongContextAttention()
else:
hybrid_seq_parallel_attn = LongContextAttention()
# Use EXTRACT_FUNC_DICT to shard the tensors
local_q = EXTRACT_FUNC_DICT[ring_impl_type](
q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()
local_q.requires_grad = True

local_k = EXTRACT_FUNC_DICT[ring_impl_type](
k, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()
local_k.requires_grad = True

local_v = EXTRACT_FUNC_DICT[ring_impl_type](
v, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()
local_v.requires_grad = True
usp_attn = LongContextAttention(ring_impl_type=ring_impl_type)

if rank == 0:
print("#" * 30)
@@ -112,7 +121,9 @@ def log(msg, a, rank0_only=False):
alibi_slopes, attn_bias = None, None
dropout_mask = None

local_out = hybrid_seq_parallel_attn(
print(f"before usp attn forward: {local_q.shape} {local_k.shape} {local_v.shape}")
# usp attn forward
local_out = usp_attn(
local_q,
local_k,
local_v,
@@ -125,11 +136,17 @@ def log(msg, a, rank0_only=False):
return_attn_probs=True,
)

# extract local dout
local_dout = EXTRACT_FUNC_DICT[ring_impl_type](
dout, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()

if rank == 0:
print("#" * 30)
print("# ds-ulysses backward:")
print("#" * 30)

# usp attn backward
if use_bwd:
local_out.backward(local_dout)

@@ -177,26 +194,40 @@ def log(msg, a, rank0_only=False):
dist.barrier()

# check correctness

local_out_ref = out_ref.chunk(world_size, dim=1)[rank]
local_out_pt_ref = out_ref.chunk(world_size, dim=1)[rank]
# When checking correctness, use EXTRACT_FUNC_DICT for reference outputs
local_out_ref = EXTRACT_FUNC_DICT[ring_impl_type](
out_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
)
local_out_pt_ref = EXTRACT_FUNC_DICT[ring_impl_type](
out_pt_ref, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
)

log("local (rank) out", local_out, rank0_only=True)
log("out (distributed) - out_ref (non-distributed) diff", local_out_ref - local_out)
log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref)

# log("out_ref (non-distributed) - out_pt_ref (gpu) diff", local_out_ref - local_out_pt_ref)

torch.testing.assert_close(local_out, local_out_ref, atol=1e-2, rtol=0)
torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)
# torch.testing.assert_close(out_ref, out_pt_ref, atol=1e-2, rtol=0)

if use_bwd:
local_dq_ref = q.grad.chunk(world_size, dim=1)[rank]
local_dq_ref = EXTRACT_FUNC_DICT[ring_impl_type](
q.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
)
log("load_dq", local_q.grad)
log("dq diff", local_dq_ref - local_q.grad)

local_dk_ref = k.grad.chunk(world_size, dim=1)[rank]
local_dk_ref = EXTRACT_FUNC_DICT[ring_impl_type](
k.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
)
log("load_dk", local_k.grad)
log("dk diff", local_dk_ref - local_k.grad)

local_dv_ref = v.grad.chunk(world_size, dim=1)[rank]
log("load_dk", local_v.grad)
local_dv_ref = EXTRACT_FUNC_DICT[ring_impl_type](
v.grad, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
)
log("load_dv", local_v.grad)
log("dv diff", local_dv_ref - local_v.grad)

if dist.is_initialized():
dist.destroy_process_group()