refactor: improve Ollama provider implementation #593

Open

wants to merge 11 commits into main

Changes from 8 commits
279 changes: 201 additions & 78 deletions agentops/llms/providers/ollama.py
@@ -1,13 +1,30 @@
import inspect
import sys
from typing import Optional
import json
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from dataclasses import dataclass
import asyncio
import httpx

from agentops.event import LLMEvent
from agentops.event import LLMEvent, ErrorEvent
from agentops.session import Session
from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id
from .instrumented_provider import InstrumentedProvider
from agentops.singleton import singleton


@dataclass
class Choice:
message: dict = None
delta: dict = None
finish_reason: str = None
index: int = 0


@dataclass
class ChatResponse:
model: str
choices: list[Choice]


original_func = {}


@@ -16,111 +33,217 @@ class OllamaProvider(InstrumentedProvider):
original_create = None
original_create_async = None

def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
if session is not None:
llm_event.session_id = session.session_id

def handle_stream_chunk(chunk: dict):
message = chunk.get("message", {"role": None, "content": ""})

if chunk.get("done"):
llm_event.end_timestamp = get_ISO_time()
llm_event.model = f'ollama/{chunk.get("model")}'
llm_event.returns = chunk
llm_event.returns["message"] = llm_event.completion
llm_event.prompt = kwargs["messages"]
llm_event.agent_id = check_call_stack_for_agent_id()
self._safe_record(session, llm_event)

if llm_event.completion is None:
llm_event.completion = {
"role": message.get("role"),
"content": message.get("content", ""),
"tool_calls": None,
"function_call": None,
}
else:
llm_event.completion["content"] += message.get("content", "")

if inspect.isgenerator(response):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

llm_event.end_timestamp = get_ISO_time()
llm_event.model = f'ollama/{response["model"]}'
llm_event.returns = response
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.prompt = kwargs["messages"]
llm_event.completion = {
"role": response["message"].get("role"),
"content": response["message"].get("content", ""),
"tool_calls": None,
"function_call": None,
def handle_response(self, response_data, request_data, init_timestamp, session=None):
"""Handle the response from the Ollama API."""
end_timestamp = get_ISO_time()
model = request_data.get("model", "unknown")

# Extract error if present
error = None
if isinstance(response_data, dict) and "error" in response_data:
error = response_data["error"]

# Create event data
event_data = {
"model": model, # Use the raw model name from request
"params": request_data,
"returns": response_data, # Include full response data
"init_timestamp": init_timestamp,
"end_timestamp": end_timestamp,
"prompt": request_data.get("messages", []),
"prompt_tokens": None, # Ollama doesn't provide token counts
"completion_tokens": None,
"cost": None, # Ollama is free/local
}
self._safe_record(session, llm_event)
return response

if error:
event_data["completion"] = error
else:
# Extract completion from response
if isinstance(response_data, dict):
message = response_data.get("message", {})
if isinstance(message, dict):
content = message.get("content", "")
event_data["completion"] = content

# Create and emit LLM event
if session:
event = LLMEvent(**event_data)
session.record(event)

return event_data

def override(self):
"""Override Ollama methods with instrumented versions."""
self._override_chat_client()
self._override_chat()
self._override_chat_async_client()

def undo_override(self):
if original_func is not None and original_func != {}:
import ollama

ollama.chat = original_func["ollama.chat"]
ollama.Client.chat = original_func["ollama.Client.chat"]
ollama.AsyncClient.chat = original_func["ollama.AsyncClient.chat"]
import ollama

def __init__(self, client):
super().__init__(client)
if hasattr(self, "_original_chat"):
ollama.chat = self._original_chat
if hasattr(self, "_original_client_chat"):
ollama.Client.chat = self._original_client_chat
if hasattr(self, "_original_async_chat"):
ollama.AsyncClient.chat = self._original_async_chat

def __init__(self, http_client=None, client=None, model=None):
"""Initialize the Ollama provider."""
super().__init__(client=client)
self.base_url = "http://localhost:11434" # Ollama runs locally by default
self.timeout = 60.0 # Default timeout in seconds
self.model = model # Store default model

# Initialize HTTP client if not provided
if http_client is None:
self.http_client = httpx.AsyncClient(timeout=self.timeout)
else:
self.http_client = http_client

# Store original methods for restoration
self._original_chat = None
self._original_chat_client = None
self._original_chat_async_client = None

def _override_chat(self):
import ollama

original_func["ollama.chat"] = ollama.chat
self._original_chat = ollama.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
result = original_func["ollama.chat"](*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
session = kwargs.pop("session", None)
result = self._original_chat(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
ollama.chat = patched_function

Comment on lines 109 to 121

🤖 Bug Fix:

Ensure Proper Session Management in Patched Function
The recent change in the _override_chat method involves popping the session from kwargs, which can lead to potential issues if the session is needed later in the function or elsewhere in the system. This could result in unexpected behavior or bugs, especially if other parts of the code rely on the presence of session in kwargs.

Recommendations:

  • Review the necessity of popping the session: Ensure that removing the session from kwargs does not affect other parts of the function or system.
  • Consider alternative approaches: If the session is required elsewhere, consider passing it explicitly or using a different mechanism to manage it.
  • Test thoroughly: Ensure that the changes do not break existing functionality by running comprehensive tests.

By addressing these points, you can maintain the integrity of the function and prevent potential issues related to session management.

🔧 Suggested Code Diff:

 def _override_chat(self):
     import ollama
 
     self._original_chat = ollama.chat
 
     def patched_function(*args, **kwargs):
         init_timestamp = get_ISO_time()
-        session = kwargs.pop("session", None)
+        session = kwargs.get("session", None)
         result = self._original_chat(*args, **kwargs)
         return self.handle_response(result, kwargs, init_timestamp, session=session)
 
     ollama.chat = patched_function
📝 Committable Code Suggestion

‼️ Ensure you review the code suggestion before committing it to the branch. Make sure it replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change (replacing the highlighted hunk above):

def _override_chat(self):
    import ollama

    self._original_chat = ollama.chat

    def patched_function(*args, **kwargs):
        init_timestamp = get_ISO_time()
        session = kwargs.get("session", None)  # Use get instead of pop to avoid removing session
        result = self._original_chat(*args, **kwargs)
        return self.handle_response(result, kwargs, init_timestamp, session=session)

    ollama.chat = patched_function
📜 Guidelines

• Use meaningful variable and function names following specific naming conventions
• Use exceptions for error handling, but avoid assert statements for critical logic
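
For context, a small self-contained sketch of the trade-off discussed in this comment. fake_ollama_chat is a hypothetical stand-in, written on the assumption that the real ollama.chat rejects unknown keyword arguments such as session:

def fake_ollama_chat(model, messages, stream=False):
    # Stand-in for ollama.chat: accepts only its documented parameters.
    return {"model": model, "message": {"role": "assistant", "content": "ok"}}


def patched_with_pop(**kwargs):
    session = kwargs.pop("session", None)       # session is removed before forwarding
    return fake_ollama_chat(**kwargs), session


def patched_with_get(**kwargs):
    session = kwargs.get("session", None)       # session stays in kwargs
    return fake_ollama_chat(**kwargs), session  # fails if a session was supplied


patched_with_pop(model="llama3.2", messages=[], session="sess-1")    # works
# patched_with_get(model="llama3.2", messages=[], session="sess-1")  # TypeError: unexpected keyword argument 'session'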


Comment on lines 109 to 121

⚠️ Potential Issue:

Ensure Safe Overriding of Core Functionality
The current change introduces a patched version of the ollama.chat function, which can lead to logical errors if not handled correctly. Overriding core functions can have significant implications, especially if the new implementation does not fully replicate the original behavior or introduces side effects.

Recommendations:

  • Replicate Original Behavior: Ensure that the patched function covers all scenarios handled by the original ollama.chat function. This includes edge cases and error handling.
  • Comprehensive Testing: Develop a suite of tests to verify the behavior of the patched function across various scenarios. This will help catch any discrepancies early.
  • Consider Dependency Injection: Instead of directly overriding the function, consider using dependency injection. This approach can help isolate changes and reduce the risk of unintended side effects.

By following these steps, you can mitigate the risks associated with overriding core functionality and ensure the system remains stable and reliable. 🛠️

🔧 Suggested Code Diff:

+    def _override_chat(self):
+        import ollama
+
+        self._original_chat = ollama.chat
+
+        def patched_function(*args, **kwargs):
+            init_timestamp = get_ISO_time()
+            session = kwargs.pop("session", None)
+            result = self._original_chat(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        ollama.chat = patched_function
📝 Committable Code Suggestion

‼️ Ensure you review the code suggestion before committing it to the branch. Make sure it replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change (replacing the highlighted hunk above):

def _override_chat(self):
    import ollama

    self._original_chat = ollama.chat

    def patched_function(*args, **kwargs):
        init_timestamp = get_ISO_time()
        session = kwargs.pop("session", None)
        try:
            result = self._original_chat(*args, **kwargs)
        except Exception as e:
            # Handle exceptions that may occur during the chat operation
            self.log_error(f"Error in chat operation: {str(e)}")
            raise
        return self.handle_response(result, kwargs, init_timestamp, session=session)

    ollama.chat = patched_function

# Additional unit tests should be written to ensure the patched function behaves as expected
# across various scenarios, including edge cases and error conditions.
📜 Guidelines

• Write unit tests for your code
• Use meaningful variable and function names following specific naming conventions
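
As a sketch of the dependency-injection alternative mentioned in this comment (illustrative only, not part of this PR; the class name is made up), the instrumentation can live on an explicitly constructed wrapper instead of the global ollama.chat:

from agentops.helpers import get_ISO_time


class InstrumentedChat:
    """Wraps any chat callable and reports each call to an AgentOps provider."""

    def __init__(self, chat_fn, provider, session=None):
        self._chat_fn = chat_fn      # e.g. ollama.chat or ollama.Client(...).chat
        self._provider = provider    # the OllamaProvider that records events
        self._session = session

    def __call__(self, *args, **kwargs):
        init_timestamp = get_ISO_time()
        result = self._chat_fn(*args, **kwargs)
        self._provider.handle_response(result, kwargs, init_timestamp, session=self._session)
        return result

Callers would then build chat = InstrumentedChat(ollama.chat, provider, session=session) and call chat(model=..., messages=...), leaving the ollama module itself untouched.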


def _override_chat_client(self):
from ollama import Client

original_func["ollama.Client.chat"] = Client.chat
self._original_client_chat = Client.chat

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
def patched_function(self_client, *args, **kwargs):
init_timestamp = get_ISO_time()
result = original_func["ollama.Client.chat"](*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
session = kwargs.pop("session", None)
result = self._original_client_chat(self_client, *args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
Client.chat = patched_function

Comment on lines 122 to 134

⚠️ Potential Issue:

Ensure Robustness in Overriding Core Methods
The current change involves overriding the Client.chat method with a patched function. This approach can introduce unexpected behavior if the patched function does not fully replicate the original method's behavior. This is particularly critical as it can affect all parts of the system relying on this method, potentially leading to incorrect session handling or response processing.

Recommendations:

  • Replicate Original Behavior: Ensure that the patched function replicates all necessary behavior of the original Client.chat method. This includes handling all parameters and return values correctly.
  • Comprehensive Testing: Add comprehensive tests to verify that the new implementation handles all expected scenarios correctly, including edge cases. This will help in identifying any discrepancies introduced by the override.
  • Consider Alternative Approaches: Consider using a more controlled approach to modify behavior, such as subclassing or using dependency injection. This can minimize the risk of unintended side effects and improve maintainability.
  • Documentation: Clearly document the changes and the rationale behind overriding the method to aid future developers in understanding the modifications.

By following these steps, you can ensure that the system remains robust and reliable. 🛠️

📜 Guidelines

• Write unit tests for your code
• Use meaningful variable and function names following specific naming conventions
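
As a sketch of the subclassing alternative mentioned in this comment (illustrative only, not part of this PR; the class name is made up), instrumentation can live in a dedicated client type rather than a rebinding of Client.chat:

from ollama import Client

from agentops.helpers import get_ISO_time


class InstrumentedOllamaClient(Client):
    """An ollama.Client whose chat() also reports to an AgentOps provider."""

    def __init__(self, provider, *args, session=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._provider = provider
        self._session = session

    def chat(self, *args, **kwargs):
        init_timestamp = get_ISO_time()
        result = super().chat(*args, **kwargs)
        self._provider.handle_response(result, kwargs, init_timestamp, session=self._session)
        return result

Only code that opts in by constructing an InstrumentedOllamaClient is affected, which keeps the override local and easier to test.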


def _override_chat_async_client(self):
from ollama import AsyncClient

original_func = {}
original_func["ollama.AsyncClient.chat"] = AsyncClient.chat
self._original_async_chat = AsyncClient.chat

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
async def patched_function(self_client, *args, **kwargs):
init_timestamp = get_ISO_time()
result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
session = kwargs.pop("session", None)
result = await self._original_async_chat(self_client, *args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
AsyncClient.chat = patched_function

Comment on lines 135 to 147

🤖 Bug Fix:

Ensure Correct Overriding of Async Chat Function
The recent changes to the _override_chat_async_client method involve altering how the original async chat function is overridden. This change is critical as it can introduce logical errors if not handled correctly.

Key Points:

  • Session Handling: Ensure that the session parameter is consistently managed. The new implementation uses kwargs.pop("session", None), which is a good practice to avoid passing unexpected parameters to the original function.
  • Functionality Replication: Verify that the new method replicates the original functionality accurately. The patched function should behave identically to the original in terms of input and output, except for the additional instrumentation.

Recommendations:

  • Testing: Conduct thorough testing to ensure that the patched function behaves as expected in all scenarios, especially focusing on session management and response handling.
  • Documentation: Update any relevant documentation to reflect the changes in the method of overriding.

By following these steps, you can ensure that the changes do not introduce any regressions or unexpected behavior. 🛠️

🔧 Suggested Code Diff:

 def _override_chat_async_client(self):
     from ollama import AsyncClient

-    original_func = {}
-    original_func["ollama.AsyncClient.chat"] = AsyncClient.chat
+    self._original_async_chat = AsyncClient.chat

-    async def patched_function(*args, **kwargs):
-        # Call the original function with its original arguments
+    async def patched_function(self_client, *args, **kwargs):
         init_timestamp = get_ISO_time()
-        result = await original_func["ollama.AsyncClient.chat"](*args, **kwargs)
-        return self.handle_response(result, kwargs, init_timestamp, session=kwargs.get("session", None))
+        session = kwargs.pop("session", None)
+        result = await self._original_async_chat(self_client, *args, **kwargs)
+        return self.handle_response(result, kwargs, init_timestamp, session=session)

     # Override the original method with the patched one
     AsyncClient.chat = patched_function
📝 Committable Code Suggestion

‼️ Ensure you review the code suggestion before committing it to the branch. Make sure it replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change (replacing the highlighted hunk above):

def _override_chat_async_client(self):
    from ollama import AsyncClient

    self._original_async_chat = AsyncClient.chat

    async def patched_function(self_client, *args, **kwargs):
        init_timestamp = get_ISO_time()
        session = kwargs.pop("session", None)
        result = await self._original_async_chat(self_client, *args, **kwargs)
        return self.handle_response(result, kwargs, init_timestamp, session=session)

    # Override the original method with the patched one
    AsyncClient.chat = patched_function
📜 Guidelines

• Use type annotations to improve code clarity and maintainability
• Follow PEP 8 style guide for consistent code formatting


async def chat_completion(
self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
stream: bool = False,
session: Optional[Session] = None,
**kwargs,
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send a chat completion request to the Ollama API."""
model = model or self.model
init_timestamp = get_ISO_time()

request_data = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs,
}

try:
response = await self.http_client.post(
f"{self.base_url}/api/chat",
json=request_data,
)
response_data = await response.json()

# Check for error response
if "error" in response_data:
error_message = response_data["error"]
# Format error message consistently for model not found errors
if "not found" in error_message.lower():
error_message = f'model "{model}" not found'

# Record error event
if session:
error_event = ErrorEvent(details=error_message, error_type="ModelError")
session.record(error_event)
raise Exception(error_message)

if stream:
return self.stream_generator(response, request_data, init_timestamp, session)

# Record event for non-streaming response
self.handle_response(response_data, request_data, init_timestamp, session)

return ChatResponse(model=model, choices=[Choice(message=response_data["message"], finish_reason="stop")])

except Exception as e:
error_msg = str(e)
# Format error message consistently for model not found errors
if "not found" in error_msg.lower() and 'model "' not in error_msg:
error_msg = f'model "{model}" not found'

# Create error event
error_event = ErrorEvent(details=error_msg, error_type="ModelError")
if session:
session.record(error_event)
raise Exception(error_msg)

async def stream_generator(
self,
response: Any,
request_data: dict,
init_timestamp: str,
session: Optional[Session] = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Generate streaming responses from the Ollama API."""
try:
current_content = ""
async for line in response.aiter_lines():
if not line:
continue

chunk = json.loads(line)
content = chunk.get("message", {}).get("content", "")
current_content += content

if chunk.get("done", False):
# Record the final event with complete response
event_data = {
"model": request_data.get("model", "unknown"), # Use raw model name
"params": request_data,
"returns": chunk,
"prompt": request_data.get("messages", []),
"completion": current_content,
"prompt_tokens": None,
"completion_tokens": None,
"cost": None,
}
if session:
session.record(LLMEvent(**event_data))

yield ChatResponse(
model=request_data.get("model", "unknown"), # Add model parameter
choices=[Choice(message={"role": "assistant", "content": content}, finish_reason=None)],
)
except Exception as e:
# Create error event with correct model information
error_event = ErrorEvent(details=str(e), error_type="ModelError")
session.record(error_event)
raise
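
For orientation, a minimal usage sketch of the chat_completion interface added in this file. It assumes a local Ollama server on the default port, that the provider can be constructed without an AgentOps client, and the model name is only an example:

import asyncio

from agentops.llms.providers.ollama import OllamaProvider


async def main():
    provider = OllamaProvider(client=None, model="llama3.2")
    response = await provider.chat_completion(
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
    )
    # ChatResponse.choices[0].message mirrors Ollama's {"role": ..., "content": ...} dict
    print(response.choices[0].message["content"])


asyncio.run(main())
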
1 change: 1 addition & 0 deletions tests/integration/__init__.py
@@ -0,0 +1 @@
"""Integration tests for AgentOps providers and frameworks."""
46 changes: 46 additions & 0 deletions tests/integration/test_base.py
@@ -0,0 +1,46 @@
import pytest
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

from agentops.session import Session
from agentops.client import Client
from agentops.event import LLMEvent


class BaseProviderTest:
"""Base class for provider tests."""

async def async_setup_method(self, method):
"""Set up test method."""
# Initialize mock client and session
self.mock_req = AsyncMock()
client = Client()
client.configure(api_key="test-key")
config = client._config
config.http_client = self.mock_req # Set mock on config instead of session
self.session = Session(session_id=uuid4(), config=config)

async def teardown_method(self, method):
"""Clean up after test."""
if hasattr(self, "provider"):
self.provider.undo_override()

async def async_verify_events(self, session, expected_count=1):
"""Verify events were recorded."""
await asyncio.sleep(0.1) # Allow time for async event processing
create_events_requests = [req for req in self.mock_req.request_history if req.url.endswith("/v2/create_events")]
assert len(create_events_requests) >= 1, "No events were recorded"
request_body = json.loads(create_events_requests[-1].body.decode("utf-8"))
assert "session_id" in request_body, "Session ID not found in request"

async def async_verify_llm_event(self, mock_req, model=None):
"""Verify LLM event was recorded."""
await asyncio.sleep(0.1) # Allow time for async event processing
create_events_requests = [req for req in mock_req.request_history if req.url.endswith("/v2/create_events")]
assert len(create_events_requests) >= 1, "No events were recorded"
request_body = json.loads(create_events_requests[-1].body.decode("utf-8"))
assert "event_type" in request_body and request_body["event_type"] == "llms", "LLM event not found"
if model:
assert "model" in request_body and request_body["model"] == model, f"Model {model} not found in event"