diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 00000000..1e2b2293 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,29 @@ +cff-version: 1.2.0 +title: The Alignment Handbook +message: >- + Robust recipes to align language models with human and AI + preferences. +type: software +authors: + - given-names: Lewis + family-names: Tunstall + - given-names: Edward + family-names: Beeching + - given-names: Nathan + family-names: Lambert + - given-names: Nazneen + family-names: Rajani + - given-names: Shengyi + family-names: Huang + - given-names: Kashif + family-names: Rasul + - given-names: Alvaro + family-names: Bartolome + - given-names: Alexander + name-particle: M. + family-names: Rush + - given-names: Thomas + family-names: Wolf +repository-code: 'https://github.com/huggingface/alignment-handbook' +license: Apache-2.0 +version: 0.3.0.dev0 diff --git a/README.md b/README.md index bda60c86..86e30a99 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,8 @@ If you would like to train chat models on your own datasets, we recommend follow The initial release of the handbook will focus on the following techniques: -* **Continued pretraining:** adapt language models to a new language or domain, or simply improve it by continue pretraning (causal language modeling) on a new dataset. -* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your own training dataset. +* **Continued pretraining:** adapt language models to a new language or domain, or simply improve it by continued pretraining (causal language modeling) on a new dataset. +* **Supervised fine-tuning:** teach language models to follow instructions and tips on how to collect and curate your training dataset. * **Reward modeling:** teach language models to distinguish model responses according to human or AI preferences. * **Rejection sampling:** a simple, but powerful technique to boost the performance of your SFT model. * **Direct preference optimisation (DPO):** a powerful and promising alternative to PPO. @@ -115,15 +115,14 @@ You can now check out the `scripts` and `recipes` directories for instructions o ## Citation -If you find the content of this repo useful in your work, please cite it as follows: +If you find the content of this repo useful in your work, please cite it as follows via `\usepackage{biblatex}`: ```bibtex -@misc{alignment_handbook2023, - author = {Lewis Tunstall and Edward Beeching and Nathan Lambert and Nazneen Rajani and Shengyi Huang and Kashif Rasul and Alexander M. Rush and Thomas Wolf}, - title = {The Alignment Handbook}, - year = {2023}, - publisher = {GitHub}, - journal = {GitHub repository}, - howpublished = {\url{https://github.com/huggingface/alignment-handbook}} +@software{Tunstall_The_Alignment_Handbook, + author = {Tunstall, Lewis and Beeching, Edward and Lambert, Nathan and Rajani, Nazneen and Huang, Shengyi and Rasul, Kashif and Bartolome, Alvaro and M. Rush, Alexander and Wolf, Thomas}, + license = {Apache-2.0}, + title = {{The Alignment Handbook}}, + url = {https://github.com/huggingface/alignment-handbook}, + version = {0.3.0.dev0} } ``` diff --git a/recipes/constitutional-ai/README.md b/recipes/constitutional-ai/README.md index 71b073bc..08f4520a 100644 --- a/recipes/constitutional-ai/README.md +++ b/recipes/constitutional-ai/README.md @@ -21,4 +21,4 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ## Advanced: generating you own dataset -To generate the constitutional AI dataset, see https://github.com/huggingface/llm-swarm/tree/main/examples/constitutional-ai for detailed instructions if you want build or customize the dataset. +To generate the constitutional AI dataset, see https://github.com/huggingface/llm-swarm/tree/main/examples/constitutional-ai for detailed instructions if you want to build or customize the dataset. diff --git a/recipes/constitutional-ai/dpo/config_anthropic.yaml b/recipes/constitutional-ai/dpo/config_anthropic.yaml index 0ef08018..48f57676 100644 --- a/recipes/constitutional-ai/dpo/config_anthropic.yaml +++ b/recipes/constitutional-ai/dpo/config_anthropic.yaml @@ -17,7 +17,7 @@ bf16: true beta: 0.1 do_eval: true do_train: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 1000 gradient_accumulation_steps: 1 gradient_checkpointing: true diff --git a/recipes/constitutional-ai/sft/config_anthropic.yaml b/recipes/constitutional-ai/sft/config_anthropic.yaml index 64145286..6724de0c 100644 --- a/recipes/constitutional-ai/sft/config_anthropic.yaml +++ b/recipes/constitutional-ai/sft/config_anthropic.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" @@ -18,7 +18,7 @@ preprocessing_num_workers: 12 bf16: true do_eval: true do_train: true -evaluation_strategy: epoch # One of ["no", "steps", "epoch"] +eval_strategy: epoch # One of ["no", "steps", "epoch"] gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/constitutional-ai/sft/config_grok.yaml b/recipes/constitutional-ai/sft/config_grok.yaml index 6740ac19..c79031dc 100644 --- a/recipes/constitutional-ai/sft/config_grok.yaml +++ b/recipes/constitutional-ai/sft/config_grok.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" @@ -18,7 +18,7 @@ preprocessing_num_workers: 12 bf16: true do_eval: true do_train: true -evaluation_strategy: epoch # One of ["no", "steps", "epoch"] +eval_strategy: epoch # One of ["no", "steps", "epoch"] gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/gpt2-nl/README.md b/recipes/gpt2-nl/README.md index 366ae926..68eccfc8 100644 --- a/recipes/gpt2-nl/README.md +++ b/recipes/gpt2-nl/README.md @@ -2,7 +2,7 @@ This directory shows a base example of how to use continued pretraining and further tuning to adapt a language model to new data (e.g. a new language or domain). -Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size. +Three steps are needed: continued pretraining (`cpt`), supervised finetuning (`sft`), and direct preference optimisation (`dpo`). In this dummy example, we'll continue pretraining gpt2 on Dutch raw data, then sft-tuning it, and finally aligning it with DPO. Note that no extensive hyperparameters were tested in this example and that the output models are bad - it is just to show you how you can use the scripts for LM adaptation. The scripts work on 4x 3090s (24GB VRAM). If you have less powerful hardware you may need to reduce the batch size. ## Continued pretraining @@ -18,7 +18,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch \ ## Supervised finetuning -As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model we'll make use of the output of the previous step. +As other recipes, such as the famous zephyr-7b-beta recipe, have shown, we can then teach our model how to hold a conversation by finetuning it on chat-formatted data. As a base model, we'll make use of the output of the previous step. ```shell ACCELERATE_LOG_LEVEL=info accelerate launch \ diff --git a/recipes/gpt2-nl/cpt/config_full.yaml b/recipes/gpt2-nl/cpt/config_full.yaml index 69d5437f..9c7056cf 100644 --- a/recipes/gpt2-nl/cpt/config_full.yaml +++ b/recipes/gpt2-nl/cpt/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: False -evaluation_strategy: "no" +eval_strategy: "no" gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/gpt2-nl/dpo/config_full.yaml b/recipes/gpt2-nl/dpo/config_full.yaml index a2552f39..976c2537 100644 --- a/recipes/gpt2-nl/dpo/config_full.yaml +++ b/recipes/gpt2-nl/dpo/config_full.yaml @@ -16,7 +16,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.1 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true diff --git a/recipes/gpt2-nl/sft/config_full.yaml b/recipes/gpt2-nl/sft/config_full.yaml index fef3d5ee..f80d8efc 100644 --- a/recipes/gpt2-nl/sft/config_full.yaml +++ b/recipes/gpt2-nl/sft/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/pref_align_scan/README.md b/recipes/pref_align_scan/README.md index 767a7426..f9c81a51 100644 --- a/recipes/pref_align_scan/README.md +++ b/recipes/pref_align_scan/README.md @@ -5,13 +5,14 @@ This directory contains various comparisons for three algorithms: DPO, IPO, and - OpenHermes-2.5 and the OpenOrca datasets We release a collection containing the datasets and models used for these experiments, if you require the other trained models, we can release them on request. -You can find a longer decription of there results in our [blogpost](https://huggingface.co/blog/pref-tuning) +You can find a longer description of these results in our [blogpost](https://huggingface.co/blog/pref-tuning) + ## Comparisons For each algorithm, we aim to tune the beta parameter for a fixed learning rate. We vary beta from 0.1-0.9 in steps of 0.1, we have also found that in certain configurations a tiny value of beta, 0.01, can be effective. So we have included this smaller value in all our comparisons. ## Usage The experiments can be launched with the following bash script: -``` +```bash #!/bin/bash # Define an array containing the base configs we wish to fine tune diff --git a/recipes/pref_align_scan/dpo/config_openhermes.yaml b/recipes/pref_align_scan/dpo/config_openhermes.yaml index 93d9ef33..43e8a230 100644 --- a/recipes/pref_align_scan/dpo/config_openhermes.yaml +++ b/recipes/pref_align_scan/dpo/config_openhermes.yaml @@ -16,7 +16,7 @@ beta: 0.01 loss_type: sigmoid do_eval: true do_train: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 2 gradient_checkpointing: true diff --git a/recipes/pref_align_scan/dpo/config_zephyr.yaml b/recipes/pref_align_scan/dpo/config_zephyr.yaml index 01899bda..0dd6d379 100644 --- a/recipes/pref_align_scan/dpo/config_zephyr.yaml +++ b/recipes/pref_align_scan/dpo/config_zephyr.yaml @@ -15,7 +15,7 @@ bf16: true beta: 0.01 loss_type: sigmoid do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 2 gradient_checkpointing: true diff --git a/recipes/starchat2-15b/dpo/config_v0.1.yaml b/recipes/starchat2-15b/dpo/config_v0.1.yaml index d53c8121..cf0ddb3f 100644 --- a/recipes/starchat2-15b/dpo/config_v0.1.yaml +++ b/recipes/starchat2-15b/dpo/config_v0.1.yaml @@ -16,7 +16,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.05 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true diff --git a/recipes/starchat2-15b/sft/config_v0.1.yaml b/recipes/starchat2-15b/sft/config_v0.1.yaml index bd65890a..f5892de5 100644 --- a/recipes/starchat2-15b/sft/config_v0.1.yaml +++ b/recipes/starchat2-15b/sft/config_v0.1.yaml @@ -2,7 +2,7 @@ model_name_or_path: bigcode/starcoder2-15b model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" @@ -20,7 +20,7 @@ preprocessing_num_workers: 24 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/zephyr-141b-A35b/orpo/config_full.yaml b/recipes/zephyr-141b-A35b/orpo/config_full.yaml index 57ae4393..b5210132 100644 --- a/recipes/zephyr-141b-A35b/orpo/config_full.yaml +++ b/recipes/zephyr-141b-A35b/orpo/config_full.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistral-community/Mixtral-8x22B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" diff --git a/recipes/zephyr-7b-beta/README.md b/recipes/zephyr-7b-beta/README.md index d27de43a..8c082f17 100644 --- a/recipes/zephyr-7b-beta/README.md +++ b/recipes/zephyr-7b-beta/README.md @@ -4,9 +4,9 @@ As described in the Zephyr [technical report](https://huggingface.co/papers/2310.16944), training this model proceeds in two steps: 1. Apply SFT to fine-tune Mistral 7B on a filtered version of the UltraChat dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)). The result is an SFT model like [`zephyr-7b-sft-full`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full) or [`zephyr-7b-sft-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-sft-qlora). -2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is an DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora). +2. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset ([link](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized)). The result is a DPO model like [`zephyr-7b-dpo-full`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-full) or [`zephyr-7b-dpo-qlora`](https://huggingface.co/alignment-handbook/zephyr-7b-dpo-qlora). -**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters from those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was suffucient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`). +**Note:** after the release of Zephyr, the team at [Argilla](https://argilla.io) found that the source UltraFeedback dataset had a few thousand incorrect preference labels from GPT-4. Additionally, TRL's `SFTTrainer` had a bug in the learning rate scheduler which terminated training early. Accounting for these changes led us to find a better set of hyperparameters from those described in the technical report. In particular, for DPO training we found that training for 1 epoch with `beta=0.01` was sufficient to achieve comparable performance to `zephyr-7b-beta` (vs. 3 epochs with `beta=0.1`). See below for commands to train these models using either DeepSpeed ZeRO-3 or LoRA. @@ -34,11 +34,11 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con P.S. Using Flash Attention also allows you to drastically increase the batch size (x2 in my case) -Train without flash-attention: +Train without flash-attention (i.e. via PyTorch's scaled dot product attention): ```````shell # Step 1 - SFT -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --use_flash_attention_2=false +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_sft.py recipes/zephyr-7b-beta/sft/config_qlora.yaml --load_in_4bit=true --attn_implementation=sdpa # Step 2 - DPO -ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --use_flash_attention_2=false +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/multi_gpu.yaml --num_processes=1 scripts/run_dpo.py recipes/zephyr-7b-beta/dpo/config_qlora.yaml --attn_implementation=sdpa ``````` \ No newline at end of file diff --git a/recipes/zephyr-7b-beta/dpo/config_full.yaml b/recipes/zephyr-7b-beta/dpo/config_full.yaml index 9ea336b6..12b47b18 100644 --- a/recipes/zephyr-7b-beta/dpo/config_full.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.01 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 2 gradient_checkpointing: true diff --git a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml index 10a147ce..04536672 100644 --- a/recipes/zephyr-7b-beta/dpo/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/dpo/config_qlora.yaml @@ -1,7 +1,7 @@ # Model arguments model_name_or_path: alignment-handbook/zephyr-7b-sft-qlora torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # LoRA arguments use_peft: true @@ -31,7 +31,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.01 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 4 gradient_checkpointing: true @@ -54,4 +54,4 @@ save_strategy: "steps" save_steps: 100 save_total_limit: 1 seed: 42 -warmup_ratio: 0.1 \ No newline at end of file +warmup_ratio: 0.1 diff --git a/recipes/zephyr-7b-beta/sft/config_full.yaml b/recipes/zephyr-7b-beta/sft/config_full.yaml index f5eb4405..f1e8457d 100644 --- a/recipes/zephyr-7b-beta/sft/config_full.yaml +++ b/recipes/zephyr-7b-beta/sft/config_full.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" @@ -16,7 +16,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 1 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/recipes/zephyr-7b-beta/sft/config_qlora.yaml b/recipes/zephyr-7b-beta/sft/config_qlora.yaml index 8a753565..13376107 100644 --- a/recipes/zephyr-7b-beta/sft/config_qlora.yaml +++ b/recipes/zephyr-7b-beta/sft/config_qlora.yaml @@ -2,7 +2,7 @@ model_name_or_path: mistralai/Mistral-7B-v0.1 model_revision: main torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # LoRA arguments load_in_4bit: true @@ -31,7 +31,7 @@ preprocessing_num_workers: 12 # SFT trainer config bf16: true do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 2 gradient_checkpointing: true gradient_checkpointing_kwargs: @@ -40,7 +40,7 @@ hub_model_id: zephyr-7b-sft-qlora hub_strategy: every_save learning_rate: 2.0e-04 log_level: info -logging_steps: 5 +logging_steps: 5 logging_strategy: steps lr_scheduler_type: cosine max_seq_length: 2048 @@ -57,4 +57,4 @@ save_strategy: "steps" save_steps: 100 save_total_limit: 1 seed: 42 -warmup_ratio: 0.1 \ No newline at end of file +warmup_ratio: 0.1 diff --git a/recipes/zephyr-7b-gemma/dpo/config_full.yaml b/recipes/zephyr-7b-gemma/dpo/config_full.yaml index d643b94a..f17ac683 100644 --- a/recipes/zephyr-7b-gemma/dpo/config_full.yaml +++ b/recipes/zephyr-7b-gemma/dpo/config_full.yaml @@ -15,7 +15,7 @@ preprocessing_num_workers: 12 bf16: true beta: 0.05 do_eval: true -evaluation_strategy: steps +eval_strategy: steps eval_steps: 100 gradient_accumulation_steps: 8 gradient_checkpointing: true diff --git a/recipes/zephyr-7b-gemma/sft/config_full.yaml b/recipes/zephyr-7b-gemma/sft/config_full.yaml index a28f0e46..03226ab3 100644 --- a/recipes/zephyr-7b-gemma/sft/config_full.yaml +++ b/recipes/zephyr-7b-gemma/sft/config_full.yaml @@ -3,7 +3,7 @@ model_name_or_path: google/gemma-7b model_revision: main tokenizer_name_or_path: philschmid/gemma-tokenizer-chatml # Custom tokenizer with <|im_start|> and <|im_end|> tokens torch_dtype: bfloat16 -use_flash_attention_2: true +attn_implementation: flash_attention_2 # Data training arguments dataset_mixer: @@ -19,7 +19,7 @@ dataset_kwargs: add_special_tokens: false # We already wrap and in the chat template append_concat_token: false # No need to add across samples do_eval: true -evaluation_strategy: epoch +eval_strategy: epoch gradient_accumulation_steps: 4 gradient_checkpointing: true gradient_checkpointing_kwargs: diff --git a/scripts/README.md b/scripts/README.md index 3860e41b..1613d8cd 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -28,7 +28,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/fsdp+qlora.yaml --num_processes={num_gpus} scripts/run_{task}.py recipes/{model_name}/{task}/config_qlora.yaml --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 ``` -Here `{task}` refers to the type of training you wish to run. Currently the following tasks are supported: +Here `{task}` refers to the type of training you wish to run. Currently, the following tasks are supported: * continued pretraining `cpt` (note that `cpt` is only present in the `gpt-nl` example recipe) * supervised finetuning `sft` * direct preference optimisation `dpo` @@ -54,8 +54,7 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con ``` ## Logging with Weights and Biases - -By default all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g. +By default, all training metrics are logged with TensorBoard. If you have a [Weights and Biases](https://wandb.ai/site) account and are logged in, you can view the training metrics by appending `--report_to=wandb`, e.g. ```shell ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_{task}.py recipes/{model_name}/{task}/config_full.yaml --report_to=wandb @@ -120,7 +119,7 @@ If you format your dataset in the same way, our training scripts should work out We recommend benchmarking chat models on: * [MT-Bench](https://huggingface.co/spaces/lmsys/mt-bench): a multi-turn benchmark spanning 80 dialogues and 10 domains. -* [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): a single-turn benchmark which evaluates the helpfulness of chat and instruct models against `text-davinci-003`. +* [AlpacaEval](https://github.com/tatsu-lab/alpaca_eval): a single-turn benchmark that evaluates the helpfulness of chat and instruct models against `text-davinci-003`. For both benchmarks, we have added support for the [Zephyr chat template](https://huggingface.co/alignment-handbook/zephyr-7b-sft-full/blob/ac6e600eefcce74f5e8bae1035d4f66019e93190/tokenizer_config.json#L30) (which is the default produced by our scripts), so you can evaluate models produced by our scripts as follows: @@ -137,6 +136,6 @@ For both benchmarks, we have added support for the [Zephyr chat template](https: * Next, update the [config name](https://github.com/tatsu-lab/alpaca_eval/blob/2daa6e11b194653043ca74f735728dc068e04aae/src/alpaca_eval/models_configs/zephyr-7b-beta/configs.yaml#L1) and [Hub model ID](https://github.com/tatsu-lab/alpaca_eval/blob/2daa6e11b194653043ca74f735728dc068e04aae/src/alpaca_eval/models_configs/zephyr-7b-beta/configs.yaml#L5) to match your model name. * Follow the steps to evaluate your model [here](https://github.com/tatsu-lab/alpaca_eval/tree/main#evaluating-a-model). -Note that MT-Bench and AlpacaEval rely on LLMs like GPT-4 to judge the quality of the model responses, and thus the ranking exhibit various biases including a preference for models distilled from GPTs. For that reason, we also recommend submitting your best models for human evaluation in: +Note that MT-Bench and AlpacaEval rely on LLMs like GPT-4 to judge the quality of the model responses, and thus the ranking exhibits various biases including a preference for models distilled from GPTs. For that reason, we also recommend submitting your best models for human evaluation in: * [Chatbot Arena](https://chat.lmsys.org): a live, human evaluation of chat models in head-to-head comparisons. diff --git a/scripts/run_cpt.py b/scripts/run_cpt.py index 273d9ebc..d5e56e67 100644 --- a/scripts/run_cpt.py +++ b/scripts/run_cpt.py @@ -135,7 +135,7 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, diff --git a/scripts/run_dpo.py b/scripts/run_dpo.py index f2a8f65c..af6d5a84 100644 --- a/scripts/run_dpo.py +++ b/scripts/run_dpo.py @@ -173,7 +173,7 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, @@ -192,7 +192,7 @@ def main(): model_kwargs = dict( revision=model_args.base_model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=( diff --git a/scripts/run_orpo.py b/scripts/run_orpo.py index 7ab8c947..0894f4d4 100644 --- a/scripts/run_orpo.py +++ b/scripts/run_orpo.py @@ -35,8 +35,7 @@ get_quantization_config, get_tokenizer, ) -from alignment.configs import ORPOConfig -from trl import ORPOTrainer, setup_chat_format +from trl import ORPOConfig, ORPOTrainer, setup_chat_format logger = logging.getLogger(__name__) @@ -110,7 +109,7 @@ def main(): model_args.model_name_or_path, revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, diff --git a/scripts/run_sft.py b/scripts/run_sft.py index 3a7879b1..848e0b98 100644 --- a/scripts/run_sft.py +++ b/scripts/run_sft.py @@ -22,15 +22,12 @@ import sys from pathlib import Path - p = Path(__file__).parent.parent / "src" sys.path.append(p.as_posix()) import datasets import torch import transformers -from transformers import AutoModelForCausalLM, set_seed - from alignment import ( DataArguments, GpuUtilPrintCallBack, @@ -46,9 +43,9 @@ get_quantization_config, get_tokenizer, ) +from transformers import AutoModelForCausalLM, set_seed from trl import SFTTrainer, setup_chat_format - logger = logging.getLogger(__name__) @@ -128,7 +125,7 @@ def main(): model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, - use_flash_attention_2=model_args.use_flash_attention_2, + attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, diff --git a/setup.py b/setup.py index a9b1ce6e..66792a78 100644 --- a/setup.py +++ b/setup.py @@ -43,9 +43,9 @@ _deps = [ "accelerate>=0.29.2", "bitsandbytes>=0.43.0", - "black==23.1.0", + "black>=24.4.2", "datasets>=2.18.0", - "deepspeed==0.12.2", + "deepspeed>=0.14.4", "einops>=0.6.1", "evaluate==0.4.0", "flake8>=6.0.0", @@ -64,9 +64,9 @@ "sentencepiece>=0.1.99", "scipy", "tensorboard", - "torch==2.1.2", + "torch>=2.1.2", "transformers>=4.39.3", - "trl>=0.8.2", + "trl>=0.9.6", "jinja2>=3.0.0", "tqdm>=4.64.1", ] diff --git a/src/alignment/__init__.py b/src/alignment/__init__.py index 6afd54c4..5c92315c 100644 --- a/src/alignment/__init__.py +++ b/src/alignment/__init__.py @@ -24,3 +24,21 @@ print_gpu_utilization, print_summary, ) + + +__all__ = [ + "DataArguments", + "DPOConfig", + "H4ArgumentParser", + "ModelArguments", + "SFTConfig", + "apply_chat_template", + "get_datasets", + "decontaminate_humaneval", + "get_checkpoint", + "get_kbit_device_map", + "get_peft_config", + "get_quantization_config", + "get_tokenizer", + "is_adapter_model", +] diff --git a/src/alignment/configs.py b/src/alignment/configs.py index 466dc0e1..e85c4ec6 100644 --- a/src/alignment/configs.py +++ b/src/alignment/configs.py @@ -18,10 +18,9 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, NewType, Optional, Union -import transformers +import trl from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser - MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -74,7 +73,7 @@ def parse_yaml_and_args( inputs[arg] = [str(v) for v in val.split(",")] # bool of a non-empty string is True, so we manually check for bools - if base_type == bool: + if base_type is bool: if val in ["true", "True"]: inputs[arg] = True else: @@ -161,14 +160,16 @@ class ModelArguments: ) }, ) + trust_remote_code: bool = field( default=False, metadata={"help": "Trust remote code when loading a model."} ) - use_flash_attention_2: bool = field( - default=False, + + attn_implementation: Optional[str] = field( + default=None, metadata={ "help": ( - "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`" + "Which attention implementation to use; you can use --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`" ) }, ) @@ -213,6 +214,7 @@ class ModelArguments: default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}, ) + use_flash_attention_2: bool = field(default=False) def __post_init__(self): if self.load_in_8bit and self.load_in_4bit: @@ -274,22 +276,15 @@ class DataArguments: @dataclass -class SFTConfig(transformers.TrainingArguments): +class SFTConfig(trl.SFTConfig): """ - Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments + Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments Also used for the continued pretraining task. """ - dataset_kwargs: Optional[Dict[str, Any]] = field( - default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"} - ) - max_seq_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "Used by TRL for reward model training, which tries to read this parameter in init." - ) - }, + hub_model_revision: Optional[str] = field( + default="main", + metadata={"help": ("The Hub model branch to push the model to.")}, ) logging_first_step: bool = field( default=True, @@ -297,21 +292,14 @@ class SFTConfig(transformers.TrainingArguments): "help": ("Whether to log and evaluate the first global_step or not.") }, ) - optim: Optional[str] = field(default="adamw_torch") @dataclass -class DPOConfig(transformers.TrainingArguments): +class DPOConfig(trl.DPOConfig): """ - Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments + Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments """ - beta: Optional[float] = field( - default=0.1, - metadata={ - "help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy." - }, - ) hub_model_revision: Optional[str] = field( default="main", metadata={"help": ("The Hub model branch to push the model to.")}, @@ -322,97 +310,5 @@ class DPOConfig(transformers.TrainingArguments): "help": ("Whether to log and evaluate the first global_step or not.") }, ) - max_prompt_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "For DPO, the maximum length of the prompt to use for conditioning the model." - ) - }, - ) - max_length: Optional[int] = field( - default=None, - metadata={ - "help": ( - "Used by TRL for reward model training, which tries to read this parameter in init." - ) - }, - ) optim: Optional[str] = field(default="rmsprop") remove_unused_columns: bool = field(default=False) - loss_type: Optional[str] = field( - default="sigmoid", metadata={"help": ("The loss type for DPO.")} - ) - - -@dataclass -class ORPOConfig(transformers.TrainingArguments): - max_length: Optional[int] = field( - default=None, - metadata={"help": "The maximum length of the sequences in the batch."}, - ) - max_prompt_length: Optional[int] = field( - default=None, - metadata={"help": "The maximum length of the prompt."}, - ) - max_completion_length: Optional[int] = field( - default=None, - metadata={"help": "The maximum length of the completions."}, - ) - - beta: float = field( - default=0.1, - metadata={ - "help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss." - }, - ) - disable_dropout: bool = field( - default=True, - metadata={"help": "Whether or not to disable dropouts in `model`."}, - ) - - label_pad_token_id: int = field( - default=-100, - metadata={"help": "The label pad token id."}, - ) - padding_value: Optional[int] = field( - default=None, - metadata={ - "help": "The padding value if it is different to the tokenizer's pad_token_id." - }, - ) - truncation_mode: str = field( - default="keep_end", - metadata={ - "help": "The truncation mode to use, either `keep_end` or `keep_start`." - }, - ) - - generate_during_eval: bool = field( - default=False, - metadata={ - "help": "Whether to sample and log generations during evaluation step." - }, - ) - is_encoder_decoder: Optional[bool] = field( - default=None, - metadata={ - "help": ( - "If no model is provided, we need to know if the model_init returns an encoder-decoder." - ) - }, - ) - - model_init_kwargs: Optional[Dict] = field( - default=None, - metadata={ - "help": ( - "Dict of Optional kwargs to pass when instantiating the model from a string" - ) - }, - ) - - dataset_num_proc: Optional[int] = field( - default=None, - metadata={"help": ("The number of workers to use to tokenize the data.")}, - ) diff --git a/src/alignment/release.py b/src/alignment/release.py new file mode 100644 index 00000000..a733c481 --- /dev/null +++ b/src/alignment/release.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re + +import packaging.version + + +REPLACE_PATTERNS = { + "init": ( + re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), + '__version__ = "VERSION"\n', + ), + "setup": ( + re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), + r'\1version="VERSION",', + ), + "citation": (re.compile(r"^version:\s+[^ ]+", re.MULTILINE), "version: VERSION"), + "readme": ( + re.compile(r"version\s+=\s+\{[^}]+\}", re.MULTILINE), + "version = {VERSION}", + ), +} + +README_FILE = "README.md" + +REPLACE_FILES = { + "init": "src/alignment/__init__.py", + "setup": "setup.py", + "citation": "CITATION.cff", + "readme": README_FILE, +} + + +def update_version_in_file(fname, version, pattern): + """Update the version in one file using a specific pattern.""" + with open(fname, "r", encoding="utf-8", newline="\n") as f: + code = f.read() + re_pattern, replace = REPLACE_PATTERNS[pattern] + replace = replace.replace("VERSION", version) + code = re_pattern.sub(replace, code) + with open(fname, "w", encoding="utf-8", newline="\n") as f: + f.write(code) + + +def global_version_update(version, patch=False): + """Update the version in all needed files.""" + for pattern, fname in REPLACE_FILES.items(): + update_version_in_file(fname, version, pattern) + + +def get_version(): + """Reads the current version in the __init__.""" + with open(REPLACE_FILES["init"], "r") as f: + code = f.read() + default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] + return packaging.version.parse(default_version) + + +def pre_release_work(patch=False): + """Do all the necessary pre-release steps.""" + # First let's get the default version: base version if we are in dev, bump minor otherwise. + default_version = get_version() + if patch and default_version.is_devrelease: + raise ValueError( + "Can't create a patch version from the dev branch, checkout a released version!" + ) + if default_version.is_devrelease: + default_version = default_version.base_version + elif patch: + default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" + else: + default_version = f"{default_version.major}.{default_version.minor + 1}.0" + + # Now let's ask nicely if that's the right one. + version = input(f"Which version are you releasing? [{default_version}]") + if len(version) == 0: + version = default_version + + print(f"Updating version to {version}.") + global_version_update(version, patch=patch) + + +def post_release_work(): + """Do all the necessary post-release steps.""" + # First let's get the current version + current_version = get_version() + dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" + current_version = current_version.base_version + + # Check with the user we got that right. + version = input(f"Which version are we developing now? [{dev_version}]") + if len(version) == 0: + version = dev_version + + print(f"Updating version to {version}.") + global_version_update(version) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--post_release", + action="store_true", + help="Whether this is pre or post release.", + ) + parser.add_argument( + "--patch", action="store_true", help="Whether or not this is a patch release." + ) + args = parser.parse_args() + if not args.post_release: + pre_release_work(patch=args.patch) + elif args.patch: + print("Nothing to do after a patch :-)") + else: + post_release_work() diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/fixtures/config_dpo_full.yaml b/tests/fixtures/config_dpo_full.yaml deleted file mode 100644 index 5110f591..00000000 --- a/tests/fixtures/config_dpo_full.yaml +++ /dev/null @@ -1,37 +0,0 @@ -# Model arguments -model_name_or_path: alignment-handbook/zephyr-7b-sft-full - -# Data training arguments -# For definitions, see: src/h4/training/config.py -dataset_mixer: - HuggingFaceH4/ultrafeedback_binarized: 1.0 -dataset_splits: -- train_prefs -- test_prefs -preprocessing_num_workers: 12 - -# DPOTrainer arguments -bf16: true -beta: 0.1 -do_eval: true -evaluation_strategy: steps -eval_steps: 100 -gradient_accumulation_steps: 1 -gradient_checkpointing: true -hub_model_id: zephyr-7b-dpo-full -learning_rate: 5.0e-7 -log_level: info -logging_steps: 10 -lr_scheduler_type: linear -max_length: 1024 -max_prompt_length: 512 -num_train_epochs: 3 -optim: rmsprop -output_dir: data/zephyr-7b-dpo-full -per_device_train_batch_size: 8 -per_device_eval_batch_size: 4 -push_to_hub: true -save_strategy: "no" -save_total_limit: null -seed: 42 -warmup_ratio: 0.1 \ No newline at end of file diff --git a/tests/fixtures/config_sft_full.yaml b/tests/fixtures/config_sft_full.yaml deleted file mode 100644 index adf13dae..00000000 --- a/tests/fixtures/config_sft_full.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# Model arguments -model_name_or_path: mistralai/Mistral-7B-v0.1 -model_revision: main -torch_dtype: bfloat16 -use_flash_attention_2: true - -# Data training arguments -dataset_mixer: - HuggingFaceH4/ultrachat_200k: 1.0 -dataset_splits: -- train_sft -- test_sft -preprocessing_num_workers: 12 - -# SFT trainer config -bf16: true -do_eval: true -evaluation_strategy: epoch -gradient_accumulation_steps: 2 -gradient_checkpointing: true -hub_model_id: zephyr-7b-sft-full -hub_strategy: every_save -learning_rate: 2.0e-05 -log_level: info -logging_steps: 5 -logging_strategy: steps -lr_scheduler_type: cosine -max_seq_length: 2048 -max_steps: -1 -num_train_epochs: 1 -output_dir: data/zephyr-7b-sft-full -overwrite_output_dir: true -per_device_eval_batch_size: 16 -per_device_train_batch_size: 32 -push_to_hub: true -remove_unused_columns: true -report_to: -- tensorboard -save_strategy: "no" -save_total_limit: null -seed: 42 \ No newline at end of file diff --git a/tests/test_configs.py b/tests/test_configs.py deleted file mode 100644 index f42348d5..00000000 --- a/tests/test_configs.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import unittest - -from alignment import DataArguments, H4ArgumentParser, ModelArguments, SFTConfig - - -class H4ArgumentParserTest(unittest.TestCase): - def setUp(self): - self.parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) - self.yaml_file_path = "tests/fixtures/config_sft_full.yaml" - - def test_load_yaml(self): - model_args, data_args, training_args = self.parser.parse_yaml_file( - os.path.abspath(self.yaml_file_path) - ) - self.assertEqual(model_args.model_name_or_path, "mistralai/Mistral-7B-v0.1") - - def test_load_yaml_and_args(self): - command_line_args = [ - "--model_name_or_path=test", - "--use_peft=true", - "--lora_r=16", - "--lora_dropout=0.5", - ] - model_args, data_args, training_args = self.parser.parse_yaml_and_args( - os.path.abspath(self.yaml_file_path), command_line_args - ) - self.assertEqual(model_args.model_name_or_path, "test") - self.assertEqual(model_args.use_peft, True) - self.assertEqual(model_args.lora_r, 16) - self.assertEqual(model_args.lora_dropout, 0.5) diff --git a/tests/test_data.py b/tests/test_data.py deleted file mode 100644 index cd8021e9..00000000 --- a/tests/test_data.py +++ /dev/null @@ -1,209 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest -from copy import deepcopy - -import pytest -from datasets import Dataset -from transformers import AutoTokenizer - -from alignment import ( - DataArguments, - ModelArguments, - apply_chat_template, - get_datasets, - get_tokenizer, -) -from alignment.data import maybe_insert_system_message - - -class GetDatasetsTest(unittest.TestCase): - """Each of these test datasets has 100 examples""" - - def test_loading_data_args(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.5, - "HuggingFaceH4/testing_self_instruct_small": 0.3, - "HuggingFaceH4/testing_codealpaca_small": 0.2, - } - data_args = DataArguments(dataset_mixer=dataset_mixer) - datasets = get_datasets(data_args, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 100) - self.assertEqual(len(datasets["test"]), 300) - - def test_loading_data_dict(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.5, - "HuggingFaceH4/testing_self_instruct_small": 0.3, - "HuggingFaceH4/testing_codealpaca_small": 0.2, - } - datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 100) - self.assertEqual(len(datasets["test"]), 300) - - def test_loading_with_unit_fractions(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 1.0, - "HuggingFaceH4/testing_self_instruct_small": 1.0, - "HuggingFaceH4/testing_codealpaca_small": 1.0, - } - datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 300) - self.assertEqual(len(datasets["test"]), 300) - - def test_loading_with_fractions_greater_than_unity(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.7, - "HuggingFaceH4/testing_self_instruct_small": 0.4, - } - datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - self.assertEqual(len(datasets["train"]), 70 + 40) - self.assertEqual(len(datasets["test"]), 200) - - def test_loading_fails_with_negative_fractions(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 0.7, - "HuggingFaceH4/testing_self_instruct_small": -0.3, - } - with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."): - get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"]) - - def test_loading_single_split_with_unit_fractions(self): - dataset_mixer = { - "HuggingFaceH4/testing_alpaca_small": 1.0, - } - datasets = get_datasets( - dataset_mixer, splits=["test"], columns_to_keep=["prompt", "completion"] - ) - self.assertEqual(len(datasets["test"]), 100) - self.assertRaises(KeyError, lambda: datasets["train"]) - - -class ApplyChatTemplateTest(unittest.TestCase): - def setUp(self): - model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha") - data_args = DataArguments() - self.tokenizer = get_tokenizer(model_args, data_args) - self.dataset = Dataset.from_dict( - { - "prompt": ["Hello!"], - "messages": [ - [ - {"role": "system", "content": "You are a happy chatbot"}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Bonjour!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I am doing well, thanks!"}, - ] - ], - "chosen": [ - [ - {"role": "system", "content": "You are a happy chatbot"}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Bonjour!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "I am doing well, thanks!"}, - ] - ], - "rejected": [ - [ - {"role": "system", "content": "You are a happy chatbot"}, - {"role": "user", "content": "Hello!"}, - {"role": "assistant", "content": "Bonjour!"}, - {"role": "user", "content": "How are you?"}, - {"role": "assistant", "content": "Not so good tbh"}, - ] - ], - } - ) - - def test_maybe_insert_system_message(self): - # does not accept system prompt - mistral_tokenizer = AutoTokenizer.from_pretrained( - "mistralai/Mistral-7B-Instruct-v0.2" - ) - # accepts system prompt. use codellama since it has no HF token reqiurement - llama_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf") - messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}] - messages_sys_incl = [ - {"role": "system", "content": ""}, - {"role": "user", "content": "Tell me a joke."}, - ] - - mistral_messages = deepcopy(messages_sys_excl) - llama_messages = deepcopy(messages_sys_excl) - maybe_insert_system_message(mistral_messages, mistral_tokenizer) - maybe_insert_system_message(llama_messages, llama_tokenizer) - - # output from mistral should not have a system message, output from llama should - self.assertEqual(mistral_messages, messages_sys_excl) - self.assertEqual(llama_messages, messages_sys_incl) - - def test_sft(self): - dataset = self.dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nI am doing well, thanks!\n" - }, - ) - - def test_generation(self): - # Remove last turn from messages - dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]}) - dataset = dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\n" - }, - ) - - def test_rm(self): - dataset = self.dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text_chosen": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nI am doing well, thanks!\n", - "text_rejected": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n<|assistant|>\nNot so good tbh\n", - }, - ) - - def test_dpo(self): - dataset = self.dataset.map( - apply_chat_template, - fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"}, - remove_columns=self.dataset.column_names, - ) - self.assertDictEqual( - dataset[0], - { - "text_prompt": "<|system|>\nYou are a happy chatbot\n<|user|>\nHello!\n<|assistant|>\nBonjour!\n<|user|>\nHow are you?\n", - "text_chosen": "<|assistant|>\nI am doing well, thanks!\n", - "text_rejected": "<|assistant|>\nNot so good tbh\n", - }, - ) diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py deleted file mode 100644 index 612c97c1..00000000 --- a/tests/test_model_utils.py +++ /dev/null @@ -1,121 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -import torch -from transformers import AutoTokenizer - -from alignment import ( - DataArguments, - ModelArguments, - get_peft_config, - get_quantization_config, - get_tokenizer, - is_adapter_model, -) -from alignment.data import DEFAULT_CHAT_TEMPLATE - - -class GetQuantizationConfigTest(unittest.TestCase): - def test_4bit(self): - model_args = ModelArguments(load_in_4bit=True) - quantization_config = get_quantization_config(model_args) - self.assertTrue(quantization_config.load_in_4bit) - self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16) - self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4") - self.assertFalse(quantization_config.bnb_4bit_use_double_quant) - - def test_8bit(self): - model_args = ModelArguments(load_in_8bit=True) - quantization_config = get_quantization_config(model_args) - self.assertTrue(quantization_config.load_in_8bit) - - def test_no_quantization(self): - model_args = ModelArguments() - quantization_config = get_quantization_config(model_args) - self.assertIsNone(quantization_config) - - -class GetTokenizerTest(unittest.TestCase): - def setUp(self) -> None: - self.model_args = ModelArguments( - model_name_or_path="HuggingFaceH4/zephyr-7b-alpha" - ) - - def test_right_truncation_side(self): - tokenizer = get_tokenizer( - self.model_args, DataArguments(truncation_side="right") - ) - self.assertEqual(tokenizer.truncation_side, "right") - - def test_left_truncation_side(self): - tokenizer = get_tokenizer( - self.model_args, DataArguments(truncation_side="left") - ) - self.assertEqual(tokenizer.truncation_side, "left") - - def test_default_chat_template(self): - tokenizer = get_tokenizer(self.model_args, DataArguments()) - self.assertEqual(tokenizer.chat_template, DEFAULT_CHAT_TEMPLATE) - - def test_default_chat_template_no_overwrite(self): - """ - If no chat template is passed explicitly in the config, then for models with a - `default_chat_template` but no `chat_template` we do not set a `chat_template`, - and that we do not change `default_chat_template` - """ - model_args = ModelArguments( - model_name_or_path="m-a-p/OpenCodeInterpreter-SC2-7B" - ) - base_tokenizer = AutoTokenizer.from_pretrained( - "m-a-p/OpenCodeInterpreter-SC2-7B" - ) - processed_tokenizer = get_tokenizer(model_args, DataArguments()) - - assert getattr(processed_tokenizer, "chat_template") is None - self.assertEqual( - base_tokenizer.default_chat_template, - processed_tokenizer.default_chat_template, - ) - - def test_chatml_chat_template(self): - chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" - tokenizer = get_tokenizer( - self.model_args, DataArguments(chat_template=chat_template) - ) - self.assertEqual(tokenizer.chat_template, chat_template) - - -class GetPeftConfigTest(unittest.TestCase): - def test_peft_config(self): - model_args = ModelArguments( - use_peft=True, lora_r=42, lora_alpha=0.66, lora_dropout=0.99 - ) - peft_config = get_peft_config(model_args) - self.assertEqual(peft_config.r, 42) - self.assertEqual(peft_config.lora_alpha, 0.66) - self.assertEqual(peft_config.lora_dropout, 0.99) - - def test_no_peft_config(self): - model_args = ModelArguments(use_peft=False) - peft_config = get_peft_config(model_args) - self.assertIsNone(peft_config) - - -class IsAdapterModelTest(unittest.TestCase): - def test_is_adapter_model_calls_listdir(self): - # Assert that for an invalid repo name it gets to the point where it calls os.listdir, - # which is expected to raise a FileNotFoundError - self.assertRaises(FileNotFoundError, is_adapter_model, "nonexistent/model")