diff --git a/experiments/demo_magtrain_llm_sft.sh b/experiments/demo_magtrain_llm_sft.sh
index 68af07ba..521a4f51 100644
--- a/experiments/demo_magtrain_llm_sft.sh
+++ b/experiments/demo_magtrain_llm_sft.sh
@@ -14,9 +14,9 @@ source ${SCRIPTPATH}/wandb.sh
 echo $SLURM_TMPDIR
 export TMPDIR="/cache"
 
-export TRITON_CACHE_DIR=${HOME}/.cache/triton
-export HF_DATASETS_CACHE=${HOME}/.cache/dataset
-export HF_HOME=${HOME}/.cache/huggingface
+export TRITON_CACHE_DIR=${HOME}/project/cache/triton
+export HF_DATASETS_CACHE=${HOME}/project/cache/dataset
+export HF_HOME=${HOME}/project/cache/huggingface
 
 # TORCH and NCCL
 export CUDA_LAUNCH_BLOCKING=1
diff --git a/experiments/demo_magtrain_slurm.sh b/experiments/demo_magtrain_slurm.sh
index 3b6a9505..1e2ebf55 100644
--- a/experiments/demo_magtrain_slurm.sh
+++ b/experiments/demo_magtrain_slurm.sh
@@ -24,8 +24,18 @@ source ${SCRIPTPATH}/util.sh
 
 CONTAINER=${HOME}/container/pt2402.sif
 
-srun --jobid $SLURM_JOB_ID apptainer exec -B $SLURM_TMPDIR:/cache --nv $CONTAINER bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
+# srun --jobid $SLURM_JOB_ID apptainer exec -B $SLURM_TMPDIR:/cache --nv $CONTAINER bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
 
 # use nsys to profile training process
-# srun --jobid $SLURM_JOB_ID apptainer exec -B $SLURM_TMPDIR:/cache --nv --fakeroot $CONTAINER nsys profile -t cuda,nvtx -o /cache/nsys_${SLURM_JOB_ID} bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
-# cp $SLURM_TMPDIR/nsys_${SLURM_JOB_ID}.nsys-rep ${HOME}/project/log/nsys/
+srun --jobid $SLURM_JOB_ID \
+    apptainer exec -B $SLURM_TMPDIR:/cache --nv --fakeroot $CONTAINER \
+    nsys profile -s none -t cuda,nvtx \
+    --gpu-metrics-device=all \
+    --gpu-metrics-frequency=100 \
+    --nic-metrics=true \
+    --capture-range=cudaProfilerApi \
+    --capture-range-end=stop \
+    -o /cache/nsys_${SLURM_JOB_ID} \
+    bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
+
+cp $SLURM_TMPDIR/nsys_${SLURM_JOB_ID}.nsys-rep ${HOME}/project/log/nsys/
diff --git a/recipes/llama3-8b/sft/config_full.yaml b/recipes/llama3-8b/sft/config_full.yaml
index 0efa95bb..6076304f 100644
--- a/recipes/llama3-8b/sft/config_full.yaml
+++ b/recipes/llama3-8b/sft/config_full.yaml
@@ -32,13 +32,15 @@ optim_target_modules: all-linear
 weight_decay: 0.01
 lr_scheduler_type: cosine
 max_seq_length: 8192
+packing: false
+dataset_num_proc: 16
 max_steps: -1
 num_train_epochs: 3
 output_dir: /home/l069561/project/alignment-handbook/experiments/models/llama-3-full-ultrachat
 overwrite_output_dir: true
 per_device_eval_batch_size: 1
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 64
+per_device_train_batch_size: 1 # per device; the global batch size = per_device * gradient_accumulation_steps * num_gpus * num_nodes
+gradient_accumulation_steps: 4
 push_to_hub: false
 remove_unused_columns: true
 report_to:
diff --git a/scripts/run_sft.py b/scripts/run_sft.py
index 3efb3013..b795e9ab 100644
--- a/scripts/run_sft.py
+++ b/scripts/run_sft.py
@@ -219,9 +219,7 @@ def main():
         eval_dataset=eval_dataset,
         dataset_text_field="text",
         dataset_num_proc=data_args.preprocessing_num_workers,
-        max_seq_length=training_args.max_seq_length,
         tokenizer=tokenizer,
-        packing=True,
         dataset_kwargs=training_args.dataset_kwargs,
         callbacks=[GpuUtilPrintCallBack()],
     )
@@ -234,9 +232,7 @@ def main():
         eval_dataset=eval_dataset,
         dataset_text_field="text",
         dataset_num_proc=data_args.preprocessing_num_workers,
-        max_seq_length=training_args.max_seq_length,
         tokenizer=tokenizer,
-        packing=True,
         peft_config=get_peft_config(model_args),
         dataset_kwargs=training_args.dataset_kwargs,
         callbacks=[GpuUtilPrintCallBack()],
diff --git a/src/alignment/configs.py b/src/alignment/configs.py
index e85c4ec6..a0e59270 100644
--- a/src/alignment/configs.py
+++ b/src/alignment/configs.py
@@ -18,9 +18,11 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, NewType, Optional, Union
 
-import trl
 from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
 
+import trl
+
+
 MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
@@ -214,7 +216,7 @@ class ModelArguments:
         default="uint8",
         metadata={"help": "storage type to pack the quanitzed 4-bit prarams."},
     )
-    use_flash_attention_2: bool = field(default=False)
+    # use_flash_attention_2: bool = field(default=False)
 
     def __post_init__(self):
         if self.load_in_8bit and self.load_in_4bit:
@@ -292,6 +294,10 @@ class SFTConfig(trl.SFTConfig):
             "help": ("Whether to log and evaluate the first global_step or not.")
         },
     )
+    # max_seq_length: Optional[int] = field(
+    #     default=None,
+    # )
+    # packing: Optional[bool] = field(default=False)
 
 
 @dataclass
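
Note: the new nsys invocation uses --capture-range=cudaProfilerApi with --capture-range-end=stop, so nsys only records the window between cudaProfilerStart() and cudaProfilerStop(); the training process has to open and close that window itself. Below is a minimal sketch of one way to do this from a transformers Trainer callback. The class name and the start/stop step numbers are illustrative assumptions, not part of this patch.

# Sketch only: open/close the nsys capture window from a Trainer callback.
# NsysCaptureCallback and the start/stop steps are hypothetical, not from this repo.
import torch
from transformers import TrainerCallback


class NsysCaptureCallback(TrainerCallback):
    def __init__(self, start_step: int = 10, stop_step: int = 20):
        self.start_step = start_step
        self.stop_step = stop_step

    def on_step_begin(self, args, state, control, **kwargs):
        if state.global_step == self.start_step:
            torch.cuda.profiler.start()  # emits cudaProfilerStart -> nsys begins recording

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == self.stop_step:
            torch.cuda.profiler.stop()  # emits cudaProfilerStop -> nsys stops (capture-range-end=stop)

Such a callback could be passed alongside GpuUtilPrintCallBack in the callbacks=[...] list in run_sft.py.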
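
Note: with the config change, the effective global batch size is per_device_train_batch_size * gradient_accumulation_steps * GPUs per node * number of nodes; the GPU and node counts come from the SLURM allocation, not from this diff. A worked example with an assumed 8-GPU, 2-node topology:

# Illustrative arithmetic only; gpus_per_node and num_nodes are assumed values.
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
gpus_per_node = 8  # assumption, determined by the SLURM allocation
num_nodes = 2      # assumption, determined by the SLURM allocation
global_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * gpus_per_node * num_nodes
)
print(global_batch_size)  # -> 64 sequences per optimizer step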