diff --git a/src/ell/providers/groq.py b/src/ell/providers/groq.py
index 1a824685..52d9a004 100644
--- a/src/ell/providers/groq.py
+++ b/src/ell/providers/groq.py
@@ -4,8 +4,11 @@
 import json
 import contextvars
+from typing import List
 
 from pydantic import BaseModel
 
+import ell
+from ell.provider import EllCallParams
 from ell.providers.openai import OpenAIProvider
 from ell.configurator import register_provider
 from ell.types.message import ContentBlock, Message
@@ -18,22 +21,22 @@ import groq
 
     class GroqProvider(OpenAIProvider):
         dangerous_disable_validation = True
 
-        def translate_to_provider(self, *args, **kwargs):
-            params = super().translate_to_provider(*args, **kwargs)
-            params.pop('stream_options', None)
-            params['messages'] = messages_to_groq_message_format(params['messages'])
-
+        def translate_to_provider(self, ell_call: EllCallParams):
             # assert 'response_format' not in params, 'Groq does not support response_format.'
             # Store the response_format model between to_provider and from_provider
-            response_format = params.get('response_format')
+            response_format = ell_call.api_params.get('response_format')
             store_response_format.set(response_format)
             if isinstance(response_format, type) and issubclass(response_format, BaseModel):
-                # Groq beta JSON response does not support streaming or stop tokens
-                params.pop('stream', None)
-                params.pop('stop', None)
-                params['response_format'] = {'type': 'json_object'}
                 # Groq suggests explain how to respond with JSON in system prompt
-                params['messages'] = add_json_schema_to_system_prompt(response_format, params['messages'])
+                ell_call.messages = add_json_schema_to_system_prompt(response_format, ell_call.messages)
+                # Groq beta JSON response does not support streaming or stop tokens
+                ell_call.api_params.pop('stream', None)
+                ell_call.api_params.pop('stop', None)
+                ell_call.api_params['response_format'] = {'type': 'json_object'}
+
+            params = super().translate_to_provider(ell_call)
+            params.pop('stream_options', None)
+            params['messages'] = messages_to_groq_message_format(params['messages'])
             return params
 
@@ -56,15 +59,15 @@ def translate_from_provider(self, *args, **kwargs):
 
 except ImportError:
     pass
 
 
-def add_json_schema_to_system_prompt(response_format: BaseModel, messages):
+def add_json_schema_to_system_prompt(response_format: BaseModel, messages: List[Message]) -> List[Message]:
     json_prompt = f'\n\nYou must respond with a JSON object compliant with following JSON schema:\n{json.dumps(response_format.model_json_schema(), indent=4)}'
-    system_prompt = next(filter(lambda m: m['role'] == 'system', messages), None)
+    system_prompt = next(filter(lambda m: m.role == 'system', messages), None)
 
     if system_prompt is None:
-        messages = [{'role': 'system', 'content': json_prompt}] + messages
+        messages = [ell.system(content=json_prompt)] + messages
     else:
-        system_prompt['content'] += json_prompt
+        system_prompt.content.append(ContentBlock(text=json_prompt))
 
     return messages
@@ -74,8 +77,8 @@ def messages_to_groq_message_format(messages):
     # XXX: Issue #289: groq.BadRequestError: Error code: 400 - {'error': {'message': "'messages.1' : for 'role:assistant' the following must be satisfied[('messages.1.content' : value must be a string)]", 'type': 'invalid_request_error'}}
     new_messages = []
     for message in messages:
-        if message['role'] == 'assistant':
-            # Assistant messages must be strings
+        # Assistant messages must be strings or tool calls
+        if message['role'] == 'assistant' and 'tool_calls' not in message:
             # If content is a list, only one string element is allowed
             if isinstance(message['content'], str):
                 new_messages.append({'role': 'assistant', 'content': message['content']})
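
Reviewer note: below is a minimal standalone sketch of the structured-output flow that `translate_to_provider` now sets up when `response_format` is a pydantic `BaseModel` — the JSON schema is appended to the system prompt, `stream`/`stop` are dropped, and Groq's `json_object` mode is enabled. The model name, prompts, and `Review` model are illustrative only; the `groq` SDK and pydantic calls are the only APIs assumed.

```python
import json
import os

import groq
from pydantic import BaseModel

# Hypothetical response model, standing in for whatever response_format the caller passes.
class Review(BaseModel):
    title: str
    rating: int

# Mirrors add_json_schema_to_system_prompt: Groq's json_object mode takes no
# schema parameter, so the schema is spelled out in the system prompt instead.
schema = json.dumps(Review.model_json_schema(), indent=4)
messages = [
    {'role': 'system',
     'content': f'You are a film critic.\n\nYou must respond with a JSON object '
                f'compliant with following JSON schema:\n{schema}'},
    {'role': 'user', 'content': 'Review the movie Heat.'},
]

client = groq.Groq(api_key=os.environ['GROQ_API_KEY'])
response = client.chat.completions.create(
    model='llama-3.1-8b-instant',             # illustrative; any Groq chat model
    messages=messages,
    response_format={'type': 'json_object'},  # no stream/stop here, per the patch
)

# The model replies with raw JSON text; pydantic validates it against the schema.
review = Review.model_validate_json(response.choices[0].message.content)
print(review)
```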