Add: attention-buckets
言斯 committed Feb 29, 2024
1 parent 3c1890d commit ef1f2e1
Showing 16 changed files with 3,859 additions and 0 deletions.
73 changes: 73 additions & 0 deletions attention-buckets/README.md
@@ -0,0 +1,73 @@
# Fortify the Shortest Stave in Attention: Enhancing Context Awareness of Large Language Models for Effective Tool Use

The code is implemented in PyTorch, and the tool-use code is built on [ToolBench](https://github.com/OpenBMB/ToolBench). We thank the authors of these open-source projects.

## Install
```bash
git clone https://github.com/AlibabaResearch/DAMO-ConvAI.git
cd DAMO-ConvAI/attention-buckets
conda create -n your_envs python=3.9
conda activate your_envs
pip install -r requirements.txt
# replace "your_env_path/lib/site-packages/transformers/models/llama/modeling_llama.py" with our "modeling_llama.py"
```
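
If you prefer not to locate the site-packages copy by hand, here is a minimal sketch (not part of the repo) that finds the installed `modeling_llama.py` and overwrites it; it assumes the patched file sits in the current directory under the same name:
```python
# Locate the installed transformers modeling_llama.py and swap in the patched version.
# Assumption: the patched file is ./modeling_llama.py; a .bak copy of the original is kept.
import shutil
import transformers.models.llama.modeling_llama as modeling_llama

installed_path = modeling_llama.__file__
shutil.copy(installed_path, installed_path + ".bak")  # back up the stock file
shutil.copy("modeling_llama.py", installed_path)      # install the patched file
print(f"patched {installed_path}")
```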
## Code for tool-use
```bash
git clone [email protected]:OpenBMB/ToolBench.git
cd ToolBench
```
We only describe what is specific to our work; for more details about ToolBench, please refer to [ToolBench](https://github.com/OpenBMB/ToolBench).

### Data
Put all datasets in `ToolBench/data`.
- Original ToolBench data: download it from [Google Drive](https://drive.google.com/drive/folders/1yBUQ732mPu-KclJnuQELEhtKakdXFc3J) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/c9e50625743b40bfbe10/).
- Results of our method: download them from [this link](https://alibaba-research.oss-cn-beijing.aliyuncs.com/attention-buckets/all_data.zip).


### Core code: how to run and evaluate
1. Replace `ToolBench/toolbench/inference/utils.py` with our `inference/utils.py`.
2. Move our `inference/config.py` into `ToolBench/toolbench/inference/`.
3. Replace `ToolBench/toolbench/utils.py` with our `utils.py` (sketched below).
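
A minimal sketch of these three copy steps (an assumption, not part of the repo: this directory and a ToolBench checkout sit side by side; adjust the paths to your layout):
```python
# Copy the patched files into a ToolBench checkout.
import shutil

TOOLBENCH = "../ToolBench"  # assumed path to your ToolBench checkout
shutil.copy("inference/utils.py", f"{TOOLBENCH}/toolbench/inference/utils.py")    # step 1
shutil.copy("inference/config.py", f"{TOOLBENCH}/toolbench/inference/config.py")  # step 2
shutil.copy("utils.py", f"{TOOLBENCH}/toolbench/utils.py")                         # step 3
```
Then run and evaluate: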
```bash
# run
bash scripts/inference_toolllama_pipeline.sh
# eval, same procedure as in ToolBench
cd tooleval
bash run_convert_answer.sh
bash run_pass_rate.sh
# after obtaining the pass rates of chatgpt_cot and your method, run the following to get the preference
bash run_preference.sh
```

## Code for RAG
```bash
cd base_rag
```
### Data
Put the data in `../qa_dataset`.
Download the dataset from [this link](https://alibaba-research.oss-cn-beijing.aliyuncs.com/attention-buckets/all_data.zip).
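
A small sketch for fetching and extracting the data with the standard library (an assumption: the archive unpacks directly into the `../qa_dataset` layout expected by the scripts):
```python
# Download all_data.zip and extract it into ../qa_dataset.
import os
import urllib.request
import zipfile

url = "https://alibaba-research.oss-cn-beijing.aliyuncs.com/attention-buckets/all_data.zip"
os.makedirs("../qa_dataset", exist_ok=True)
zip_path = "../qa_dataset/all_data.zip"
urllib.request.urlretrieve(url, zip_path)
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall("../qa_dataset")
```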

### How to run and evaluate
```bash
# run
# bsz must be >= the total number of bases in base_list
# launch one process per GPU; i is the GPU index (0 .. n_gpu-1)
CUDA_VISIBLE_DEVICES=i python test_nq_kl.py --flag i --bsz 8 --num_doc $num_doc --ngpu $n_gpu --data_name $data_name

# eval
python merge_result.py --ngpu $n_gpu --data_name $data_name --num_doc $num_doc
```
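
The run command above is meant to be launched once per GPU, with `i` ranging over the GPU indices. A minimal Python launcher sketch (example values, not part of the repo: 4 GPUs, the `nq` split, 10 documents):
```python
# Launch one test_nq_kl.py worker per GPU, wait for all of them, then merge the shards.
import subprocess

n_gpu, num_doc, data_name = 4, 10, "nq"  # adjust to your setup
procs = []
for i in range(n_gpu):
    cmd = (f"CUDA_VISIBLE_DEVICES={i} python test_nq_kl.py --flag {i} --bsz 8 "
           f"--num_doc {num_doc} --ngpu {n_gpu} --data_name {data_name}")
    procs.append(subprocess.Popen(cmd, shell=True))
for p in procs:
    p.wait()

subprocess.run(
    f"python merge_result.py --ngpu {n_gpu} --data_name {data_name} --num_doc {num_doc}",
    shell=True, check=True,
)
```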


## Citation
Please cite our paper if you find our work helpful.
```bibtex
@article{Chen2023FortifyTS,
title={Fortify the Shortest Stave in Attention: Enhancing Context Awareness of Large Language Models for Effective Tool Use},
author={Yuhan Chen and Ang Lv and Ting-En Lin and Chang Heng Chen and Yuchuan Wu and Fei Huang and Yongbin Li and Rui Yan},
journal={ArXiv},
year={2023},
volume={abs/2312.04455},
url={https://api.semanticscholar.org/CorpusID:266053571}
}
```

4 changes: 4 additions & 0 deletions attention-buckets/base_rag/download_model.py
@@ -0,0 +1,4 @@
import huggingface_hub
from huggingface_hub import snapshot_download

# paste your Hugging Face access token here (or log in beforehand with `huggingface-cli login`)
huggingface_hub.login("")
snapshot_download(repo_id="meta-llama/Llama-2-7b")
76 changes: 76 additions & 0 deletions attention-buckets/base_rag/merge_result.py
@@ -0,0 +1,76 @@
import os
import json
import argparse
import regex
import unicodedata
import string


def normalize_answer(s):
    def remove_articles(text):
        return regex.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(lower(s)))



def main(args):
    data_file = args.data_dir + args.data_name + '-test.jsonl'
    # answer_file = args.answer_file  # args.result_dir + args.data_name + '_doc_num' + str(args.num_doc)
    answer_file = 'answer/nq_doc_num10(10000, 13000, 16000, 19000, 22000, 25000, 28000)'
    base_list_file = answer_file + '/base_list.json'
    if args.ngpu > 1:
        # merge the per-GPU answer shards into a single file
        f2 = open(base_list_file, 'w')
        for i in range(args.ngpu):
            base_list_file_tmp = answer_file + '/base_list' + '_' + str(i) + '.json'
            f2_tmp = open(base_list_file_tmp).read()
            f2.write(f2_tmp)
        f2.close()

    # base_list_file = 'answer/webq_doc_num10/base_(10000,17000,18000,19000,20000,23000,25000).json'
    f2 = open(base_list_file, 'r').readlines()
    total = len(f2)

    true = 0
    with open(data_file) as fin:
        data = json.load(fin)
        # assert total == len(data)
        for idx, input_example in enumerate(data):
            gold_answer = [x.strip() for x in input_example["answers"]]
            answer = json.loads(f2[idx])['answer']

            # count a hit if any gold answer appears in the generated answer after normalization
            for x in gold_answer:
                if normalize_answer(x).lower() in normalize_answer(answer).lower():
                    true += 1
                    break

    print(true)
    print(total)
    print(true / total)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', type=str, default='meta-llama/Llama-2-7b-chat-hf')
    parser.add_argument('--model_type', type=str, default='llama')
    parser.add_argument('--data_dir', type=str, default='../qa_dataset/')
    parser.add_argument('--prompt_dir', type=str, default='prompts/qa.prompt')
    parser.add_argument('--result_dir', type=str, default='answer/')
    parser.add_argument('--answer_file', type=str, default='answer/')
    parser.add_argument('--num_doc', type=int, default=10)
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--data_name', type=str, default='nq')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--chosen_base', type=int, default=10000)
    args = parser.parse_args()

    main(args)
2 changes: 2 additions & 0 deletions attention-buckets/base_rag/prompts/closedbook_qa.prompt
@@ -0,0 +1,2 @@
Question: {question}
Answer:
7 changes: 7 additions & 0 deletions attention-buckets/base_rag/prompts/kv_retrieval.prompt
@@ -0,0 +1,7 @@
Extract the value corresponding to the specified key in the JSON object below.

JSON data:
{formatted_kv_records}

Key: "{key}"
Corresponding value:
@@ -0,0 +1,9 @@
Extract the value corresponding to the specified key in the JSON object below.

Key: "{key}"

JSON data:
{formatted_kv_records}

Key: "{key}"
Corresponding value:
6 changes: 6 additions & 0 deletions attention-buckets/base_rag/prompts/qa.prompt
@@ -0,0 +1,6 @@
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant).

{search_results}

Question: {question}
Answer:
6 changes: 6 additions & 0 deletions attention-buckets/base_rag/prompts/qa_ordered_randomly.prompt
@@ -0,0 +1,6 @@
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant). The search results are ordered randomly.

{search_results}

Question: {question}
Answer:
@@ -0,0 +1,8 @@
Write a high-quality answer for the given question using only the provided search results (some of which might be irrelevant).

Question: {question}

{search_results}

Question: {question}
Answer:
168 changes: 168 additions & 0 deletions attention-buckets/base_rag/test_nq_kl.py
@@ -0,0 +1,168 @@
import argparse
import json
import os
import time

import numpy as np
import tensor_parallel as tp
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from typing import List, Tuple
import torch.nn.functional as F




def load(ckpt_dir, model_type):
    hub_token = ""  # paste your Hugging Face access token here
    n_gpus = torch.cuda.device_count()

    if model_type == 'llama':
        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=False, padding_side="left", use_auth_token=hub_token)  # e.g. 'meta-llama/Llama-2-7b-chat-hf'
        model = AutoModelForCausalLM.from_pretrained(ckpt_dir, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_auth_token=hub_token)

        tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
        tokenizer.bos_token_id = 1
    else:
        # note: running falcon with tensor parallel can trigger bugs
        tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=False, padding_side="left")
        model = AutoModelForCausalLM.from_pretrained(ckpt_dir, device_map='balanced_low_0', torch_dtype=torch.bfloat16, trust_remote_code=True)
        if tokenizer.pad_token_id is None:
            if tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.pad_token_id = 0

    model.eval()
    model.cuda()

    return model, tokenizer

def get_nq_retrieval_prompt(
    data: List[Tuple[str, str]],
    key: str,
    prompt_dir: str,
    query_aware_contextualization: bool = False,
):
    if not data:
        raise ValueError(f"Provided `data` must be truthy, got: {data}")
    if not key:
        raise ValueError(f"Provided `key` must be truthy, got: {key}")
    # if len(data) != len(set([x["text"] for x in data])):
    #     raise ValueError(f"`data` has duplicate keys: {data}")
    if len(data) < 2:
        raise ValueError(f"Must have at least 2 items in data: {data}")

    with open(prompt_dir) as f:
        prompt_template = f.read().rstrip("\n")

    # Format the retrieved documents into a single string
    formatted_kv_records = ""
    for index, record in enumerate(data):
        # start_character = "" if index == 0 else " "
        data_string = f'"Document (Title: {record["title"]})": "{record["text"]}"'
        end_character = "\n" if index != len(data) - 1 else ""
        formatted_kv_records += data_string + end_character  # start_character +

    return prompt_template.format(search_results=formatted_kv_records, question=key)

def batch_infer(model, tokenizer, args):
    from xopen import xopen
    from copy import deepcopy

    data_file = args.data_dir + args.data_name + '-test.jsonl'
    answer_file = args.result_dir + args.data_name + '_doc_num' + str(args.num_doc) + '(10000, 13000, 16000, 19000, 22000, 25000, 28000)'
    os.makedirs(answer_file, exist_ok=True)
    np.random.seed(0)
    torch.manual_seed(0)
    base_list = [10000, 13000, 16000, 19000, 22000, 25000, 28000]  # candidate bases

    with xopen(data_file) as fin:
        num_doc = args.num_doc
        data = json.load(fin)

    pre_true = 0
    true = 0
    model.eval()
    if args.ngpu == 1:
        base_list_file = answer_file + '/base_list.json'
        sub_data = data
        left = 0
    else:
        # each process handles one shard of the data, indexed by --flag
        base_list_file = answer_file + '/base_list' + '_' + str(args.flag) + '.json'
        data_split = int(len(data) / args.ngpu)
        left = args.flag * data_split

        if args.flag == args.ngpu - 1:
            sub_data = data[left:]
        else:
            sub_data = data[left:left + data_split]

    base_list_data = []
    if os.path.exists(base_list_file):
        # resume from previously generated answers
        f2 = open(base_list_file, 'r')
        base_list_data = f2.readlines()

    for idx, input_example in enumerate(tqdm(sub_data)):
        left_docs = input_example["ctxs"]
        question = input_example["question"]
        gold_answer = [x.strip() for x in input_example["answers"]]

        kv_prompt = get_nq_retrieval_prompt(
            data=left_docs[:num_doc], key=question, prompt_dir=args.prompt_dir
        )

        inputs = tokenizer.encode(kv_prompt, return_tensors="pt", padding=True).cuda()
        prompt_length = inputs.shape[1]

        bsz = args.bsz
        # model._set_best_base(base)  # for a single base
        model.set_base_mean(base_list, bsz)  # provided by the patched modeling_llama.py
        if idx < len(base_list_data):
            answer = json.loads(base_list_data[idx])['answer']
        else:
            with torch.no_grad():
                outputs = model.generate(inputs, max_new_tokens=100, do_sample=False, num_beams=1)
            answer = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0]
            with open(base_list_file, 'a') as f2:
                f2.write(json.dumps({'id': idx + left, 'answer': answer}) + '\n')

def main(args):
    model, tokenizer = load(args.ckpt_dir, args.model_type)
    start_time = time.time()
    batch_infer(model, tokenizer, args)
    end_time = time.time()
    print("total run time %.2f" % (end_time - start_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', type=str, default='meta-llama/Llama-2-7b-chat-hf')
    parser.add_argument('--model_type', type=str, default='llama')
    parser.add_argument('--data_dir', type=str, default='../qa_dataset/')
    parser.add_argument('--prompt_dir', type=str, default='prompts/qa.prompt')
    parser.add_argument('--result_dir', type=str, default='answer/')
    parser.add_argument('--num_doc', type=int, default=None)
    parser.add_argument('--bsz', type=int, default=1)
    parser.add_argument('--data_name', type=str, default='nq')
    parser.add_argument('--ngpu', type=int, default=1)
    parser.add_argument('--flag', type=int, default=0)
    parser.add_argument('--chosen_base', type=int, default=10000)
    args = parser.parse_args()

    main(args)