# openai_utils_curl.py
import os
import copy
import functools
import json
import logging
import math
import multiprocessing
import random
import time
import hashlib
from typing import Optional, Sequence
from pathlib import Path
import numpy as np
import openai
from openai import OpenAI
import tiktoken
import tqdm
import requests
__all__ = ["openai_completions"]
tiktoken.model.MODEL_TO_ENCODING['ChatGPT'] = tiktoken.model.MODEL_TO_ENCODING['gpt-3.5-turbo']
# API specific
DEFAULT_OPENAI_API_BASE = openai.base_url
OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", os.environ.get("OPENAI_API_KEY", None))
if isinstance(OPENAI_API_KEYS, str):
OPENAI_API_KEYS = OPENAI_API_KEYS.split(",")
OPENAI_ORGANIZATION_IDS = os.environ.get("OPENAI_ORGANIZATION_IDS", None)
if isinstance(OPENAI_ORGANIZATION_IDS, str):
OPENAI_ORGANIZATION_IDS = OPENAI_ORGANIZATION_IDS.split(",")
OPENAI_MAX_CONCURRENCY = int(os.environ.get("OPENAI_MAX_CONCURRENCY", 5))
cache_dir = Path(os.path.abspath(__file__)).parent / "cache"
cache_dir.mkdir(exist_ok=True)
cache_base_path = None
cache_base = None
def get_prompt_uids(prompt: str) -> str:
return hashlib.sha256(prompt.encode()).hexdigest()
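# Illustrative note on the cache key (not part of the public API): the uid is just the
# SHA-256 hex digest of the raw prompt string, so repeated queries with an identical
# prompt map to the same 64-character key in the cache file.
# >>> len(get_prompt_uids("Respond with one digit: 1+1="))
# 64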
def openai_completions(
prompts: Sequence[str],
model_name: str,
tokens_to_favor: Optional[Sequence[str]] = None,
tokens_to_avoid: Optional[Sequence[str]] = None,
is_skip_multi_tokens_to_avoid: bool = True,
is_strip: bool = True,
num_procs: Optional[int] = OPENAI_MAX_CONCURRENCY,
batch_size: Optional[int] = None,
use_cache: bool = True,
rpm: int = 50,
**decoding_kwargs,
) -> dict[str, list]:
r"""Get openai completions for the given prompts. Allows additional parameters such as tokens to avoid and
tokens to favor.
Parameters
----------
prompts : list of str
Prompts to get completions for.
model_name : str
Name of the model to use for decoding.
tokens_to_favor : list of str, optional
Substrings to favor in the completions. We will add a positive bias to the logits of the tokens constituting
the substrings.
tokens_to_avoid : list of str, optional
Substrings to avoid in the completions. We will add a large negative bias to the logits of the tokens
constituting the substrings.
is_skip_multi_tokens_to_avoid : bool, optional
Whether to skip substrings in `tokens_to_avoid` that consist of more than one token, to avoid undesired
side effects on other tokens.
is_strip : bool, optional
Whether to strip trailing and leading spaces from the prompts.
num_procs : int, optional
Number of parallel worker processes used to query the API.
batch_size : int, optional
Number of prompts per request; forced to 1 for chat models.
use_cache : bool, optional
Whether to use a local cache so that repeated queries with the same prompt are not re-sent.
rpm : int, optional
Target number of requests per minute, shared across the worker processes.
decoding_kwargs :
Additional kwargs to pass to the completion endpoint (e.g. `max_tokens`, `temperature`, `logit_bias`).
Example
-------
>>> prompts = ["Respond with one digit: 1+1=", "Respond with one digit: 2+2="]
>>> openai_completions(prompts, model_name="text-davinci-003", tokens_to_avoid=["2"," 2"])['completions']
['\n\nAnswer: \n\nTwo (or, alternatively, the number "two" or the numeral "two").', '\n\n4']
>>> openai_completions(prompts, model_name="text-davinci-003", tokens_to_favor=["2"])['completions']
['2\n\n2', '\n\n4']
>>> openai_completions(prompts, model_name="text-davinci-003",
... tokens_to_avoid=["2 a long sentence that is not a token"])['completions']
['\n\n2', '\n\n4']
>>> chat_prompt = ["<|im_start|>user\n1+1=<|im_end|>", "<|im_start|>user\nRespond with one digit: 2+2=<|im_end|>"]
>>> openai_completions(chat_prompt, "gpt-3.5-turbo", tokens_to_avoid=["2"," 2"])['completions']
['As an AI language model, I can confirm that 1+1 equals 02 in octal numeral system, 10 in decimal numeral
system, and 02 in hexadecimal numeral system.', '4']
"""
# add cache support for query
num_procs = num_procs or OPENAI_MAX_CONCURRENCY
if use_cache:
global cache_base
global cache_base_path
cache_base_path = cache_dir / f"{model_name}.jsonl"
if cache_base is None:
if not cache_base_path.exists():
cache_base = {}
logging.warning(
f"Cache file {cache_base_path} does not exist. Creating new cache.")
else:
with open(cache_base_path, "r") as f:
cache_base = [json.loads(line) for line in f.readlines()]
cache_base = {item['uid']: item for item in cache_base}
logging.warning(f"Loaded cache base from {cache_base_path}.")
n_examples = len(prompts)
if n_examples == 0:
logging.warning("No samples to annotate.")
return dict(completions=[], price_per_example=[], time_per_example=[])
else:
logging.warning(
f"Using `openai_completions` on {n_examples} prompts using {model_name}.")
if tokens_to_avoid or tokens_to_favor:
tokenizer = tiktoken.encoding_for_model(model_name)
logit_bias = decoding_kwargs.get("logit_bias", {})
if tokens_to_avoid is not None:
for t in tokens_to_avoid:
curr_tokens = tokenizer.encode(t)
if len(curr_tokens) != 1 and is_skip_multi_tokens_to_avoid:
logging.warning(
f"'{t}' has more than one token, skipping because `is_skip_multi_tokens_to_avoid`.")
continue
for tok_id in curr_tokens:
logit_bias[tok_id] = -100 # avoids certain tokens
if tokens_to_favor is not None:
for t in tokens_to_favor:
curr_tokens = tokenizer.encode(t)
for tok_id in curr_tokens:
# increase the log prob of tokens to favor
logit_bias[tok_id] = 7
decoding_kwargs["logit_bias"] = logit_bias
if is_strip:
prompts = [p.strip() for p in prompts]
# pop `requires_chatml` so it does not leak into the request payload as an API kwarg
is_chat = decoding_kwargs.pop(
"requires_chatml", _requires_chatml(model_name))
if is_chat:
# prompts = [_prompt_to_chatml(prompt) for prompt in prompts]
num_procs = num_procs or 4
batch_size = batch_size or 1
if batch_size > 1:
logging.warning(
"batch_size > 1 is not supported yet for chat models. Setting to 1")
batch_size = 1
else:
num_procs = num_procs or 1
batch_size = batch_size or 10
logging.warning(f"Kwargs to completion: {decoding_kwargs}")
n_batches = int(math.ceil(n_examples / batch_size))
prompt_batches = [
prompts[batch_id * batch_size: (batch_id + 1) * batch_size] for batch_id in range(n_batches)]
if "azure" == openai.api_type:
# Azure API uses engine instead of model
kwargs = dict(n=1, model=model_name, is_chat=is_chat,
use_cache=use_cache, **decoding_kwargs)
else:
# OpenAI API uses model instead of engine
kwargs = dict(n=1, model=model_name, is_chat=is_chat,
use_cache=use_cache, **decoding_kwargs)
kwargs.update({"rpm": rpm / num_procs})
logging.warning(f"Kwargs to completion: {kwargs}")
with Timer() as t:
if num_procs == 1:
completions = [
_openai_completion_helper(prompt_batch, **kwargs)
for prompt_batch in tqdm.tqdm(prompt_batches, desc="prompt_batches")
]
else:
with multiprocessing.Pool(num_procs) as p:
partial_completion_helper = functools.partial(
_openai_completion_helper, **kwargs)
completions = list(
tqdm.tqdm(
p.imap(partial_completion_helper, prompt_batches),
desc="prompt_batches",
total=len(prompt_batches),
)
)
logging.warning(f"Completed {n_examples} examples in {t}.")
# flatten the list and select only the text
completions_text = [completion['content']
for completion_batch in completions for completion in completion_batch]
price = [
completion["total_tokens"] * _get_price_per_token(model_name)
if completion["total_tokens"] is not None else 0
for completion_batch in completions
for completion in completion_batch
]
avg_time = [t.duration / n_examples] * len(completions_text)
return dict(completions=completions_text, price_per_example=price, time_per_example=avg_time)
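# Illustrative note on the scheduling above: with, e.g., 100 prompts, batch_size=10 and
# num_procs=5, the prompts are split into ceil(100 / 10) = 10 batches and each worker
# process throttles itself to rpm / num_procs = 50 / 5 = 10 requests per minute.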
def _openai_completion_helper(
prompt_batch: Sequence[str],
is_chat: bool,
sleep_time: int = 2,
openai_organization_ids: Optional[Sequence[str]] = OPENAI_ORGANIZATION_IDS,
openai_api_keys: Optional[Sequence[str]] = OPENAI_API_KEYS,
openai_api_base: Optional[str] = None,
max_tokens: Optional[int] = 1000,
top_p: Optional[float] = 1.0,
temperature: Optional[float] = 0.7,
use_cache: bool = True,
rpm: int = 10,
**kwargs,
):
client_kwargs = dict()
# randomly select orgs
if openai_organization_ids is not None:
client_kwargs["organization"] = random.choice(openai_organization_ids)
openai_api_keys = openai_api_keys or OPENAI_API_KEYS
if openai_api_keys is not None:
client_kwargs["api_key"] = random.choice(openai_api_keys)
# set api base
client_kwargs["base_url"] = base_url = openai_api_base if openai_api_base is not None else DEFAULT_OPENAI_API_BASE
client = OpenAI(**client_kwargs)
# copy shared_kwargs to avoid modifying it
kwargs.update(dict(max_tokens=max_tokens,
top_p=top_p, temperature=temperature))
curr_kwargs = copy.deepcopy(kwargs)
if use_cache:
prompt_uids = [get_prompt_uids(prompt) for prompt in prompt_batch]
cache_completions = [cache_base[prompt_uid]['completion']
if prompt_uid in cache_base else None for prompt_uid in prompt_uids]
to_query_prompt_batch = [prompt for prompt, cache_completion in zip(
prompt_batch, cache_completions) if cache_completion is None]
else:
to_query_prompt_batch = prompt_batch
if is_chat:
to_query_prompt_batch = [_prompt_to_chatml(
prompt) for prompt in to_query_prompt_batch]
# now_cand = ""
# retry_times = 0
if len(to_query_prompt_batch) != 0:
while True:
try:
if is_chat:
# print(curr_kwargs)
# completion_batch = client.chat.completions.create(messages=to_query_prompt_batch[0], **curr_kwargs)
# curl
url = 'http://121.127.44.53:8102/v1/chat/completions'
headers = {
'Authorization': 'Bearer {}'.format(client_kwargs["api_key"]),
'Content-Type': 'application/json'
}
# chat requests are forced to batch_size 1 above, so only the single prompt in this batch is sent
data = {
"messages": to_query_prompt_batch[0],
**curr_kwargs
}
start_time = time.time()
response = requests.post(url, headers=headers, data=json.dumps(data))
end_time = time.time()
total_time = end_time - start_time
sleep_time = max(0, 60 / rpm - total_time)
print(f"Sleeping {sleep_time} seconds...")
time.sleep(sleep_time)
completion_batch = json.loads(response.text)
if "choices" not in completion_batch:
raise Exception(response.text)
choices = completion_batch['choices']
for choice in choices:
assert choice['message']['role'] == "assistant"
# if choice.message.get("function_call"):
# # currently we only use function calls to get a JSON object
# # => overwrite text with the JSON object. In the future, we could
# # allow actual function calls
# all_args = json.loads(
# choice.message.function_call.arguments)
# assert len(all_args) == 1
# choice["text"] = all_args[list(all_args.keys())[0]]
else:
raise NotImplementedError
batch_avg_tokens = completion_batch['usage']['total_tokens'] / len(prompt_batch)
break
except Exception as e:
if "Please reduce your prompt" in str(e):
kwargs["max_tokens"] = int(kwargs["max_tokens"] * 0.8)
logging.warning(
f"Reducing target length to {kwargs['max_tokens']}, Retrying...")
if kwargs["max_tokens"] == 0:
logging.exception(
"Prompt is already longer than max context length. Error:")
raise e
else:
if "rate limit" in str(e).lower() or "model_cap_exceeded" in str(e).lower():
print(e)
else:
logging.warning(
f"Unknown error {e}. \n It's likely a rate limit so we are retrying...")
logging.warning(prompt_batch)
if openai_organization_ids is not None and len(openai_organization_ids) > 1:
# pick a different organization than the one currently in use
client_kwargs["organization"] = random.choice(
[o for o in openai_organization_ids if o != client_kwargs.get("organization")]
)
client = OpenAI(**client_kwargs)
logging.info("Switching OAI organization.")
if openai_api_keys is not None and len(openai_api_keys) > 1:
# pick a different API key than the one currently in use
client_kwargs["api_key"] = random.choice([k for k in openai_api_keys if k != client_kwargs.get("api_key")])
client = OpenAI(**client_kwargs)
logging.info("Switching OAI API key.")
logging.info(f"Sleeping {sleep_time} before retrying to call openai API...")
time.sleep(sleep_time) # Annoying rate limit on requests.
if use_cache:
responses = []
to_cache_items = []
to_query_idx = 0
for i in range(len(prompt_batch)):
prompt_uid = prompt_uids[i]
if cache_completions[i] is None:
cache_base[prompt_uid] = dict(uid=prompt_uid,
prompt=prompt_batch[i], completion=choices[to_query_idx]['message']['content'],
top_p=top_p, temperature=temperature, max_tokens=max_tokens, total_tokens=batch_avg_tokens)
to_cache_items.append(cache_base[prompt_uid])
responses.append(dict(
content=choices[to_query_idx]['message']['content'], total_tokens=batch_avg_tokens))
to_query_idx += 1
else:
responses.append(
dict(content=cache_completions[i], total_tokens=None))
assert to_query_idx == len(to_query_prompt_batch)
# save cache items
with open(cache_base_path, "a+") as f:
for item in to_cache_items:
f.write(json.dumps(item, ensure_ascii=False) + "\n")
else:
responses = [dict(content=choice['message']['content'], total_tokens=batch_avg_tokens) for choice in choices]
return responses
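# Each element of the returned list has the form
# {"content": <assistant reply>, "total_tokens": <avg tokens per prompt in the batch, or None for cache hits>},
# which `openai_completions` flattens into `completions` and `price_per_example`.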
def _requires_chatml(model: str) -> bool:
"""Whether a model requires the ChatML format."""
# TODO: this should ideally be an OpenAI function... Maybe it already exists?
return "turbo" in model or "gpt-4" in model or "chatgpt" in model.lower()
def _prompt_to_chatml(prompt: str, start_token: str = "<|im_start|>", end_token: str = "<|im_end|>"):
r"""Convert a text prompt to ChatML formal
Examples
--------
>>> prompt = (
... "<|im_start|>system\n"
... "You are a helpful assistant.\n<|im_end|>\n"
... "<|im_start|>system name=example_user\nKnock knock.\n<|im_end|>\n<|im_start|>system name=example_assistant\n"
... "Who's there?\n<|im_end|>\n<|im_start|>user\nOrange.\n<|im_end|>"
... )
>>> print(prompt)
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>system name=example_user
Knock knock.
<|im_end|>
<|im_start|>system name=example_assistant
Who's there?
<|im_end|>
<|im_start|>user
Orange.
<|im_end|>
>>> _prompt_to_chatml(prompt)
[{'content': 'You are a helpful assistant.', 'role': 'system'},
{'content': 'Knock knock.', 'role': 'system', 'name': 'example_user'},
{'content': "Who's there?", 'role': 'system', 'name': 'example_assistant'},
{'content': 'Orange.', 'role': 'user'}]
"""
prompt = prompt.strip()
assert prompt.startswith(start_token)
assert prompt.endswith(end_token)
message = []
for p in prompt.split("<|im_start|>")[1:]:
newline_splitted = p.split("\n", 1)
role = newline_splitted[0].strip()
content = newline_splitted[1].split(end_token, 1)[0].strip()
if role.startswith("system") and role != "system":
# based on https://github.com/openai/openai-cookbook/blob/main/examples
# /How_to_format_inputs_to_ChatGPT_models.ipynb
# and https://github.com/openai/openai-python/blob/main/chatml.md it seems that system can specify a
# dictionary of other args
other_params = _string_to_dict(role.split("system", 1)[-1])
role = "system"
else:
other_params = dict()
message.append(dict(content=content, role=role, **other_params))
return message
def _chatml_to_prompt(message: Sequence[dict], start_token: str = "<|im_start|>", end_token: str = "<|im_end|>"):
r"""Convert a ChatML message to a text prompt
Examples
--------
>>> message = [
... {'content': 'You are a helpful assistant.', 'role': 'system'},
... {'content': 'Knock knock.', 'role': 'system', 'name': 'example_user'},
... {'content': "Who's there?", 'role': 'system', 'name': 'example_assistant'},
... {'content': 'Orange.', 'role': 'user'}
... ]
>>> _chatml_to_prompt(message)
'<|im_start|>system\nYou are a helpful assistant.\n<|im_end|>\n<|im_start|>system name=example_user\nKnock knock.\n<|im_end|>\n<|im_start|>system name=example_assistant\nWho\'s there?\n<|im_end|>\n<|im_start|>user\nOrange.\n<|im_end|>'
"""
prompt = ""
for m in message:
role = m["role"]
name = m.get("name", None)
if name is not None:
role += f" name={name}"
prompt += f"<|im_start|>{role}\n{m['content']}\n<|im_end|>\n"
return prompt
def _string_to_dict(to_convert):
r"""Converts a string with equal signs to dictionary. E.g.
>>> _string_to_dict(" name=user university=stanford")
{'name': 'user', 'university': 'stanford'}
"""
return {s.split("=", 1)[0]: s.split("=", 1)[1] for s in to_convert.split(" ") if len(s) > 0}
def _get_price_per_token(model):
"""Returns the price per token for a given model"""
if "gpt-4" in model:
return (
0.03 / 1000
) # that's not completely true because decoding is 0.06 but close enough given that most is context
elif "gpt-3.5-turbo" in model.lower() or 'chatgpt' in model.lower():
return 0.002 / 1000
elif "text-davinci-003" in model:
return 0.02 / 1000
else:
logging.warning(
f"Unknown model {model} for computing price per token.")
return np.nan
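# Worked example of the rough cost estimate above (a single flat per-token rate; prompt and
# completion tokens are not priced separately here): a gpt-3.5-turbo call consuming 500 total
# tokens is priced at about 500 * 0.002 / 1000 = $0.001.
# cost = 500 * _get_price_per_token("gpt-3.5-turbo")  # ~0.001 USD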
class Timer:
"""Timer context manager"""
def __enter__(self):
"""Start a new timer as a context manager"""
self.start = time.time()
return self
def __exit__(self, *args):
"""Stop the context manager timer"""
self.end = time.time()
self.duration = self.end - self.start
def __str__(self):
return f"{self.duration:.1f} seconds"
# # Example usage with ChatGPT Azure Model
# from openai_utils_curl import openai_completions, _chatml_to_prompt
# prompts = ["Respond with one digit: 1+1=", "Respond with one digit: 2+2="]
# chatmls = [[{"role":"system","content":"You are an AI assistant that helps people find information."},
# {"role":"user","content": prompt}] for prompt in prompts]
# chatml_prompts = [_chatml_to_prompt(chatml) for chatml in chatmls]
# print(chatml_prompts)
# openai_completions(chatml_prompts, model_name="ChatGPT")['completions']
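# # A further sketch against the hard-coded curl endpoint in `_openai_completion_helper`
# # (illustrative only: the "ChatGPT" model name and the endpoint / OPENAI_API_KEY setup are
# # assumptions taken from this file, not guaranteed to match any particular deployment):
# result = openai_completions(chatml_prompts, model_name="ChatGPT", rpm=30, max_tokens=16, temperature=0.0)
# print(result["completions"])
# print(sum(result["price_per_example"]), "USD estimated")  # computed via _get_price_per_token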