Commit
Showing 7 changed files with 263 additions and 0 deletions.
.gitignore
@@ -0,0 +1,6 @@
data/
.ipynb_checkpoints/
wandb/
toy/
llama2_qlora.ipynb
__pycache__
@@ -0,0 +1,102 @@
import os
from datetime import datetime

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, DataCollatorForSeq2Seq,
                          TrainingArguments)

import wandb
from mandrill_utils.loggin_utils import generate_random_string
from preprocess.chat import llama_get_input_with_labels
from train.trainers import MandrillTrainer
from train.utils import print_trainable_parameters

# Read the token from the environment rather than committing it to the repo.
HUGGINGFACE_API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")

TOY = True
BATCH_SIZE = 2
BASE_RUN_NAME = "llama2-7b-qlora"
SAVE_DATA_POINTS = 2000
HF_CACHE_DIR = "/notebooks/.cache/huggingface"
WANDB_PROJECT = "mandrill"
WANDB_TEAM = "yieldinc"
""" | ||
https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.set_training.gradient_accumulation_steps | ||
When using gradient accumulation, one step is counted as one step with backward pass. | ||
Therefore, logging, evaluation, save will be conducted every gradient_accumulation_steps * xxx_step training examples. | ||
""" | ||
GRADIENT_ACCUMULATION_STEPS = 4 | ||
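# Worked example: with BATCH_SIZE = 2 and GRADIENT_ACCUMULATION_STEPS = 4, one
# optimizer step consumes 2 * 4 = 8 examples, so save_steps = 2000 // 8 = 250
# checkpoints the model roughly every SAVE_DATA_POINTS = 2000 training examples.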

if TOY:
    SAVE_DATA_POINTS = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    WANDB_PROJECT += "-toy"

model_id = "meta-llama/Llama-2-7b-chat-hf"
# model_id = "codellama/CodeLlama-7b-Instruct-hf"
# QLoRA-style quantization: 4-bit NF4 weights with nested (double) quantization,
# computing activations in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_id, cache_dir=HF_CACHE_DIR, token=HUGGINGFACE_API_TOKEN
)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    cache_dir=HF_CACHE_DIR,
    token=HUGGINGFACE_API_TOKEN,
)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
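# prepare_model_for_kbit_training freezes the quantized base weights, upcasts
# layer norms to fp32, and enables input gradients so the frozen 4-bit base
# plays well with gradient checkpointing.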

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)
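# With r=8 adapters on the q/k/v projections, only the low-rank matrices train;
# the 4-bit base weights stay frozen.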

data = load_dataset("json", data_files="data/instructions.jsonl")
data = data.map(llama_get_input_with_labels)

output_root = "outputs/toy" if TOY else "outputs"
run_name = (
    f"{datetime.today().date()}_{BASE_RUN_NAME}_{generate_random_string(5).lower()}"
)
output_dir = f"{output_root}/{run_name}"
print("output_dir:", output_dir)
save_steps = SAVE_DATA_POINTS // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

wandb.init(entity=WANDB_TEAM, project=WANDB_PROJECT, name=run_name)

trainer = MandrillTrainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=2,
        save_steps=save_steps,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir=output_dir,
        optim="paged_adamw_8bit",
    ),
    data_collator=DataCollatorForSeq2Seq(tokenizer),
)
model.config.use_cache = False  # silence the warnings; re-enable for inference!
trainer.train()
wandb.finish()
mandrill_utils/loggin_utils.py
@@ -0,0 +1,22 @@
import hashlib
import random
import string
from datetime import datetime

# Get the current datetime as a string
current_datetime_str = str(datetime.now())

# Hash the datetime string to generate a seed
seed = int(hashlib.sha256(current_datetime_str.encode()).hexdigest(), 16)

# Create a random number generator instance with the seed
str_random = random.Random(seed)


def generate_random_string(length):
    # Define the set of characters allowed in the random string
    characters = string.ascii_letters + string.digits  # include other characters if needed

    # Generate the random string by selecting characters randomly
    random_string = ''.join(str_random.choice(characters) for _ in range(length))

    return random_string
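
# Usage sketch (the output shown is illustrative; it varies per run because the
# seed is derived from the current time at import):
#   >>> generate_random_string(5)
#   'a3XkQ'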
preprocess/chat.py
@@ -0,0 +1,104 @@
import os
from typing import List

from llama.generation import Message, Dialog, B_INST, E_INST, B_SYS, E_SYS
from transformers import AutoTokenizer

from preprocess.prompts import SYSTEM_PROMPT

HF_CACHE_DIR = os.environ.get('HUGGINGFACE_CACHE_DIR', '/notebooks/.cache/huggingface')
# Read the token from the environment rather than committing it to the repo.
HF_API_TOKEN = os.environ.get('HUGGINGFACE_API_TOKEN')

LLAMA_HF_ID = "meta-llama/Llama-2-7b-chat-hf"
llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_HF_ID, cache_dir=HF_CACHE_DIR, token=HF_API_TOKEN)
llama_tokenizer.pad_token = llama_tokenizer.eos_token


def llama_dialog2tokens(dialog: Dialog, tokenizer=llama_tokenizer, verbose=False):
    # copied / adapted from https://github.com/facebookresearch/llama/blob/d58f9ae95c299fe6388ee2da2c87fd90cd360d41/llama/generation.py#L284
    if dialog[0]["role"] == "system":
        # Fold the system message into the first user turn, as in the reference code.
        dialog = [
            {
                "role": dialog[1]["role"],
                "content": B_SYS
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
        ] + dialog[2:]
    assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    # BOS/EOS are written into the strings explicitly, so disable the tokenizer's
    # automatic special tokens to avoid prepending BOS twice.
    dialog_tokens: List[int] = sum(
        [
            tokenizer.encode(
                f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {tokenizer.eos_token}",
                add_special_tokens=False,
            )
            for prompt, answer in zip(
                dialog[::2],
                dialog[1::2],
            )
        ],
        [],
    )
    if verbose:
        messages = [
            f"{tokenizer.bos_token} {B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {tokenizer.eos_token}"
            for prompt, answer in zip(
                dialog[::2],
                dialog[1::2],
            )
        ]
    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"
    dialog_tokens += tokenizer.encode(
        f"{tokenizer.bos_token} {B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        add_special_tokens=False,
    )
    if verbose:
        messages.append(f"{tokenizer.bos_token} {B_INST} {(dialog[-1]['content']).strip()} {E_INST}")
        print(messages)  # `display` is IPython-only; print works in a plain script
    return dialog_tokens
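
# For a [system, user] dialog, the rendered prompt string follows the Llama 2
# chat layout (whitespace approximate):
#   <s>[INST] <<SYS>>
#   {system prompt}
#   <</SYS>>
#
#   {user message} [/INST]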


def llama_get_prompt_tokens(jsonl_row, system_message=SYSTEM_PROMPT):
    # TODO: make a dataclass or pydantic type for jsonl_row
    # Use the system_message argument (the original read the SYSTEM_PROMPT global,
    # silently ignoring any caller-supplied system message).
    system = Message(role='system', content=system_message)
    dialog = [
        system,
        Message(role='user', content=jsonl_row['instruction']),
    ]
    return llama_dialog2tokens(dialog)


def llama_get_input_with_labels(row, tokenizer=llama_tokenizer):
    prompt_tokens = llama_get_prompt_tokens(row)
    # EOS is appended as a string, so skip automatic special tokens here as well.
    response_tokens = tokenizer(
        f"{row['response']} {tokenizer.eos_token}",
        add_special_tokens=False,
    )['input_ids']
    input_ids = prompt_tokens + response_tokens
    attention_mask = [1] * len(input_ids)
    # Mask the prompt with -100 so the loss is computed only on the response tokens.
    labels = [-100] * len(prompt_tokens) + response_tokens
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}
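
# Shape sketch (hypothetical token ids): for a 4-token prompt and a 3-token response,
#   input_ids = [p0, p1, p2, p3, r0, r1, r2]
#   labels    = [-100, -100, -100, -100, r0, r1, r2]
# so cross-entropy ignores every prompt position.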


if __name__ == '__main__':
    sample_dialog: Dialog = [
        {"role": "system", "content": "Welcome to the virtual assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there! How can I assist you today?"},
        {"role": "user", "content": "I need some help with Python programming."},
        {"role": "assistant", "content": "Sure, I can help with Python. What do you need assistance with?"},
        {"role": "user", "content": "I'm having trouble with a Python script."},
        {"role": "assistant", "content": "Could you please provide the script or describe the issue you're facing?"},
        {"role": "user", "content": "```python\n"
                                    "def calculate_square(x):\n"
                                    "    return x ** 2\n"
                                    "```"},
        {"role": "assistant", "content": "Thank you for sharing the script. What seems to be the problem with it?"},
        {"role": "user", "content": "I'm getting a 'NameError' for 'x' when I run it."},
        {"role": "assistant", "content": "The 'NameError' indicates that 'x' is not defined. You should provide a value for 'x' when calling the function."},
        {"role": "user", "content": "```python\n"
                                    "def calculate_square(x):\n"
                                    "    x = 5  # Assign a value to x\n"
                                    "    return x ** 2\n"
                                    "```"},
        {"role": "assistant", "content": "Great! Now the 'x' variable is defined. Is there anything else you need help with?"},
        {"role": "user", "content": "No, that's all for now. Thanks for your assistance!"}
    ]

    # The function is named llama_dialog2tokens (the original called dialog2tokens,
    # which is undefined).
    print(llama_dialog2tokens(sample_dialog, verbose=True))
preprocess/prompts.py
@@ -0,0 +1,4 @@
SYSTEM_PROMPT = """
You are a superintelligent reasoning agent. You think logically and carefully, and verbalize your reasoning process when thinking.
""".strip()
train/trainers.py
@@ -0,0 +1,12 @@
from transformers import Trainer


class MandrillTrainer(Trainer):
    """
    Avoid the parent's label handling that can set labels to None:
    https://github.com/huggingface/transformers/blob/5a4f340df74b42b594aedf60199eea95cdb9bed0/src/transformers/trainer.py#L2703C26-L2703C26
    """

    def compute_loss(self, model, inputs, return_outputs=False):
        # `labels` stay in `inputs`, so the model computes the loss itself.
        outputs = model(**inputs)
        return (outputs.loss, outputs) if return_outputs else outputs.loss
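
# Minimal usage sketch (mirrors the training script):
#   trainer = MandrillTrainer(model=model, train_dataset=ds,
#                             args=TrainingArguments(...),
#                             data_collator=DataCollatorForSeq2Seq(tokenizer))
#   trainer.train()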
train/utils.py
@@ -0,0 +1,13 @@
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.4f}%"
    )
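
# Example output (illustrative numbers, not measured here; with a 4-bit base
# model, numel undercounts packed weights):
#   trainable params: 6291456 || all params: 3533967360 || trainable%: 0.1780%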