Commit fbebbc1: sync

xiyang-aads-lilly committed Aug 19, 2024
2 parents 4f70851 + a8dcde2
Showing 7 changed files with 384 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
```diff
@@ -19,6 +19,7 @@ However, we know from the [InstructGPT](https://huggingface.co/papers/2203.02155)
 The Alignment Handbook aims to fill that gap by providing the community with a series of robust training recipes that span the whole pipeline.
 
 ## News 🗞️
+* **August 18, 2024**: We release SmolLM-Instruct v0.2, along with the [recipe](recipes/smollm/README.md) to fine-tune small LLMs 💻
 * **April 12, 2024**: We release Zephyr 141B (A35B), in collaboration with Argilla and Kaist AI, along with the recipe to fine-tune Mixtral 8x22B with ORPO 🪁
 * **March 12, 2024:** We release StarChat2 15B, along with the recipe to train capable coding assistants 🌟
 * **March 1, 2024:** We release Zephyr 7B Gemma, which is a new recipe to align Gemma 7B with RLAIF 🔥
```
19 changes: 19 additions & 0 deletions recipes/smollm/README.md
@@ -0,0 +1,19 @@

# Instructions to train SmolLM-Instruct

We build the [SmolLM-Instruct](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) (v0.2) models (135M, 360M and 1.7B) by doing SFT on a mix of these datasets (a quick inspection sketch follows the list):
- [everyday-conversations-llama3.1-2k](https://huggingface.co/datasets/HuggingFaceTB/everyday-conversations-llama3.1-2k/), a dataset of 2k simple everyday conversations we generated with llama3.1-70B
- [Magpie-Pro-300K-Filtered](https://huggingface.co/datasets/Magpie-Align/Magpie-Pro-300K-Filtered)
- [StarCoder2-Self-OSS-Instruct](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k)
- a small subset of [OpenHermes-2.5](https://huggingface.co/datasets/teknium/OpenHermes-2.5)
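
Each component of the mix can be inspected directly with 🤗 Datasets before training. A minimal sketch, assuming the dataset exposes the `train_sft` split used by the config below and the standard `messages` chat column:

```python
from datasets import load_dataset

# Assumes a `train_sft` split and a `messages` column, as the SFT config
# and chat-template code in this repo expect.
ds = load_dataset("HuggingFaceTB/everyday-conversations-llama3.1-2k", split="train_sft")
print(ds)                     # row count and column names
print(ds[0]["messages"][:2])  # first two chat turns of the first example
```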

## Setup

Follow the installation instructions in https://github.com/huggingface/alignment-handbook/tree/main?tab=readme-ov-file#installation-instructions

## Training
We train the models on 8 GPUs using the following command:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/deepspeed_zero3.yaml scripts/run_sft.py recipes/smollm/sft/config.yaml
```
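
The 8-process count comes from the `deepspeed_zero3.yaml` accelerate config; on a machine with fewer GPUs you can pass `--num_processes=<N>` to `accelerate launch`, and correspondingly raise `gradient_accumulation_steps` in the config to preserve the effective batch size.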
53 changes: 53 additions & 0 deletions recipes/smollm/sft/config.yaml
@@ -0,0 +1,53 @@
```yaml
# Model arguments
model_name_or_path: HuggingFaceTB/SmolLM-360M
model_revision: main
tokenizer_name_or_path: HuggingFaceTB/SmolLM-360M-Instruct # Custom tokenizer with <|im_start|> and <|im_end|> tokens
torch_dtype: bfloat16
use_flash_attention_2: true

# Data training arguments
dataset_mixer:
  HuggingFaceTB/Magpie-Pro-300K-Filtered-H4: 1.0
  HuggingFaceTB/self-oss-instruct-sc2-H4: 1.0
  HuggingFaceTB/OpenHermes-2.5-H4: 0.001
  HuggingFaceTB/everyday-conversations-llama3.1-2k: 1.0
  HuggingFaceTB/instruct-data-basics-smollm-H4: 1.0

dataset_splits:
  - train_sft
  - test_sft
preprocessing_num_workers: 36

# SFT trainer config
bf16: true
dataset_kwargs:
  add_special_tokens: false # We already wrap <bos> and <eos> in the chat template
  append_concat_token: false # No need to add <eos> across samples
do_eval: true
evaluation_strategy: epoch
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: smollm-360M-instruct-new
hub_strategy: every_save
learning_rate: 1.0e-03 # 3e-4
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_seq_length: 2048
max_steps: -1
num_train_epochs: 1
output_dir: data/smollm-360M-instruct-new
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 4
push_to_hub: true
remove_unused_columns: true
report_to:
  - tensorboard
  - wandb
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
```
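
Two settings above interact: `add_special_tokens: false` is safe only because the custom tokenizer's chat template emits `<|im_start|>`/`<|im_end|>` itself. A minimal sketch to check this, assuming the `SmolLM-360M-Instruct` tokenizer ships that template:

```python
from transformers import AutoTokenizer

# Sanity check: the chat template, not the tokenizer's special-token logic,
# should wrap each turn in <|im_start|> / <|im_end|>.
tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct")
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
print(tok.apply_chat_template(messages, tokenize=False))
```

For reference, the effective batch size here is 4 per device × 8 GPUs × 4 gradient-accumulation steps = 128 sequences per optimizer step.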
2 changes: 1 addition & 1 deletion src/alignment/data.py
```diff
@@ -36,7 +36,7 @@ def maybe_insert_system_message(messages, tokenizer):
     # chat template can be one of two attributes, we check in order
     chat_template = tokenizer.chat_template
     if chat_template is None:
-        chat_template = tokenizer.default_chat_template
+        chat_template = tokenizer.get_chat_template()
 
     # confirm the jinja template refers to a system message before inserting
     if "system" in chat_template or "<|im_start|>" in chat_template:
```
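
The motivation for this change is that `default_chat_template` was deprecated in recent `transformers` releases, so the fallback now goes through `get_chat_template()`. A usage sketch of the helper itself, mirroring the new tests added below:

```python
from transformers import AutoTokenizer

from alignment.data import maybe_insert_system_message

# Qwen2's chat template accepts a system turn, so an empty one is prepended;
# templates without system support leave the messages untouched.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
messages = [{"role": "user", "content": "Tell me a joke."}]
maybe_insert_system_message(messages, tok)
print(messages[0])  # expected: {'role': 'system', 'content': ''}
```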
7 changes: 1 addition & 6 deletions src/alignment/model_utils.py
```diff
@@ -148,12 +148,7 @@ def get_tokenizer(
 
     if data_args.chat_template is not None:
         tokenizer.chat_template = data_args.chat_template
-
-    elif (
-        auto_set_chat_template
-        and tokenizer.chat_template is None
-        and tokenizer.default_chat_template is None
-    ):
+    elif auto_set_chat_template and tokenizer.get_chat_template() is None:
         tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
 
     return tokenizer
```
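
For completeness, the first branch above (an explicit template from `DataArguments`) can be exercised directly; a minimal sketch with a toy Jinja template that is illustrative only, not the handbook's default:

```python
from alignment import DataArguments, ModelArguments, get_tokenizer

model_args = ModelArguments(model_name_or_path="HuggingFaceTB/SmolLM-360M")
data_args = DataArguments(
    # Toy template for illustration; any Jinja string works here.
    chat_template="{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
)
tokenizer = get_tokenizer(model_args, data_args)  # takes the `data_args.chat_template` branch
```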
209 changes: 209 additions & 0 deletions tests/test_data.py
@@ -0,0 +1,209 @@
```python
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from copy import deepcopy

import pytest
from datasets import Dataset
from transformers import AutoTokenizer

from alignment import (
    DataArguments,
    ModelArguments,
    apply_chat_template,
    get_datasets,
    get_tokenizer,
)
from alignment.data import maybe_insert_system_message


class GetDatasetsTest(unittest.TestCase):
    """Each of these test datasets has 100 examples"""

    def test_loading_data_args(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.5,
            "HuggingFaceH4/testing_self_instruct_small": 0.3,
            "HuggingFaceH4/testing_codealpaca_small": 0.2,
        }
        data_args = DataArguments(dataset_mixer=dataset_mixer)
        datasets = get_datasets(data_args, columns_to_keep=["prompt", "completion"])
        self.assertEqual(len(datasets["train"]), 100)
        self.assertEqual(len(datasets["test"]), 300)

    def test_loading_data_dict(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.5,
            "HuggingFaceH4/testing_self_instruct_small": 0.3,
            "HuggingFaceH4/testing_codealpaca_small": 0.2,
        }
        datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
        self.assertEqual(len(datasets["train"]), 100)
        self.assertEqual(len(datasets["test"]), 300)

    def test_loading_with_unit_fractions(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 1.0,
            "HuggingFaceH4/testing_self_instruct_small": 1.0,
            "HuggingFaceH4/testing_codealpaca_small": 1.0,
        }
        datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
        self.assertEqual(len(datasets["train"]), 300)
        self.assertEqual(len(datasets["test"]), 300)

    def test_loading_with_fractions_greater_than_unity(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.7,
            "HuggingFaceH4/testing_self_instruct_small": 0.4,
        }
        datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
        self.assertEqual(len(datasets["train"]), 70 + 40)
        self.assertEqual(len(datasets["test"]), 200)

    def test_loading_fails_with_negative_fractions(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.7,
            "HuggingFaceH4/testing_self_instruct_small": -0.3,
        }
        with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
            get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])

    def test_loading_single_split_with_unit_fractions(self):
        dataset_mixer = {
            "HuggingFaceH4/testing_alpaca_small": 1.0,
        }
        datasets = get_datasets(
            dataset_mixer, splits=["test"], columns_to_keep=["prompt", "completion"]
        )
        self.assertEqual(len(datasets["test"]), 100)
        self.assertRaises(KeyError, lambda: datasets["train"])


class ApplyChatTemplateTest(unittest.TestCase):
    def setUp(self):
        model_args = ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha")
        data_args = DataArguments()
        self.tokenizer = get_tokenizer(model_args, data_args)
        self.dataset = Dataset.from_dict(
            {
                "prompt": ["Hello!"],
                "messages": [
                    [
                        {"role": "system", "content": "You are a happy chatbot"},
                        {"role": "user", "content": "Hello!"},
                        {"role": "assistant", "content": "Bonjour!"},
                        {"role": "user", "content": "How are you?"},
                        {"role": "assistant", "content": "I am doing well, thanks!"},
                    ]
                ],
                "chosen": [
                    [
                        {"role": "system", "content": "You are a happy chatbot"},
                        {"role": "user", "content": "Hello!"},
                        {"role": "assistant", "content": "Bonjour!"},
                        {"role": "user", "content": "How are you?"},
                        {"role": "assistant", "content": "I am doing well, thanks!"},
                    ]
                ],
                "rejected": [
                    [
                        {"role": "system", "content": "You are a happy chatbot"},
                        {"role": "user", "content": "Hello!"},
                        {"role": "assistant", "content": "Bonjour!"},
                        {"role": "user", "content": "How are you?"},
                        {"role": "assistant", "content": "Not so good tbh"},
                    ]
                ],
            }
        )

    def test_maybe_insert_system_message(self):
        # Chat template that does not accept a system prompt. Use a community checkpoint since it has no HF token requirement
        tokenizer_sys_excl = AutoTokenizer.from_pretrained(
            "mistral-community/Mistral-7B-Instruct-v0.3"
        )
        # Chat template that accepts a system prompt
        tokenizer_sys_incl = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
        messages_sys_excl = [{"role": "user", "content": "Tell me a joke."}]
        messages_sys_incl = [
            {"role": "system", "content": ""},
            {"role": "user", "content": "Tell me a joke."},
        ]

        messages_proc_excl = deepcopy(messages_sys_excl)
        message_proc_incl = deepcopy(messages_sys_excl)
        maybe_insert_system_message(messages_proc_excl, tokenizer_sys_excl)
        maybe_insert_system_message(message_proc_incl, tokenizer_sys_incl)

        # output from mistral should not have a system message, output from qwen should
        self.assertEqual(messages_proc_excl, messages_sys_excl)
        self.assertEqual(message_proc_incl, messages_sys_incl)

    def test_sft(self):
        dataset = self.dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
            remove_columns=self.dataset.column_names,
        )
        self.assertDictEqual(
            dataset[0],
            {
                "text": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\nI am doing well, thanks!</s>\n"
            },
        )

    def test_generation(self):
        # Remove last turn from messages
        dataset = self.dataset.map(lambda x: {"messages": x["messages"][:-1]})
        dataset = dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
            remove_columns=self.dataset.column_names,
        )
        self.assertDictEqual(
            dataset[0],
            {
                "text": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\n"
            },
        )

    def test_rm(self):
        dataset = self.dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
            remove_columns=self.dataset.column_names,
        )
        self.assertDictEqual(
            dataset[0],
            {
                "text_chosen": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\nI am doing well, thanks!</s>\n",
                "text_rejected": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\nNot so good tbh</s>\n",
            },
        )

    def test_dpo(self):
        dataset = self.dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
            remove_columns=self.dataset.column_names,
        )
        self.assertDictEqual(
            dataset[0],
            {
                "text_prompt": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n",
                "text_chosen": "<|assistant|>\nI am doing well, thanks!</s>\n",
                "text_rejected": "<|assistant|>\nNot so good tbh</s>\n",
            },
        )
```
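
The suite can be run on its own with `pytest tests/test_data.py`; note that the `GetDatasetsTest` cases pull the small `HuggingFaceH4/testing_*` fixture datasets from the Hub, so they need network access.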