Skip to content

Commit

Permalink
fix: support e2e val for AGIEval
Browse files Browse the repository at this point in the history
  • Loading branch information
harshraj172 committed Sep 18, 2023
1 parent 2f5278d commit 41782a2
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 38 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
## Setup

* Initialize the submodules
```bash
git submodule init
git submodule update
```

* Install the requirements
```bash
conda create -n venv
conda activate venv
pip install -r requirements.txt
```

* Finetune with QLoRA quantization
```bash
python llama2_qlora.py
```
16 changes: 10 additions & 6 deletions eval_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from dataclasses import asdict, dataclass, field, fields

@dataclass
Expand All @@ -11,10 +12,10 @@ class EvaluationArguments:
Parameters:
system_prompt (`str`, *optional*):
The system prompt to use while evaluating
temperature (`float`, *optional*, defaults to 0.0):
tasks_list (`List[str]`, *optional*):
The tasks to evaluate the model on.
temperature (`float`, *optional*, defaults to 0.2):
The model's temperature for evaluating.
per_device_eval_batch_size (`int`, *optional*, defaults to 16):
The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.
max_new_tokens (`int`, *optional*, defaults to 50):
Maximum number of tokens to generate in evaluation.
top_p (`float`, *optional*):
Expand All @@ -25,9 +26,12 @@ class EvaluationArguments:
default="You are a helpful AI assistant",
metadata={"help": " The system prompt to use while evaluating."},
)
tasks_list: List[str] = field(
default_factory=["agieval"],
metadata={"help": " The system prompt to use while evaluating."},
)
temperature: float = field(
default=0.0,
default=0.2,
metadata={"help": ("The model's temperature for evaluating.")},)
per_device_eval_batch_size: bool = field(default=16, metadata={"help": "The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation."})
max_new_tokens: bool = field(default=False, metadata={"help": "Maximum number of tokens to generate in evaluation."})
top_p: bool = field(default=False, metadata={"help": "Parameter for nucleus sampling in decoding."})
top_p: float = field(default=0.2, metadata={"help": "Parameter for nucleus sampling in decoding."})
10 changes: 5 additions & 5 deletions evaluator/agentbench/hf_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@


class HuggingFaceChatAgent(Agent):
def __init__(self, model, model_id, system_prompt, temperature=0, max_new_tokens=32, top_p=0, **kwargs) -> None:
def __init__(self, model, model_id, system_prompt, hf_api_token, temperature=0, max_new_tokens=32, top_p=0, **kwargs) -> None:
self.model = model
self.hf_api_token = hf_api_token
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.top_p = top_p
Expand Down Expand Up @@ -49,11 +50,10 @@ def conv2prompt(self, history: List[dict]) -> str:
elif idx==1:
prompt += f"<s>[INST] {prompt.strip()} [/INST] {history[idx]['content'].strip()} </s>"
elif idx==len(history)-1:
prompt += f"<s>[INST] {history[idx]['user'].strip()} [/INST]"
prompt += f"<s>[INST] {history[idx]['content'].strip()} [/INST]"
else:
if history[idx]['role']=='user':
prompt += f"<s>[INST] {history[idx]['user'].strip()} [/INST] "
prompt += f"<s>[INST] {history[idx]['content'].strip()} [/INST] "
else:
prompt += f"{history[idx]['agent'].strip()} </s>"
prompt += f"<s>[INST] {history[-1]['user'].strip()} [/INST]"
prompt += f"{history[idx]['content'].strip()} </s>"
return prompt
1 change: 1 addition & 0 deletions evaluator/agentbench/hfchat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module: "hfchat.HuggingFaceChatAgent"
parameters:
model_id: "codellama/CodeLlama-7b-hf"
system_prompt: "You are a helpful AI assistant"
hf_api_token: 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
max_new_tokens: 128
temperature: 0
top_p: 0
29 changes: 17 additions & 12 deletions evaluator/agieval/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
import os
from typing import List, Callable
from transformers import AutoTokenizer
from peft import PeftModelForCausalLM

from .AGIEval.src import utils, dataset_loader
from .AGIEval.src import post_process, utils, dataset_loader
from .AGIEval.src import evaluation


