
[feat] support zbv in mixtral benchmark; #6083

Merged
merged 25 commits on Oct 31, 2024

Commits
3f5bec8
[feat] support zbv in mixtral benchmark;
duanjunwen Oct 9, 2024
9ee80fc
[fix] MixtralForCausalLMPolicy get_held_layer support zbv;
duanjunwen Oct 10, 2024
72b507a
[feat] update MixtralPipelineForwards --> mixtral_model_forward; supp…
duanjunwen Oct 10, 2024
e234dfa
[feat] support MixtralPipelineForwards--> mixtral_for_causal_lm_forwa…
duanjunwen Oct 10, 2024
0ca16d5
[fix] fix llama, mixtral benchmark zbv loss none bug; update mixtral …
duanjunwen Oct 11, 2024
cfade4c
[feat] Linear1D_COL/ROW support zbv WeightGradStore;
duanjunwen Oct 14, 2024
a11b4b5
[feat] support use_zbv in llama, mixtral modeling; only replace Linea…
duanjunwen Oct 14, 2024
abd4551
[fix] fix test case; moe error in second iter
duanjunwen Oct 14, 2024
160e9a4
[feat]EPMixtralSparseMoeBlock (op in MOE) support zbv;
duanjunwen Oct 14, 2024
9912cc8
[fix] fix bwd b; now bwd w only for Layer replaced by Linear1D_Col/Ro…
duanjunwen Oct 15, 2024
52dcc73
Merge branch 'feature/zerobubble' of github.com:hpcaitech/ColossalAI …
duanjunwen Oct 15, 2024
90939b7
[fix] debug zbv llama test;
duanjunwen Oct 15, 2024
e76308c
[fix] rm use_zbv flag in Shardconfig; rm debug info;
duanjunwen Oct 16, 2024
705b18e
[fix] add & fix llama test
duanjunwen Oct 16, 2024
2eca112
[feat] support meta cache, meta_grad_send, meta_tensor_send; fix runt…
duanjunwen Oct 24, 2024
d0ec221
[fix] fix fail case test_shard_llama
duanjunwen Oct 25, 2024
cc0dfdd
[fix] fix test_shard_llama
duanjunwen Oct 25, 2024
03fa79a
[fix] fix llama modeling policy;
duanjunwen Oct 25, 2024
6377aa0
[fix] fix test_shard_llama ci;
duanjunwen Oct 28, 2024
5aee426
[fix] fix test zerobubble
duanjunwen Oct 28, 2024
fafe049
[fix] fix handle name; rm useless comments;
duanjunwen Oct 29, 2024
fa3ccda
[fix] fix send recv signature;
duanjunwen Oct 29, 2024
982e4ee
[fix] fix comment in llama & benchmark
duanjunwen Oct 29, 2024
d2e05a9
[feat] support no tensor parallel Linear in shardformer; Add test for…
duanjunwen Oct 30, 2024
5f09243
[fix] fix linear (no tp) ops func name;
duanjunwen Oct 31, 2024
162 changes: 105 additions & 57 deletions colossalai/pipeline/schedule/zero_bubble_pp.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion colossalai/pipeline/stage_manager.py
@@ -223,7 +223,6 @@ def distribute_layers(

# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks

# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
32 changes: 32 additions & 0 deletions colossalai/pipeline/weight_grad_store.py
@@ -0,0 +1,32 @@
import queue


class WeightGradStore:

cache = []
weight_grad_queue = [queue.Queue(), queue.Queue()]

@classmethod
def put(cls, total_input, grad_output, weight, func):
# func(total_input, grad_output, weight.main_grad)
cls.cache.append((total_input, grad_output, weight, func))

@classmethod
def flush(cls, chunk=0):
cls.weight_grad_queue[chunk].put(cls.cache)
cls.cache = []

@classmethod
def pop(cls, chunk=0):
# print(f"chunk id {chunk} queue size {cls.weight_grad_queue[chunk].qsize()}")
if cls.weight_grad_queue[chunk].qsize() > 0:
stored_grads = cls.weight_grad_queue[chunk].get()
for total_input, grad_output, weight, func in stored_grads:
if weight.grad is not None:
func(total_input, grad_output, weight.grad)
# for first bwd; weight.grad is None, assign grad_weight to weight.grad
else:
grad_weight = func(total_input, grad_output)
weight.grad = grad_weight
else:
raise Exception("Pop empty queue.")
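
The intended lifecycle of the store above, sketched as a standalone snippet (the toy lambda stands in for the functools.partial callables that the linear ops register; it is not code from the PR):

import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore

x = torch.randn(4, 16)                      # flattened total_input
dy = torch.randn(4, 8)                      # flattened grad_output
w = torch.nn.Parameter(torch.randn(8, 16))  # weight; w.grad is still None

# B pass: defer the weight-gradient GEMM instead of computing it now.
WeightGradStore.put(x, dy, w, lambda inp, gout, grad=None: gout.t().matmul(inp))

# Chunk boundary: move the cached tuples into the per-chunk queue.
WeightGradStore.flush(chunk=0)

# W pass (driven by the zero-bubble schedule): run the deferred GEMMs.
WeightGradStore.pop(chunk=0)
assert w.grad.shape == w.shape
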
3 changes: 2 additions & 1 deletion colossalai/shardformer/layer/__init__.py
@@ -2,7 +2,7 @@
from .attn import AttnMaskType, ColoAttention, RingAttention, get_pad_info
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .linear import Linear1D, Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d, dist_cross_entropy
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@@ -11,6 +11,7 @@
__all__ = [
"Embedding1D",
"VocabParallelEmbedding1D",
"Linear1D",
"Linear1D_Col",
"Linear1D_Row",
"GPT2FusedLinearConv1D_Col",
178 changes: 167 additions & 11 deletions colossalai/shardformer/layer/_operation.py
@@ -1,7 +1,11 @@
import functools

import torch
import torch.distributed as dist
import torch.nn.functional as F

from colossalai.pipeline.weight_grad_store import WeightGradStore

from .utils import is_share_sp_tp

try:
@@ -125,12 +129,13 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
@@ -143,6 +148,13 @@ def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
fp8_communication = ctx.fp8_communication
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
@@ -164,24 +176,162 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=ctx.process_group, async_op=True)
# Relay on CUDA_DEVICE_MAX_CONNECTIONS=1 to have
# all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.async_grad_allreduce and not fp8_communication:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
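
To make the deferral above concrete: when use_zbv is set, the backward queues a functools.partial that reproduces the usual grad_weight GEMM later. A minimal shape check for the non-fused path (standalone sketch, not code from the PR):

import functools
import torch

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
    return wgrad_gemm_func(_grad_output_.t(), _input_)

# What WeightGradStore ends up holding for the non-fused path:
deferred = functools.partial(execute_w_pass, wgrad_gemm_func=torch.matmul)

x = torch.randn(6, 16)      # flattened total_input
dy = torch.randn(6, 4)      # flattened grad_output
dw = deferred(x, dy)        # executed later, during the W pass
assert dw.shape == (4, 16)  # same result as grad_output.t().matmul(total_input)
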


class LinearBase(torch.autograd.Function):
"""
Linear layer baseline (no tensor parallel version).
"""

@staticmethod
def forward(ctx, input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.async_grad_allreduce = async_grad_allreduce
ctx.fp8_communication = fp8_communication
ctx.use_zbv = use_zbv
if bias is not None:
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)

return output

@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
ctx.fp8_communication
use_zbv = ctx.use_zbv

def execute_w_pass_grad_accum(_input_, _grad_output_, _weight_main_grad_, wgrad_gemm_accum_func=None):
wgrad_gemm_accum_func(_input_, _grad_output_, _weight_main_grad_)

def execute_w_pass(_input_, _grad_output_, _weight_main_grad_=None, wgrad_gemm_func=None):
return wgrad_gemm_func(_grad_output_.t(), _input_)

# In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
if use_bias:
bias.view(bias.shape)

total_input = input.contiguous()
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])

if _grad_accum_fusion_available and weight.grad is not None:
grad = weight.grad
if use_zbv:
# TODO: append input, grad_output_, weight, grad func to WeightGradStore
if grad.dtype == torch.float32:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32,
),
)
grad_weight = None
elif grad.dtype in (torch.float16, torch.bfloat16):
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass_grad_accum,
wgrad_gemm_accum_func=fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16,
),
)
grad_weight = None
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
if grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, grad)
grad_weight = None
elif grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
else:
if use_zbv:
WeightGradStore.put(
total_input,
grad_output,
weight,
functools.partial(
execute_w_pass,
wgrad_gemm_func=torch.matmul,
),
)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

return grad_input, grad_weight, grad_bias, None, None, None, None

@@ -1043,12 +1193,18 @@ def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allre
)


def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False):
def linear_with_async_comm(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication=False, use_zbv=False
):
return LinearWithAsyncCommunication.apply(
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication
input_, weight, bias, process_group, async_grad_allreduce, fp8_communication, use_zbv
)


def linear_base(input_, weight, bias, async_grad_allreduce, fp8_communication=False, use_zbv=False):
return LinearBase.apply(input_, weight, bias, async_grad_allreduce, fp8_communication, use_zbv)
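
A hedged end-to-end sketch of the new no-tensor-parallel entry point with deferred weight gradients (shapes are illustrative; in the PR the flush/pop calls are issued by the zero-bubble pipeline schedule in zero_bubble_pp.py rather than by user code):

import torch
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer._operation import linear_base

x = torch.randn(4, 16, requires_grad=True)
w = torch.nn.Parameter(torch.randn(8, 16))

# Forward and backward with the weight-gradient GEMM deferred (use_zbv=True).
y = linear_base(x, w, None, async_grad_allreduce=False, use_zbv=True)
y.sum().backward()
assert w.grad is None        # only grad_input was produced; the W pass has not run yet

WeightGradStore.flush(chunk=0)
WeightGradStore.pop(chunk=0)
assert w.grad is not None    # the deferred GEMM has now populated the weight grad
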


def linear_gather_forward_reducescatter_backward(
input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap, ring=False
):