Skip to content

Commit

Permalink
fix config.get_provider_for for provider inherit issue
Browse files Browse the repository at this point in the history
  • Loading branch information
lenage committed Oct 9, 2024
1 parent 4509ceb commit 8fbe643
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
34 changes: 24 additions & 10 deletions src/ell/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class _Model:
name: str
default_client: Optional[Union[openai.Client, Any]] = None
#XXX: Deprecation in 0.1.0
#XXX: We will depreciate this when streaming is implemented.
#XXX: We will depreciate this when streaming is implemented.
# Currently we stream by default for the verbose renderer,
# but in the future we will not support streaming by default
# but in the future we will not support streaming by default
# and stream=True must be passed which will then make API providers the
# single source of truth for whether or not a model supports an api parameter.
# This makes our implementation extremely light, only requiring us to provide
Expand All @@ -44,9 +44,9 @@ def __init__(self, **data):
self._lock = threading.Lock()
self._local = threading.local()


def register_model(
self,
self,
name: str,
default_client: Optional[Union[openai.Client, Any]] = None,
supports_streaming: Optional[bool] = None
Expand Down Expand Up @@ -74,12 +74,12 @@ def model_registry_override(self, overrides: Dict[str, _Model]):
"""
if not hasattr(self._local, 'stack'):
self._local.stack = []

with self._lock:
current_registry = self._local.stack[-1] if self._local.stack else self.registry
new_registry = current_registry.copy()
new_registry.update(overrides)

self._local.stack.append(new_registry)
try:
yield
Expand Down Expand Up @@ -133,11 +133,25 @@ def get_provider_for(self, client: Union[Type[Any], Any]) -> Optional[Provider]:
"""

client_type = type(client) if not isinstance(client, type) else client
for provider_type, provider in self.providers.items():
if issubclass(client_type, provider_type) or client_type == provider_type:
return provider
# First, try to find an exact match
if client_type in self.providers:
return self.providers[client_type]

# If no exact match, look for the most specific subclass
matching_providers = [
(provider_type, provider)
for provider_type, provider in self.providers.items()
if issubclass(client_type, provider_type)
]

if matching_providers:
# Sort by inheritance depth (most derived class first)
matching_providers.sort(key=lambda x: len(x[0].mro()), reverse=True)
return matching_providers[0][1]

return None


# Single* instance
# XXX: Make a singleton
config = Config()
Expand Down Expand Up @@ -187,7 +201,7 @@ def init(
def get_store() -> Union[Store, None]:
return config.store

# Will be deprecated at 0.1.0
# Will be deprecated at 0.1.0

# You can add more helper functions here if needed
def register_provider(provider: Provider, client_type: Type[Any]) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/ell/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

logger = logging.getLogger(__name__)


def register(client: openai.Client):
"""
Register OpenAI models with the provided client.
Expand Down Expand Up @@ -92,4 +93,4 @@ def register(client: openai.Client):
pass

register(default_client)
config.default_client = default_client
config.default_client = default_client

0 comments on commit 8fbe643

Please sign in to comment.