# mllm.py (forked from YangLing0818/RPG-DiffusionMaster)
import requests
import json
import re

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

def extract_output(text):
    # Pull out the text that follows the "### Output:" marker, up to the
    # next "###" section header or the end of the string.
    output_pattern = r'### Output:(.*?)(?=###|$)'
    output_match = re.search(output_pattern, text, re.DOTALL)
    return output_match.group(1).strip() if output_match else None
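
# A minimal sketch of what extract_output parses, on a hypothetical model
# response written in the template's "### Output:" format (the sample text
# is illustrative, not taken from the repo's templates):
#
#   sample = "### Thinking: ...\n### Output:\nSplit ratio: 1;1\nRegional Prompt: ..."
#   extract_output(sample)
#   # -> "Split ratio: 1;1\nRegional Prompt: ..."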

def GPT4(prompt, version, key):
    url = "https://api.openai.com/v1/chat/completions"
    api_key = key
    with open('template/template.txt', 'r') as f:
        template = f.readlines()
    # Pick the in-context examples that match the requested task variant.
    if version == 'multi-attribute':
        with open('template/human_multi_attribute_examples.txt', 'r') as f:
            incontext_examples = f.readlines()
    elif version == 'complex-object':
        with open('template/complex_multi_object_examples.txt', 'r') as f:
            incontext_examples = f.readlines()
    else:
        raise ValueError(f"Unknown version: {version!r} (expected 'multi-attribute' or 'complex-object')")
    user_textprompt = f"Caption:{prompt} \n Let's think step by step:"
    textprompt = f"{' '.join(template)} \n {' '.join(incontext_examples)} \n {user_textprompt}"
    payload = json.dumps({
        # We suggest using the latest version of GPT; you can also use
        # gpt-4-vision-preview. See https://platform.openai.com/docs/models/ for details.
        "model": "gpt-4-1106-preview",
        "messages": [
            {
                "role": "user",
                "content": textprompt
            }
        ]
    })
    headers = {
        'Accept': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
        'Content-Type': 'application/json'
    }
    print('waiting for GPT-4 response')
    response = requests.post(url, headers=headers, data=payload)
    obj = response.json()
    text = obj['choices'][0]['message']['content']
    print(text)
    # Extract the split ratio and regional prompt from the model's answer.
    return get_params_dict(text)

def local_llm(prompt, version, model_path=None):
    if model_path is None:
        model_id = "Llama-2-13b-chat-hf"
    else:
        model_id = model_path
    print('Using model:', model_id)
    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=False, device_map='auto', torch_dtype=torch.float16)
    with open('template/template.txt', 'r') as f:
        template = f.readlines()
    # Pick the in-context examples that match the requested task variant.
    if version == 'multi-attribute':
        with open('template/human_multi_attribute_examples.txt', 'r') as f:
            incontext_examples = f.readlines()
    elif version == 'complex-object':
        with open('template/complex_multi_object_examples.txt', 'r') as f:
            incontext_examples = f.readlines()
    else:
        raise ValueError(f"Unknown version: {version!r} (expected 'multi-attribute' or 'complex-object')")
    user_textprompt = f"Caption:{prompt} \n Let's think step by step:"
    textprompt = f"{' '.join(template)} \n {' '.join(incontext_examples)} \n {user_textprompt}"
    model_input = tokenizer(textprompt, return_tensors="pt").to("cuda")
    model.eval()
    with torch.no_grad():
        print('waiting for LLM response')
        res = model.generate(**model_input, max_new_tokens=1024)[0]
        output = tokenizer.decode(res, skip_special_tokens=True)
    # Drop the echoed prompt so only the newly generated text remains.
    output = output.replace(textprompt, '')
    return get_params_dict(output)

def get_params_dict(output_text):
    split_ratio_marker = "Split ratio: "
    regional_prompt_marker = "Regional Prompt: "
    output_text = extract_output(output_text)
    if output_text is None:
        raise ValueError("No '### Output:' section found in the model response")
    print(output_text)
    # Find the start and end indices for the split ratio and regional prompt.
    split_ratio_start = output_text.find(split_ratio_marker) + len(split_ratio_marker)
    split_ratio_end = output_text.find("\n", split_ratio_start)
    regional_prompt_start = output_text.find(regional_prompt_marker) + len(regional_prompt_marker)
    regional_prompt_end = len(output_text)  # Assume the regional prompt runs to the end.
    # Extract the split ratio and regional prompt from the text.
    split_ratio = output_text[split_ratio_start:split_ratio_end].strip()
    regional_prompt = output_text[regional_prompt_start:regional_prompt_end].strip()
    # Delete any "(" and ")" the model may have wrapped around the split ratio.
    split_ratio = split_ratio.replace('(', '').replace(')', '')
    # Create the dictionary with the extracted information.
    image_region_dict = {
        'split ratio': split_ratio,
        'Regional Prompt': regional_prompt
    }
    print(image_region_dict)
    return image_region_dict
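
# A minimal offline usage sketch: parse a hand-written response in the
# template's "### Output:" format, with no network or GPU needed. The sample
# values below are illustrative assumptions, not output from the repo's
# templates or a real model call.
if __name__ == '__main__':
    sample_response = (
        "### Output:\n"
        "Split ratio: 1;1\n"
        "Regional Prompt: a red apple on a wooden table BREAK a glass of water"
    )
    params = get_params_dict(sample_response)
    # Expected: {'split ratio': '1;1',
    #            'Regional Prompt': 'a red apple on a wooden table BREAK a glass of water'}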