diff --git a/.gitignore b/.gitignore
index df27b497..8c29b4eb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,4 +169,5 @@ wandb/
 experiments/*
 !experiments/.gitkeep
 !experiments/demo*
-!experiments/README.md
\ No newline at end of file
+!experiments/README.md
+!experiments/util.sh
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 95d2556e..7f4f6038 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,10 +9,25 @@ repos:
         # pre-commit's default_language_version, see
         # https://pre-commit.com/#top_level-default_language_version
         # we do not set python version so it will use default
-
+      - id: black-jupyter
       # #  It is recommended to specify the latest version of Python
       # #  supported by your project here, or alternatively use
       # #  pre-commit's default_language_version, see
       # #  https://pre-commit.com/#top_level-default_language_version
       #   language_version: python3.11
+
+  # - repo: https://github.com/gitleaks/gitleaks
+  #   rev: v8.18.2 # Specify the desired version of Gitleaks
+  #   hooks:
+  #     - id: gitleaks
+
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: trailing-whitespace
+      - id: check-added-large-files
+      - id: check-merge-conflict
+      - id: detect-private-key # if this works well we can avoid using gitleaks
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
diff --git a/experiments/demo_dgx2.sh b/experiments/demo_dgx2.sh
index 22b9b357..084bc434 100644
--- a/experiments/demo_dgx2.sh
+++ b/experiments/demo_dgx2.sh
@@ -1,4 +1,4 @@
-#!/usr/bin/bash
+#!/usr/bin/bash
 
 ROOT=$(realpath ~)
 
@@ -7,6 +7,7 @@ echo activate virtual ENV
 PYTHON_ENV=${ROOT}/project/scripts/v2306.sh
 source $PYTHON_ENV
+# pip freeze
 
 # CUDA
 export CUDA_VISIBLE_DEVICES=0,1
@@ -21,10 +22,12 @@ export HF_DATASETS_CACHE="${ROOT}/project/.cache/dataset"
 export HF_HOME="${ROOT}/project/.cache/"
 
 # Wandb
-export WANDB_API_KEY=""
-# export WANDB_API_KEY=""
+export WANDB_API_KEY="<redacted>"
 export WANDB_USERNAME="xi-yang5"
 export WANDB_PROJECT="demo_dgx2"
+# export WANDB_API_KEY=""
+# export WANDB_USERNAME=""
+# export WANDB_PROJECT=""
 export WANDB_LOG_MODEL="false"
 export WANDB_WATCH="false"
 
@@ -33,21 +36,38 @@ export TORCH_DISTRIBUTED_DEBUG=INFO
 export NCCL_DEBUG=INFO
 # export NCCL_SOCKET_NTHREADS=16
 
-export ACCELERATE_LOG_LEVEL=debug
+export ACCELERATE_LOG_LEVEL=debug
 export ACCELERATE_DEBUG_MODE="1"
 export DEEPSPEED_TIMEOUT=120
 
+# get this script's location
+SCRIPT=$(readlink -f "$0")
+SCRIPTPATH=$(dirname "$SCRIPT")
+
 # accelerate launch
 # accelerate launch \
 #     --config_file ${ROOT}/project/alignment_handbook/recipes/accelerate_configs/deepspeed_zero2.yaml \
 #     --num_processes $WORLD_SIZE \
 #     --tee 3 \
 #     ${ROOT}/project/alignment_handbook/scripts/run_sft.py \
-#     ${ROOT}/project/alignment_handbook/recipes/llama3-8b/sft/config_qlora.yaml
-#     ${ROOT}/project/alignment_handbook/recipes/llama3-8b/sft/config_full.yaml
+#     ${ROOT}/project/alignment_handbook/recipes/llama3-8b/sft/config_qlora.yaml
 
-# deepspeed launch
+# torch launch
+# source ${SCRIPTPATH}/util.sh
+#   --master_addr=$PRIMARY --master_port=$PRIMARY_PORT
+# python -m torch.distributed.run
+# need to add the virtual env package path as PYTHONPATH
+export PYTHONPATH=${ROOT}/project/pyenv/2306/lib/python3.10/site-packages
+torchrun --nproc_per_node=$WORLD_SIZE --nnode=1 --node_rank=0 \
+    ${ROOT}/project/alignment_handbook/scripts/run_sft.py \
+    ${ROOT}/project/alignment_handbook/recipes/llama3-8b/sft/config_qlora.yaml \
+    --deepspeed=${ROOT}/project/alignment_handbook/recipes/accelerate_configs/deepspeed_zs2.json \
+    --tee=2 >> ${SCRIPTPATH}/log.txt
 
-# torch launch
\ No newline at end of file
+# python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --nnode=1 --node_rank=0 \
+#     ${ROOT}/project/alignment_handbook/scripts/run_sft.py \
+#     ${ROOT}/project/alignment_handbook/recipes/llama3-8b/sft/config_qlora.yaml \
+#     --deepspeed=${ROOT}/project/alignment_handbook/recipes/accelerate_configs/deepspeed_zs2.json \
+#     --tee=2
diff --git a/experiments/util.sh b/experiments/util.sh
new file mode 100644
index 00000000..ae3b5fb3
--- /dev/null
+++ b/experiments/util.sh
@@ -0,0 +1,34 @@
+# for slurm use
+get_unused_port() {
+    # Well-known ports end at 1023. On Linux, dynamic ports start at 32768
+    # (see /proc/sys/net/ipv4/ip_local_port_range).
+    local MIN_PORT=10001
+    local MAX_PORT=32767
+
+    local USED_PORTS=$(netstat -a -n -t | tail -n +3 | tr -s ' ' | \
+        cut -d ' ' -f 4 | sed 's/.*:\([0-9]\+\)$/\1/' | sort -n | uniq)
+
+    # Generate random port numbers within the search range (inclusive) until we
+    # find one that isn't in use.
+    local RAN_PORT
+    while
+        RAN_PORT=$(shuf -i 10001-32767 -n 1)
+        [[ "$USED_PORTS" =~ $RAN_PORT ]]
+    do
+        continue
+    done
+
+    echo $RAN_PORT
+}
+
+init_node_info() {
+    export PRIMARY=$(hostname -s)
+    SECONDARIES=$(scontrol show hostnames $SLURM_JOB_NODELIST | \
+        grep -v $PRIMARY)
+
+    ALL_NODES="$PRIMARY $SECONDARIES"
+    export PRIMARY_PORT=$(get_unused_port)
+    echo $PRIMARY $SECONDARIES $PRIMARY_PORT
+}
+
+init_node_info
diff --git a/recipes/accelerate_configs/deepspeed_zs2.json b/recipes/accelerate_configs/deepspeed_zs2.json
index b3347327..215f8469 100644
--- a/recipes/accelerate_configs/deepspeed_zs2.json
+++ b/recipes/accelerate_configs/deepspeed_zs2.json
@@ -9,9 +9,11 @@
         "consecutive_hysteresis": false,
         "min_loss_scale": 1
     },
+
     "bf16": { "enabled": true },
+
     "optimizer": {
         "type": "AdamW",
         "params": {
@@ -23,6 +25,7 @@
             "adam_w_mode": true
         }
     },
+
     "scheduler": {
         "type": "WarmupDecayLR",
         "params": {
@@ -32,19 +35,22 @@
             "total_num_steps": "auto"
         }
     },
+
     "zero_optimization": {
         "stage": 2,
         "allgather_partitions": true,
-        "allgather_bucket_size": 2e8,
+        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
         "reduce_scatter": true,
         "reduce_bucket_size": "auto",
-        "contiguous_gradients": true
+        "contiguous_gradients": true,
+        "round_robin_gradients": true
     },
+
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
-    "steps_per_print": 2000,
+    "steps_per_print": 20000000,
     "train_batch_size": "auto",
     "train_micro_batch_size_per_gpu": "auto",
     "wall_clock_breakdown": false
-}
\ No newline at end of file
+}
diff --git a/recipes/accelerate_configs/readme.md b/recipes/accelerate_configs/readme.md
index 3eaefe10..c6e371bc 100644
--- a/recipes/accelerate_configs/readme.md
+++ b/recipes/accelerate_configs/readme.md
@@ -1,10 +1,26 @@
 ## Accelerate launch only supports partial parameters in deepspeed
 - to avoid this, we need to launch with deepspeed, not accelerate
+## more info: HF-DeepSpeed integration
+- https://huggingface.co/docs/transformers/deepspeed?zero-config=ZeRO-2
+
 ## deepspeed optimizers
 - DeepSpeed natively supports Adam, AdamW, OneBitAdam, Lamb, OneBitLamb, FusedLamb, FusedAdam
 - see for details on how to config https://deepspeed.readthedocs.io/en/latest/optimizers.html
 
 ```json
+// You can set the parameters to "auto" or manually input your own desired values.
+{
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  }
+}
+
 {
   "optimizer": {
     "type": "OneBitLamb",
@@ -56,4 +72,81 @@
     "enabled": true
   }
 }
-```
\ No newline at end of file
+```
+
+- offload
+```json
+{
+  "zero_optimization": {
+    "stage": 2,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "contiguous_gradients": true,
+    "round_robin_gradients": true
+  }
+}
+
+{
+  "zero_optimization": {
+    "stage": 3,
+    "offload_optimizer": {
+      "device": "nvme",
+      "nvme_path": "/local_nvme",
+      "pin_memory": true,
+      "buffer_count": 4,
+      "fast_init": false
+    },
+    "offload_param": {
+      "device": "nvme",
+      "nvme_path": "/local_nvme",
+      "pin_memory": true,
+      "buffer_count": 5,
+      "buffer_size": 1e8,
+      "max_in_cpu": 1e9
+    },
+    "aio": {
+      "block_size": 262144,
+      "queue_depth": 32,
+      "thread_count": 1,
+      "single_submit": false,
+      "overlap_events": true
+    },
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 1e9,
+    "reduce_bucket_size": "auto",
+    "stage3_prefetch_bucket_size": "auto",
+    "stage3_param_persistence_threshold": "auto",
+    "stage3_max_live_parameters": 1e9,
+    "stage3_max_reuse_distance": 1e9,
+    "stage3_gather_16bit_weights_on_model_save": true
+  }
+}
+```
+
+- communication data type
+> Choosing fp32 adds a small amount of overhead but ensures the reduction operation is accumulated in fp32 and when it is ready, it is downcasted to whichever half-precision dtype you’re training in.
+> Default is fp16 if you use AMP.
+```json
+{ "communication_data_type": "fp32"}
+```
+
+- launch
+
+```sh
+deepspeed --num_gpus=2 examples/pytorch/translation/run_translation.py \
+  --deepspeed tests/deepspeed/ds_config_zero3.json \
+...
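+
+# (sketch) the same job launched across nodes with the deepspeed launcher;
+# the hostfile path and port below are placeholders, not files in this repo
+deepspeed --hostfile=/path/to/hostfile --master_port=29500 your_program.py \
+    --deepspeed ds_config.json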
+
+
+torchrun --nproc_per_node=8 --nnode=2 --node_rank=0 --master_addr=hostname1 --master_port=9901 \
+    your_program.py \
+    --deepspeed ds_config.json
+```
diff --git a/recipes/llama3-8b/sft/config_qlora.yaml b/recipes/llama3-8b/sft/config_qlora.yaml
index c1465d93..92a8b795 100644
--- a/recipes/llama3-8b/sft/config_qlora.yaml
+++ b/recipes/llama3-8b/sft/config_qlora.yaml
@@ -39,26 +39,32 @@ gradient_checkpointing_kwargs:
   use_reentrant: false
 learning_rate: 1.0e-04
 log_level: info
-logging_steps: 5
+logging_steps: 5
 logging_strategy: steps
-optim: adamw_torch # adamw_torch paged_adamw_32bit galore_adamw lion_32bit
+optim: adamw_torch # adamw_torch paged_adamw_32bit galore_adamw lion_32bit adamw_apex_fused
 # optim_target_modules: all-linear
 weight_decay: 0.01
 lr_scheduler_type: cosine
-max_seq_length: 2048
+max_seq_length: 4096
 max_steps: -1
 num_train_epochs: 1
-output_dir: /home/l069561/project/models/fine-tuned/demo-llama-3-8b-lora-ultrachat
+output_dir: /home/l069561/project/alignment_handbook/experiments/models/demo-llama-3-8b-lora-ultrachat
 overwrite_output_dir: true
-per_device_eval_batch_size: 2
-gradient_accumulation_steps: 32
 per_device_train_batch_size: 4
+gradient_accumulation_steps: 4
+per_device_eval_batch_size: 4
 push_to_hub: false
 report_to:
-- tensorboard
+- tensorboard
 - wandb
 save_strategy: "steps"
 save_steps: 100
 save_total_limit: 1
 seed: 42
-warmup_ratio: 0.1
\ No newline at end of file
+warmup_ratio: 0.1
+
+torch_compile: false
+# https://pytorch.org/docs/stable/generated/torch.compile.html ('cudagraphs', 'inductor', 'onnxrt', 'openxla', 'openxla_eval', 'tvm')
+# https://huggingface.co/docs/transformers/perf_train_gpu_one#using-torchcompile
+torch_compile_backend: "inductor"
+torch_compile_mode: "default" # reduce-overhead max-autotune
diff --git a/scripts/run_cpt.py b/scripts/run_cpt.py
index 3db97eb7..273d9ebc 100644
--- a/scripts/run_cpt.py
+++ b/scripts/run_cpt.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # coding=utf-8
 # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
 #
diff --git a/scripts/run_orpo.py b/scripts/run_orpo.py
index d3fe54e3..7ab8c947 100644
--- a/scripts/run_orpo.py
+++ b/scripts/run_orpo.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python
 # coding=utf-8
 # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
 #
diff --git a/scripts/run_simpo.py b/scripts/run_simpo.py
new file mode 100644
index 00000000..0e771e31
--- /dev/null
+++ b/scripts/run_simpo.py
@@ -0,0 +1,376 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
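+
+# SimPO (Simple Preference Optimization) training entry point. It follows the DPO/ORPO
+# scripts in this repo: preference data with `chosen`/`rejected` (and optionally `prompt`)
+# columns is formatted with the chat template and trained with SimPOTrainer. SimPOConfig
+# below extends DPOConfig with the target reward margin `gamma`; the reference model is
+# passed only to satisfy the DPOTrainer interface and is not used by the SimPO loss.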
+import logging +import random +import sys +from dataclasses import dataclass, field +from typing import Literal, Optional + +import torch +import transformers +from transformers import AutoModelForCausalLM, set_seed + +from alignment import ( + DataArguments, + DPOConfig, + H4ArgumentParser, + ModelArguments, + SimPOTrainer, + get_checkpoint, + get_datasets, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + get_tokenizer, + is_adapter_model, +) +from alignment.data import is_openai_format, maybe_insert_system_message +from peft import PeftConfig, PeftModel + + +logger = logging.getLogger(__name__) + +MISTRAL_CHAT_TEMPLATE = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'].strip() + '\n\n' %}{% else %}{% set loop_messages = messages %}{% set system_message = '' %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% set content = system_message + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}" + + +@dataclass +class SimPOConfig(DPOConfig): + gamma: Optional[float] = field( + default=0.5, + metadata={"help": "The target reward margin term in SimPO loss."}, + ) + + +def apply_chat_template( + example, + tokenizer, + task: Literal["sft", "generation", "rm", "simpo"], + auto_insert_empty_system_msg: bool = True, + change_template=None, +): + if change_template == "mistral": + tokenizer.chat_template = MISTRAL_CHAT_TEMPLATE + if task in ["sft", "generation"]: + messages = example["messages"] + # We add an empty system message if there is none + if auto_insert_empty_system_msg: + maybe_insert_system_message(messages, tokenizer) + example["text"] = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True if task == "generation" else False, + ) + elif task == "rm": + if all(k in example.keys() for k in ("chosen", "rejected")): + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + # We add an empty system message if there is none + if auto_insert_empty_system_msg: + maybe_insert_system_message(chosen_messages, tokenizer) + maybe_insert_system_message(rejected_messages, tokenizer) + + example["text_chosen"] = tokenizer.apply_chat_template( + chosen_messages, tokenize=False + ) + example["text_rejected"] = tokenizer.apply_chat_template( + rejected_messages, tokenize=False + ) + else: + raise ValueError( + f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" + ) + elif task == "simpo": + if all(k in example.keys() for k in ("chosen", "rejected")): + if not is_openai_format(example["chosen"]) or not is_openai_format( + example["rejected"] + ): + raise ValueError( + f"Could not format example as dialogue for `{task}` task! 
Require OpenAI format for all messages" + ) + + # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + if "prompt" in example and is_openai_format(example["prompt"]): + prompt_messages = example["prompt"] + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + else: + prompt_messages = example["chosen"][:-1] + # Now we extract the final turn to define chosen/rejected responses + chosen_messages = example["chosen"][-1:] + rejected_messages = example["rejected"][-1:] + + # Prepend a system message if the first message is not a system message + if auto_insert_empty_system_msg: + maybe_insert_system_message(prompt_messages, tokenizer) + + example["text_prompt"] = tokenizer.apply_chat_template( + prompt_messages, tokenize=False + ) + example["text_chosen"] = tokenizer.apply_chat_template( + chosen_messages, tokenize=False + ) + if example["text_chosen"].startswith(tokenizer.bos_token): + example["text_chosen"] = example["text_chosen"][ + len(tokenizer.bos_token) : + ] + example["text_rejected"] = tokenizer.apply_chat_template( + rejected_messages, tokenize=False + ) + if example["text_rejected"].startswith(tokenizer.bos_token): + example["text_rejected"] = example["text_rejected"][ + len(tokenizer.bos_token) : + ] + else: + raise ValueError( + f"Could not format example as dialogue for `{task}` task! Require either the " + f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}" + ) + else: + raise ValueError( + f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']" + ) + return example + + +def main(): + parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig)) + model_args, data_args, training_args = parser.parse() + + ####### + # Setup + ####### + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.info(f"Model parameters {model_args}") + logger.info(f"Data parameters {data_args}") + logger.info(f"Training/evaluation parameters {training_args}") + + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") + + # Set seed for reproducibility + set_seed(training_args.seed) + + ############### + # Load datasets + ############### + raw_datasets = get_datasets( + data_args, + splits=data_args.dataset_splits, + configs=data_args.dataset_configs, + columns_to_keep=[ + "messages", + "chosen", + "rejected", + "prompt", + "completion", + "label", + ], + # seed=training_args.seed, + ) + logger.info( + f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" + ) + column_names = list(raw_datasets["train"].features) + + ##################################### + # Load tokenizer and process datasets + ##################################### + data_args.truncation_side = ( + "left" # 
Truncate from left to ensure we don't lose labels in final turn + ) + tokenizer = get_tokenizer(model_args, data_args) + + if "mistral" in model_args.model_name_or_path.lower(): + change_template = "mistral" + else: + change_template = None + ##################### + # Apply chat template + ##################### + raw_datasets = raw_datasets.map( + apply_chat_template, + fn_kwargs={ + "tokenizer": tokenizer, + "task": "simpo", + "auto_insert_empty_system_msg": data_args.auto_insert_empty_system_msg, + "change_template": change_template, + }, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + desc="Formatting comparisons with prompt template", + ) + + # Replace column names with what TRL needs, text_chosen -> chosen and text_rejected -> rejected + for split in ["train", "test"]: + raw_datasets[split] = raw_datasets[split].rename_columns( + { + "text_prompt": "prompt", + "text_chosen": "chosen", + "text_rejected": "rejected", + } + ) + + # Log a few random samples from the training set: + for index in random.sample(range(len(raw_datasets["train"])), 3): + logger.info( + f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}" + ) + logger.info( + f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}" + ) + logger.info( + f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}" + ) + + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + use_flash_attention_2=model_args.use_flash_attention_2, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = model_args.model_name_or_path + if is_adapter_model(model, model_args.model_revision) is True: + logger.info(f"Loading SFT adapter for {model_args.model_name_or_path=}") + peft_config = PeftConfig.from_pretrained( + model_args.model_name_or_path, revision=model_args.model_revision + ) + model_kwargs = dict( + revision=model_args.base_model_revision, + trust_remote_code=model_args.trust_remote_code, + use_flash_attention_2=model_args.use_flash_attention_2, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=( + get_kbit_device_map() if quantization_config is not None else None + ), + quantization_config=quantization_config, + ) + base_model = AutoModelForCausalLM.from_pretrained( + peft_config.base_model_name_or_path, + **model_kwargs, + ) + model = PeftModel.from_pretrained( + base_model, + model_args.model_name_or_path, + revision=model_args.model_revision, + ) + model_kwargs = None + + ref_model = model + ref_model_kwargs = model_kwargs + + if model_args.use_peft is True: + ref_model = None + ref_model_kwargs = None + + ######################### + # Instantiate SimPO trainer + ######################### + trainer = SimPOTrainer( + model=model, + ref_model=ref_model, # pass in to bypass DPO Trainer check for ref model but is not actually used + model_init_kwargs=model_kwargs, + args=training_args, + beta=training_args.beta, + train_dataset=raw_datasets["train"], + eval_dataset=raw_datasets["test"], + 
tokenizer=tokenizer, + max_length=training_args.max_length, + max_prompt_length=training_args.max_prompt_length, + peft_config=get_peft_config(model_args), + loss_type=training_args.loss_type, + ) + + ############### + # Training loop + ############### + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + metrics["train_samples"] = len(raw_datasets["train"]) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + logger.info("*** Training complete ***") + + ################################## + # Save model and create model card + ################################## + logger.info("*** Save model ***") + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + + # Save everything else on main process + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "dataset": list(data_args.dataset_mixer.keys()), + "dataset_tags": list(data_args.dataset_mixer.keys()), + "tags": ["alignment-handbook"], + } + if trainer.accelerator.is_main_process: + trainer.create_model_card(**kwargs) + # Restore k,v cache for fast inference + trainer.model.config.use_cache = True + trainer.model.config.save_pretrained(training_args.output_dir) + + ########## + # Evaluate + ########## + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate() + metrics["eval_samples"] = len(raw_datasets["test"]) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.push_to_hub is True: + logger.info("Pushing to hub...") + trainer.push_to_hub(**kwargs) + + logger.info("*** Training complete! ***") + + +if __name__ == "__main__": + main() diff --git a/src/alignment/__init__.py b/src/alignment/__init__.py index ff68f840..6afd54c4 100644 --- a/src/alignment/__init__.py +++ b/src/alignment/__init__.py @@ -17,6 +17,7 @@ is_adapter_model, tokenizer_and_embedding_resize, ) +from .simpo_trainer import SimPOTrainer from .utils import ( GpuUtilPrintCallBack, ProfCallback, diff --git a/src/alignment/simpo_trainer.py b/src/alignment/simpo_trainer.py new file mode 100644 index 00000000..1b7d19c2 --- /dev/null +++ b/src/alignment/simpo_trainer.py @@ -0,0 +1,146 @@ +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from trl import DPOTrainer + + +class SimPOTrainer(DPOTrainer): + + def __init__(self, **kwargs): + super().__init__(**kwargs) # Pass all other arguments using **kwargs + training_args = kwargs["args"] + self.gamma = training_args.gamma + + def simpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the SimPO loss for a batch of policy model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the SimPO loss for each example in the batch. 
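+            For the default "sigmoid" loss type (ignoring label smoothing) this is
+            -logsigmoid(beta * (policy_chosen_logps - policy_rejected_logps) - gamma),
+            computed on length-normalized log probabilities (concatenated_forward calls
+            get_batch_logps with average_log_prob=True).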
+ The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + gamma_logratios = self.gamma / self.beta + pi_logratios = pi_logratios.to(self.accelerator.device) + logits = pi_logratios - gamma_logratios + + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge']" + ) + + chosen_rewards = ( + self.beta * policy_chosen_logps.to(self.accelerator.device).detach() + ) + rejected_rewards = ( + self.beta * policy_rejected_logps.to(self.accelerator.device).detach() + ) + + return losses, chosen_rewards, rejected_rewards + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[ + torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor + ]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop( + "concatenated_decoder_input_ids", None + ), + } + if self.is_encoder_decoder + else {} + ) + + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ).logits + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the SimPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(model, batch) + + losses, chosen_rewards, rejected_rewards = self.simpo_loss( + policy_chosen_logps, policy_rejected_logps + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() + metrics[f"{prefix}rewards/margins"] = ( + (chosen_rewards - rejected_rewards).mean().cpu() + ) + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() + 
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() + metrics[f"{prefix}logits/rejected"] = ( + policy_rejected_logits.detach().mean().cpu() + ) + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + + return losses.mean(), metrics