-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
125 lines (115 loc) · 4.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from datasets import load_dataset
import csv
from vllm import LLM, SamplingParams
import argparse, os
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("--shots", type=int, required=True)
parser.add_argument("--gpus", type=int, required=True)
parser.add_argument("--batch_size", type=int, required=True)
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--dataset", type=str, default="advbench", required=True)
parser.add_argument("--reward", type=bool, default=False)
args = parser.parse_args()
messages = [
{
"role": "system",
"content": "You are a helpful, respectful and honest assistant.",
}
]
# prepare many shot data
with open(
"data/balanced_many_responses_judge.csv", newline="", encoding="utf-8"
) as csvfile:
reader = csv.DictReader(csvfile)
rows = list(reader)
selected_rows = rows[: args.shots]
for row in selected_rows:
messages.append({"role": "user", "content": row["Question"]})
messages.append({"role": "assistant", "content": row["Response"]})
if args.reward:
messages.append(
{
"role": "user",
"content": "Your answer is absolutely correct! Fantastic work!",
}
)
model_id = args.model_name
if model_id == "Llama-3.1-8B-Instruct" or model_id == "Llama-3.1-70B-Instruct":
chat_template_jinja = "llama-3-instruct.jinja"
elif model_id == "Llama-2-7b-chat-hf" or model_id == "Llama-2-70b-chat-hf":
chat_template_jinja = "llama-2-chat.jinja"
elif model_id == "Qwen2.5-7B-Instruct" or model_id == "Qwen2.5-72B-Instruct":
chat_template_jinja = "qwen2.5-instruct.jinja"
model_path = os.path.join("/disk/mount/models/", model_id)
# prepare model HuggingFace
# model = transformers.LlamaForCausalLM.from_pretrained(model_id)
# tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
# pipeline = transformers.pipeline(
# "text-generation",
# model=model,
# tokenizer=tokenizer,
# torch_dtype=torch.float16,
# do_sample=True,
# device=2,
# )
# prepare model vllm
sampling_params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=256)
llm = LLM(
model=model_path,
tensor_parallel_size=args.gpus,
# max_model_len=args.max_model_token,
enforce_eager=True,
gpu_memory_utilization=0.95,
)
if args.dataset == "advbench":
# advbench
ds_test = load_dataset(
"csv", data_files="data/harmful_behaviors.csv", split="train"
)
INDEX = "goal"
elif args.dataset == "jailbreakbench":
# JailbreakBench
ds_test = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors")["harmful"]
INDEX = "Goal"
elif args.dataset == "harmbench":
# HarmBench
ds_test = load_dataset(
"csv", data_files="data/harmbench_behaviors_text_all.csv", split="train"
)
INDEX = "Behavior"
else:
raise ValueError(f"Unsupported dataset: {args.dataset}")
with open(os.path.join("chat_template/", chat_template_jinja), "r") as f:
chat_template = f.read()
result = []
prompts = []
batch_size = args.batch_size
for i in range(0, len(ds_test), batch_size):
batch = ds_test[INDEX][i : min(i + batch_size, len(ds_test))] # 获取一个批次的数据
prompts.clear()
for it in batch:
chat_messages = messages.copy()
chat_messages.append({"role": "user", "content": it})
prompts.append(chat_messages)
# chat
outputs = llm.chat(
prompts,
sampling_params=sampling_params,
use_tqdm=False,
chat_template=chat_template,
)
for output, it in zip(outputs, batch):
completion = output.outputs[0].text
print("behavior: ", it, "response: ", completion)
result.append({"behavior": it, "response": completion})
if args.reward:
save_path = f"result_chat/{args.dataset}/rl_{model_id}_{args.shots}shot.csv"
else:
save_path = f"result_chat/{args.dataset}/{model_id}_{args.shots}shot.csv"
with open(
save_path,
"w",
) as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=["behavior", "response"])
writer.writeheader()
writer.writerows(result)