remove dup code; add unsloth support
xiyang-aads-lilly committed May 30, 2024
1 parent 2996eda commit 6dc510f
Showing 12 changed files with 219 additions and 278 deletions.
3 changes: 2 additions & 1 deletion experiments/demo_dgx2.sh
@@ -16,7 +16,8 @@ export HF_DATASETS_CACHE="${ROOT}/project/.cache/dataset"
export HF_HOME="${ROOT}/project/.cache/"

# Wandb
export WANDB_API_KEY="<key>"
export WANDB_API_KEY="05411100e08ac02e3fcbdc821b4116cf1c066e99"
# export WANDB_API_KEY="<key>"
export WANDB_USERNAME="xi-yang5"
export WANDB_PROJECT="demo_dgx2"
export WANDB_LOG_MODEL="false"
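The hunk above swaps the `<key>` placeholder for a literal W&B API key in the launch script. A common alternative, not part of this commit, is to keep the key outside version control and export it at runtime; a minimal Python sketch, where the `~/.wandb_key` path is purely hypothetical:

```python
import os
from pathlib import Path

import wandb

# Pull the key from an untracked file instead of hard-coding it in the script.
key_file = Path("~/.wandb_key").expanduser()
if "WANDB_API_KEY" not in os.environ and key_file.exists():
    os.environ["WANDB_API_KEY"] = key_file.read_text().strip()

# wandb.login() also picks up WANDB_API_KEY from the environment on its own.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
```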
2 changes: 1 addition & 1 deletion experiments/demo_dgx2_launch.sh
@@ -6,7 +6,7 @@ ROOT=$(realpath ~)
CONTAINER=${ROOT}/project/singularity_containers/py2402.sig

# CUDA
export CUDA_VISIBLE_DEVICES=0,1,2
export CUDA_VISIBLE_DEVICES=0,1

# PATH
DEMO_PATH=${ROOT}/project/alignment_handbook/experiments
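With `CUDA_VISIBLE_DEVICES` trimmed from three GPUs to two, the launched processes should see exactly two re-indexed devices; a quick sanity check, assuming PyTorch is available inside the container:

```python
import os

import torch

# CUDA_VISIBLE_DEVICES=0,1 re-indexes the visible GPUs as cuda:0 and cuda:1.
print(os.environ.get("CUDA_VISIBLE_DEVICES"))  # expected: "0,1"
print(torch.cuda.device_count())               # expected: 2
```

If the accelerate config pins `num_processes`, it needs to stay in sync with this device list.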
16 changes: 9 additions & 7 deletions recipes/accelerate_configs/deepspeed_zero2.yaml
@@ -1,13 +1,15 @@
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero3_save_16bit_model: false
zero_stage: 2
mixed_precision: bf16
deepspeed_config_file: /home/l069561/project/alignment-handbook/recipes/accelerate_configs/deepspeed_zs2.json
zero3_init_flag: true
# deepspeed_multinode_launcher: standard
# offload_optimizer_device: none
# offload_param_device: none
# zero3_init_flag: true
# zero3_save_16bit_model: false
# zero_stage: 2
# mixed_precision: bf16
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
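The accelerate YAML now delegates the DeepSpeed settings to the JSON file named by `deepspeed_config_file` instead of spelling them out inline. A small consistency check, assuming the repository root as the working directory and PyYAML installed:

```python
import json

import yaml  # PyYAML

# Path of the accelerate config shown above, relative to the repo root.
with open("recipes/accelerate_configs/deepspeed_zero2.yaml") as f:
    accel_cfg = yaml.safe_load(f)

# Follow the pointer to the DeepSpeed JSON and check the precision settings
# that used to live inline in the YAML.
with open(accel_cfg["deepspeed_config"]["deepspeed_config_file"]) as f:
    ds_cfg = json.load(f)

assert ds_cfg["bf16"]["enabled"] and not ds_cfg["fp16"]["enabled"]
print(ds_cfg.get("optimizer", {}).get("type"))  # expected: "AdamW"
```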
9 changes: 8 additions & 1 deletion recipes/accelerate_configs/deepspeed_zs2.json
@@ -1,17 +1,24 @@
{
"fp16": {
"enabled": true,
"enabled": false,
"loss_scale": 0,
"auto_cast": false,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"consecutive_hysteresis": false,
"min_loss_scale": 1
},
"bf16": {
"enabled": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"weight_decay": "auto",
"betas": "auto",
"eps": "auto",
"torch_adam": true,
"adam_w_mode": true
}
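The `"auto"` entries in the optimizer block are placeholders; when the JSON reaches the Hugging Face Trainer they are filled in from `TrainingArguments` by its DeepSpeed integration. A minimal sketch of the direct route, with the JSON path taken from this commit (model and dataset setup omitted):

```python
from transformers import TrainingArguments

# The Trainer's DeepSpeed integration maps these values onto the "auto"
# placeholders: learning_rate -> lr, weight_decay -> weight_decay, and
# bf16=True must agree with the JSON's "bf16": {"enabled": true}.
training_args = TrainingArguments(
    output_dir="out",
    bf16=True,
    learning_rate=2.0e-5,
    weight_decay=0.01,
    deepspeed="recipes/accelerate_configs/deepspeed_zs2.json",
)
```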
38 changes: 38 additions & 0 deletions recipes/accelerate_configs/readme.md
@@ -0,0 +1,38 @@
## DeepSpeed optimizers
- DeepSpeed natively supports Adam, AdamW, OneBitAdam, Lamb, OneBitLamb, FusedLamb, and FusedAdam.
- See https://deepspeed.readthedocs.io/en/latest/optimizers.html for details on how to configure each of them, for example:
```json
{
    "optimizer": {
        "type": "OneBitLamb",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "bias_correction": false,
            "max_coeff": 0.3,
            "min_coeff": 0.01,
            "freeze_step": 1000,
            "cuda_aware": false,
            "comm_backend_name": "nccl",
            "coeff_beta": 0.9,
            "factor_max": 4.0,
            "factor_min": 0.5,
            "factor_threshold": 0.1
        }
    }
}

{
    "optimizer": {
        "type": "Lamb",
        "params": {
            "lr": 1e-3,
            "weight_decay": 0.01,
            "bias_correction": false,
            "max_coeff": 0.3,
            "min_coeff": 0.01
        }
    }
}
```
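The optimizers above (OneBitLamb, Lamb, and friends) are built by DeepSpeed itself from the `optimizer` block, not through the Trainer's `optim` flag. One way to exercise such a config directly is `deepspeed.initialize`; a sketch, assuming the JSON above is saved as `ds_onebit_lamb.json` and also carries the usual batch-size and precision entries:

```python
import deepspeed
import torch

model = torch.nn.Linear(512, 512)  # stand-in for a real model

# DeepSpeed constructs the optimizer named in the JSON's "optimizer" block,
# so no optimizer object is passed in here.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_onebit_lamb.json",  # hypothetical path to the config above
)
```

OneBitLamb additionally expects a distributed launch (for example via the `deepspeed` launcher) so its NCCL compression backend can initialize.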
12 changes: 7 additions & 5 deletions recipes/llama3-8b/sft/config_full.yaml
@@ -17,28 +17,30 @@ preprocessing_num_workers: 8
bf16: true
do_eval: true
evaluation_strategy: epoch
max_grad_norm: 1.0
gradient_accumulation_steps: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: False
hub_model_id: null
hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
learning_rate: 2.0e-05
optim: galore_adamw # adamw_torch paged_adamw_32bit galore_adamw lion_32bit
weight_decay: 0.01
lr_scheduler_type: cosine
max_seq_length: 2048
max_seq_length: 4096
max_steps: -1
num_train_epochs: 1
output_dir: /home/l069561/project/models/fine-tuned/demo-llama-3-full-ultrachat
output_dir: /home/l069561/project/alignment_handbook/experiments/models/llama-3-full-ultrachat
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 16
push_to_hub: false
remove_unused_columns: true
report_to:
- tensorboard
- wandb
save_strategy: "steps"
save_steps: 100
save_total_limit: 1
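`optim: galore_adamw` selects the GaLore integration that recent `transformers` releases ship behind the regular `optim` switch; it also needs the optional `galore-torch` package (added to requirements below) and a list of target modules. A minimal sketch of the equivalent arguments, where `optim_target_modules` is an assumption and does not appear in the YAML above:

```python
from transformers import TrainingArguments

# GaLore projects the gradients of the named modules into a low-rank space;
# transformers builds the galore_adamw optimizer when `optim` names it.
training_args = TrainingArguments(
    output_dir="out",
    optim="galore_adamw",
    optim_target_modules=["attn", "mlp"],  # assumed pattern, adjust per model
    learning_rate=2.0e-5,
    weight_decay=0.01,
    bf16=True,
)
```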
3 changes: 3 additions & 0 deletions requirements.txt
@@ -17,3 +17,6 @@ jinja2>=3.0.0
tqdm>=4.64.1
flash-attn>=2.1.0
pynvml>=11.4.0

# optional
galore-torch
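`galore-torch` is listed as optional; the `galore_adamw` choice above fails at optimizer-construction time if it is missing. A cheap guard, as a sketch:

```python
import importlib.util

# optim=galore_adamw needs the optional galore-torch package.
if importlib.util.find_spec("galore_torch") is None:
    raise ImportError("optim=galore_adamw requires `pip install galore-torch`")
```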
85 changes: 40 additions & 45 deletions scripts/run_sft.py
@@ -45,7 +45,6 @@
get_peft_config,
get_quantization_config,
get_tokenizer,
tokenizer_and_embedding_resize,
)
from trl import SFTTrainer, setup_chat_format

@@ -110,31 +109,6 @@ def main():
)
column_names = list(raw_datasets["train"].features)

#######################
# Load pretrained model
#######################
logger.info("*** Load pretrained model ***")
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)

model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_flash_attention_2=model_args.use_flash_attention_2, # attn_implementation="flash_attention_2"
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
logger.info("*** Model loaded! ***")
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)

################
# Load tokenizer
################
@@ -162,6 +136,7 @@ def main():
)

model = model_args.model_name_or_path

# For ChatML we need to add special tokens and resize the embedding layer
if (
"<|im_start|>" in tokenizer.chat_template
@@ -173,11 +148,6 @@
model, tokenizer = setup_chat_format(model, tokenizer)
model_kwargs = None

###############
# update new tokens added to tokenizer
###############
tokenizer_and_embedding_resize(data_args, tokenizer, model)

#####################
# Apply chat template
#####################
@@ -222,30 +192,55 @@
########################
# Initialize the Trainer
########################
trainer = SFTTrainer(
model=model,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text",
max_seq_length=training_args.max_seq_length,
tokenizer=tokenizer,
packing=True,
peft_config=get_peft_config(model_args),
dataset_kwargs=training_args.dataset_kwargs,
callbacks=[GpuUtilPrintCallBack()],
)
if model_args.use_unsloth:
from alignment.unsloth import get_unsloth_peft_model

peft_config = get_peft_config(model_args)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, **model_kwargs
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = get_unsloth_peft_model(model, training_args.max_seq_length, peft_config)

trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text",
max_seq_length=training_args.max_seq_length,
tokenizer=tokenizer,
packing=True,
dataset_kwargs=training_args.dataset_kwargs,
callbacks=[GpuUtilPrintCallBack()],
)
else:
trainer = SFTTrainer(
model=model,
model_init_kwargs=model_kwargs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text",
max_seq_length=training_args.max_seq_length,
tokenizer=tokenizer,
packing=True,
peft_config=get_peft_config(model_args),
dataset_kwargs=training_args.dataset_kwargs,
callbacks=[GpuUtilPrintCallBack()],
)

###############
# Training loop
###############
logger.info("*** Train ***")

checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint

train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)
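The new `use_unsloth` branch leans on `alignment.unsloth.get_unsloth_peft_model`, whose source is not part of this commit view. For orientation, the canonical Unsloth pattern such a helper typically wraps looks roughly like the sketch below; the model id and LoRA hyperparameters are assumptions, not values taken from the repo:

```python
# Illustration of the usual Unsloth flow, not the contents of alignment/unsloth.py.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Meta-Llama-3-8B",  # assumed model id
    max_seq_length=4096,
    dtype=None,          # let Unsloth pick bf16/fp16 for the GPU
    load_in_4bit=False,
)

# Unsloth's get_peft_model swaps in fused LoRA kernels instead of plain
# peft.get_peft_model.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing=True,
)
```

In the branch above, the base model is still loaded with `AutoModelForCausalLM.from_pretrained` and only then handed to the helper.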
4 changes: 4 additions & 0 deletions src/alignment/configs.py
@@ -172,6 +172,10 @@ class ModelArguments:
)
},
)
use_unsloth: bool = field(
default=False,
metadata={"help": ("Whether to use unsloth to accelerate lora.")},
)
use_peft: bool = field(
default=False,
metadata={"help": ("Whether to use PEFT or not for training.")},
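`use_unsloth` defaults to `False` and is read in `run_sft.py` to select the Unsloth code path; like any other `ModelArguments` field it can be set in a recipe YAML (`use_unsloth: true`) or overridden on the command line. A minimal parsing sketch, with an illustrative argv:

```python
import sys

from alignment import H4ArgumentParser, ModelArguments

# Illustrative command line; in practice the recipe YAML path comes first and
# flags placed after it override the YAML values.
sys.argv = [
    "run_sft.py",
    "--model_name_or_path", "meta-llama/Meta-Llama-3-8B",
    "--use_unsloth", "true",
]
model_args = H4ArgumentParser(ModelArguments).parse()
print(model_args.use_unsloth)  # True
```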
98 changes: 0 additions & 98 deletions src/alignment/decontaminate.py

This file was deleted.

