Skip to content

Commit

Permalink
add mock as decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
allenanswerzq committed Sep 27, 2024
1 parent 19126f0 commit 10c3e2a
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/ell/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


from ell.lmp.simple import simple
from ell.lmp.simple import mock
from ell.lmp.tool import tool
from ell.lmp.complex import complex
from ell.types.message import system, user, assistant, Message, ContentBlock
Expand Down
1 change: 1 addition & 0 deletions src/ell/lmp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from ell.lmp.simple import simple
from ell.lmp.simple import mock
from ell.lmp.complex import complex
11 changes: 10 additions & 1 deletion src/ell/lmp/simple.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from functools import wraps
from typing import Any, Optional
from typing import Any, Optional, Callable

from ell.lmp.complex import complex
from ell.providers.mock import MockAIClient


def mock(model: str, client: Optional[Any] = None, exempt_from_tracking=False, mock_func:Callable[..., Any]=None, **api_params):
"""Mock decortoar should accept everything passed to simple"""
if mock_func:
api_params['mock_func'] = mock_func

return simple(model, client=MockAIClient(), exempt_from_tracking=exempt_from_tracking, **api_params)


def simple(model: str, client: Optional[Any] = None, exempt_from_tracking=False, **api_params):
Expand Down
1 change: 1 addition & 0 deletions src/ell/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ell.providers.openai
import ell.providers.groq
import ell.providers.anthropic
import ell.providers.mock
# import ell.providers.mistral
# import ell.providers.cohere
# import ell.providers.gemini
Expand Down
62 changes: 62 additions & 0 deletions src/ell/providers/mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random
import string

from typing import Optional, Dict, Any, List, Type, Tuple
from ell.provider import Provider, EllCallParams, Metadata
from ell.types import Message
from ell.types.message import LMP
from ell.configurator import config, register_provider
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast


class MockAIClient:

def __init__(self, **kwargs):
self.api_key = "mock"

def chat_completions_create(self, **kwargs):
return None


class MockAIProvider(Provider):
dangerous_disable_validation = True

def provider_call_function(
self, client: MockAIClient, api_call_params: Optional[Dict[str, Any]] = None
) -> Callable[..., Any]:
return client.chat_completions_create

def translate_to_provider(self, ell_call: EllCallParams) -> Dict[str, Any]:
return ell_call.api_params.copy()

def default_mock_func(self) -> Tuple[List[Message], Metadata]:
results = []
random_str = "".join(
random.choices(
string.ascii_letters + string.digits, k=random.randint(1, 40)
)
)
results.append(
Message(
role=("user"),
content="mock_" + random_str,
)
)
return results, Metadata

def translate_from_provider(
self,
_provider_response: Any,
_ell_call: EllCallParams,
provider_call_params: Dict[str, Any],
_origin_id: Optional[str] = None,
_logger: Optional[Callable[..., None]] = None,
) -> Tuple[List[Message], Metadata]:
if "mock_func" in provider_call_params:
mock_func = provider_call_params["mock_func"]
return [Message(role=("user"), content=mock_func())], Metadata
else:
return self.default_mock_func()


register_provider(MockAIProvider(), MockAIClient)

0 comments on commit 10c3e2a

Please sign in to comment.