-
Notifications
You must be signed in to change notification settings - Fork 1
/
lemma_component.py
81 lines (74 loc) · 3.41 KB
/
lemma_component.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import json
from configs import prompts_root, imgbed_root, cache_root
from utils import onlineImg_process, offlineImg_process, gpt_no_image
class LemmaComponent:
    """One step of a lemma pipeline: format a prompt template, send it to a
    GPT model (optionally with an image), and optionally cache the answer.

    The cache is a plain dict keyed by the fully-formatted prompt string and
    persisted as JSON under ``cache_root``.

    Args:
        prompt: Either a filename relative to ``prompts_root`` (loaded as the
            template) or the template object itself; must support ``.format``.
        name: Human-readable component name, used in log lines and as the
            default cache filename stem.
        model: One of ``'gpt4v'``, ``'gpt3.5'``, ``'gpt4'``.
        using_cache: Load/persist results keyed by formatted prompt.
        cache_name: Cache filename; defaults to ``<name>.json`` when empty.
        online_image: For ``gpt4v``, pass the image by URL instead of a local
            file path.
        max_retry: Number of attempts before giving up and returning ``None``.
        max_tokens: Token limit forwarded to the model call.
        temperature: Sampling temperature forwarded to the model call.
        post_process: Optional callable applied to the raw model output; if it
            raises, the attempt counts as failed and is retried.
    """

    def __init__(self, prompt, name, model='gpt4v', using_cache=False, cache_name='', online_image=True, max_retry=5,
                 max_tokens=1000, temperature=0.1, post_process=None):
        self.name = name
        self.model = model
        self.using_cache = using_cache
        self.online_image = online_image
        self.max_retry = max_retry
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.post_process = post_process
        # Default cache file is derived from the component name.
        self.cache_name = cache_name if cache_name != '' else self.name + '.json'
        # A string is treated as a prompt filename; anything else is assumed
        # to already be the template (it only needs a .format method).
        if isinstance(prompt, str):
            with open(prompts_root + prompt, 'r', encoding='utf-8') as f:
                self.prompt = f.read()
        else:
            self.prompt = prompt
        if using_cache:
            try:
                with open(cache_root + self.cache_name, 'r', encoding='utf-8') as f:
                    self.cache = json.load(f)
            except FileNotFoundError:
                # First run: start with an empty cache.
                self.cache = {}

    def __call__(self, *args, **kwargs):
        """Format the prompt with ``kwargs``, query the model, and return the
        (optionally post-processed) result, or ``None`` after ``max_retry``
        consecutive failures.

        The special keyword ``image`` is consumed here and never reaches the
        template; all remaining kwargs are ``str.format`` fields.
        """
        print(f'Lemma Component {self.name}: Starting...')
        # 'image' is routed to the vision call, not to the template.
        image_path = kwargs.pop('image', '')
        prompt = self.prompt.format(**kwargs)
        if self.using_cache and prompt in self.cache:
            print(f'Lemma Component {self.name}: retrieve from cache')
            return self.cache[prompt]
        # Resolve the image location ONCE, before the retry loop. The original
        # prepended imgbed_root inside the loop, so a failed attempt could
        # prefix the path a second time on retry.
        if self.model == 'gpt4v' and self.online_image and 'http' not in image_path:
            image_path = imgbed_root + image_path
        for _attempt in range(self.max_retry):
            try:
                if self.model == 'gpt4v':
                    if self.online_image:
                        result = onlineImg_process(prompt=prompt, url=image_path,
                                                   max_tokens=self.max_tokens, temperature=self.temperature)
                    else:
                        result = offlineImg_process(prompt=prompt, image_path=image_path,
                                                    max_tokens=self.max_tokens, temperature=self.temperature)
                elif self.model == 'gpt3.5':
                    result = gpt_no_image(prompt, model='gpt-3.5-turbo', max_tokens=self.max_tokens,
                                          temperature=self.temperature)
                elif self.model == 'gpt4':
                    result = gpt_no_image(prompt, model='gpt-4-1106-preview', max_tokens=self.max_tokens,
                                          temperature=self.temperature)
                else:
                    raise ValueError(f'Unknown model {self.model}')
                if self.post_process is not None:
                    result = self.post_process(result)
                break
            except Exception as e:
                # Broad on purpose: transient API/network/parse errors all
                # count as a failed attempt and trigger a retry.
                print(f'Lemma Component {self.name}: {e}, retrying...')
                continue
        else:
            # Loop exhausted without a successful break.
            print(f'Lemma Component {self.name}: Max retry exceeded')
            return None
        if self.using_cache and result is not None:
            self.cache[prompt] = result
            # ensure_ascii=False keeps non-ASCII prompts readable in the file.
            with open(cache_root + self.cache_name, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False)
        return result