Skip to content

Commit

Permalink
Solve pydantic errors after update
Browse files Browse the repository at this point in the history
  • Loading branch information
kaancayli committed Nov 12, 2024
1 parent 5d81124 commit ec52f25
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 51 deletions.
20 changes: 11 additions & 9 deletions app/llm/external/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

from ...llm.external.model import LanguageModel
from ...llm.external.openai_completion import (
DirectOpenAICompletionModel,
Expand All @@ -10,12 +12,12 @@
)
from ...llm.external.ollama import OllamaModel

type AnyLLM = (
DirectOpenAICompletionModel
| AzureOpenAICompletionModel
| DirectOpenAIChatModel
| AzureOpenAIChatModel
| DirectOpenAIEmbeddingModel
| AzureOpenAIEmbeddingModel
| OllamaModel
)
AnyLLM = Union[
DirectOpenAICompletionModel,
AzureOpenAICompletionModel,
DirectOpenAIChatModel,
AzureOpenAIChatModel,
DirectOpenAIEmbeddingModel,
AzureOpenAIEmbeddingModel,
OllamaModel,
]
5 changes: 1 addition & 4 deletions app/llm/external/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from langchain_core.tools import BaseTool
from openai.types.chat import ChatCompletionMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as LegacyBaseModel

from ...common.pyris_message import PyrisMessage
from ...llm import CompletionArguments
Expand Down Expand Up @@ -53,9 +52,7 @@ def chat(
@abstractmethod
def bind_tools(
self,
tools: Sequence[
Union[Dict[str, Any], Type[LegacyBaseModel], Callable, BaseTool]
],
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
):
"""Bind tools"""
raise NotImplementedError(
Expand Down
7 changes: 2 additions & 5 deletions app/llm/external/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import Field
from pydantic import Field, BaseModel

from ollama import Client, Message
from pydantic.v1 import BaseModel as LegacyBaseModel

from ...common.message_converters import map_role_to_str, map_str_to_role
from ...common.pyris_message import PyrisMessage
Expand Down Expand Up @@ -155,9 +154,7 @@ def embed(self, text: str) -> list[float]:
# TODO: Implement tool binding support for Ollama models
def bind_tools(
self,
tools: Sequence[
Union[Dict[str, Any], Type[LegacyBaseModel], Callable, BaseTool]
],
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
) -> Runnable[LanguageModelInput, BaseMessage]:
raise NotImplementedError(
f"The LLM {self.__str__()} does not support binding tools"
Expand Down
9 changes: 3 additions & 6 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.shared_params import ResponseFormatJSONObject
from pydantic import Field
from pydantic.v1 import BaseModel as LegacyBaseModel
from pydantic import Field, BaseModel

from app.domain.data.text_message_content_dto import TextMessageContentDTO
from ...common.message_converters import map_role_to_str, map_str_to_role
Expand Down Expand Up @@ -202,7 +201,7 @@ class OpenAIChatModel(ChatModel):
model: str
api_key: str
tools: Optional[
Sequence[Union[Dict[str, Any], Type[LegacyBaseModel], Callable, BaseTool]]
Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]]
] = Field(default_factory=list, alias="tools")

def chat(
Expand Down Expand Up @@ -258,9 +257,7 @@ def chat(

def bind_tools(
self,
tools: Sequence[
Union[Dict[str, Any], Type[LegacyBaseModel], Callable, BaseTool]
],
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
):
self.tools = [convert_to_openai_tool(tool) for tool in tools]

Expand Down
3 changes: 2 additions & 1 deletion app/llm/langchain/iris_langchain_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from logging import Logger
from typing import List, Optional, Any, Sequence, Union, Dict, Type, Callable

from langchain_core.callbacks import CallbackManagerForLLMRun
Expand Down Expand Up @@ -28,7 +29,7 @@ class IrisLangchainChatModel(BaseChatModel):
request_handler: RequestHandler
completion_args: CompletionArguments
tokens: TokenUsageDTO = None
logger = logging.getLogger(__name__)
logger: Logger = logging.getLogger(__name__)

def __init__(
self,
Expand Down
5 changes: 3 additions & 2 deletions app/llm/llm_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Annotated

from pydantic import BaseModel, Field
from pydantic import BaseModel, Discriminator

import yaml

Expand All @@ -15,7 +16,7 @@

# Small workaround to get pydantic discriminators working
class LlmList(BaseModel):
llms: list[AnyLLM] = Field(discriminator="type")
llms: list[Annotated[AnyLLM, Discriminator("type")]]


class LlmManager(metaclass=Singleton):
Expand Down
7 changes: 5 additions & 2 deletions app/llm/request_handler/basic_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional, Sequence, Union, Dict, Any, Type, Callable

from langchain_core.tools import BaseTool
from pydantic.v1 import BaseModel
from pydantic import ConfigDict
from pydantic import BaseModel

from app.common.pyris_message import PyrisMessage
from app.domain.data.image_message_content_dto import ImageMessageContentDTO
Expand All @@ -13,9 +14,11 @@

class BasicRequestHandler(RequestHandler):
model_id: str
llm_manager: LlmManager
llm_manager: LlmManager | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, model_id: str):
super().__init__(model_id=model_id, llm_manager=None)
self.model_id = model_id
self.llm_manager = LlmManager()

Expand Down
9 changes: 7 additions & 2 deletions app/llm/request_handler/capability_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Sequence, Union, Dict, Any, Type, Callable

from langchain_core.tools import BaseTool
from pydantic.v1 import BaseModel
from pydantic import ConfigDict
from pydantic import BaseModel

from app.common.pyris_message import PyrisMessage
from app.llm.capability import RequirementList
Expand All @@ -29,13 +30,17 @@ class CapabilityRequestHandler(RequestHandler):

requirements: RequirementList
selection_mode: CapabilityRequestHandlerSelectionMode
llm_manager: LlmManager
llm_manager: LlmManager | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(
self,
requirements: RequirementList,
selection_mode: CapabilityRequestHandlerSelectionMode = CapabilityRequestHandlerSelectionMode.WORST,
) -> None:
super().__init__(
requirements=requirements, selection_mode=selection_mode, llm_manager=None
)
self.requirements = requirements
self.selection_mode = selection_mode
self.llm_manager = LlmManager()
Expand Down
4 changes: 2 additions & 2 deletions app/llm/request_handler/request_handler_interface.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence, Union, Dict, Any, Type, Callable
from langchain_core.tools import BaseTool
from pydantic.v1 import BaseModel
from pydantic import BaseModel

from .. import LanguageModel
from ...common.pyris_message import PyrisMessage
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...llm import CompletionArguments


class RequestHandler(metaclass=ABCMeta):
class RequestHandler(BaseModel, metaclass=ABCMeta):
"""Interface for the request handlers"""

@classmethod
Expand Down
34 changes: 16 additions & 18 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
black==24.8.0
fastapi==0.112.2
black==24.10.0
fastapi==0.115.5
flake8==7.1.1
langchain==0.2.14
ollama==0.3.1
openai==1.42.0
pre-commit==3.8.0
psutil==6.0.0
pydantic==2.8.2
PyMuPDF==1.24.9
langchain==0.3.7
ollama==0.3.3
openai==1.54.4
pre-commit==4.0.1
psutil==6.1.0
pydantic==2.9.2
PyMuPDF==1.24.13
pytz==2024.1
PyYAML==6.0.2
requests==2.32.3
sentry-sdk[starlette,fastapi,openai]==2.13.0
unstructured==0.15.7
uvicorn==0.30.6
weaviate-client==4.7.1
langchain_openai==0.1.19
starlette~=0.37.2
langsmith~=0.1.75
langgraph~=0.1.17
langchain-core~=0.2.41
langchain-text-splitters~=0.2.1
unstructured==0.16.5
uvicorn==0.32.0
weaviate-client==4.9.3
langchain-core~=0.3.17
starlette~=0.41.2
langsmith~=0.1.142
langchain-text-splitters~=0.3.2

0 comments on commit ec52f25

Please sign in to comment.