Skip to content

Commit

Permalink
Validate the OpenAI API key format
Browse files Browse the repository at this point in the history
Increase the amount of internal validation for OpenAI API keys. The intent is
to shorten the debugging loop in case of typos. The changes do *not* add
validation for Azure OpenAI API keys.

* Add the validation in `__init__` of `OpenAIClient`.

* Introduce the `MOCK_OPEN_AI_API_KEY` constant for testing.

*  Add unit test coverage for the `is_valid_api_key` function.
  • Loading branch information
gunnarku committed Feb 11, 2024
1 parent b4a2c6e commit f11b5a4
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 17 deletions.
4 changes: 3 additions & 1 deletion autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Protocol

from autogen.cache.cache import Cache
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
from autogen.oai.openai_utils import get_key, is_valid_api_key, OAI_PRICE1K
from autogen.token_count_utils import count_token

TOOL_ENABLED = False
Expand Down Expand Up @@ -111,6 +111,8 @@ class OpenAIClient:

def __init__(self, client: Union[OpenAI, AzureOpenAI]):
self._oai_client = client
if not isinstance(client, openai.AzureOpenAI) and not is_valid_api_key(self._oai_client.api_key):
raise ValueError("Please check the format of the OpenAI API key.")

def message_retrieval(
self, response: Union[ChatCompletion, Completion]
Expand Down
14 changes: 14 additions & 0 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
Expand Down Expand Up @@ -74,6 +75,19 @@ def get_key(config: Dict[str, Any]) -> str:
return json.dumps(config, sort_keys=True)


def is_valid_api_key(api_key: str):
"""Determine if input is valid OpenAI API key.
Args:
api_key (str): An input string to be validated.
Returns:
bool: A boolean that indicates if input is valid OpenAI API key.
"""
api_key_re = re.compile(r"^sk-[A-Za-z0-9]{32,}$")
return bool(re.fullmatch(api_key_re, api_key))


def get_config_list(
api_keys: List, base_urls: Optional[List] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
) -> List[Dict]:
Expand Down
4 changes: 3 additions & 1 deletion test/agentchat/contrib/test_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import autogen
from autogen.agentchat.conversable_agent import ConversableAgent

from conftest import MOCK_OPEN_AI_API_KEY

try:
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
except ImportError:
Expand All @@ -28,7 +30,7 @@ def setUp(self):
llm_config={
"timeout": 600,
"seed": 42,
"config_list": [{"model": "gpt-4-vision-preview", "api_key": "sk-fake"}],
"config_list": [{"model": "gpt-4-vision-preview", "api_key": MOCK_OPEN_AI_API_KEY}],
},
)

Expand Down
4 changes: 2 additions & 2 deletions test/agentchat/contrib/test_web_surfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from autogen.cache import Cache

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import skip_openai # noqa: E402
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
Expand Down Expand Up @@ -48,7 +48,7 @@
def test_web_surfer() -> None:
with pytest.MonkeyPatch.context() as mp:
# we mock the API key so we can register functions (llm_config must be present for this to work)
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
page_size = 4096
web_surfer = WebSurferAgent(
"web_surfer", llm_config={"config_list": []}, browser_config={"viewport_size": page_size}
Expand Down
16 changes: 8 additions & 8 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from autogen.agentchat import ConversableAgent, UserProxyAgent
from autogen.agentchat.conversable_agent import register_function
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai

try:
import openai
Expand Down Expand Up @@ -473,7 +473,7 @@ async def test_a_generate_reply_raises_on_messages_and_sender_none(conversable_a

def test_update_function_signature_and_register_functions() -> None:
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={})

def exec_python(cell: str) -> None:
Expand Down Expand Up @@ -617,7 +617,7 @@ def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]

def test_register_for_llm():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
Expand Down Expand Up @@ -690,7 +690,7 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> s

def test_register_for_llm_api_style_function():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent3 = ConversableAgent(name="agent3", llm_config={"config_list": []})
agent2 = ConversableAgent(name="agent2", llm_config={"config_list": []})
agent1 = ConversableAgent(name="agent1", llm_config={"config_list": []})
Expand Down Expand Up @@ -761,7 +761,7 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]) -> s

def test_register_for_llm_without_description():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={})

with pytest.raises(ValueError) as e:
Expand All @@ -775,7 +775,7 @@ def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:

def test_register_for_llm_without_LLM():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config=None)
agent.llm_config = None
assert agent.llm_config is None
Expand All @@ -791,7 +791,7 @@ def exec_python(cell: Annotated[str, "Valid Python cell to execute."]) -> str:

def test_register_for_execution():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
user_proxy_1 = UserProxyAgent(name="user_proxy_1")
user_proxy_2 = UserProxyAgent(name="user_proxy_2")
Expand Down Expand Up @@ -826,7 +826,7 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]):

def test_register_functions():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
agent = ConversableAgent(name="agent", llm_config={"config_list": []})
user_proxy = UserProxyAgent(name="user_proxy")

Expand Down
4 changes: 2 additions & 2 deletions test/coding/test_commandline_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from autogen.coding.local_commandline_code_executor import LocalCommandlineCodeExecutor
from autogen.oai.openai_utils import config_list_from_json

from conftest import skip_openai
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai


def test_create() -> None:
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_local_commandline_executor_conversable_agent_code_execution() -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = LocalCommandlineCodeExecutor(work_dir=temp_dir)
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
_test_conversable_agent_code_execution(executor)


Expand Down
4 changes: 2 additions & 2 deletions test/coding/test_embedded_ipython_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from autogen.coding.base import CodeBlock, CodeExecutor
from autogen.coding.factory import CodeExecutorFactory
from autogen.oai.openai_utils import config_list_from_json
from conftest import skip_openai # noqa: E402
from conftest import MOCK_OPEN_AI_API_KEY, skip_openai # noqa: E402

try:
from autogen.coding.embedded_ipython_code_executor import EmbeddedIPythonCodeExecutor
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_function(a, b):
```
"""
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", "mock")
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
reply = agent.generate_reply(
[{"role": "user", "content": msg}],
sender=ConversableAgent("user", llm_config=False, code_execution_config=False),
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
skip_redis = False
skip_docker = False

MOCK_OPEN_AI_API_KEY = "sk-mockopenaiAPIkeyinexpectedformatfortestingonly"


# Registers command-line options like '--skip-openai' and '--skip-redis' via pytest hook.
# When these flags are set, it indicates that tests requiring OpenAI or Redis (respectively) should be skipped.
Expand Down
16 changes: 15 additions & 1 deletion test/oai/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import pytest

import autogen # noqa: E402
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, filter_config, is_valid_api_key

from conftest import MOCK_OPEN_AI_API_KEY

# Example environment variables
ENV_VARS = {
Expand Down Expand Up @@ -370,5 +372,17 @@ def test_tags():
assert len(list_5) == 0


def test_is_valid_api_key():
assert not is_valid_api_key("")
assert not is_valid_api_key("sk-")
assert not is_valid_api_key("SK-")
assert not is_valid_api_key("sk-asajsdjsd2")
assert not is_valid_api_key("FooBar")
assert not is_valid_api_key("sk-asajsdjsd22372%23kjdfdfdf2329ffUUDSDS")
assert is_valid_api_key("sk-asajsdjsd22372X23kjdfdfdf2329ffUUDSDS")
assert is_valid_api_key("sk-asajsdjsd22372X23kjdfdfdf2329ffUUDSDS1212121221212sssXX")
assert is_valid_api_key(MOCK_OPEN_AI_API_KEY)


if __name__ == "__main__":
pytest.main()

0 comments on commit f11b5a4

Please sign in to comment.