Commit

bad code
hjc-puro committed Sep 17, 2023
1 parent 5a54a9d commit ee194fa
Showing 7 changed files with 263 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
data/
.ipynb_checkpoints/
wandb/
toy/
llama2_qlora.ipynb
__pycache__
102 changes: 102 additions & 0 deletions llama2_qlora.py
@@ -0,0 +1,102 @@
import os
from datetime import datetime

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, DataCollatorForSeq2Seq,
                          TrainingArguments)

import wandb
from mandrill_utils.logging_utils import generate_random_string
from preprocess.chat import llama_get_input_with_labels
from train.trainers import MandrillTrainer
from train.utils import print_trainable_parameters

# Read the token from the environment rather than committing a secret to the repo.
HUGGINGFACE_API_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")

TOY = True
BATCH_SIZE = 2
BASE_RUN_NAME = "llama2-7b-qlora"
SAVE_DATA_POINTS = 2000
HF_CACHE_DIR = "/notebooks/.cache/huggingface"
WANDB_PROJECT = "mandrill"
WANDB_TEAM = "yieldinc"
"""
https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.set_training.gradient_accumulation_steps
When using gradient accumulation, one step is counted as one step with backward pass.
Therefore, logging, evaluation, save will be conducted every gradient_accumulation_steps * xxx_step training examples.
"""
GRADIENT_ACCUMULATION_STEPS = 4
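# Illustrative arithmetic (not in the original): with BATCH_SIZE = 2 and
# GRADIENT_ACCUMULATION_STEPS = 4, one optimizer step consumes 8 examples,
# so SAVE_DATA_POINTS = 2000 gives save_steps = 2000 // 8 = 250 below.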

if TOY:
    SAVE_DATA_POINTS = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    WANDB_PROJECT += "-toy"

model_id = "meta-llama/Llama-2-7b-chat-hf"
# model_id = "codellama/CodeLlama-7b-Instruct-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(
    model_id, cache_dir=HF_CACHE_DIR, token=HUGGINGFACE_API_TOKEN
)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    cache_dir=HF_CACHE_DIR,
    token=HUGGINGFACE_API_TOKEN,
)

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

data = load_dataset("json", data_files="data/instructions.jsonl")
data = data.map(llama_get_input_with_labels)

output_root = "outputs/toy" if TOY else "outputs"
run_name = (
f"{datetime.today().date()}_{BASE_RUN_NAME}_{generate_random_string(5).lower()}"
)
output_dir = f"{output_root}/{run_name}"
print("output_dir:", output_dir)
save_steps = SAVE_DATA_POINTS // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)

wandb.init(entity=WANDB_TEAM, project=WANDB_PROJECT, name=run_name)

trainer = MandrillTrainer(
    model=model,
    train_dataset=data["train"],
    args=TrainingArguments(
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,  # use the constant, not a second hardcoded 4
        warmup_steps=2,
        save_steps=save_steps,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir=output_dir,
        optim="paged_adamw_8bit",
    ),
    data_collator=DataCollatorForSeq2Seq(tokenizer),  # imported above; the bare `transformers.` reference was undefined
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
trainer.train()
wandb.finish()
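# Assumed invocation (not part of the original file): `python llama2_qlora.py`,
# with data/instructions.jsonl containing one JSON object per line with
# "instruction" and "response" fields (the keys read in preprocess/chat.py).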
22 changes: 22 additions & 0 deletions mandrill_utils/logging_utils.py
@@ -0,0 +1,22 @@
from datetime import datetime
import random
import string
import hashlib

# Get the current datetime as a string
current_datetime_str = str(datetime.now())

# Hash the datetime string to generate a seed
seed = int(hashlib.sha256(current_datetime_str.encode()).hexdigest(), 16)

# Create a random number generator instance with the seed
str_random = random.Random(seed)

def generate_random_string(length):
    # Define the set of characters you want in the random string
    characters = string.ascii_letters + string.digits  # You can include other characters if needed

    # Generate the random string by selecting characters randomly
    random_string = ''.join(str_random.choice(characters) for _ in range(length))

    return random_string
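# Example usage (illustrative; the output is random):
#   generate_random_string(5)  # -> e.g. 'x3K9a'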
104 changes: 104 additions & 0 deletions preprocess/chat.py
@@ -0,0 +1,104 @@
from llama.generation import Message, Dialog, B_INST, E_INST, B_SYS, E_SYS
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from preprocess.prompts import SYSTEM_PROMPT
import os
from typing import List

HF_CACHE_DIR = os.environ.get('HUGGINGFACE_CACHE_DIR', '/notebooks/.cache/huggingface')
HF_API_TOKEN = os.environ.get('HUGGINGFACE_API_TOKEN')  # read from the environment; don't hardcode a secret

LLAMA_HF_ID = "meta-llama/Llama-2-7b-chat-hf"
llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_HF_ID, cache_dir=HF_CACHE_DIR, token=HF_API_TOKEN)
llama_tokenizer.pad_token = llama_tokenizer.eos_token

def llama_dialog2tokens(dialog: Dialog, tokenizer=llama_tokenizer, verbose=False):
# copied / adapted from https://github.com/facebookresearch/llama/blob/d58f9ae95c299fe6388ee2da2c87fd90cd360d41/llama/generation.py#L284
    if dialog[0]["role"] == "system":
        dialog = [
            {
                "role": dialog[1]["role"],
                "content": B_SYS
                + dialog[0]["content"]
                + E_SYS
                + dialog[1]["content"],
            }
        ] + dialog[2:]
    assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
        [msg["role"] == "assistant" for msg in dialog[1::2]]
    ), (
        "model only supports 'system', 'user' and 'assistant' roles, "
        "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
    )
    dialog_tokens: List[int] = sum(
        [
            tokenizer.encode(
                f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {tokenizer.eos_token}",
                add_special_tokens=False,  # bos/eos are written out explicitly above; avoid a second bos from HF
            )
            for prompt, answer in zip(
                dialog[::2],
                dialog[1::2],
            )
        ],
        [],
    )
    if verbose:
        messages = [
            f"{tokenizer.bos_token} {B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {tokenizer.eos_token}"
            for prompt, answer in zip(
                dialog[::2],
                dialog[1::2],
            )
        ]
    assert (
        dialog[-1]["role"] == "user"
    ), f"Last message must be from user, got {dialog[-1]['role']}"
    dialog_tokens += tokenizer.encode(
        f"{tokenizer.bos_token} {B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        add_special_tokens=False,  # bos is written out explicitly
    )
    if verbose:
        messages.append(f"{tokenizer.bos_token} {B_INST} {(dialog[-1]['content']).strip()} {E_INST}")
        print(messages)  # `display` is notebook-only; use print in a plain script
    return dialog_tokens
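# Illustrative layout of the encoded dialog (assumed from the Llama-2 chat
# format, where B_SYS = "<<SYS>>\n" and E_SYS = "\n<</SYS>>\n\n"):
#   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user_1} [/INST] {answer_1} </s>
#   <s>[INST] {user_2} [/INST] {answer_2} </s> ... <s>[INST] {user_n} [/INST]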

def llama_get_prompt_tokens(jsonl_row, system_message=SYSTEM_PROMPT):
    # TODO: make a dataclass or pydantic type for jsonl_row
    system_msg = Message(role='system', content=system_message)  # use the parameter, not the module constant
    dialog = [
        system_msg,
        Message(role='user', content=jsonl_row['instruction']),
    ]
    return llama_dialog2tokens(dialog)

def llama_get_input_with_labels(row, tokenizer=llama_tokenizer):
    prompt_tokens = llama_get_prompt_tokens(row)
    response_tokens = tokenizer(
        f"{row['response']} {tokenizer.eos_token}",
        add_special_tokens=False,  # avoid inserting a second bos mid-sequence
    )['input_ids']
    input_ids = prompt_tokens + response_tokens
    attention_mask = [1] * len(input_ids)
    # -100 masks the prompt positions: cross-entropy treats -100 as ignore_index,
    # so the loss is computed only on the response tokens.
    labels = [-100] * len(prompt_tokens) + response_tokens
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

if __name__ == '__main__':
    sample_dialog: Dialog = [
        {"role": "system", "content": "Welcome to the virtual assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there! How can I assist you today?"},
        {"role": "user", "content": "I need some help with Python programming."},
        {"role": "assistant", "content": "Sure, I can help with Python. What do you need assistance with?"},
        {"role": "user", "content": "I'm having trouble with a Python script."},
        {"role": "assistant", "content": "Could you please provide the script or describe the issue you're facing?"},
        {"role": "user", "content": "```python\n"
                                    "def calculate_square(x):\n"
                                    "    return x ** 2\n"
                                    "```"},
        {"role": "assistant", "content": "Thank you for sharing the script. What seems to be the problem with it?"},
        {"role": "user", "content": "I'm getting a 'NameError' for 'x' when I run it."},
        {"role": "assistant", "content": "The 'NameError' indicates that 'x' is not defined. You should provide a value for 'x' when calling the function."},
        {"role": "user", "content": "```python\n"
                                    "def calculate_square(x):\n"
                                    "    x = 5  # Assign a value to x\n"
                                    "    return x ** 2\n"
                                    "```"},
        {"role": "assistant", "content": "Great! Now the 'x' variable is defined. Is there anything else you need help with?"},
        {"role": "user", "content": "No, that's all for now. Thanks for your assistance!"}
    ]

    print(llama_dialog2tokens(sample_dialog, verbose=True))
4 changes: 4 additions & 0 deletions preprocess/prompts.py
@@ -0,0 +1,4 @@
SYSTEM_PROMPT = """
You are a superintelligent reasoning agent. You think logically and carefully, and verbalize your reasoning process when thinking.
""".strip()

12 changes: 12 additions & 0 deletions train/trainers.py
@@ -0,0 +1,12 @@
from transformers import Trainer

class MandrillTrainer(Trainer):
"""
avoid setting label to None: https://github.com/huggingface/transformers/blob/5a4f340df74b42b594aedf60199eea95cdb9bed0/src/transformers/trainer.py#L2703C26-L2703C26
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def compute_loss(self, model, inputs):
outputs = model(**inputs)
return outputs.loss
13 changes: 13 additions & 0 deletions train/utils.py
@@ -0,0 +1,13 @@
def print_trainable_parameters(model):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param: .4f}%"
)
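# Back-of-envelope check (illustrative, not from the original file): a LoRA
# adapter of rank r on a d_in x d_out linear layer adds r * (d_in + d_out)
# trainable parameters (A: r x d_in, B: d_out x r). With r=8 on the 4096x4096
# q/k/v projections of Llama-2-7B, that is 65,536 trainable params per module.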
