From 66b72aa067155840a762ef8155d9090a6c96d3c9 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 14 Oct 2024 11:47:31 +0000 Subject: [PATCH 1/6] support npu --- colossalai/legacy/communication/p2p.py | 2 +- colossalai/pipeline/p2p.py | 41 ++++++++++++------- .../pipeline/schedule/interleaved_pp.py | 3 +- colossalai/pipeline/schedule/one_f_one_b.py | 1 - colossalai/zero/gemini/gemini_ddp.py | 8 ++-- colossalai/zero/gemini/gemini_hook.py | 2 +- examples/language/bert/benchmark_utils.py | 19 +++++---- examples/language/bert/finetune.py | 5 ++- examples/language/llama/benchmark.py | 1 - .../flash_attention/flash_attention_npu.py | 3 ++ 10 files changed, 52 insertions(+), 33 deletions(-) diff --git a/colossalai/legacy/communication/p2p.py b/colossalai/legacy/communication/p2p.py index cf0bd4ba2437..089fcf23b884 100644 --- a/colossalai/legacy/communication/p2p.py +++ b/colossalai/legacy/communication/p2p.py @@ -171,7 +171,7 @@ def _communicate( for req in reqs: req.wait() # To protect against race condition when using batch_isend_irecv(). - torch.cuda.synchronize() + get_accelerator().synchronize() if recv_prev and recv_prev_split: if isinstance(tensor_recv_prev, torch.Tensor): diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index b7b2842136c5..841d2f161ce1 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -14,6 +14,8 @@ from torch.distributed import distributed_c10d as c10d from torch.utils._pytree import tree_flatten, tree_unflatten +from colossalai.accelerator import get_accelerator + from .stage_manager import PipelineStageManager @@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - buf = tensor.numpy().tobytes()[:tensor_size] if b"cuda" in buf: buf_array = bytearray(buf) - device_index = torch.cuda.current_device() + device_index = get_accelerator().current_device() # There might be more than one output tensors during forward for cuda_str in re.finditer(b"cuda", buf_array): pos = cuda_str.start() @@ -86,7 +88,7 @@ def _broadcast_object_list( else: current_device = torch.device("cpu") if is_nccl_backend: - current_device = torch.device("cuda", torch.cuda.current_device()) + current_device = torch.device("cuda", get_accelerator().current_device()) my_rank = dist.get_rank() # Serialize object_list elements to tensors on src rank. 
@@ -139,9 +141,9 @@ def _broadcast_object_list( # unconsistence in device if ( isinstance(unpickle_object, torch.Tensor) - and unpickle_object.device.index != torch.cuda.current_device() + and unpickle_object.device.index != get_accelerator().current_device() ): - unpickle_object = unpickle_object.cuda() + unpickle_object = unpickle_object.to(get_accelerator().current_device()) object_list[i] = unpickle_object @@ -157,12 +159,18 @@ def _check_for_nccl_backend(group): return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL +# def _check_device(group): +# is_nccl_backend = _check_for_nccl_backend(group) +# print("_check_device", is_nccl_backend) +# current_device = torch.device("cpu") +# if is_nccl_backend: +# current_device = torch.device(get_accelerator().current_device()) +# return current_device, is_nccl_backend + + def _check_device(group): - is_nccl_backend = _check_for_nccl_backend(group) - current_device = torch.device("cpu") - if is_nccl_backend: - current_device = torch.device("cuda", torch.cuda.current_device()) - return current_device, is_nccl_backend + current_device = torch.device(get_accelerator().current_device()) + return current_device, True TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"]) @@ -348,8 +356,11 @@ def _send_recv_serialization_object( unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item()) - if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): - unpickle_object = unpickle_object.cuda() + if ( + isinstance(unpickle_object, torch.Tensor) + and unpickle_object.device.index != get_accelerator().current_device() + ): + unpickle_object = unpickle_object.to(get_accelerator().current_device()) return unpickle_object @@ -475,9 +486,11 @@ def _p2p_comm( recv_prev_shape = None if tensor_send_next is not None: - send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64) + send_next_shape = torch.tensor( + tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64 + ) if recv_prev: - recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64) ops = [] if send_next_shape is not None: @@ -502,7 +515,7 @@ def _p2p_comm( # send and recv data tensor_recv_prev = None if recv_prev: - tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype) + tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype) ops = [] if tensor_send_next is not None: diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index c538ee0715b4..17ccb74cf19a 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -2,7 +2,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch -import torch.cuda import torch.distributed from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map @@ -18,7 +17,7 @@ from .base import PipelineSchedule -def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: +def _wait_p2p(wait_handles: List[get_accelerator().Event]) -> None: if wait_handles is not None: for req in wait_handles: req.wait() diff --git a/colossalai/pipeline/schedule/one_f_one_b.py 
b/colossalai/pipeline/schedule/one_f_one_b.py index 0fc90995adcc..2b486d9bcf94 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -2,7 +2,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union import torch -import torch.cuda from torch.nn import Module from torch.utils._pytree import tree_map diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index dbaae66108fe..280dfc123591 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -157,7 +157,7 @@ def __init__( self.enable_async_reduce = enable_async_reduce if enable_async_reduce: - self.async_reduce_stream = torch.cuda.Stream() + self.async_reduce_stream = get_accelerator().Stream() else: self.async_reduce_stream = None @@ -393,7 +393,7 @@ def grad_handle( master_weights: bool, enable_gradient_accumulation: bool, p: nn.Parameter, - async_reduce_stream: Optional[torch.cuda.Stream] = None, + async_reduce_stream=None, ): async_reduce_scatter = async_reduce_stream is not None setattr(p, "_gemini_reduced", True) @@ -432,9 +432,9 @@ def grad_handle( grad_chunk.add_tensor_to_chunk_slice(p, grad) if async_reduce_stream is not None: - async_reduce_stream.wait_stream(torch.cuda.current_stream()) + async_reduce_stream.wait_stream(get_accelerator().current_stream()) - with torch.cuda.stream(async_reduce_stream): + with get_accelerator().stream(async_reduce_stream): reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce_scatter) if reduced: grad_chunk.wait_async_reduce() diff --git a/colossalai/zero/gemini/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py index bf5faa0fe884..786b30c242a1 100644 --- a/colossalai/zero/gemini/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -62,7 +62,7 @@ def pre_op(self, params): # # Other than that, self._gemini_manager.wait_chunks will have synced with default stream # by calling dist.Work.wait() and this line makes no diff. 
- self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(torch.cuda.current_stream()) + self._gemini_manager.chunk_manager._prefetch_stream.wait_stream(get_accelerator().current_stream()) with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream): for chunk in chunks_fetch_async: diff --git a/examples/language/bert/benchmark_utils.py b/examples/language/bert/benchmark_utils.py index 04d55cb2e7b6..b70dc7496f05 100644 --- a/examples/language/bert/benchmark_utils.py +++ b/examples/language/bert/benchmark_utils.py @@ -9,6 +9,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm +from colossalai.accelerator import get_accelerator from colossalai.booster import Booster from colossalai.cluster import DistCoordinator @@ -59,7 +60,9 @@ def warm_up( for i, data in enumerate(dataloader): if i > num_runs: break - inputs, labels = data[0].cuda(), data[1].cuda() + inputs, labels = data[0].to(get_accelerator().get_current_device()), data[1].to( + get_accelerator().get_current_device() + ) outputs = model(inputs, labels=labels) loss = criterion(outputs) booster.backward(loss, optimizer) @@ -85,7 +88,7 @@ def benchmark( warm_up_steps: int = 3, ): results = {} - model_device = torch.cuda.current_device() + model_device = get_accelerator().get_current_device() # Warm up warm_up_fn( @@ -106,8 +109,8 @@ def benchmark( # Measure Allocated Memory and Throughput memory = {} throughput = {} - torch.cuda.reset_peak_memory_stats(device=model_device) - pre_mem = torch.cuda.memory_allocated(device=model_device) + get_accelerator().reset_peak_memory_stats(device=model_device) + pre_mem = get_accelerator().memory_allocated(device=model_device) start_time = time() @@ -116,7 +119,9 @@ def benchmark( dataloader, desc=f"Epoch [{epoch + 1}/{epoch_num}]", disable=not DistCoordinator().is_master() ) as pbar: for data in pbar: - inputs, labels = data[0].cuda(), data[1].cuda() + inputs, labels = data[0].to(get_accelerator().get_current_device()), data[1].to( + get_accelerator().get_current_device() + ) outputs = model(inputs, labels=labels) loss = criterion(outputs) booster.backward(loss, optimizer) @@ -128,8 +133,8 @@ def benchmark( all_sample = epoch_num * len(dataloader) - post_mem = torch.cuda.memory_allocated(device=model_device) - max_mem = torch.cuda.max_memory_allocated(device=model_device) + post_mem = get_accelerator().memory_allocated(device=model_device) + max_mem = get_accelerator().max_memory_allocated(device=model_device) memory[f"batch_size_{batch_size}"] = { "cuda_pre_training_bytes": format_num(pre_mem, bytes=True), diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index f048abdd253a..96f1bece0591 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -38,7 +38,7 @@ def move_to_cuda(batch): - return {k: v.cuda() for k, v in batch.items()} + return {k: v.to(get_accelerator().get_current_device()) for k, v in batch.items()} @torch.no_grad() @@ -266,7 +266,8 @@ def main(): cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels) if model_name == "bert-base-uncased": - model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda() + model = BertForSequenceClassification.from_pretrained(model_name, config=cfg) + model = model.to(get_accelerator().get_current_device()) elif model_name == "albert-xxlarge-v2": model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg) else: diff --git a/examples/language/llama/benchmark.py 
b/examples/language/llama/benchmark.py index 0e88fabf1eb0..4858ded4b6d9 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -283,7 +283,6 @@ def empty_init(): config, trust_remote_code=True, **init_kwargs, - attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, ) if args.grad_checkpoint: diff --git a/extensions/pybind/flash_attention/flash_attention_npu.py b/extensions/pybind/flash_attention/flash_attention_npu.py index 8a30972b6fba..f22bc4022757 100644 --- a/extensions/pybind/flash_attention/flash_attention_npu.py +++ b/extensions/pybind/flash_attention/flash_attention_npu.py @@ -27,6 +27,7 @@ def build_jit(self) -> None: ) def load(self): + import math from typing import Optional import torch @@ -47,6 +48,8 @@ def flash_attention( q_indices: Optional[torch.Tensor] = None, kv_indices: Optional[torch.Tensor] = None, ): + if scale is None: + scale = 1.0 / math.sqrt(q.size(-1)) num_heads = q.size(1) return torch_npu.npu_fusion_attention( q, From a6ac181d6e576a44dc3b16b81b8af8de1db1a1c0 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Tue, 15 Oct 2024 09:31:57 +0000 Subject: [PATCH 2/6] support pretrain support pretrain fix --- .../colossal_llama/dataset/conversation.py | 2 +- .../dataset/spliced_and_tokenized_dataset.py | 2 +- applications/Colossal-LLaMA/train.py | 88 +++++++++---------- colossalai/shardformer/layer/normalization.py | 26 +++++- examples/language/llama/benchmark.py | 16 ++-- 5 files changed, 72 insertions(+), 62 deletions(-) diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py index 8ec9c848b2c8..3ffbb87db7fb 100644 --- a/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py @@ -100,7 +100,7 @@ def dict(self): messages=[], offset=0, sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN, - seps=["<|begin_of_text|>", "<|end_of_text|>"], + seps=["<|begin_of_text|>", "<|eot_id|>"], ) default_conversation = LLaMA3_Conv diff --git a/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py index 30122d2838f9..15cb2987489b 100644 --- a/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py +++ b/applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py @@ -88,7 +88,7 @@ def supervised_tokenize_sft( assert ( tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1] - ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`." + ), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}." 
if ignore_index is None: ignore_index = IGNORE_INDEX diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index db23275e4e31..fb14a3a2832b 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -65,7 +65,7 @@ def train(args) -> None: initial_scale=2**16, max_norm=args.grad_clip, enable_gradient_accumulation=(args.accumulation_steps > 1), - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.use_flash_attn, ) elif args.plugin == "gemini_auto": @@ -75,7 +75,7 @@ def train(args) -> None: initial_scale=2**16, max_norm=args.grad_clip, enable_gradient_accumulation=(args.accumulation_steps > 1), - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.use_flash_attn, ) elif args.plugin == "zero2": @@ -101,7 +101,7 @@ def train(args) -> None: sequence_parallelism_mode=args.sp_mode, zero_stage=args.zero_stage, enable_flash_attention=args.use_flash_attn, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_sequence_parallelism=args.enable_sequence_parallelism, cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, parallel_output=False, @@ -170,19 +170,11 @@ def train(args) -> None: else nullcontext() ) with init_ctx: - if args.use_flash_attn: - model = AutoModelForCausalLM.from_pretrained( - args.pretrained, - attn_implementation="flash_attention_2", - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - trust_remote_code=True, - ) - else: - model = AutoModelForCausalLM.from_pretrained( - args.pretrained, - torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, - trust_remote_code=True, - ) + model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + trust_remote_code=True, + ) # Freeze part of parameters. if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) @@ -371,44 +363,44 @@ def train(args) -> None: total_loss.fill_(0.0) pbar.update() - # Save modeling. - save_model_condition = ( - args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 - ) + # Save modeling. 
+ save_model_condition = ( + args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0 + ) - if not args.skip_save_each_epoch: - save_model_condition = save_model_condition or (step + 1) == len(dataloader) + if not args.skip_save_each_epoch: + save_model_condition = save_model_condition or (step + 1) == len(dataloader) - if save_model_condition and not args.benchmark: - coordinator.print_on_master("\nStart saving model checkpoint with running states") + if save_model_condition and not args.benchmark: + coordinator.print_on_master("\nStart saving model checkpoint with running states") - if args.use_neft: - coordinator.print_on_master("Deactivate NEFTune before saving model.") - deactivate_neftune(model, handle) + if args.use_neft: + coordinator.print_on_master("Deactivate NEFTune before saving model.") + deactivate_neftune(model, handle) - accelerator.empty_cache() - save_checkpoint( - save_dir=args.save_dir, - booster=booster, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - epoch=epoch, - step=step + 1, - batch_size=args.batch_size, - coordinator=coordinator, - ) - coordinator.print_on_master( - f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" - ) + accelerator.empty_cache() + save_checkpoint( + save_dir=args.save_dir, + booster=booster, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + epoch=epoch, + step=step + 1, + batch_size=args.batch_size, + coordinator=coordinator, + ) + coordinator.print_on_master( + f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" + ) - if args.use_neft: - coordinator.print_on_master("Activate NEFTune.") - model, handle = activate_neftune(model) + if args.use_neft: + coordinator.print_on_master("Activate NEFTune.") + model, handle = activate_neftune(model) - # Delete cache. - # del batch, batch_labels, batch_output, loss - accelerator.empty_cache() + # Delete cache. 
+ # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() # the continue epochs are not resumed, so we need to reset the sampler start index and start step dataloader.sampler.set_start_index(start_index=0) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 043bf6aeb4cd..05bf9ead1f04 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -1,9 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import numbers import warnings from abc import ABC, abstractmethod +import torch import torch.nn as nn +import torch_npu +from torch.nn import init +from torch.nn.parameter import Parameter from colossalai.lazy import LazyInitContext @@ -21,7 +26,6 @@ try: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm - from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm class FusedLayerNormWithHook(ApexFusedLayerNorm): def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): @@ -32,12 +36,26 @@ def forward(self, input): output = hook_parameter_in_backward(output, self.weight, self.bias) return output - class FusedRMSNormWithHook(ApexFusedRMSNorm): + class FusedRMSNormWithHook(nn.Module): def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): - super().__init__(normalized_shape, eps, elementwise_affine) + super().__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(torch.empty(*normalized_shape)) + else: + self.register_parameter("weight", None) + self.reset_parameters() + + def reset_parameters(self): + if self.elementwise_affine: + init.ones_(self.weight) def forward(self, input): - output = super().forward(input) + output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps) output = hook_parameter_in_backward(output, self.weight) return output diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 4858ded4b6d9..d6b009724bf4 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -146,7 +146,7 @@ def empty_init(): offload_param_frac=args.offload_param_frac, tp_size=args.tp, extra_dp_size=args.extra_dp, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.xformers, max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, @@ -160,7 +160,7 @@ def empty_init(): warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, @@ -219,7 +219,7 @@ def empty_init(): sp_size=args.sp, sequence_parallelism_mode=args.sp_mode, enable_sequence_parallelism=args.sp > 1, - enable_fused_normalization=torch.cuda.is_available(), + enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, precision="bf16", @@ -237,7 +237,7 @@ def empty_init(): num_model_chunks=args.n_chunks, zero_stage=args.zero, cpu_offload=True, - enable_fused_normalization=torch.cuda.is_available(), 
+ enable_fused_normalization=get_accelerator().is_available(), enable_flash_attention=args.xformers, microbatch_size=args.mbs, initial_scale=2**8, @@ -260,7 +260,7 @@ def empty_init(): config = MODEL_CONFIGS[args.config] else: config = AutoConfig.from_pretrained(args.config, trust_remote_code=True) - torch.cuda.manual_seed(42) + get_accelerator().manual_seed(42) dataset = RandomDataset( num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size ) @@ -308,7 +308,7 @@ def empty_init(): torch.set_default_dtype(torch.float) coordinator.print_on_master( - f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" + f"Booster init max NPU memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB" ) coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" @@ -319,7 +319,7 @@ def empty_init(): args.ignore_steps, 1, # avoid creating massive log files save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}", - nsys=args.nsys, + nsys=False, ) as prof: if isinstance(plugin, HybridParallelPlugin) and args.pp > 1: data_iter = iter(dataloader) @@ -356,7 +356,7 @@ def empty_init(): performance_evaluator.on_step_end(**batch) prof.step() performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max NPU memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": From fae90d681ba6ae7ea5debabdbed0b8da57e3834d Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Mon, 21 Oct 2024 12:28:19 +0000 Subject: [PATCH 3/6] support lora fix fix --- .../Colossal-LLaMA/colossal_llama/utils/ckpt_io.py | 6 +++++- applications/Colossal-LLaMA/train.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py b/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py index 05342ce41a60..2d712f41605e 100644 --- a/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py +++ b/applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py @@ -43,6 +43,7 @@ def save_checkpoint( step: int, batch_size: int, coordinator: DistCoordinator, + use_lora: bool = False, ) -> None: """ Save model checkpoint, optimizer, LR scheduler and intermedidate running states. 
@@ -51,7 +52,10 @@ def save_checkpoint( save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}") os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True) - booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) + if use_lora: + booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling")) + else: + booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True) booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True) booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler")) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index fb14a3a2832b..9571171b5c2b 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -21,6 +21,7 @@ from colossal_llama.utils.froze import freeze_non_embeds_parameters from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel +from peft import LoraConfig from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer @@ -166,7 +167,7 @@ def train(args) -> None: # ====================================================== init_ctx = ( LazyInitContext(default_device=get_current_device()) - if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0 else nullcontext() ) with init_ctx: @@ -178,11 +179,16 @@ def train(args) -> None: # Freeze part of parameters. if args.freeze_non_embeds_params: freeze_non_embeds_parameters(model=model) + + if args.lora_rank > 0: + lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1) + model = booster.enable_lora(model, lora_config=lora_config) + # this is essential, otherwise the grad checkpoint will not work. model.train() if args.use_grad_checkpoint: - model.gradient_checkpointing_enable() + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") model_numel = get_model_numel(model) @@ -319,6 +325,7 @@ def train(args) -> None: step=step + 1, batch_size=args.batch_size, coordinator=coordinator, + use_lora=(args.lora_rank > 0), ) coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" @@ -389,6 +396,7 @@ def train(args) -> None: step=step + 1, batch_size=args.batch_size, coordinator=coordinator, + use_lora=(args.lora_rank > 0), ) coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}" @@ -514,6 +522,7 @@ def train(args) -> None: parser.add_argument( "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin." ) + parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.") # Additional arguments for benchmark. 
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.") From 8e8c0812a1f34fbf97553c4736a7a095576075e0 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Thu, 7 Nov 2024 19:19:51 +0800 Subject: [PATCH 4/6] support chatglm fix fxi fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fix fix fix --- applications/Colossal-LLaMA/train.py | 14 +- colossalai/lazy/lazy_init.py | 6 +- colossalai/pipeline/p2p.py | 20 +-- colossalai/shardformer/layer/normalization.py | 36 ++++- colossalai/shardformer/modeling/chatglm2.py | 151 ++++++++++++------ colossalai/shardformer/policies/chatglm2.py | 8 + examples/language/llama/benchmark.py | 6 +- 7 files changed, 171 insertions(+), 70 deletions(-) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index 9571171b5c2b..17c5abd53236 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -105,7 +105,6 @@ def train(args) -> None: enable_fused_normalization=get_accelerator().is_available(), enable_sequence_parallelism=args.enable_sequence_parallelism, cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, - parallel_output=False, max_norm=args.grad_clip, precision=args.mixed_precision, microbatch_size=args.microbatch_size, @@ -118,11 +117,17 @@ def train(args) -> None: # ====================================================== # Initialize Tokenizer, Dataset, Collator and Dataloader # ====================================================== - tokenizer = AutoTokenizer.from_pretrained(args.pretrained) + tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) if args.pad_token == "eos": - tokenizer.pad_token = tokenizer.eos_token + try: + tokenizer.pad_token = tokenizer.eos_token + except AttributeError: + coordinator.print_on_master(f"pad_token can't be set") elif args.pad_token == "unk": - tokenizer.pad_token = tokenizer.unk_token + try: + tokenizer.pad_token = tokenizer.unk_token + except AttributeError: + coordinator.print_on_master(f"pad_token can't be set") tokenizer.add_bos_token = False tokenizer.add_eos_token = False @@ -165,6 +170,7 @@ def train(args) -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== + # TODO chatglm doesn't support lora now init_ctx = ( LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0 diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index b130111ba3d9..4072bb1974c2 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -509,9 +509,9 @@ def wrap_factory_like_method(orig_target, target): # factory_like functions (eg. 
torch.empty_like()) def wrapper(*args, **kwargs): orig_t = args[0] - return self.tensor_cls( - orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs - ) + device = kwargs.pop("device", orig_t.device) + dtype = kwargs.pop("dtype", orig_t.dtype) + return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs) return wrapper, target diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 841d2f161ce1..b9a324f709d2 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -148,7 +148,7 @@ def _broadcast_object_list( object_list[i] = unpickle_object -def _check_for_nccl_backend(group): +def _check_for_nccl_hccl_backend(group): pg = group or c10d._get_default_group() # Gate PG wrapper check on Gloo availability. if c10d._GLOO_AVAILABLE: @@ -156,21 +156,15 @@ def _check_for_nccl_backend(group): while isinstance(pg, c10d._ProcessGroupWrapper): pg = pg.wrapped_pg - return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL - - -# def _check_device(group): -# is_nccl_backend = _check_for_nccl_backend(group) -# print("_check_device", is_nccl_backend) -# current_device = torch.device("cpu") -# if is_nccl_backend: -# current_device = torch.device(get_accelerator().current_device()) -# return current_device, is_nccl_backend + return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL def _check_device(group): - current_device = torch.device(get_accelerator().current_device()) - return current_device, True + is_nccl_backend = _check_for_nccl_hccl_backend(group) + current_device = torch.device("cpu") + if is_nccl_backend: + current_device = torch.device(get_accelerator().current_device()) + return current_device, is_nccl_backend TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"]) diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 05bf9ead1f04..aef8796acc21 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -import torch_npu from torch.nn import init from torch.nn.parameter import Parameter @@ -15,6 +14,16 @@ from ._operation import hook_parameter_in_backward from .utils import SeqParallelUtils +SUPPORT_NPU = False +try: + import torch_npu + + SUPPORT_NPU = True + warnings.warn("support npu") +except Exception: + warnings.warn("support gpu") + + __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] try: @@ -36,7 +45,13 @@ def forward(self, input): output = hook_parameter_in_backward(output, self.weight, self.bias) return output - class FusedRMSNormWithHook(nn.Module): +except ImportError: + warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") + +FusedRMSNormWithHook = None +if SUPPORT_NPU: + + class NPUFusedRMSNormWithHook(nn.Module): def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): super().__init__() if isinstance(normalized_shape, numbers.Integral): @@ -55,12 +70,25 @@ def reset_parameters(self): init.ones_(self.weight) def forward(self, input): + output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps) output = hook_parameter_in_backward(output, self.weight) return output -except ImportError: - warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMSNorm kernel") + 
FusedRMSNormWithHook = NPUFusedRMSNormWithHook +else: + from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + + class CUDAFusedRMSNormWithHook(ApexFusedRMSNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_parameter_in_backward(output, self.weight) + return output + + FusedRMSNormWithHook = CUDAFusedRMSNormWithHook FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index a9be5c74dba8..be13200b5c4f 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -9,7 +9,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig -from colossalai.shardformer.layer import AttnMaskType, ColoAttention +from colossalai.shardformer.layer import ColoAttention from colossalai.shardformer.layer._operation import ( all_to_all_comm, gather_sp_output, @@ -25,42 +25,7 @@ def get_flash_core_attention_forward(): def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask): query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - attention_mask_type = AttnMaskType.CAUSAL - attn_bias = torch.zeros( - query_layer.shape[0], - 1, - query_layer.shape[2], - key_layer.shape[2], - dtype=query_layer.dtype, - device=query_layer.device, - ) - temp_mask = ( - torch.ones( - query_layer.shape[2], - key_layer.shape[2], - dtype=torch.bool, - device=query_layer.device, - ) - .tril(diagonal=0) - .expand(query_layer.shape[0], 1, -1, -1) - ) - attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min) - else: - attention_mask_type = AttnMaskType.CUSTOM - if attention_mask is not None: - attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype) - attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min) - dropout_p = self.attention_dropout.p if self.training else 0.0 - context_layer = ColoAttention.attention( - query_layer, - key_layer, - value_layer, - attention_mask=attn_bias, - attention_mask_type=attention_mask_type, - dropout_p=dropout_p, - scale=1.0 / self.norm_factor, - ) + context_layer = ColoAttention.attention(query_layer, key_layer, value_layer, **attention_mask) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) @@ -180,9 +145,20 @@ def chatglm_model_forward( ], dim=-1, ) - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + if shard_config.enable_flash_attention: + mask_shape = (batch_size, 1, seq_length, seq_length) + full_attention_mask: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, 
past_key_values, padding_mask=attention_mask) # Support SP + PP sp_size = shard_config.sequence_parallel_size @@ -237,7 +213,7 @@ def chatglm_model_forward( layer_ret = torch.utils.checkpoint.checkpoint( layer, hidden_states, - attention_mask, + full_attention_mask, rotary_pos_emb, past_key_values[idx], use_cache, @@ -402,10 +378,19 @@ def forward( ], dim=-1, ) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + if shard_config.enable_flash_attention: + mask_shape = (batch_size, 1, seq_length, seq_length) + full_attention_mask: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + hidden_states.dtype, + hidden_states.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + else: + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) # Rotary positional embeddings rotary_pos_emb = self.rotary_pos_emb(self.seq_length) @@ -652,3 +637,79 @@ def forward( return output, kv_cache return forward + + +def get_flash_attention_forward_for_chat_glm_model(): + from .chatglm2_6b.modeling_chatglm import ChatGLMModel + + def forward( + self: ChatGLMModel, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype + ) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1 + ) + + mask_shape = (batch_size, 1, seq_length, seq_length) + full_attention_mask: dict = ColoAttention.prepare_attn_kwargs( + mask_shape, + inputs_embeds.dtype, + inputs_embeds.device, + q_padding_mask=attention_mask, + is_causal=True, + ) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. 
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index c003570a0582..4ddcf8bfce6b 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -11,6 +11,7 @@ from ..modeling.chatglm2 import ( get_chatglm_sequence_parallel_attention_forward, get_chatglm_sequence_parallel_forward_fn, + get_flash_attention_forward_for_chat_glm_model, get_flash_core_attention_forward, get_jit_fused_glm_block_forward, ) @@ -203,6 +204,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: policy=policy, target_key="CoreAttention", ) + self.append_or_create_method_replacement( + description={ + "forward": get_flash_attention_forward_for_chat_glm_model(), + }, + policy=policy, + target_key="ChatGLMModel", + ) # use sequence parallel if self.shard_config.enable_sequence_parallelism: diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index d6b009724bf4..ad3b9a34bc0d 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -292,9 +292,13 @@ def empty_init(): model_numel = get_model_numel(model) coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}") + if config.model_type == "chatglm": + num_layers = model.config.num_layers + else: + num_layers = model.config.num_hidden_layers performance_evaluator = PerformanceEvaluator( model_numel, - model.config.num_hidden_layers, + num_layers, model.config.hidden_size, model.config.vocab_size, args.grad_checkpoint, From 85be43f4f4ab3130ede98f4d0f61160c30d43e5e Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 20 Nov 2024 15:27:04 +0800 Subject: [PATCH 5/6] Update train.py --- applications/Colossal-LLaMA/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index 17c5abd53236..cb39c9121924 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -170,7 +170,7 @@ def train(args) -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - # TODO chatglm doesn't support lora now + # When training of the ChatGLM model, LoRA and gradient checkpointing are incompatible. 
init_ctx = ( LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0 From beffe3ab0aa23b3683dab2f0ce712c3ad0a8f7a6 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 20 Nov 2024 15:27:55 +0800 Subject: [PATCH 6/6] Update train.py --- applications/Colossal-LLaMA/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py index cb39c9121924..6650469f30c2 100644 --- a/applications/Colossal-LLaMA/train.py +++ b/applications/Colossal-LLaMA/train.py @@ -170,7 +170,7 @@ def train(args) -> None: # ====================================================== # Initialize Model, Objective, Optimizer and LR Scheduler # ====================================================== - # When training of the ChatGLM model, LoRA and gradient checkpointing are incompatible. + # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible. init_ctx = ( LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0