restore openai
bmosaicml committed Dec 13, 2023
1 parent be604a8 commit aeda62c
Showing 4 changed files with 139 additions and 102 deletions.
7 changes: 4 additions & 3 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -39,7 +39,8 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):

def get_metrics(self, is_train: bool = False):
if is_train:
metrics = None
raise NotImplementedError(
'You cannot use inference wrappers for training')
else:
metrics = self.eval_metrics

@@ -54,7 +55,6 @@ def rebatch(self, batch: Batch):
return batch

def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
padding_tok = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
# If the batch mode is generate, we will generate a requested number of tokens using the underlying
# model's generate function. Extra generation kwargs can be passed in via the batch. Strings will
# be returned from eval_forward
@@ -80,7 +80,8 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
[output_logits,
next_logit_tensor.reshape(1, -1)])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],), padding_tok),
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)
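
A minimal sketch of the padding step these hunks adjust, with standalone names standing in for the wrapper's attributes; after the change the pad rows come straight from the tokenizer's pad_token_id (the tests construct the tokenizer with pad_token='<|endoftext|>'), with no fallback to the EOS token:

import torch

def pad_logits_to_seqlen(output_logits: torch.Tensor, seqlen: int,
                         pad_token_id: int, vocab_size: int) -> torch.Tensor:
    """Append one-hot rows for `pad_token_id` until there are `seqlen` rows."""
    num_missing = seqlen - output_logits.shape[0]
    if num_missing <= 0:
        return output_logits
    padding = torch.nn.functional.one_hot(
        torch.full((num_missing,), pad_token_id),
        num_classes=vocab_size).to(output_logits.dtype)
    return torch.cat([output_logits, padding])

# Example: a 3 x 10 logit matrix padded out to 5 rows over a 10-token vocab.
padded = pad_logits_to_seqlen(torch.randn(3, 10), seqlen=5,
                              pad_token_id=0, vocab_size=10)
assert padded.shape == (5, 10)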
76 changes: 39 additions & 37 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
@@ -5,7 +5,6 @@

import logging
import os
import random
from time import sleep
from typing import Any, Dict, List, Optional, Union

@@ -31,23 +30,20 @@ class OpenAIEvalInterface(InferenceAPIEvalWrapper):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
assert os.getenv(
'OPENAI_API_KEY'
) is not None, 'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
self.client = openai.OpenAI()
openai.api_key = os.getenv('OPENAI_API_KEY')
self.model_name = model_cfg['version']

def generate_completion(self, prompt: str, num_tokens: int):
raise NotImplementedError()

def process_result(self, completion): # pyright: ignore
def process_result(self, completion: Optional[dict]):
raise NotImplementedError()

def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):
@@ -56,30 +52,26 @@ def get_next_token_logit_tensor(self, prompt: str, num_tokens: int = 1):

def try_generate_completion(self, prompt: str, num_tokens: int):
try:
from openai import RateLimitError
from openai.error import RateLimitError
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e
tries = 0
completion = None
delay = 1
while tries < MAX_RETRIES:
tries += 1
try:

completion = self.generate_completion(prompt, num_tokens)
break
except RateLimitError as e:
if 'You exceeded your current quota' in str(
e._message): # pyright: ignore
if 'You exceeded your current quota' in str(e._message):
raise e
delay *= 2 * (1 + random.random())
sleep(delay)
sleep(60)
continue
except Exception as e:
print(f'Found Exception: {e}')
# TODO: Why continue on unspecified Exception?
except Exception:
continue
return completion
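
A minimal sketch of the retry behavior restored above, assuming the pre-1.0 openai package (which exposes openai.error.RateLimitError): retry up to a fixed number of attempts, re-raise immediately on a hard quota error, and sleep a flat 60 seconds on ordinary rate limiting. The helper name and the MAX_RETRIES value here are illustrative.

from time import sleep

from openai.error import RateLimitError  # pre-1.0 openai layout

MAX_RETRIES = 10  # illustrative; the wrapper module defines its own constant

def call_with_retries(make_request):
    """Call make_request() until it succeeds or the retry budget runs out."""
    completion = None
    for _ in range(MAX_RETRIES):
        try:
            completion = make_request()
            break
        except RateLimitError as e:
            if 'You exceeded your current quota' in str(e):
                raise  # a quota error will not resolve by waiting
            sleep(60)  # plain rate limiting: back off, then try again
    return completion  # None if every attempt failed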

@@ -88,23 +80,23 @@ class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e

self.generate_completion = lambda prompt, num_tokens: self.client.chat.completions.create(
model=self.model_name,
self.generate_completion = lambda prompt, num_tokens: openai.ChatCompletion.create(
self.model_name,
messages=[{
'role':
'system',
'content':
model_cfg.get('sytsem_role_prompt',
'Please complete the following text: ')
}, {
'role': 'user',
'content': prompt
}],
max_tokens=num_tokens,
temperature=0.0)

# TODO: Do we still need retokenize, rebatch, and eval_forward?
def retokenize(self, tokens: List[int], cont_idxs: List[int]):
"""Chat API will never respond with a word-initial space.
@@ -170,7 +162,6 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
# than what the continuation would expect.
# Get around this issue by retokenizing the batch to remove spacing from the continuation as well as
# decoding the whole continuation at once.
padding_tok = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id
output_logits_batch = []
batch = self.rebatch(batch)
for tokens, cont_idxs in zip(batch['input_ids'],
@@ -191,18 +182,20 @@ def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
if next_logit_tensor is not None:
output_logits = torch.cat([output_logits, next_logit_tensor])
padding = torch.nn.functional.one_hot(
torch.full((seqlen - output_logits.shape[0],), padding_tok),
torch.full((seqlen - output_logits.shape[0],),
self.tokenizer.pad_token_id),
num_classes=self.tokenizer.vocab_size)
output_logits = torch.cat([output_logits, padding])
output_logits_batch.append(output_logits)

return torch.stack(output_logits_batch).to(batch['input_ids'].device)

def process_result(self, completion): # pyright: ignore
if len(completion.choices) > 0: # pyright: ignore
def process_result(self, completion: Optional[dict]):
assert isinstance(completion, dict)
if len(completion['choices']) > 0:
tensors = []
for t in self.tokenizer(completion.choices[0].message.content
)['input_ids']: # pyright: ignore
for t in self.tokenizer(completion['choices'][0]['message']
['content'])['input_ids']:
tensors.append(
self.tokenizer.construct_logit_tensor(
{self.tokenizer.decode([t]): 0.0}))
@@ -220,20 +213,29 @@ class OpenAICausalLMEvalWrapper(OpenAIEvalInterface):

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
super().__init__(model_cfg, tokenizer)
# TODO: this will be deprecated
self.generate_completion = lambda prompt, num_tokens: self.client.completions.create(
model=self.model_name,
try:
import openai
except ImportError as e:
raise MissingConditionalImportError(
extra_deps_group='openai',
conda_package='openai',
conda_channel='conda-forge') from e

self.generate_completion = lambda prompt, num_tokens: openai.Completion.create(
engine=self.model_name,
prompt=prompt,
max_tokens=num_tokens,
max_tokens=1,
logprobs=5,
temperature=0.0)

def process_result(self, completion): # pyright: ignore
def process_result(self, completion: Optional[dict]):
if completion is None:
raise ValueError("Couldn't generate model output")
if len(completion.choices[0].logprobs.top_logprobs[0]) > 0:

assert isinstance(completion, dict)
if len(completion['choices'][0]['logprobs']['top_logprobs']) > 0:
tensor = self.tokenizer.construct_logit_tensor(
dict(completion.choices[0].logprobs.top_logprobs[0]))
dict(completion['choices'][0]['logprobs']['top_logprobs'][0]))
return tensor
else:
# the model sometimes stops early even though we are still requesting tokens!
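
After the restore, both wrappers call the pre-1.0 module-level endpoints and parse plain dicts. A rough sketch of the shapes the two process_result methods index into (field values are made up, and real responses carry additional fields such as id, usage, and finish_reason):

# openai.ChatCompletion.create: generated text at choices[0]['message']['content'].
chat_response = {
    'choices': [{
        'message': {'role': 'assistant', 'content': 'Treason!'},
    }],
}

# openai.Completion.create with logprobs=5: per-position top log-probabilities
# at choices[0]['logprobs']['top_logprobs'].
completion_response = {
    'choices': [{
        'logprobs': {
            'top_logprobs': [{' Tre': -0.01, ' The': -4.2}],
        },
    }],
}

assert chat_response['choices'][0]['message']['content'] == 'Treason!'
assert ' Tre' in completion_response['choices'][0]['logprobs']['top_logprobs'][0]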
119 changes: 57 additions & 62 deletions tests/models/inference_api_wrapper/test_inference_api_eval_wrapper.py
@@ -34,69 +34,60 @@ def load_icl_config():
})


class MockTopLogProb:

def __init__(self, expected_token: str) -> None:
setattr(self, 'top_logprobs', [{expected_token: 0}])


class MockLogprob:

def __init__(self, expected_token: str) -> None:
setattr(self, 'logprobs', MockTopLogProb(expected_token))


class MockCompletion:

def __init__(self, expected_token: str) -> None:
setattr(self, 'choices', [MockLogprob(expected_token)])


class MockContent:

def __init__(self, expected_token: str) -> None:
setattr(self, 'content', expected_token)


class MockMessage:

def __init__(self, expected_token: str) -> None:
setattr(self, 'message', MockContent(expected_token))


class MockChatCompletion:

def __init__(self, expected_token: str) -> None:
setattr(self, 'choices', [MockMessage(expected_token)])


def mock_create(**kwargs: Dict[str, str]):
prompt = kwargs['prompt']
if prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer:': # pyright: ignore[reportUnnecessaryComparison]
return MockCompletion(' Tre')

return {
'choices': [{
'logprobs': {
'top_logprobs': [{
' Tre': 0,
}],
},
}],
}
elif prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer: Tre': # pyright: ignore[reportUnnecessaryComparison]
return MockCompletion('ason')

return {
'choices': [{
'logprobs': {
'top_logprobs': [{
'ason': 0,
}],
},
}],
}
elif prompt == 'AMERICAN HISTORY: On May 29, 1765 Patrick Henrys Stamp Act protest was interrupted with this one word\nAnswer: Treason': # pyright: ignore[reportUnnecessaryComparison]
return MockCompletion('!')

return {
'choices': [{
'logprobs': {
'top_logprobs': [{
'!': 0,
}],
},
}],
}
else:
# dummy token to make sure the model is incorrect on any other prompt
return MockCompletion(' ')
return {
'choices': [{
'logprobs': {
'top_logprobs': [{
' ': 0,
}],
},
}],
}


def test_openai_api_eval_wrapper(tmp_path: str):
_ = pytest.importorskip('openai')

model_name = 'davinci'
tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
pad_token='<|endoftext|>')
model = OpenAICausalLMEvalWrapper(model_cfg={'version': model_name},
tokenizer=tokenizer)
with patch.object(model, 'client') as mock:
mock.completions.create = mock_create

with patch('openai.Completion') as mock:
mock.create = mock_create
model_name = 'davinci'
tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
pad_token='<|endoftext|>')
model = OpenAICausalLMEvalWrapper(model_cfg={'version': model_name},
tokenizer=tokenizer)
task_cfg = load_icl_config()
evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks,
tokenizer,
@@ -118,16 +109,20 @@ def test_openai_api_eval_wrapper(tmp_path: str):

def test_chat_api_eval_wrapper(tmp_path: str):
_ = pytest.importorskip('openai')

model_name = 'gpt-3.5-turbo'
tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
pad_token='<|endoftext|>')
chatmodel = OpenAIChatAPIEvalWrapper(model_cfg={'version': model_name},
tokenizer=tokenizer)
with patch.object(chatmodel, 'client') as mock:
mock.chat.completions.create.return_value = MockChatCompletion(
'Treason!')

with patch('openai.ChatCompletion') as mock:
mock.create.return_value = {
'choices': [{
'message': {
'role': 'assistant',
'content': 'Treason!'
},
}],
}
model_name = 'gpt-3.5-turbo'
tokenizer = TiktokenTokenizerWrapper(model_name=model_name,
pad_token='<|endoftext|>')
chatmodel = OpenAIChatAPIEvalWrapper(model_cfg={'version': model_name},
tokenizer=tokenizer)
task_cfg = load_icl_config()
evaluators, _ = build_icl_evaluators(task_cfg.icl_tasks,
tokenizer,
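
A usage sketch of the patching pattern these tests now use, assuming a locally importable openai package (no API key or network access is needed): unittest.mock.patch swaps the module-level attribute for a MagicMock, so anything that calls openai.ChatCompletion.create(...) during the with-block receives the canned dict instead of an HTTP response. The prompt and canned content below are illustrative.

from unittest.mock import patch

import openai  # assumed importable

canned = {'choices': [{'message': {'role': 'assistant', 'content': 'Treason!'}}]}

with patch('openai.ChatCompletion') as mock_chat:
    mock_chat.create.return_value = canned
    out = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=[{'role': 'user', 'content': 'Complete the phrase'}])

assert out['choices'][0]['message']['content'] == 'Treason!'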
39 changes: 39 additions & 0 deletions tests/models/test_model.py
@@ -831,6 +831,45 @@ def test_forward_with_padding(attention_impl: str, pos_emb_config: dict,
batched_output[1, :],
atol=1e-6 if attention_impl == 'torch' else 1e-8)

try:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
unpad_input, pad_input = None, None

if unpad_input is not None and pad_input is not None:
# Checking numerical precision with pad_token ffn
for block in mpt.transformer.blocks:
# Flip the padding usage in the model
block.use_pad_tok_in_ffn = not block.use_pad_tok_in_ffn

right_padding_output_pad_flipped = mpt(
right_padding_input_ids,
attention_mask=right_padding_attention_mask).logits
middle_padding_output_pad_flipped = mpt(
middle_padding_input_ids,
attention_mask=middle_padding_attention_mask).logits
left_padding_output_pad_flipped = mpt(
left_padding_input_ids,
attention_mask=left_padding_attention_mask).logits

pad_vs_unpad_rtol = 1e-5
pad_vs_unpad_atol = 1e-6
assert torch.allclose(right_padding_output[0, :3],
right_padding_output_pad_flipped[0, :3],
rtol=pad_vs_unpad_rtol,
atol=pad_vs_unpad_atol)

assert torch.allclose(middle_padding_output[0, [0, 1, 5]],
middle_padding_output_pad_flipped[0,
[0, 1, 5]],
rtol=pad_vs_unpad_rtol,
atol=pad_vs_unpad_atol)

assert torch.allclose(left_padding_output[0, 3:],
left_padding_output_pad_flipped[0, 3:],
rtol=pad_vs_unpad_rtol,
atol=pad_vs_unpad_atol)
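
A compressed sketch of the equivalence check added above, with the model, inputs, token positions, and tolerances standing in for the real fixtures: flip use_pad_tok_in_ffn on every block, rerun the forward pass, and require the logits at non-padding positions to match within tight tolerances.

import torch

def assert_pad_flag_invariant(model, input_ids, attention_mask, positions,
                              rtol: float = 1e-5, atol: float = 1e-6):
    """Logits at `positions` should not depend on use_pad_tok_in_ffn."""
    with torch.no_grad():
        baseline = model(input_ids, attention_mask=attention_mask).logits
        for block in model.transformer.blocks:
            block.use_pad_tok_in_ffn = not block.use_pad_tok_in_ffn
        flipped = model(input_ids, attention_mask=attention_mask).logits
    assert torch.allclose(baseline[0, positions], flipped[0, positions],
                          rtol=rtol, atol=atol)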


@pytest.mark.parametrize('attention_impl', ['torch', 'triton'])
def test_advanced_mask_building(attention_impl: str):
