add plw trainer
xiyang-aads-lilly committed Oct 17, 2024
1 parent b7524c8 commit 9b2fde5
Showing 13 changed files with 414 additions and 79 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -171,3 +171,4 @@ experiments/*
 !experiments/demo*
 !experiments/README.md
 !experiments/util.sh
+*/sang_project/*
4 changes: 2 additions & 2 deletions experiments/demo_magtrain_llm_sft.sh
@@ -38,8 +38,8 @@ echo $PRIMARY_PORT

 # manually set
 export WANDB_PROJECT="sang"
-# TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_1.yaml
-TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_2.yaml
+TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_1.yaml
+# TRAIN_CONF=${ROOT}/recipes/sang_project/config_full_2.yaml

 DEEPSPEED_CONF=${ROOT}/recipes/accelerate_configs/deepspeed_zs2.json
2 changes: 1 addition & 1 deletion experiments/demo_magtrain_slurm.sh
@@ -3,7 +3,7 @@
 #SBATCH --job-name=llm_sft
 #SBATCH --mail-type=ALL
 #SBATCH [email protected]
-#SBATCH --nodes=4
+#SBATCH --nodes=2
 #SBATCH --ntasks-per-node=1
 #SBATCH --gpus-per-node=4
 #SBATCH --gpus-per-task=4
2 changes: 1 addition & 1 deletion recipes/accelerate_configs/deepspeed_zs2.json
@@ -29,7 +29,7 @@
     "scheduler": {
         "type": "WarmupDecayLR",
         "params": {
-            "warmup_min_lr": "auto",
+            "warmup_min_lr": 1e-8,
             "warmup_max_lr": "auto",
             "warmup_num_steps": "auto",
             "total_num_steps": "auto"
79 changes: 79 additions & 0 deletions recipes/accelerate_configs/deepspeed_zs3.json
@@ -0,0 +1,79 @@
+{
+    "fp16": {
+        "enabled": false,
+        "loss_scale": 0,
+        "auto_cast": false,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "consecutive_hysteresis": false,
+        "min_loss_scale": 1
+    },
+
+    "bf16": {
+        "enabled": true
+    },
+
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "weight_decay": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "torch_adam": true,
+            "adam_w_mode": true
+        }
+    },
+
+    "scheduler": {
+        "type": "WarmupDecayLR",
+        "params": {
+            "warmup_min_lr": 1e-8,
+            "warmup_max_lr": "auto",
+            "warmup_num_steps": "auto",
+            "total_num_steps": "auto"
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu"
+        },
+        "offload_param": {
+            "device": "cpu"
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "reduce_bucket_size": "auto",
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "contiguous_gradients": true,
+        "round_robin_gradients": true
+    },
+
+    "aio": {
+        "block_size": 262144,
+        "queue_depth": 32,
+        "thread_count": 1,
+        "single_submit": false,
+        "overlap_events": true
+    },
+
+    "activation_checkpointing": {
+        "partition_activations": false,
+        "cpu_checkpointing": false,
+        "contiguous_memory_optimization": true,
+        "number_checkpoints": null,
+        "synchronize_checkpoint_boundary": false,
+        "profile": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 20000000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
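
Note: within the files shown, nothing references the new ZeRO-3 config yet; the demo script still points at deepspeed_zs2.json. If the new file is meant to be selectable the same way, the switch would be one line in experiments/demo_magtrain_llm_sft.sh (a sketch based on the existing DEEPSPEED_CONF line, not a change made by this commit):

    # DEEPSPEED_CONF=${ROOT}/recipes/accelerate_configs/deepspeed_zs2.json
    DEEPSPEED_CONF=${ROOT}/recipes/accelerate_configs/deepspeed_zs3.json
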
21 changes: 12 additions & 9 deletions recipes/sang_project/config_full_1.yaml
@@ -1,22 +1,25 @@
 # Model arguments
-model_name_or_path: /home/l069561/project/models/gemma-2-2b
+model_name_or_path: /home/l069561/project/models/Meta-Llama-3-8B #togethercomputer/StripedHyena-Hessian-7B
 model_revision: main
 torch_dtype: bfloat16
 attn_implementation: flash_attention_2

 # Data training arguments
+chat_template: "{% if messages[0]['role'] == 'system' %}{% set system_message = '### System Instruction: ' + messages[0]['content'] | trim + '' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### Context: ' + message['content'] | trim + '' }}{% elif message['role'] == 'assistant' %}{{ '### Result: ' + message['content'] | trim + eos_token + '' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '### Result: ' }}{% endif %}"
 dataset_mixer:
-  /home/l069561/project/data/processed_data_open_sourced_xml_to_text/merged_open_sourced_xml_to_text_dataset: 1.0
+  /home/l069561/project/data/test_8k_nlp: 1.0
   # HuggingFaceH4/ultrachat_200k: 1.0
+  # /home/l069561/project/data/processed_data_open_sourced_xml_to_text/merged_open_sourced_xml_to_text_dataset: 1.0
+  # /home/l069561/project/data/sang_data_formatted: 1.0
 dataset_splits:
 - train_sft
 - test_sft
 preprocessing_num_workers: 4

 # SFT trainer config
 trust_remote_code: true
 bf16: true
-do_eval: true
+do_eval: false
 # evaluation_strategy: epoch
 eval_strategy: epoch
 max_grad_norm: 1.0
@@ -36,19 +39,19 @@ max_seq_length: 8192
 packing: false
 dataset_num_proc: 16
 max_steps: -1
-num_train_epochs: 2
-output_dir: /home/l069561/project/alignment-handbook/experiments/models/sang_exp1_stage1_gemma-2-2b_full
+num_train_epochs: 100
+output_dir: /home/l069561/project/alignment-handbook/experiments/models/test_deepspeed
 overwrite_output_dir: true
 per_device_eval_batch_size: 1
 per_device_train_batch_size: 1 # per device; the global batch size is per_device * gradient_accumulation_steps * gpus_per_node * nodes
 gradient_accumulation_steps: 4
 push_to_hub: false
 remove_unused_columns: true
 report_to:
-- tensorboard
 - wandb
-save_strategy: "steps"
-save_steps: 2000
-save_total_limit: 10
+# - tensorboard
+save_strategy: "no"
+save_steps: 2500
+save_total_limit: 1
 seed: 42
 warmup_ratio: 0.1
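
Worked example for the batch-size comment above, using the values in this commit (per_device_train_batch_size 1, gradient_accumulation_steps 4, 4 GPUs per node from the SLURM script, and its new 2-node setting):

    global_batch = 1 * 4 * 4 * 2 = 32 sequences per optimizer step
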
2 changes: 1 addition & 1 deletion recipes/sang_project/config_full_2.yaml
@@ -49,7 +49,7 @@ report_to:
 - tensorboard
 - wandb
 save_strategy: "steps"
-save_steps: 1500
+save_steps: 2500
 save_total_limit: 10
 seed: 42
 warmup_ratio: 0.1
112 changes: 75 additions & 37 deletions scripts/run_sft.py
@@ -36,7 +36,8 @@
     GpuUtilPrintCallBack,
     H4ArgumentParser,
     ModelArguments,
-    ProfCallback,
+    PLW_apply_chat_template,
+    PLWTrainer,
     SFTConfig,
     apply_chat_template,
     get_checkpoint,
@@ -107,7 +108,9 @@ def main():
     ################
     # Load tokenizer
     ################
-    tokenizer = get_tokenizer(model_args, data_args, training_args)
+    tokenizer = get_tokenizer(
+        model_args, data_args, training_args, auto_set_chat_template=True
+    )

     #######################
     # Load pretrained model
@@ -150,22 +153,35 @@ def main():
     #####################
     # Apply chat template
     #####################
     logger.info("*** apply chat template ***")
-    raw_datasets = raw_datasets.map(
-        apply_chat_template,
-        fn_kwargs={
-            "tokenizer": tokenizer,
-            "task": "sft",
-            "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
-        },
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=column_names,
-        desc="Applying chat template",
-    )
+    if training_args.use_plw:
+        raw_datasets = raw_datasets.map(
+            PLW_apply_chat_template,
+            fn_kwargs={
+                "tokenizer": tokenizer,
+                "use_sample_template": training_args.use_plw_sample_template,
+            },
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            desc="Applying chat template",
+        )
+    else:
+        raw_datasets = raw_datasets.map(
+            apply_chat_template,
+            fn_kwargs={
+                "tokenizer": tokenizer,
+                "task": "sft",
+                "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg,
+            },
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            desc="Applying chat template",
+        )

     train_dataset = raw_datasets["train"]
     eval_dataset = raw_datasets["test"]

-    # this is hard coded
+    # this is hard coded - move to config.yaml
     training_args.dataset_text_field = "text"

     # # no need for logging samples
@@ -213,26 +229,50 @@ def main():
     ):
         model, tokenizer = setup_chat_format(model, tokenizer)

-        trainer = SFTTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            dataset_kwargs=training_args.dataset_kwargs,
-            callbacks=[GpuUtilPrintCallBack()],
-        )
+        if training_args.use_plw:
+            trainer = PLWTrainer(
+                prompt_loss_weight=training_args.prompt_loss_weight,
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
+        else:
+            trainer = SFTTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
     else:
-        trainer = SFTTrainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            eval_dataset=eval_dataset,
-            tokenizer=tokenizer,
-            peft_config=get_peft_config(model_args),
-            dataset_kwargs=training_args.dataset_kwargs,
-            callbacks=[GpuUtilPrintCallBack()],
-        )
+        if training_args.use_plw:
+            trainer = PLWTrainer(
+                prompt_loss_weight=training_args.prompt_loss_weight,
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )
+        else:
+            trainer = SFTTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=train_dataset,
+                eval_dataset=eval_dataset,
+                tokenizer=tokenizer,
+                peft_config=get_peft_config(model_args),
+                dataset_kwargs=training_args.dataset_kwargs,
+                # callbacks=[GpuUtilPrintCallBack()],
+            )

     ###############
     # Training loop
@@ -290,13 +330,11 @@ def main():
     # logger.info("Pushing to hub...")
     # trainer.push_to_hub(**kwargs)

-    torch.cuda.memory._dump_snapshot(
-        Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle"
-    )
+    # torch.cuda.memory._dump_snapshot(Path(training_args.output_dir) / "GPU_RAM_PROFILE.pickle")
     # prof.close()
     logger.info("*** Training complete ***")


 if __name__ == "__main__":
-    torch.cuda.memory._record_memory_history()
+    # torch.cuda.memory._record_memory_history()
     main()
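
The PLWTrainer implementation itself (src/alignment/plw_trainer.py) is among the files not rendered in this view. For orientation only, here is a minimal, hypothetical sketch of the prompt-loss-weighting idea as a Trainer subclass; it is not the repository's actual code, and it assumes preprocessing attaches a prompt_mask column (1 on prompt tokens, 0 on completion tokens):

    import torch
    from transformers import Trainer


    class PLWTrainerSketch(Trainer):
        """Causal-LM trainer that down-weights prompt tokens instead of masking them.

        Completion tokens keep weight 1.0; prompt tokens get `prompt_loss_weight`.
        """

        def __init__(self, *args, prompt_loss_weight: float = 0.1, **kwargs):
            super().__init__(*args, **kwargs)
            self.prompt_loss_weight = prompt_loss_weight

        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            prompt_mask = inputs.pop("prompt_mask")  # (batch, seq); 1 = prompt token
            labels = inputs.pop("labels")
            outputs = model(**inputs)

            # Standard causal shift: token t predicts token t + 1.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_mask = prompt_mask[..., 1:].contiguous().float()

            # Per-token cross-entropy; padding positions carry label -100.
            loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
            token_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            ).view(shift_labels.size())

            # Blend prompt and completion losses by the configured weight.
            weights = self.prompt_loss_weight * shift_mask + (1.0 - shift_mask)
            valid = (shift_labels != -100).float()
            loss = (token_loss * weights * valid).sum() / (weights * valid).sum().clamp(min=1e-8)
            return (loss, outputs) if return_outputs else loss

With prompt_loss_weight set to 0.0 this reduces to the usual completion-only SFT loss, and at 1.0 it becomes full-sequence language modeling, so the float field in SFTConfig acts as a continuous dial between the two.
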
13 changes: 6 additions & 7 deletions src/alignment/__init__.py
@@ -9,6 +9,7 @@
 )
 from .data import apply_chat_template, get_datasets
 from .model_utils import (
+    add_new_special_token,
     get_checkpoint,
     get_kbit_device_map,
     get_peft_config,
@@ -17,13 +18,9 @@
     is_adapter_model,
     tokenizer_and_embedding_resize,
 )
+from .plw_trainer import PLW_apply_chat_template, PLWTrainer
 from .simpo_trainer import SimPOTrainer
-from .utils import (
-    GpuUtilPrintCallBack,
-    ProfCallback,
-    print_gpu_utilization,
-    print_summary,
-)
+from .utils import GpuUtilPrintCallBack, ProfCallback


 __all__ = [
@@ -34,11 +31,13 @@
     "SFTConfig",
     "apply_chat_template",
     "get_datasets",
-    "decontaminate_humaneval",
     "get_checkpoint",
     "get_kbit_device_map",
     "get_peft_config",
     "get_quantization_config",
     "get_tokenizer",
     "is_adapter_model",
+    "PLW_apply_chat_template",
+    "PLWTrainer",
     "SimPOTrainer",
 ]
7 changes: 3 additions & 4 deletions src/alignment/configs.py
@@ -294,10 +294,9 @@ class SFTConfig(trl.SFTConfig):
             "help": ("Whether to log and evaluate the first global_step or not.")
         },
     )
-    # max_seq_length: Optional[int] = field(
-    #     default=None,
-    # )
-    # packing: Optional[bool] = field(default=False)
+    prompt_loss_weight: float = field(default=0.1)
+    use_plw: bool = field(default=False)
+    use_plw_sample_template: bool = field(default=False)


 @dataclass
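
Taken together with the run_sft.py changes above, these new SFTConfig fields would be toggled from a recipe YAML. A sketch (the field names are from this diff; the values are illustrative, and none of the recipe files shown in this commit set them yet):

    use_plw: true                    # route run_sft.py to PLWTrainer instead of SFTTrainer
    prompt_loss_weight: 0.1          # matches the default in configs.py
    use_plw_sample_template: false   # forwarded to PLW_apply_chat_template as use_sample_template
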
(Diffs for the remaining changed files were not loaded in this view.)