[NPU] support npu #6089

Merged: 6 commits, Nov 20, 2024
@@ -100,7 +100,7 @@ def dict(self):
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)

default_conversation = LLaMA3_Conv
@@ -88,7 +88,7 @@ def supervised_tokenize_sft(

assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

if ignore_index is None:
ignore_index = IGNORE_INDEX
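
The separator change in the conversation template above is exactly what this assertion guards: the SFT tokenizer refuses to build labels unless the tokenizer's bos/eos tokens match `conversation_template.seps`, and the message now prints the actual values so mismatches are easy to diagnose. A hedged pre-flight check along the same lines, not part of this diff (the model path is a placeholder):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/llama3-style-model", trust_remote_code=True)
    seps = ["<|begin_of_text|>", "<|eot_id|>"]  # LLaMA3_Conv.seps after the change above
    assert tokenizer.bos_token == seps[0] and tokenizer.eos_token == seps[1], (
        f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} "
        f"should be the same with `conversation_template.seps`{seps}."
    )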
@@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermediate running states.
@@ -51,7 +52,10 @@
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
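
With `use_lora` set, only the adapter weights land in the `modeling` folder (via `save_lora_as_pretrained`), while full runs keep the sharded `save_model` path. A hedged sketch of consuming such a checkpoint, assuming the adapter is written in the standard peft layout (the paths below are placeholders):

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained("path/to/base-model", trust_remote_code=True)
    model = PeftModel.from_pretrained(base, "outputs/epoch-0_step-100/modeling")
    model = model.merge_and_unload()  # optionally fold the adapter back into the base weights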
115 changes: 61 additions & 54 deletions applications/Colossal-LLaMA/train.py
@@ -21,6 +21,7 @@
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
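
The hunks below replace hard-coded `torch.cuda` calls with the accelerator interface so the same training script runs on Ascend NPU as well as CUDA GPUs. A hedged sketch of the pattern, using only methods that appear elsewhere in this diff:

    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()            # resolves to the CUDA or NPU backend at runtime
    if accelerator.is_available():             # replaces torch.cuda.is_available()
        index = accelerator.current_device()   # replaces torch.cuda.current_device()
        accelerator.synchronize()              # replaces torch.cuda.synchronize()
        accelerator.empty_cache()              # replaces torch.cuda.empty_cache()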
@@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
Expand All @@ -75,7 +76,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@@ -101,10 +102,9 @@
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
@@ -117,11 +117,17 @@
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
if args.pad_token == "eos":
tokenizer.pad_token = tokenizer.eos_token
try:
tokenizer.pad_token = tokenizer.eos_token
except AttributeError:
coordinator.print_on_master("pad_token can't be set")
elif args.pad_token == "unk":
tokenizer.pad_token = tokenizer.unk_token
try:
tokenizer.pad_token = tokenizer.unk_token
except AttributeError:
coordinator.print_on_master("pad_token can't be set")
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
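
Some `trust_remote_code` tokenizers expose `pad_token` as a read-only property or ship without an eos/unk token, which is why the assignments above are wrapped in try/except. A hedged fallback, not part of this PR, for tokenizers that simply lack a usable padding token:

    # Hypothetical fallback: register a dedicated pad token and resize embeddings to match.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # model.resize_token_embeddings(len(tokenizer))  # needed once the model is loaded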

@@ -164,33 +170,31 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)

if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)

# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

model_numel = get_model_numel(model)
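
For reference, `booster.enable_lora` wraps the model with peft adapters before boosting, and gradient checkpointing is switched to the non-reentrant implementation, which generally cooperates better with partially frozen (LoRA) setups. The diff relies on peft's default target modules; an illustrative variant that names the LLaMA attention projections explicitly (the `target_modules` list is an assumption, not what this PR hard-codes) would look like:

    from peft import LoraConfig

    lora_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=16,                       # rank of the low-rank update, i.e. --lora_rank
        lora_alpha=32,              # scaling applied to the update
        lora_dropout=0.1,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # illustrative choice
    )
    model = booster.enable_lora(model, lora_config=lora_config)  # booster/model as in the script above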
@@ -327,6 +331,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()

# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)

accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)

if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)

# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()

# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ def train(args) -> None:
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
6 changes: 3 additions & 3 deletions colossalai/lazy/lazy_init.py
@@ -509,9 +509,9 @@ def wrap_factory_like_method(orig_target, target):
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
device = kwargs.pop("device", orig_t.device)
dtype = kwargs.pop("dtype", orig_t.dtype)
return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)

return wrapper, target
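
The change above fixes the factory_like wrappers (e.g. torch.empty_like): if the caller passes its own `device=` or `dtype=`, forwarding the source tensor's values as well would raise a duplicate-keyword error, so the caller's override has to be popped out of `kwargs` first. A hedged standalone illustration of the same pattern:

    import torch

    def empty_like_sketch(orig_t, *args, **kwargs):
        # caller override wins; fall back to the source tensor's device/dtype otherwise
        device = kwargs.pop("device", orig_t.device)
        dtype = kwargs.pop("dtype", orig_t.dtype)
        return torch.empty(*orig_t.shape, *args, device=device, dtype=dtype, **kwargs)

    x = torch.zeros(2, 3)
    y = empty_like_sketch(x, dtype=torch.float16)  # previously: "got multiple values for keyword argument 'dtype'"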

2 changes: 1 addition & 1 deletion colossalai/legacy/communication/p2p.py
@@ -171,7 +171,7 @@ def _communicate(
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
get_accelerator().synchronize()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
33 changes: 20 additions & 13 deletions colossalai/pipeline/p2p.py
@@ -14,6 +14,8 @@
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten

from colossalai.accelerator import get_accelerator

from .stage_manager import PipelineStageManager


@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = tensor.numpy().tobytes()[:tensor_size]
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
device_index = get_accelerator().current_device()
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b"cuda", buf_array):
pos = cuda_str.start()
@@ -86,7 +88,7 @@ def _broadcast_object_list(
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device("cuda", get_accelerator().current_device())

my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
@@ -139,29 +141,29 @@ def _broadcast_object_list(
# fix device inconsistency after unpickling
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.cuda()
unpickle_object = unpickle_object.to(get_accelerator().current_device())

object_list[i] = unpickle_object


def _check_for_nccl_backend(group):
def _check_for_nccl_hccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg

return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL


def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_hccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device(get_accelerator().current_device())
return current_device, is_nccl_backend
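
On Ascend the process group runs on the HCCL backend, so the probe above now accepts either nccl or hccl availability and keeps placing communication buffers on the current accelerator device. A hedged, simplified sketch of the same idea (not the exact code in this diff):

    import torch
    import torch.distributed as dist
    from colossalai.accelerator import get_accelerator

    def pick_comm_device(group=None):
        backend = dist.get_backend(group)  # e.g. "nccl", "hccl", or "gloo"
        if backend in ("nccl", "hccl"):
            # device-side backend: put p2p buffers on the current GPU/NPU
            return torch.device(get_accelerator().current_device())
        return torch.device("cpu")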


@@ -348,8 +350,11 @@ def _send_recv_serialization_object(

unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.to(get_accelerator().current_device())

return unpickle_object

@@ -475,9 +480,11 @@ def _p2p_comm(
recv_prev_shape = None

if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
send_next_shape = torch.tensor(
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)

ops = []
if send_next_shape is not None:
@@ -502,7 +509,7 @@ def _p2p_comm(
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)

ops = []
if tensor_send_next is not None:
3 changes: 1 addition & 2 deletions colossalai/pipeline/schedule/interleaved_pp.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
@@ -18,7 +17,7 @@
from .base import PipelineSchedule


def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
def _wait_p2p(wait_handles: List[get_accelerator().Event]) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
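
`_wait_p2p` drains the work handles returned by the batched point-to-point calls; only the type annotation changes here, going through the accelerator instead of `torch.cuda`. A hedged standalone sketch of the asynchronous send/recv pattern whose handles must be waited on before the received tensor is used (ranks and tensors are placeholders):

    import torch
    import torch.distributed as dist

    def exchange(tensor_send, tensor_recv, next_rank, prev_rank):
        ops = [
            dist.P2POp(dist.isend, tensor_send, next_rank),
            dist.P2POp(dist.irecv, tensor_recv, prev_rank),
        ]
        handles = dist.batch_isend_irecv(ops)
        for handle in handles:  # same role as _wait_p2p above
            handle.wait()
        return tensor_recv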
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
