Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
bmosaicml committed Apr 3, 2024
1 parent 0b5da4b commit c889fd2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 31 deletions.
8 changes: 4 additions & 4 deletions llmfoundry/models/inference_api_wrapper/fmapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def block_until_ready(self, base_url: str):
f'Endpoint {ping_url} did not become ready after {waited_s:,} seconds, exiting'
)

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
is_local = model_cfg.pop('local', False)
api_key = model_cfg.pop('api_key', None)
def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer):
is_local = om_model_config.pop('local', False)
api_key = om_model_config.pop('api_key', None)
if is_local:
base_url = os.environ.get('MOSAICML_MODEL_ENDPOINT',
'http://0.0.0.0:8080/v2')
Expand All @@ -62,7 +62,7 @@ def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
'Must specify base_url or use local=True in model_cfg for FMAPIsEvalWrapper'
)

super().__init__(model_cfg, tokenizer, api_key)
super().__init__(om_model_config, tokenizer, api_key)


class FMAPICasualLMEvalWrapper(FMAPIEvalInterface, OpenAICausalLMEvalWrapper):
Expand Down
27 changes: 8 additions & 19 deletions llmfoundry/models/inference_api_wrapper/gemini_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@
import os
import google.generativeai as google_genai

from typing import Dict, List, Optional, Union
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Implements a OpenAI chat and causal LM inference API wrappers."""

from typing import List, Optional, Union
from omegaconf import DictConfig
import logging
import os
import random
from time import sleep
from typing import Any, Dict , Optional
from typing import Any, Optional

from composer.core.types import Batch
from openai.types.chat.chat_completion import ChatCompletion
Expand All @@ -41,26 +37,19 @@
class GeminiChatAPIEvalrapper(InferenceAPIEvalWrapper):
"""Databricks Foundational Model API wrapper for causal LM models."""

def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer) -> None:
api_key = model_cfg.pop('api_key', None)
def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer) -> None:
api_key = om_model_config.pop('api_key', None)
if api_key is None:
api_key = os.environ.get('GEMINI_API_KEY')
google_genai.configure(api_key=api_key)
super().__init__(model_cfg, tokenizer)
self.model_cfg = model_cfg
self.model = google_genai.GenerativeModel(model_cfg.get('version', ''))
super().__init__(om_model_config, tokenizer)
self.model_cfg = om_model_config
self.model = google_genai.GenerativeModel(om_model_config.get('version', ''))
ignore = [
google_genai.types.HarmCategory.HARM_CATEGORY_HARASSMENT,
google_genai.types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
google_genai.types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
google_genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
# google_genai.types.HarmCategory.HARM_CATEGORY_UNSPECIFIED,
# google_genai.types.HarmCategory.HARM_CATEGORY_DEROGATORY,
# google_genai.types.HarmCategory.HARM_CATEGORY_TOXICITY,
# google_genai.types.HarmCategory.HARM_CATEGORY_VIOLENCE,
# google_genai.types.HarmCategory.HARM_CATEGORY_SEXUAL,
# google_genai.types.HarmCategory.HARM_CATEGORY_MEDICAL,
# google_genai.types.HarmCategory.HARM_CATEGORY_DANGEROUS,
]
self.safety_settings = {
category: google_genai.types.HarmBlockThreshold.BLOCK_NONE
Expand Down
5 changes: 2 additions & 3 deletions llmfoundry/models/inference_api_wrapper/openai_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer, api_ke
raise ValueError(
'No OpenAI API Key found. Ensure it is saved as an environmental variable called OPENAI_API_KEY.'
)

else:
# Using a custom base URL, where the API key may not be required
log.info(
f'Making request to custom base URL: {base_url}{"" if api_key is not None else " (no API key set)"}'
)
# api_key = 'placeholder' # This cannot be None
api_key = 'placeholder' # This cannot be None

self.client = openai.OpenAI(base_url=base_url, api_key=api_key)
if 'version' in om_model_config:
Expand Down Expand Up @@ -128,7 +127,7 @@ class OpenAIChatAPIEvalWrapper(OpenAIEvalInterface):

def __init__(self, om_model_config: DictConfig, tokenizer: AutoTokenizer, api_key: Optional[str] = None) -> None:
super().__init__(om_model_config, tokenizer, api_key)
self.model_cfg = om_model_config
self.om_model_config = om_model_config

def generate_completion(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

@pytest.fixture(scope='module')
def openai_api_key_env_var() -> str:
# os.environ['OPENAI_API_KEY'] = 'dummy'
os.environ['OPENAI_API_KEY'] = 'dummy'
return os.environ['OPENAI_API_KEY']


Expand Down Expand Up @@ -198,9 +198,9 @@ def test_openai_completions_api_eval_wrapper(tmp_path: str,
batch,
result,
metric=model.get_metrics()
['InContextLearningQAAccuracy']) # pyright: ignore
['InContextLearningGenerationExactMatchAccuracy']) # pyright: ignore
acc = model.get_metrics(
)['InContextLearningQAAccuracy'].compute( # pyright: ignore
)['InContextLearningGenerationExactMatchAccuracy'].compute( # pyright: ignore
) # pyright: ignore
assert acc == 0.5
elif evaluator.label == 'human_eval/0-shot':
Expand Down Expand Up @@ -265,9 +265,9 @@ def test_chat_api_eval_wrapper(tmp_path: str, openai_api_key_env_var: str):
batch,
result,
metric=chatmodel.get_metrics()
['InContextLearningQAAccuracy']) # pyright: ignore
['InContextLearningGenerationExactMatchAccuracy']) # pyright: ignore
acc = chatmodel.get_metrics(
)['InContextLearningQAAccuracy'].compute( # pyright: ignore
)['InContextLearningGenerationExactMatchAccuracy'].compute( # pyright: ignore
) # pyright: ignore
assert acc == 0.5
elif evaluator.label == 'human_eval/0-shot':
Expand Down

0 comments on commit c889fd2

Please sign in to comment.