Commit

update config
xiyang-aads-lilly committed Aug 15, 2024
1 parent 19a676f commit 4c579bc
Showing 5 changed files with 28 additions and 14 deletions.
6 changes: 3 additions & 3 deletions experiments/demo_magtrain_llm_sft.sh
@@ -14,9 +14,9 @@ source ${SCRIPTPATH}/wandb.sh
echo $SLURM_TMPDIR
export TMPDIR="/cache"

-export TRITON_CACHE_DIR=${HOME}/.cache/triton
-export HF_DATASETS_CACHE=${HOME}/.cache/dataset
-export HF_HOME=${HOME}/.cache/huggingface
+export TRITON_CACHE_DIR=${HOME}/project/cache/triton
+export HF_DATASETS_CACHE=${HOME}/project/cache/dataset
+export HF_HOME=${HOME}/project/cache/huggingface

# TORCH and NCCL
export CUDA_LAUNCH_BLOCKING=1
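The change above relocates the Triton and Hugging Face caches out of ${HOME}/.cache into ${HOME}/project/cache. The commit does not say why, but a plausible reading is the usual cluster constraint: home directories tend to carry small quotas while project storage does not. Note the target directories may need to exist before the first run.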
16 changes: 13 additions & 3 deletions experiments/demo_magtrain_slurm.sh
@@ -24,8 +24,18 @@ source ${SCRIPTPATH}/util.sh

CONTAINER=${HOME}/container/pt2402.sif

-srun --jobid $SLURM_JOB_ID apptainer exec -B $SLURM_TMPDIR:/cache --nv $CONTAINER bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
+# srun --jobid $SLURM_JOB_ID apptainer exec -B $SLURM_TMPDIR:/cache --nv $CONTAINER bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh

# use nsys to profile training process
-# srun --jobid $SLURM_JOB_ID apptainer exec -B $SLURM_TMPDIR:/cache --nv --fakeroot $CONTAINER nsys profile -t cuda,nvtx -o /cache/nsys_${SLURM_JOB_ID} bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
-# cp $SLURM_TMPDIR/nsys_${SLURM_JOB_ID}.nsys-rep ${HOME}/project/log/nsys/
+srun --jobid $SLURM_JOB_ID \
+    apptainer exec -B $SLURM_TMPDIR:/cache --nv --fakeroot $CONTAINER \
+    nsys profile -s none -t cuda,nvtx \
+    --gpu-metrics-device=all \
+    --gpu-metrics-frequency=100 \
+    --nic-metrics=true \
+    --capture-range=cudaProfilerApi \
+    --capture-range-end=stop \
+    -o /cache/nsys_${SLURM_JOB_ID} \
+    bash ${SCRIPTPATH}/demo_magtrain_llm_sft.sh
+
+cp $SLURM_TMPDIR/nsys_${SLURM_JOB_ID}.nsys-rep ${HOME}/project/log/nsys/
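The new nsys invocation is worth unpacking: -s none turns off CPU sampling, --gpu-metrics-device=all samples hardware GPU metrics on every GPU (at 100 Hz via --gpu-metrics-frequency), --nic-metrics=true records network-interface counters, and --capture-range=cudaProfilerApi makes nsys start recording only when the application calls cudaProfilerStart, ending at cudaProfilerStop (--capture-range-end=stop). That means the training script itself must trigger the capture; in PyTorch the calls are exposed as torch.cuda.profiler.start()/stop(). A minimal sketch, assuming the training loop can be edited (the step bounds and names are hypothetical; the commit does not show where these calls live):

# Sketch: gating the nsys capture from inside a PyTorch training loop.
# With --capture-range=cudaProfilerApi, nsys records only between these calls.
import torch

def training_loop(model, optimizer, batches, start_step=10, stop_step=30):
    for step, batch in enumerate(batches):
        if step == start_step:
            torch.cuda.profiler.start()  # cudaProfilerStart -> nsys begins recording
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step == stop_step:
            torch.cuda.profiler.stop()   # cudaProfilerStop -> capture ends (capture-range-end=stop)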
6 changes: 4 additions & 2 deletions recipes/llama3-8b/sft/config_full.yaml
@@ -32,13 +32,15 @@ optim_target_modules: all-linear
weight_decay: 0.01
lr_scheduler_type: cosine
max_seq_length: 8192
+packing: false
+dataset_num_proc: 16
max_steps: -1
num_train_epochs: 3
output_dir: /home/l069561/project/alignment-handbook/experiments/models/llama-3-full-ultrachat
overwrite_output_dir: true
per_device_eval_batch_size: 1
-per_device_train_batch_size: 1
-gradient_accumulation_steps: 64
+per_device_train_batch_size: 1 # per-device value; global batch = per_device * grad_accum_steps * gpus_per_node * num_nodes
+gradient_accumulation_steps: 4
push_to_hub: false
remove_unused_columns: true
report_to:
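For reference, the effective global batch size is per_device_train_batch_size × gradient_accumulation_steps × GPUs per node × nodes. With the new values, a hypothetical 2-node × 8-GPU job yields 1 × 4 × 8 × 2 = 64 sequences per optimizer step, the same global batch the old gradient_accumulation_steps: 64 gave on a single GPU.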
4 changes: 0 additions & 4 deletions scripts/run_sft.py
@@ -219,9 +219,7 @@ def main():
eval_dataset=eval_dataset,
dataset_text_field="text",
dataset_num_proc=data_args.preprocessing_num_workers,
-max_seq_length=training_args.max_seq_length,
tokenizer=tokenizer,
-packing=True,
dataset_kwargs=training_args.dataset_kwargs,
callbacks=[GpuUtilPrintCallBack()],
)
@@ -234,9 +232,7 @@ def main():
eval_dataset=eval_dataset,
dataset_text_field="text",
dataset_num_proc=data_args.preprocessing_num_workers,
-max_seq_length=training_args.max_seq_length,
tokenizer=tokenizer,
-packing=True,
peft_config=get_peft_config(model_args),
dataset_kwargs=training_args.dataset_kwargs,
callbacks=[GpuUtilPrintCallBack()],
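Dropping max_seq_length and packing from the SFTTrainer calls lines up with the YAML (packing: false, max_seq_length: 8192) and with the commented-out fields in configs.py below: in recent trl releases these options live on SFTConfig, so passing them to the trainer as well is redundant or rejected, depending on the version (a reading of the commit, not something it states). A minimal sketch of the resulting call shape, with hypothetical model/dataset objects:

# Sketch, assuming a trl version where SFTConfig owns max_seq_length/packing.
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="out",     # hypothetical
    max_seq_length=8192,  # set on the config, matching the YAML
    packing=False,        # was previously packing=True in the trainer call
)
trainer = SFTTrainer(
    model=model,                   # model, datasets, tokenizer defined elsewhere
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,           # still passed directly, as in the diff
)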
10 changes: 8 additions & 2 deletions src/alignment/configs.py
@@ -18,9 +18,11 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, NewType, Optional, Union

-import trl
from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser

+import trl


MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@@ -214,7 +216,7 @@ class ModelArguments:
default="uint8",
metadata={"help": "storage type to pack the quanitzed 4-bit prarams."},
)
-use_flash_attention_2: bool = field(default=False)
+# use_flash_attention_2: bool = field(default=False)

def __post_init__(self):
if self.load_in_8bit and self.load_in_4bit:
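Commenting out use_flash_attention_2 is consistent with upstream transformers, which deprecated that flag in favor of passing attn_implementation="flash_attention_2" to from_pretrained; the commit itself does not state the motivation.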
@@ -292,6 +294,10 @@ class SFTConfig(trl.SFTConfig):
"help": ("Whether to log and evaluate the first global_step or not.")
},
)
+# max_seq_length: Optional[int] = field(
+#     default=None,
+# )
+# packing: Optional[bool] = field(default=False)


@dataclass
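The commented-out max_seq_length and packing fields tell the same story as the run_sft.py change above: trl.SFTConfig already defines both, so re-declaring them in this subclass would be redundant, which is presumably why they were disabled rather than deleted outright.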
