[NPU]support npu #6089

Merged: 6 commits, Nov 20, 2024
Changes from 3 commits
@@ -100,7 +100,7 @@ def dict(self):
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)

default_conversation = LLaMA3_Conv
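The separator change above swaps the assistant-turn terminator from `<|end_of_text|>` to `<|eot_id|>`, the end-of-turn token that Llama-3 instruct tokenizers typically report as `eos_token`; the assertion in the next hunk enforces exactly this pairing. A small pre-flight check one might run before training (the model path is a placeholder, not from the PR):

```python
# Sketch: confirm the tokenizer's special tokens match the template separators
# before launching training, mirroring the assertion in supervised_tokenize_sft.
from transformers import AutoTokenizer

seps = ["<|begin_of_text|>", "<|eot_id|>"]
tokenizer = AutoTokenizer.from_pretrained("path/to/llama3-instruct")  # placeholder path

assert tokenizer.bos_token == seps[0], f"bos_token is {tokenizer.bos_token!r}, expected {seps[0]!r}"
assert tokenizer.eos_token == seps[1], f"eos_token is {tokenizer.eos_token!r}, expected {seps[1]!r}"
```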
@@ -88,7 +88,7 @@ def supervised_tokenize_sft(

assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

if ignore_index is None:
ignore_index = IGNORE_INDEX
@@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermediate running states.
@@ -51,7 +52,10 @@
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
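With the new `use_lora` flag, the helper writes only the LoRA adapter (via `booster.save_lora_as_pretrained`) into the `modeling` folder instead of full sharded weights. A hedged sketch of how such an adapter checkpoint could be reloaded for inference, assuming the folder follows the standard PEFT adapter format (paths are placeholders, not from the PR):

```python
# Sketch: re-attach a saved LoRA adapter to its base model with peft.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base-model")               # placeholder path
model = PeftModel.from_pretrained(base, "save_dir/epoch-0_step-100/modeling")   # adapter dir written above
model = model.merge_and_unload()  # optionally fold the adapter back into the base weights
```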
101 changes: 51 additions & 50 deletions applications/Colossal-LLaMA/train.py
@@ -21,6 +21,7 @@
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@@ -101,7 +102,7 @@ def train(args) -> None:
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
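Replacing `torch.cuda.is_available()` with `get_accelerator().is_available()` in the plugin configurations above lets the same training script enable fused normalization on whichever backend colossalai detects (CUDA or NPU). A small device-agnostic sketch of the accelerator interface this PR relies on (`is_available`, `current_device`, `synchronize`, `empty_cache` all appear in the diff; `get_current_device` is assumed to come from the same interface):

```python
# Sketch: the accelerator abstraction stands in for torch.cuda so the same
# script can run on CUDA or NPU builds of colossalai.
import torch
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
if accelerator.is_available():
    device = accelerator.get_current_device()   # torch.device on the active backend
    x = torch.randn(8, 8, device=device)
    accelerator.synchronize()                   # device-wide barrier (see the p2p changes below)
    accelerator.empty_cache()                   # cache release used around checkpointing
```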
@@ -166,31 +167,28 @@ def train(args) -> None:
# ======================================================
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)

if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)

# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

model_numel = get_model_numel(model)
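Two changes in this hunk interact: with `--lora_rank > 0` the model is wrapped via `booster.enable_lora` using a peft `LoraConfig`, and gradient checkpointing now requests the non-reentrant implementation (`use_reentrant=False`), which generally behaves better when most parameters are frozen, as with LoRA. A small sanity check one might add right after `booster.enable_lora` (not part of the PR) to confirm that only the adapters remain trainable:

```python
# Sketch: report the trainable-parameter fraction after enabling LoRA.
# Assumes `model` and `coordinator` exist as in train.py above.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
coordinator.print_on_master(
    f"Trainable params: {trainable:,} / {total:,} ({100.0 * trainable / total:.2f}%)"
)
```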
@@ -327,6 +325,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +370,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()

# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)

accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)

if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)

# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()

# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +522,7 @@ def train(args) -> None:
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
2 changes: 1 addition & 1 deletion colossalai/legacy/communication/p2p.py
@@ -171,7 +171,7 @@ def _communicate(
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
get_accelerator().synchronize()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
41 changes: 27 additions & 14 deletions colossalai/pipeline/p2p.py
@@ -14,6 +14,8 @@
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten

from colossalai.accelerator import get_accelerator

from .stage_manager import PipelineStageManager


@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = tensor.numpy().tobytes()[:tensor_size]
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
device_index = get_accelerator().current_device()
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b"cuda", buf_array):
pos = cuda_str.start()
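For context, `_cuda_safe_tensor_to_object` unpickles a payload that may embed the sender's device string (e.g. `cuda:0`); the loop above patches the digit following each `cuda` occurrence in the raw pickle bytes so the object materialises on the local device index instead. A standalone sketch of that byte-patching idea (hypothetical helper name; single-digit device indices assumed, as in the original code):

```python
# Sketch: retarget "cuda:<n>" device strings inside a pickled payload before
# loading it, so deserialised tensors land on the local device index.
import io
import pickle
import re

def retarget_pickled_device(buf: bytes, device_index: int):
    buf_array = bytearray(buf)
    for match in re.finditer(b"cuda", buf_array):
        # the pickle stores b"cuda:<digit>"; overwrite the digit in place
        buf_array[match.start() + 5] = ord(str(device_index))
    return pickle.Unpickler(io.BytesIO(bytes(buf_array))).load()
```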
@@ -86,7 +88,7 @@ def _broadcast_object_list(
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device("cuda", get_accelerator().current_device())

my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
@@ -139,9 +141,9 @@ def _broadcast_object_list(
# inconsistency in device
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.cuda()
unpickle_object = unpickle_object.to(get_accelerator().current_device())

object_list[i] = unpickle_object

@@ -157,12 +159,18 @@ def _check_for_nccl_backend(group):
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


# def _check_device(group):
# is_nccl_backend = _check_for_nccl_backend(group)
# print("_check_device", is_nccl_backend)
# current_device = torch.device("cpu")
# if is_nccl_backend:
# current_device = torch.device(get_accelerator().current_device())
# return current_device, is_nccl_backend


def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
return current_device, is_nccl_backend
current_device = torch.device(get_accelerator().current_device())
return current_device, True


TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
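The rewritten `_check_device` always reports the accelerator's current device and returns `True` for the device-aware flag, while the previous backend check is kept above as a comment. A backend-aware sketch that keeps the CPU fallback (this reflects an assumption about intent, not the code the PR merged; `hccl` is the Ascend collective backend name):

```python
# Sketch: derive the communication device from the process-group backend,
# keeping a CPU fallback for host-side backends such as gloo.
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator

def check_device(group: dist.ProcessGroup):
    backend = dist.get_backend(group)
    device_aware = backend in (dist.Backend.NCCL, "hccl")  # GPU / Ascend NPU collectives
    device = get_accelerator().get_current_device() if device_aware else torch.device("cpu")
    return device, device_aware
```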
@@ -348,8 +356,11 @@ def _send_recv_serialization_object(

unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.to(get_accelerator().current_device())

return unpickle_object

@@ -475,9 +486,11 @@ def _p2p_comm(
recv_prev_shape = None

if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
send_next_shape = torch.tensor(
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)

ops = []
if send_next_shape is not None:
@@ -502,7 +515,7 @@ def _p2p_comm(
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)

ops = []
if tensor_send_next is not None:
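`_p2p_comm` follows a metadata-then-payload pattern: the sender first ships the tensor's shape so the receiver can pre-allocate a buffer of the right size, then both sides exchange the actual tensor with batched point-to-point ops. A simplified, device-agnostic sketch of that pattern (assumes an initialised process group and that both stages exchange tensors of the same rank; the PR itself hard-codes a 3-element shape):

```python
# Sketch: two-phase P2P exchange — shapes first, then the payload.
import torch
import torch.distributed as dist

def exchange(send_tensor, prev_rank, next_rank, device, dtype):
    # Phase 1: exchange shapes so the receiver can allocate a matching buffer.
    send_shape = torch.tensor(send_tensor.size(), device=device, dtype=torch.int64)
    recv_shape = torch.empty_like(send_shape)  # assumes equal tensor rank on both sides
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, send_shape, next_rank),
        dist.P2POp(dist.irecv, recv_shape, prev_rank),
    ])
    for req in reqs:
        req.wait()

    # Phase 2: exchange the payload using the agreed shape.
    recv_tensor = torch.empty(recv_shape.tolist(), device=device, dtype=dtype)
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, send_tensor, next_rank),
        dist.P2POp(dist.irecv, recv_tensor, prev_rank),
    ])
    for req in reqs:
        req.wait()
    return recv_tensor
```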
3 changes: 1 addition & 2 deletions colossalai/pipeline/schedule/interleaved_pp.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -18,7 +17,7 @@
from .base import PipelineSchedule


def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
def _wait_p2p(wait_handles: List[get_accelerator().Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map

26 changes: 22 additions & 4 deletions colossalai/shardformer/layer/normalization.py
@@ -1,9 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numbers
import warnings
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch_npu
from torch.nn import init
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext

@@ -21,7 +26,6 @@

try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

class FusedLayerNormWithHook(ApexFusedLayerNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
@@ -32,12 +36,26 @@ def forward(self, input):
output = hook_parameter_in_backward(output, self.weight, self.bias)
return output

class FusedRMSNormWithHook(ApexFusedRMSNorm):
class FusedRMSNormWithHook(nn.Module):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.empty(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()

def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)

def forward(self, input):
output = super().forward(input)
output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps)
output = hook_parameter_in_backward(output, self.weight)
return output
