forked from huggingface/alignment-handbook
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add demo on dgx02 for multi GPU training
- Loading branch information
1 parent
190935b
commit a5f0909
Showing
15 changed files
with
305 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -164,4 +164,8 @@ data/ | |
wandb/ | ||
|
||
.DS_Store | ||
.vscode | ||
.vscode | ||
|
||
experiments/* | ||
!experiments/.gitkeep | ||
!experiments/demo* |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#!/usr/bin/bash | ||
|
||
ROOT=$(realpath ~) | ||
|
||
# location | ||
echo activate virtual ENV | ||
PYTHON_ENV=${ROOT}/project/scripts/v2306.sh | ||
source $PYTHON_ENV | ||
|
||
# number of GPUs; here we use all GPUs for demo | ||
WORLD_SIZE=3 | ||
|
||
# HF cache | ||
export TMPDIR="${ROOT}/project/.cache/" | ||
export HF_DATASETS_CACHE="${ROOT}/project/.cache/dataset" | ||
export HF_HOME="${ROOT}/project/.cache/" | ||
|
||
# Wandb | ||
export WANDB_API_KEY="<key>" | ||
export WANDB_USERNAME="xi-yang5" | ||
export WANDB_PROJECT="demo_dgx2" | ||
export WANDB_LOG_MODEL="false" | ||
export WANDB_WATCH="false" | ||
|
||
# TORCH and NCCL | ||
export TORCH_DISTRIBUTED_DEBUG=INFO | ||
export NCCL_DEBUG=INFO | ||
# export NCCL_SOCKET_NTHREADS=16 | ||
|
||
export ACCELERATE_LOG_LEVEL=debug | ||
export ACCELERATE_DEBUG_MODE="1" | ||
export DEEPSPEED_TIMEOUT=120 | ||
|
||
accelerate launch \ | ||
--config_file ${ROOT}/project/alignment_handbook/recipes/accelerate_configs/deepspeed_zero2.yaml \ | ||
--num_processes $WORLD_SIZE \ | ||
--tee 3 \ | ||
${ROOT}/project/alignment_handbook/scripts/run_sft.py \ | ||
${ROOT}/project/alignment_handbook/recipes/llama3-8b/sft/config_qlora.yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/usr/bin/bash | ||
|
||
ROOT=$(realpath ~) | ||
|
||
# singularity container | ||
CONTAINER=${ROOT}/project/singularity_containers/py2402.sig | ||
|
||
# CUDA | ||
export CUDA_VISIBLE_DEVICES=0,1,2 | ||
|
||
# PATH | ||
DEMO_PATH=${ROOT}/project/alignment_handbook/experiments | ||
|
||
# launch | ||
singularity exec --nv $CONTAINER bash ${DEMO_PATH}/demo_dgx2.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Model arguments | ||
model_name_or_path: | ||
torch_dtype: null | ||
|
||
# Data training arguments | ||
# For definitions, see: src/h4/training/config.py | ||
dataset_mixer: | ||
HuggingFaceH4/ultrafeedback_binarized: 1.0 | ||
dataset_splits: | ||
- train_prefs | ||
- test_prefs | ||
preprocessing_num_workers: 12 | ||
|
||
# DPOTrainer arguments | ||
bf16: true | ||
beta: 0.01 | ||
do_eval: true | ||
evaluation_strategy: steps | ||
eval_steps: 100 | ||
gradient_accumulation_steps: 2 | ||
gradient_checkpointing: true | ||
gradient_checkpointing_kwargs: | ||
use_reentrant: False | ||
learning_rate: 5.0e-7 | ||
log_level: info | ||
logging_steps: 10 | ||
lr_scheduler_type: cosine | ||
max_length: 1024 | ||
max_prompt_length: 512 | ||
num_train_epochs: 1 | ||
optim: adamw_torch | ||
output_dir: | ||
per_device_train_batch_size: 8 | ||
per_device_eval_batch_size: 8 | ||
save_strategy: "steps" | ||
save_steps: 100 | ||
save_total_limit: 1 | ||
seed: 42 | ||
warmup_ratio: 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Model arguments | ||
model_name_or_path: | ||
torch_dtype: bfloat16 | ||
use_flash_attention_2: true | ||
|
||
# LoRA arguments | ||
use_peft: true | ||
load_in_4bit: true | ||
lora_r: 128 | ||
lora_alpha: 128 | ||
lora_dropout: 0.05 | ||
lora_target_modules: all | ||
# - q_proj | ||
# - k_proj | ||
# - v_proj | ||
# - o_proj | ||
# - gate_proj | ||
# - up_proj | ||
# - down_proj | ||
|
||
# Data training arguments | ||
|
||
dataset_mixer: | ||
HuggingFaceH4/ultrafeedback_binarized: 1.0 | ||
dataset_splits: | ||
- train_prefs | ||
- test_prefs | ||
preprocessing_num_workers: 12 | ||
|
||
# DPOTrainer arguments | ||
bf16: true | ||
beta: 0.01 | ||
do_eval: true | ||
evaluation_strategy: steps | ||
eval_steps: 100 | ||
gradient_accumulation_steps: 4 | ||
gradient_checkpointing: true | ||
gradient_checkpointing_kwargs: | ||
use_reentrant: false | ||
learning_rate: 5.0e-6 | ||
log_level: info | ||
logging_steps: 10 | ||
lr_scheduler_type: cosine | ||
max_length: 1024 | ||
max_prompt_length: 512 | ||
num_train_epochs: 1 | ||
optim: paged_adamw_32bit | ||
output_dir: | ||
per_device_train_batch_size: 4 | ||
per_device_eval_batch_size: 8 | ||
save_strategy: "steps" | ||
save_steps: 100 | ||
save_total_limit: 1 | ||
seed: 42 | ||
warmup_ratio: 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Model arguments | ||
model_name_or_path: /home/l069561/project/models/Meta-Llama-3-8B | ||
model_revision: main | ||
torch_dtype: bfloat16 | ||
use_flash_attention_2: true | ||
|
||
# Data training arguments | ||
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" | ||
dataset_mixer: | ||
HuggingFaceH4/ultrachat_200k: 1.0 | ||
dataset_splits: | ||
- train_sft | ||
- test_sft | ||
preprocessing_num_workers: 8 | ||
|
||
# SFT trainer config | ||
bf16: true | ||
do_eval: true | ||
evaluation_strategy: epoch | ||
gradient_accumulation_steps: 1 | ||
gradient_checkpointing: true | ||
gradient_checkpointing_kwargs: | ||
use_reentrant: False | ||
hub_model_id: null | ||
hub_strategy: every_save | ||
learning_rate: 2.0e-05 | ||
log_level: info | ||
logging_steps: 5 | ||
logging_strategy: steps | ||
lr_scheduler_type: cosine | ||
max_seq_length: 2048 | ||
max_steps: -1 | ||
num_train_epochs: 1 | ||
output_dir: /home/l069561/project/models/fine-tuned/demo-llama-3-full-ultrachat | ||
overwrite_output_dir: true | ||
per_device_eval_batch_size: 8 | ||
per_device_train_batch_size: 16 | ||
push_to_hub: false | ||
remove_unused_columns: true | ||
report_to: | ||
- tensorboard | ||
save_strategy: "steps" | ||
save_steps: 100 | ||
save_total_limit: 1 | ||
seed: 42 | ||
warmup_ratio: 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Model arguments | ||
model_name_or_path: /home/l069561/project/models/Meta-Llama-3-8B # no chat template | ||
model_revision: main | ||
torch_dtype: bfloat16 | ||
use_flash_attention_2: true | ||
|
||
# LoRA arguments | ||
load_in_4bit: true | ||
use_peft: true | ||
lora_r: 32 | ||
lora_alpha: 32 | ||
lora_dropout: 0.05 | ||
lora_target_modules: all | ||
# - q_proj | ||
# - k_proj | ||
# - v_proj | ||
# - o_proj | ||
# - gate_proj | ||
# - up_proj | ||
# - down_proj | ||
|
||
# Data training arguments | ||
chat_template: "{% if messages[0]['role'] == 'system' %}{% set offset = 1 %}{% else %}{% set offset = 0 %}{% endif %}{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{{ '<|im_start|>' + message['role'] + '\\n' + message['content'].strip() + '<|im_end|>\\n' }}{% if loop.last and message['role'] == 'user' and add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}{% endfor %}" | ||
dataset_mixer: | ||
HuggingFaceH4/ultrachat_200k: 1.0 | ||
dataset_splits: | ||
- train_sft | ||
- test_sft | ||
preprocessing_num_workers: 16 | ||
auto_insert_empty_system_msg: true | ||
|
||
# SFT trainer config | ||
bf16: true | ||
do_eval: true | ||
evaluation_strategy: epoch | ||
gradient_accumulation_steps: 16 | ||
gradient_checkpointing: true | ||
gradient_checkpointing_kwargs: | ||
use_reentrant: false | ||
learning_rate: 1.0e-04 | ||
log_level: info | ||
logging_steps: 5 | ||
logging_strategy: steps | ||
lr_scheduler_type: cosine | ||
max_seq_length: 4096 | ||
max_steps: -1 | ||
num_train_epochs: 1 | ||
output_dir: /home/l069561/project/models/fine-tuned/demo-llama-3-8b-lora-ultrachat | ||
overwrite_output_dir: true | ||
per_device_eval_batch_size: 8 | ||
per_device_train_batch_size: 4 | ||
push_to_hub: false | ||
report_to: | ||
- tensorboard | ||
# - wandb | ||
save_strategy: "steps" | ||
save_steps: 100 | ||
save_total_limit: 1 | ||
seed: 42 | ||
warmup_ratio: 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
accelerate>=0.23.0 | ||
datasets>=2.14.6 | ||
deepspeed>=0.12.2 | ||
einops>=0.6.1 | ||
evaluate==0.4.0 | ||
huggingface-hub>=0.14.1<1.0 | ||
ninja>=1.11.1 | ||
packaging>=23.0 | ||
parameterized>=0.9.0 | ||
peft>=0.6.1 | ||
protobuf<=3.20.2 | ||
safetensors>=0.3.3 | ||
tensorboard | ||
transformers>=4.35.0 | ||
trl>=0.7.4 | ||
jinja2>=3.0.0 | ||
tqdm>=4.64.1 | ||
flash-attn>=2.1.0 | ||
pynvml>=11.4.0 |
Oops, something went wrong.