run_experiment = True
dataset_dir = "data/v1"
raw_prompt_path = "./data/few_shot_prompts.csv"
dataset_dir = "evaluator/agieval/AGIEval/data/v1"
raw_prompt_path = "evaluator/agieval/AGIEval/data/few_shot_prompts.csv"

class HuggingFaceChat():
def __init__(self, model, model_id, system_prompt, temperature=0, max_new_tokens=32, top_p=0, batch_size=32, **kwargs) -> None:
def __init__(self, model, model_id, system_prompt, hf_api_token, temperature=0, max_new_tokens=32, top_p=0, batch_size=32, **kwargs) -> None:
if isinstance(model, PeftModelForCausalLM): model = model.merge_and_unload()
self.model = model
self.hf_api_token = hf_api_token
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.top_p = top_p
self.batch_size = batch_size
self.system_prompt = system_prompt
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_api_token)
self.tokenizer.eos_token = "<|im_end|>"
self.tokenizer.pad_token = self.tokenizer.eos_token
self.history = [{"role": "system", "content": self.system_prompt},]

super().__init__(**kwargs)
Expand Down Expand Up @@ -100,12 +105,13 @@ def query_model(self, query_list, setting_name='chat') -> str:
{"role": "user", "content": query},
)
elif isinstance(query, list):
messages += query
self.history += query
else:
raise ValueError("Unsupported query: {0}".format(query))
prompt = self.conv2prompt(self.history)
prompts.append(prompt)
input_ids_batch = self.tokenizer(prompts, return_tensors="pt")["input_ids"].to("cuda")

