support pretrain

fix
flybird11111 committed Oct 15, 2024
1 parent 66b72aa commit a6ac181
Showing 5 changed files with 72 additions and 62 deletions.
@@ -100,7 +100,7 @@ def dict(self):
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)

default_conversation = LLaMA3_Conv
@@ -88,7 +88,7 @@ def supervised_tokenize_sft(

assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

if ignore_index is None:
ignore_index = IGNORE_INDEX
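For context, a minimal sketch of the check this strengthened assertion performs, assuming a Llama-3 style tokenizer whose end-of-turn token is <|eot_id|>; the checkpoint path is a placeholder, not something taken from this commit.

    # Sketch only: confirm the tokenizer's special tokens line up with the template seps
    # before tokenizing data. The path below is a placeholder.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/llama3-checkpoint")
    expected_seps = ["<|begin_of_text|>", "<|eot_id|>"]  # matches the LLaMA3_Conv seps above
    assert tokenizer.bos_token == expected_seps[0], f"bos_token mismatch: {tokenizer.bos_token}"
    assert tokenizer.eos_token == expected_seps[1], f"eos_token mismatch: {tokenizer.eos_token}"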
88 changes: 40 additions & 48 deletions applications/Colossal-LLaMA/train.py
@@ -65,7 +65,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
@@ -75,7 +75,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@@ -101,7 +101,7 @@ def train(args) -> None:
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
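Aside: a small sketch of the device-agnostic pattern these hunks switch to. It assumes the accelerator helper is imported from colossalai.accelerator, as elsewhere in the repository; the active backend (NPU or CUDA) then answers the availability check instead of torch.cuda.

    # Sketch, assuming get_accelerator comes from colossalai.accelerator.
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()  # resolves to the NPU or CUDA backend at runtime
    enable_fused_norm = accelerator.is_available()  # device-agnostic stand-in for torch.cuda.is_available()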
@@ -170,19 +170,11 @@ def train(args) -> None:
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
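For comparison only, a sketch of how the removed flash-attention branch could be folded into the single from_pretrained call with a conditional keyword argument; this illustrates the transformers attn_implementation option, not what the commit does (the plugins above already receive enable_flash_attention from args.use_flash_attn).

    # Illustration only, not part of this commit.
    model_kwargs = dict(
        torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        trust_remote_code=True,
    )
    if args.use_flash_attn:
        model_kwargs["attn_implementation"] = "flash_attention_2"
    model = AutoModelForCausalLM.from_pretrained(args.pretrained, **model_kwargs)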
@@ -371,44 +363,44 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()

# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)

accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)

if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)

# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()

# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
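A compact restatement of the checkpoint trigger used above, written as a hypothetical helper (not in the codebase) to make the condition easier to read.

    # Hypothetical helper mirroring the save_model_condition logic above.
    def should_save(step: int, steps_per_epoch: int, save_interval: int,
                    accumulation_steps: int, skip_save_each_epoch: bool) -> bool:
        save_now = save_interval > 0 and (step + 1) % (save_interval * accumulation_steps) == 0
        if not skip_save_each_epoch:
            save_now = save_now or (step + 1) == steps_per_epoch  # also save at the end of each epoch
        return save_now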
26 changes: 22 additions & 4 deletions colossalai/shardformer/layer/normalization.py
@@ -1,9 +1,14 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import numbers
import warnings
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch_npu
from torch.nn import init
from torch.nn.parameter import Parameter

from colossalai.lazy import LazyInitContext

@@ -21,7 +26,6 @@

try:
from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

class FusedLayerNormWithHook(ApexFusedLayerNorm):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
@@ -32,12 +36,26 @@ def forward(self, input):
output = hook_parameter_in_backward(output, self.weight, self.bias)
return output

class FusedRMSNormWithHook(ApexFusedRMSNorm):
class FusedRMSNormWithHook(nn.Module):
def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
super().__init__(normalized_shape, eps, elementwise_affine)
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = Parameter(torch.empty(*normalized_shape))
else:
self.register_parameter("weight", None)
self.reset_parameters()

def reset_parameters(self):
if self.elementwise_affine:
init.ones_(self.weight)

def forward(self, input):
output = super().forward(input)
output, _ = torch_npu.npu_rms_norm(input, self.weight, self.eps)
output = hook_parameter_in_backward(output, self.weight)
return output
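A quick smoke test for the rewritten module, assuming an Ascend environment where torch_npu is installed and an "npu" device is visible; torch_npu.npu_rms_norm returns a tuple, and the wrapper keeps only the first element (the normalized output).

    # Sketch only; requires torch_npu and an available NPU device.
    import torch
    import torch_npu  # registers the "npu" device with torch

    norm = FusedRMSNormWithHook(normalized_shape=4096, eps=1e-5).to("npu")
    x = torch.randn(2, 16, 4096, device="npu")
    y = norm(x)
    assert y.shape == x.shape  # RMSNorm preserves the input shape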

16 changes: 8 additions & 8 deletions examples/language/llama/benchmark.py
@@ -146,7 +146,7 @@ def empty_init():
offload_param_frac=args.offload_param_frac,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.xformers,
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
@@ -160,7 +160,7 @@ def empty_init():
warmup_non_model_data_ratio=args.warmup_ratio,
tp_size=args.tp,
extra_dp_size=args.extra_dp,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers,
@@ -219,7 +219,7 @@ def empty_init():
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
precision="bf16",
@@ -237,7 +237,7 @@ def empty_init():
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
cpu_offload=True,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.xformers,
microbatch_size=args.mbs,
initial_scale=2**8,
@@ -260,7 +260,7 @@ def empty_init():
config = MODEL_CONFIGS[args.config]
else:
config = AutoConfig.from_pretrained(args.config, trust_remote_code=True)
torch.cuda.manual_seed(42)
get_accelerator().manual_seed(42)
dataset = RandomDataset(
num_samples=args.batch_size * args.num_steps * dp_size, max_length=args.max_length, vocab_size=config.vocab_size
)
@@ -308,7 +308,7 @@ def empty_init():

torch.set_default_dtype(torch.float)
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
f"Booster init max NPU memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
)
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
@@ -319,7 +319,7 @@ def empty_init():
args.ignore_steps,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
nsys=args.nsys,
nsys=False,
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
data_iter = iter(dataloader)
@@ -356,7 +356,7 @@ def empty_init():
performance_evaluator.on_step_end(**batch)
prof.step()
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(f"Max NPU memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")


if __name__ == "__main__":
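Complementing the earlier sketch, the seeding and memory-reporting calls in this file go through the same accelerator object; a short sketch, again assuming the colossalai.accelerator import path.

    # Sketch of the accelerator-backed replacements used in benchmark.py.
    from colossalai.accelerator import get_accelerator

    acc = get_accelerator()
    acc.manual_seed(42)  # replaces torch.cuda.manual_seed(42)
    peak_mb = acc.max_memory_allocated() / 1024**2
    print(f"Max device memory usage: {peak_mb:.2f} MB")  # reported as "Max NPU memory usage" above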