# model.py
import io
import base64
from typing import (
    Any,
    Dict,
    Optional,
)

from PIL import Image
from openai import OpenAI

def image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string for use in a data URL."""
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

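# Illustrative usage of image_to_base64 (the image path is a placeholder, not part
# of the original file):
#   img = Image.open("example.jpg").convert("RGB")
#   data_url = f"data:image/jpeg;base64,{image_to_base64(img)}"
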
class Model:
    """Thin wrapper around an OpenAI-compatible chat-completions client."""

    def __init__(self, model_name: str, base_url: str, api_key: str):
        self.model: str = model_name
        self.max_tokens: int = 4096
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
        )

    @property
    def _default_params(self) -> Dict[str, Any]:
        # Parameters shared by every request; "temperature" is only sent when a
        # subclass (or caller) has set it on the instance.
        params = {
            "model": self.model,
            # "max_tokens": self.max_tokens,
            "stream": False,
        }
        if hasattr(self, "temperature"):
            params["temperature"] = self.temperature
        return params

    def post(self, request: Any) -> Any:
        # Send the request, retrying a few times when the reply is unusable
        # (e.g. a non-"stop" finish reason or a malformed response object).
        retries = 5
        last_error: Optional[Exception] = None
        for _ in range(retries):
            response = self.client.chat.completions.create(**request)
            try:
                choice = response.choices[0]
                if choice.finish_reason != "stop":
                    print(f"Finish reason: {choice.finish_reason}")
                    raise RuntimeError(f"Unexpected finish reason: {choice.finish_reason}")
                return response
            except Exception as error:
                print(f"Response content: {response}")
                last_error = error
        raise RuntimeError(f"Request failed after {retries} attempts") from last_error

    def generate(self, prompt: str, model: Optional[str] = None) -> str:
        # Single-turn convenience wrapper: wrap the prompt in one user message.
        request = {
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ]
        }
        request.update(self._default_params)
        if model:
            request["model"] = model
        response = self.post(request)
        return response.choices[0].message.content.rstrip()

    def invoke(
        self,
        messages,
        model: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        request = kwargs
        # Keep only the fields the API expects from each message.
        request["messages"] = [
            {key: message[key] for key in ("role", "content")} for message in messages
        ]
        request.update(self._default_params)
        if model:
            request["model"] = model
        response = self.post(request)
        return response.choices[0].message.content.rstrip()

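# Illustrative call to Model.invoke (model name, endpoint, and messages below are
# placeholders, not taken from the original file):
#   model = Model("some-model", base_url="http://localhost:8000/v1", api_key="EMPTY")
#   reply = model.invoke([
#       {"role": "system", "content": "You are concise."},
#       {"role": "user", "content": "Summarize this file in one sentence."},
#   ])
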
class LocalModel(Model):
    """Model served by a local OpenAI-compatible endpoint (no real API key needed)."""

    def __init__(self, model_name: str, base_url: str = "http://localhost:8000/v1"):
        super().__init__(model_name, base_url, api_key="EMPTY")

class GPT4O(Model):
    """OpenAI GPT-4o, with optional image input passed as a base64-encoded JPEG."""

    def __init__(self, model_name: str = "gpt-4o-2024-05-13", base_url: str = "https://api.openai.com/v1"):
        # The API key is read from a local file named "openai_key".
        with open("openai_key") as key_file:
            api_key = key_file.read().strip()
        super().__init__(model_name, base_url, api_key=api_key)

    def generate(self, prompt: str, model: Optional[str] = None, base64_image: Optional[str] = None) -> str:
        messages = [{"role": "user"}]
        if base64_image:
            # Multimodal message: the prompt text plus the image as a data URL.
            messages[0]["content"] = [
                {
                    "type": "text",
                    "text": prompt,
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    },
                },
            ]
        else:
            messages[0]["content"] = prompt
        request = {
            "messages": messages,
        }
        request.update(self._default_params)
        if model:
            request["model"] = model
        response = self.post(request)
        return response.choices[0].message.content.rstrip()
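

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not taken from the original file: an
    # OpenAI-compatible server is running at http://localhost:8000/v1 and serves a
    # model named "my-local-model"; "example.jpg" exists on disk.
    local = LocalModel("my-local-model")
    print(local.generate("Say hello in one sentence."))

    # GPT4O reads its API key from a local file named "openai_key".
    # gpt = GPT4O()
    # image = Image.open("example.jpg").convert("RGB")
    # print(gpt.generate("Describe this image.", base64_image=image_to_base64(image)))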