input_ids_batch = self.tokenizer(prompts, padding=True, return_tensors="pt")["input_ids"].to("cuda")
output_batch = self.model.generate(
input_ids_batch,
temperature=self.temperature,
Expand All @@ -126,20 +132,18 @@ def conv2prompt(self, history: List[dict]) -> str:
elif idx==1:
prompt += f"<s>[INST] {prompt.strip()} [/INST] {history[idx]['content'].strip()} </s>"
elif idx==len(history)-1:
prompt += f"<s>[INST] {history[idx]['user'].strip()} [/INST]"
prompt += f"<s>[INST] {history[idx]['content'].strip()} [/INST]"
else:
if history[idx]['role']=='user':
prompt += f"<s>[INST] {history[idx]['user'].strip()} [/INST] "
prompt += f"<s>[INST] {history[idx]['content'].strip()} [/INST] "
else:
prompt += f"{history[idx]['assistant'].strip()} </s>"
prompt += f"<s>[INST] {history[-1]['user'].strip()} [/INST]"
prompt += f"{history[idx]['content'].strip()} </s>"
return prompt


def evaluate(model, model_id, system_prompt,
def evaluate(model, model_id, system_prompt, hf_api_token,
temperature=0, max_new_tokens=32, top_p=0, batch_size=32,
dataset_name_list=[
"gaokao-chinese",
"gaokao-geography",
"gaokao-history",
"gaokao-biology",
Expand All @@ -166,6 +170,7 @@ def evaluate(model, model_id, system_prompt,
## Prediction
model = HuggingFaceChat(model=model,
model_id=model_id,
hf_api_token=hf_api_token,
system_prompt=system_prompt,
temperature=temperature,
max_new_tokens=max_new_tokens,
Expand Down
25 changes: 22 additions & 3 deletions llama2_qlora.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import os
from datetime import datetime

import sys
sys.path.append(f'{os.path.dirname(__file__)}/evaluator/agieval/AGIEval')
sys.path.append(f'{os.path.dirname(__file__)}/evaluator/agentbench/AgentBench')

import torch
import transformers
from datasets import load_dataset
Expand All @@ -8,9 +13,10 @@
BitsAndBytesConfig, TrainingArguments)

import wandb
from eval_args import EvaluationArguments
from mandrill_utils.logging_utils import generate_random_string
from preprocess.chat import llama_get_input_with_labels
from train.trainers import MandrillTrainer
from train.trainer import MandrillTrainer
from train.utils import print_trainable_parameters

HUGGINGFACE_API_TOKEN = "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"
Expand Down Expand Up @@ -83,19 +89,32 @@
wandb.init(entity=WANDB_TEAM, project=WANDB_PROJECT, name=run_name)

trainer = MandrillTrainer(
model=model,
model=model,
model_id=model_id,
hf_api_token=HUGGINGFACE_API_TOKEN,
train_dataset=data["train"],
eval_dataset=data["train"],
args=TrainingArguments(
num_train_epochs=3,
per_device_train_batch_size=BATCH_SIZE,
gradient_accumulation_steps=4,
gradient_accumulation_steps=2,
warmup_steps=2,
save_steps=save_steps,
evaluation_strategy='steps',
eval_steps=1,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=output_dir,
optim="paged_adamw_8bit",
),
eval_args=EvaluationArguments(
system_prompt="You are a helpful AI assistant",
tasks_list=["agieval"],
temperature=0.2,
max_new_tokens=32,
top_p=0.2,
),
data_collator=transformers.DataCollatorForSeq2Seq(tokenizer),
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
Expand Down
18 changes: 18 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
accelerate @ git+https://github.com/huggingface/accelerate.git@cdb001ca5f9be875034ddd0aa86a542c182782fe
bitsandbytes==0.41.1
black==23.9.1
datasets==2.14.5
einops==0.6.1
evaluate==0.4.0
fairscale==0.4.13
fire==0.5.0
jsonlines==4.0.0
llama @ git+https://github.com/facebookresearch/llama@9f0e393991b45d320f5b4a287eaaeb8a7d2e6f8e
openai==0.28.0
pandas==2.1.0
peft @ git+https://github.com/huggingface/peft.git@0fa63fb4a21bf88777b2469892b76a6e096753e8
torch==2.0.1
transformers @ git+https://github.com/huggingface/transformers.git@95b374952dc27d8511541d6f5a4e22c9ec11fb24
wandb==0.15.10
scipy==1.11.2
tiktoken==0.5.1
44 changes: 44 additions & 0 deletions train/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from transformers import Trainer, TrainingArguments
from transformers.utils import logging
from transformers.trainer_utils import EvalLoopOutput

from evaluator.agieval import wrapper as wrapper_agieval
from eval_args import EvaluationArguments
# from evaluator.agentbench import wrapper as wrapper_agentbench

logger = logging.get_logger(__name__)

class MandrillTrainer(Trainer):
"""
avoid setting label to None: https://github.com/huggingface/transformers/blob/5a4f340df74b42b594aedf60199eea95cdb9bed0/src/transformers/trainer.py#L2703C26-L2703C26
"""
def __init__(self, *args, **kwargs):
self.model_id = kwargs.pop('model_id')
self.eval_args = kwargs.pop('eval_args')
self.hf_api_token = kwargs.pop('hf_api_token')
super().__init__(*args, **kwargs)

def compute_loss(self, model, inputs):
outputs = model(**inputs)
return outputs.loss

def evaluation_loop(self, dataloader, description, prediction_loss_only=False, **kwargs) -> EvalLoopOutput:
'''
https://github.com/huggingface/transformers/blob/0a55d9f7376f72ad3ff296d4249840021b03bcc4/src/transformers/trainer_utils.py#L147
'''
model = self._wrap_model(self.model, training=False, dataloader=dataloader)
model.eval()

if 'agieval' in self.eval_args.tasks_list:
logger.info(f"***** Runnning Evaluation on AGIEval *****")
wrapper_agieval.evaluate(model=model, model_id=self.model_id, hf_api_token=self.hf_api_token,
system_prompt=self.eval_args.system_prompt, temperature=self.eval_args.temperature,
max_new_tokens=self.eval_args.max_new_tokens, top_p=self.eval_args.top_p,
batch_size=self.args.per_device_eval_batch_size,)
if 'agentbench' in self.eval_args.tasks_list:
logger.info(f"***** Runnning Evaluation on AgentBench *****")
wrapper_agentbench.evaluate(model=model, model_id=self.model_id, hf_api_token=self.hf_api_token,
system_prompt=self.eval_args.system_prompt, temperature=self.eval_args.temperature,
max_new_tokens=self.eval_args.max_new_tokens, top_p=self.eval_args.top_p,
batch_size=self.args.per_device_eval_batch_size,)
return EvalLoopOutput(predictions=None, label_ids=None, metrics={'fake_metric': 0.0}, num_samples=0)
12 changes: 0 additions & 12 deletions train/trainers.py

This file was deleted.

0 comments on commit 41782a2

Please sign in to comment.