Commit ef1f2e1 by 言斯, committed Feb 29, 2024 (1 parent: 3c1890d). Showing 16 changed files with 3,859 additions and 0 deletions.
@@ -0,0 +1,73 @@
# Fortify the Shortest Stave in Attention: Enhancing Context Awareness of Large Language Models for Effective Tool Use

The code is implemented in PyTorch, and the tool-use code is based on [ToolBench](https://github.com/OpenBMB/ToolBench). We appreciate these open-source projects.
## Install
```bash
git clone https://github.com/AlibabaResearch/DAMO-ConvAI.git
cd attention-buckets
conda create -n your_envs python=3.9
conda activate your_envs
pip install -r requirment.txt
# replace "your_env_path/lib/site-packages/transformers/models/llama/modeling_llama.py" with our 'modeling_llama.py'
```
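If you are unsure which file to overwrite, the installed copy can be located programmatically. This is a minimal sketch that only assumes `transformers` is already installed in the active environment:
```python
# Print the location of the installed modeling_llama.py, i.e. the file to
# overwrite with the patched version shipped in this repo.
import transformers.models.llama.modeling_llama as modeling_llama

print(modeling_llama.__file__)
```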
## Code for tool-use
```bash
git clone git@github.com:OpenBMB/ToolBench.git
cd ToolBench
```
Here we only describe the parts specific to our work; for more information about ToolBench itself, please refer to [ToolBench](https://github.com/OpenBMB/ToolBench).
### Data
Put all datasets in `ToolBench/data`.
- Original ToolBench data: download from [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/).
- Results of our method: download from [Google Drive](https://alibaba-research.oss-cn-beijing.aliyuncs.com/attention-buckets/all_data.zip).
### Core code: how to run and evaluate
1. Replace "ToolBench/toolbench/inference/utils.py" with our "inference/utils.py".
2. Move our "inference/config.py" into "ToolBench/toolbench/inference".
3. Replace "ToolBench/toolbench/utils.py" with our "utils.py".
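The three steps above can also be scripted. The sketch below assumes this repo and ToolBench are cloned side by side and that the source paths match the checkout layout (adjust them if yours differs):
```python
# Illustrative file swaps for steps 1-3; run from the attention-buckets root.
# The source and destination paths are assumptions about the checkout layout.
import shutil

shutil.copy("inference/utils.py", "../ToolBench/toolbench/inference/utils.py")    # step 1
shutil.copy("inference/config.py", "../ToolBench/toolbench/inference/config.py")  # step 2
shutil.copy("utils.py", "../ToolBench/toolbench/utils.py")                        # step 3
```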
```bash
# run
bash scripts/inference_toolllama_pipeline.sh
# eval (same procedure as ToolBench)
cd tooleval
bash run_convert_answer.sh
bash run_pass_rate.sh
# once you have pass rates for chatgpt_cot and your method, run this to compute preference
bash run_preference.sh
```

## Code for RAG
```bash
cd base_rag
```
### Data
Put the data in `../qa_dataset`.
Download the dataset from [Google Drive](https://alibaba-research.oss-cn-beijing.aliyuncs.com/attention-buckets/all_data.zip).

### How to run and evaluate
```bash
# run
# bsz must be >= the total number of bases
CUDA_VISIBLE_DEVICES=i python test_nq_kl.py --flag i --bsz 8 --num_doc $num_doc --ngpu $n_gpu --data_name $data_name

# eval
python merge_result.py --ngpu $n_gpu --data_name $data_name --num_doc $num_doc
```
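Since `--flag` selects the data shard handled by each process (see `batch_infer` in `test_nq_kl.py` below), the intended multi-GPU usage is one process per GPU. A minimal launcher sketch, assuming 4 GPUs and the `nq` dataset with 10 documents:
```python
# Launch one test_nq_kl.py process per GPU; --flag is the shard index.
import os
import subprocess

n_gpu = 4  # assumed GPU count
procs = []
for i in range(n_gpu):
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(i))
    procs.append(subprocess.Popen(
        ["python", "test_nq_kl.py", "--flag", str(i), "--bsz", "8",
         "--num_doc", "10", "--ngpu", str(n_gpu), "--data_name", "nq"],
        env=env,
    ))
for p in procs:
    p.wait()  # wait for all shards before running merge_result.py
```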

## Citation
Feel free to cite us if you like our work.
```bibtex
@article{Chen2023FortifyTS,
  title={Fortify the Shortest Stave in Attention: Enhancing Context Awareness of Large Language Models for Effective Tool Use},
  author={Yuhan Chen and Ang Lv and Ting-En Lin and Chang Heng Chen and Yuchuan Wu and Fei Huang and Yongbin Li and Rui Yan},
  journal={ArXiv},
  year={2023},
  volume={abs/2312.04455},
  url={https://api.semanticscholar.org/CorpusID:266053571}
}
```
@@ -0,0 +1,4 @@
import huggingface_hub
from huggingface_hub import snapshot_download

# meta-llama/Llama-2-7b is a gated repo: paste a valid Hugging Face access
# token into login() before running this script.
huggingface_hub.login("")
snapshot_download(repo_id="meta-llama/Llama-2-7b")
@@ -0,0 +1,76 @@
import json
import argparse
import regex


def normalize_answer(s):
    # Lower-case, strip English articles, and collapse whitespace so answer
    # strings can be compared robustly.
    def remove_articles(text):
        return regex.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(lower(s)))
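# Illustration (hypothetical value): normalize_answer("The Eiffel Tower")
# returns "eiffel tower", so the comparison below ignores case, articles,
# and extra whitespace.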


def main(args):
    data_file = args.data_dir + args.data_name + '-test.jsonl'
    # answer_file = args.answer_file  # args.result_dir + args.data_name + '_doc_num' + str(args.num_doc)
    # NOTE: the answer directory is hardcoded to one specific run here;
    # adjust it (or restore the commented line above) for other runs.
    answer_file = 'answer/nq_doc_num10(10000, 13000, 16000, 19000, 22000, 25000, 28000)'
    base_list_file = answer_file + '/base_list.json'
    if args.ngpu > 1:
        # Concatenate the per-GPU shard files into a single answer file.
        f2 = open(base_list_file, 'w')
        for i in range(args.ngpu):
            base_list_file_tmp = answer_file + '/base_list' + '_' + str(i) + '.json'
            f2_tmp = open(base_list_file_tmp).read()
            f2.write(f2_tmp)
        f2.close()
    # base_list_file = 'answer/webq_doc_num10/base_(10000,17000,18000,19000,20000,23000,25000).json'
    f2 = open(base_list_file, 'r').readlines()
    total = len(f2)

    true = 0
    with open(data_file) as fin:
        # The test file holds a single JSON array (despite its .jsonl suffix).
        data = json.load(fin)
        # assert total == len(data)
        for idx, input_example in enumerate(data):
            gold_answer = [x.strip() for x in input_example["answers"]]
            answer = json.loads(f2[idx])['answer']

            # Count a hit if any normalized gold answer appears as a
            # substring of the normalized model answer.
            for x in gold_answer:
                if normalize_answer(x).lower() in normalize_answer(answer).lower():
                    true += 1
                    break

    print(true)
    print(total)
    print(true / total)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', type=str, default='meta-llama/Llama-2-7b-chat-hf')
    parser.add_argument('--model_type', type=str, default='llama')
    parser.add_argument('--data_dir', type=str, default='../qa_dataset/')
    parser.add_argument('--prompt_dir', type=str, default='prompts/qa.prompt')
    parser.add_argument('--result_dir', type=str, default='answer/')
    parser.add_argument('--answer_file', type=str, default='answer/')
    parser.add_argument('--num_doc', type=int, default=10)
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--data_name', type=str, default='nq')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--chosen_base', type=int, default=10000)
    args = parser.parse_args()

    main(args)
@@ -0,0 +1,2 @@
Question: {question}
Answer:
@@ -0,0 +1,7 @@
Extract the value corresponding to the specified key in the JSON object below.

JSON data:
{formatted_kv_records}

Key: "{key}"
Corresponding value:
attention-buckets/base_rag/prompts/kv_retrieval_with_query_aware_contextualization.prompt
@@ -0,0 +1,9 @@
Extract the value corresponding to the specified key in the JSON object below.

Key: "{key}"

JSON data:
{formatted_kv_records}

Key: "{key}"
Corresponding value:
@@ -0,0 +1,6 @@
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant).

{search_results}

Question: {question}
Answer:
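For reference, these `.prompt` files are plain `str.format` templates; `test_nq_kl.py` (below) fills them the same way as this minimal sketch (the document and question values are made up):
```python
# Fill prompts/qa.prompt the same way get_nq_retrieval_prompt does.
with open("prompts/qa.prompt") as f:
    template = f.read().rstrip("\n")

prompt = template.format(
    search_results='"Document (Title: Paris)": "Paris is the capital of France."',
    question="What is the capital of France?",
)
print(prompt)
```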
attention-buckets/base_rag/prompts/qa_ordered_randomly.prompt
@@ -0,0 +1,6 @@
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant). The search results are ordered randomly.

{search_results}

Question: {question}
Answer:
attention-buckets/base_rag/prompts/qa_with_query_aware_contextualization.prompt
@@ -0,0 +1,8 @@
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant).

Question: {question}

{search_results}

Question: {question}
Answer:
@@ -0,0 +1,168 @@
import argparse
import json
import os
import time

import numpy as np
import tensor_parallel as tp
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from typing import List
import torch.nn.functional as F

def load(ckpt_dir, model_type):
    hub_token = ""  # your Hugging Face access token (needed for gated Llama-2 weights)
    n_gpus = torch.cuda.device_count()

    if model_type == 'llama':
        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=False, padding_side="left", use_auth_token=hub_token)  # e.g. 'meta-llama/Llama-2-7b-chat-hf'
        model = AutoModelForCausalLM.from_pretrained(ckpt_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_auth_token=hub_token)

        tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
        tokenizer.bos_token_id = 1
    else:
        # note: running falcon with tensor parallel currently triggers bugs
        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=False, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(ckpt_dir, device_map='balanced_low_0', torch_dtype=torch.bfloat16, trust_remote_code=True)
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.pad_token_id = 0

    model.eval()
    model.cuda()

    return model, tokenizer


def get_nq_retrieval_prompt(
    data: List[dict],
    key: str,
    prompt_dir: str,
    query_aware_contextualization: bool = False,
):
    if not data:
        raise ValueError(f"Provided `data` must be truthy, got: {data}")
    if not key:
        raise ValueError(f"Provided `key` must be truthy, got: {key}")
    # if len(data) != len(set([x["text"] for x in data])):
    #     raise ValueError(f"`data` has duplicate keys: {data}")
    if len(data) < 2:
        raise ValueError(f"Must have at least 2 items in data: {data}")

    with open(prompt_dir) as f:
        prompt_template = f.read().rstrip("\n")

    # Format the retrieved documents into one newline-separated string.
    formatted_kv_records = ""
    for index, record in enumerate(data):
        data_string = f'"Document (Title: {record["title"]})": "{record["text"]}"'
        end_character = "\n" if index != len(data) - 1 else ""
        formatted_kv_records += data_string + end_character

    return prompt_template.format(search_results=formatted_kv_records, question=key)
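# Illustration with hypothetical inputs: calling
#   get_nq_retrieval_prompt(
#       data=[{"title": "Paris", "text": "Paris is the capital of France."},
#             {"title": "Lyon", "text": "Lyon is in France."}],
#       key="What is the capital of France?",
#       prompt_dir="prompts/qa.prompt")
# returns the qa.prompt template with {search_results} replaced by the two
# "Document (Title: ...)" lines and {question} replaced by the key.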


def batch_infer(model, tokenizer, args):
    from xopen import xopen

    data_file = args.data_dir + args.data_name + '-test.jsonl'
    answer_file = args.result_dir + args.data_name + '_doc_num' + str(args.num_doc) + '(10000, 13000, 16000, 19000, 22000, 25000, 28000)'
    os.makedirs(answer_file, exist_ok=True)
    np.random.seed(0)
    torch.manual_seed(0)
    # Candidate RoPE bases (the "attention buckets") used during inference.
    base_list = [10000, 13000, 16000, 19000, 22000, 25000, 28000]

    with xopen(data_file) as fin:
        num_doc = args.num_doc
        data = json.load(fin)

    true = 0
    model.eval()
    if args.ngpu == 1:
        base_list_file = answer_file + '/base_list.json'
        sub_data = data
        left = 0
    else:
        # Shard the dataset across GPUs: --flag is this process's shard index.
        base_list_file = answer_file + '/base_list' + '_' + str(args.flag) + '.json'
        data_split = int(len(data) / args.ngpu)
        left = args.flag * data_split

        if args.flag == args.ngpu - 1:
            sub_data = data[left:]
        else:
            sub_data = data[left:left + data_split]

    # Resume support: reload any answers already written for this shard.
    base_list_data = []
    if os.path.exists(base_list_file):
        f2 = open(base_list_file, 'r')
        base_list_data = f2.readlines()
    for idx, input_example in enumerate(tqdm(sub_data)):
        left_docs = input_example["ctxs"]
        question = input_example["question"]
        gold_answer = [x.strip() for x in input_example["answers"]]

        kv_prompt = get_nq_retrieval_prompt(
            data=left_docs[:num_doc], key=question, prompt_dir=args.prompt_dir
        )

        inputs = tokenizer.encode(kv_prompt, return_tensors="pt", padding=True).cuda()
        prompt_length = inputs.shape[1]

        bsz = args.bsz
        # model._set_best_base(base)  # variant: run with a single RoPE base
        # set_base_mean is provided by our patched modeling_llama.py (see the
        # Install section): it configures the candidate bases for generation.
        model.set_base_mean(base_list, bsz)
        if idx < len(base_list_data):
            # Answer already produced by a previous (interrupted) run: reuse it.
            answer = json.loads(base_list_data[idx])['answer']
        else:
            with torch.no_grad():
                outputs = model.generate(inputs, max_new_tokens=100, do_sample=False, num_beams=1)
            answer = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0]
            with open(base_list_file, 'a') as f2:
                f2.write(json.dumps({'id': idx + left, 'answer': answer}) + '\n')


def main(args):
    model, tokenizer = load(args.ckpt_dir, args.model_type)
    start_time = time.time()
    batch_infer(model, tokenizer, args)
    end_time = time.time()
    print("total run time %.2f" % (end_time - start_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', type=str, default='meta-llama/Llama-2-7b-chat-hf')
    parser.add_argument('--model_type', type=str, default='llama')
    parser.add_argument('--data_dir', type=str, default='../qa_dataset/')
    parser.add_argument('--prompt_dir', type=str, default='prompts/qa.prompt')
    parser.add_argument('--result_dir', type=str, default='answer/')
    parser.add_argument('--num_doc', type=int, default=None)
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--data_name', type=str, default='nq')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--flag', type=int, default=0)
    parser.add_argument('--chosen_base', type=int, default=10000)
    args = parser.parse_args()

    main(args)