reproduce_bbh.py

import os
import json
import random

from lorahub.algorithm import lorahub_learning, lorahub_inference
from lorahub.constant import LORA_MODULE_NAMES


def evaluate_flan_results_zero_shot(folder, flan_model_name):
    """Evaluate the base FLAN model on every BBH task under `folder` using zero-shot prompts."""
    sub_dirs = os.listdir(folder)
    for sub_dir in sub_dirs:
        test_file_path = os.path.join(folder, sub_dir, "zero_shot.jsonl")
        task_inputs, task_outputs = [], []
        for line in open(test_file_path, "r", encoding="utf-8"):
            example = json.loads(line)
            task_inputs.append(example["context"])
            task_outputs.append(example["completion"])
        print("Evaluating on task (zero shot): ", sub_dir)
        lorahub_inference(task_inputs,
                          flan_model_name,
                          flan_model_name,
                          16,
                          task_outputs)
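

# Each example.jsonl / zero_shot.jsonl / few_shot.jsonl file holds one JSON object
# per line with a "context" (prompt) and a "completion" (gold answer) field, as read
# by the loops above. A minimal loader sketch that factors out that pattern;
# `read_bbh_jsonl` is a hypothetical helper for illustration and is not used below.
def read_bbh_jsonl(file_path):
    inputs, outputs = [], []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            example = json.loads(line)
            inputs.append(example["context"])
            outputs.append(example["completion"])
    return inputs, outputs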


def evaluate_flan_results_few_shot(folder, flan_model_name):
    """Evaluate the base FLAN model on every BBH task under `folder` using five-shot prompts."""
    sub_dirs = os.listdir(folder)
    for sub_dir in sub_dirs:
        test_file_path = os.path.join(folder, sub_dir, "few_shot.jsonl")
        task_inputs, task_outputs = [], []
        for line in open(test_file_path, "r", encoding="utf-8"):
            example = json.loads(line)
            task_inputs.append(example["context"])
            task_outputs.append(example["completion"])
        print("Evaluating on task (five shot): ", sub_dir)
        lorahub_inference(task_inputs,
                          flan_model_name,
                          flan_model_name,
                          16,
                          task_outputs)


def evaluate_lorahub_results_few_shot(folder, flan_model_name):
    """Compose 20 randomly sampled LoRA modules with LoRAHub learning on 5 few-shot
    examples per task, then evaluate the composed model on the task's zero-shot test
    set, repeating over 5 random seeds."""
    sub_dirs = os.listdir(folder)
    for sub_dir in sub_dirs:
        # construct the few-shot examples for lorahub learning
        example_inputs, examples_outputs = [], []
        example_file_path = os.path.join(folder, sub_dir, "example.jsonl")
        for line in open(example_file_path, "r", encoding="utf-8"):
            example = json.loads(line)
            example_inputs.append(example["context"])
            examples_outputs.append(example["completion"])

        # randomly select 5 examples for each task
        random.seed(42)
        shuffled_set = list(zip(example_inputs, examples_outputs))
        random.shuffle(shuffled_set)
        example_inputs, examples_outputs = zip(*shuffled_set)
        # take the first 5 examples
        example_inputs, examples_outputs = example_inputs[:5], examples_outputs[:5]

        # load the zero-shot examples for evaluation
        test_file_path = os.path.join(folder, sub_dir, "zero_shot.jsonl")
        task_inputs, task_outputs = [], []
        for line in open(test_file_path, "r", encoding="utf-8"):
            example = json.loads(line)
            task_inputs.append(example["context"])
            task_outputs.append(example["completion"])

        task_perf_list = []
        # 5 seeds used in our experiments
        for seed in range(1, 6):
            random.seed(seed)

            def get_lora_module_list():
                return random.sample(LORA_MODULE_NAMES, 20)

            # get a list of modules to be used in the composition
            modules = get_lora_module_list()
            # perform LoRAHub learning
            module_weights, model, tokenizer = lorahub_learning(lora_module_list=modules,
                                                                example_inputs=example_inputs,
                                                                example_outputs=examples_outputs,
                                                                max_inference_step=40,
                                                                batch_size=5)
            print("module_weights:", module_weights)

            # perform inference to get predictions and accuracy on the test set
            _, task_acc = lorahub_inference(example_inputs=task_inputs,
                                            model_or_name_path=model,
                                            tokenizer_or_tokenizer_path=tokenizer,
                                            batch_size=10,
                                            # can be set to None if you do not have the ground truth
                                            example_outputs=task_outputs)
            task_perf_list.append(task_acc)

        avg_perf, max_perf = sum(task_perf_list) / len(task_perf_list), max(task_perf_list)
        print("average perf:", avg_perf, "best perf:", max_perf)


if __name__ == "__main__":
    if not os.path.exists("data_bbh"):
        # download the BBH evaluation data
        os.system("wget https://github.com/sail-sg/lorahub/releases/download/0.1/data_bbh.zip")
        # unzip
        os.system("unzip data_bbh.zip")
    # zero shot for flan models
    evaluate_flan_results_zero_shot("data_bbh", "google/flan-t5-large")
    # five shot for flan models
    evaluate_flan_results_few_shot("data_bbh", "google/flan-t5-large")
    # five shot for lorahub models
    evaluate_lorahub_results_few_shot("data_bbh", "google/flan-t5-large")
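

# The wget/unzip calls above assume a Unix-like environment with both tools on PATH.
# A pure standard-library alternative sketch (hypothetical helper, not called by the
# script); it fetches the same release asset and extracts it in place.
def download_data_bbh(url="https://github.com/sail-sg/lorahub/releases/download/0.1/data_bbh.zip",
                      target_dir="."):
    import urllib.request
    import zipfile
    archive_path = os.path.join(target_dir, "data_bbh.zip")
    urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path, "r") as zf:
        zf.extractall(target_dir)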