Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support image input in the chat completion request #55

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,098 changes: 952 additions & 1,146 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ packages = [{include = "simple_ai", from = "src"}]
[tool.poetry.dependencies]
python = "^3.9"
fastapi= "^0.94.1"
grpcio= "^1.42.0"
grpcio= "^1.42.0,!=1.65.0"
protobuf= "^4.21.3"
uvicorn= "^0.21.0"
tomli= {version = "^2.0.1", python = "<3.11"}
Expand All @@ -22,7 +22,7 @@ ruff = "^0.0.260"
black = "^23.3.0"
pre-commit = "^3.2.2"
poetry = "^1.4.1"
grpcio-tools = "^1.51.3"
grpcio-tools = "^1.51.3,!=1.65.0"
protobuf = "^4.21.3"

[build-system]
Expand Down
50 changes: 48 additions & 2 deletions src/simple_ai/api/grpc/chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,45 @@

import grpc
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Value, ListValue, Struct

from . import llm_chat_pb2
from . import llm_chat_pb2_grpc


def dict_to_struct(d):
    """Recursively convert a plain Python dict into a protobuf ``Struct``.

    Supported value types: ``dict`` (nested ``Struct``), ``list`` of
    dict/str/bool/int/float/None, ``str``, ``bool``, ``int``/``float``
    (both stored as protobuf ``number_value``, i.e. a double) and ``None``.

    Args:
        d: Dictionary with string keys and JSON-like values.

    Returns:
        A ``google.protobuf.struct_pb2.Struct`` mirroring ``d``.

    Raises:
        ValueError: If a value (or list element) has an unsupported type.
    """
    fields = {}
    for k, v in d.items():
        # bool MUST be tested before int/float: bool is a subclass of int,
        # so the numeric branch would otherwise capture True/False and emit
        # number_value instead of bool_value.
        if isinstance(v, bool):
            fields[k] = Value(bool_value=v)
        elif isinstance(v, dict):
            fields[k] = Value(struct_value=dict_to_struct(v))
        elif isinstance(v, list):
            list_values = []
            for item in v:
                if isinstance(item, bool):
                    # Same bool-before-number ordering as above.
                    list_values.append(Value(bool_value=item))
                elif isinstance(item, dict):
                    list_values.append(Value(struct_value=dict_to_struct(item)))
                elif isinstance(item, str):
                    list_values.append(Value(string_value=item))
                elif isinstance(item, (int, float)):
                    list_values.append(Value(number_value=item))
                elif item is None:
                    # Mirror the None handling of top-level values.
                    list_values.append(Value(null_value=0))
                else:
                    raise ValueError(f"Unsupported type in list: {type(item)} for key: {k}")
            fields[k] = Value(list_value=ListValue(values=list_values))
        elif isinstance(v, str):
            fields[k] = Value(string_value=v)
        elif isinstance(v, (int, float)):
            fields[k] = Value(number_value=v)
        elif v is None:
            # 0 == NullValue.NULL_VALUE, the only member of the enum.
            fields[k] = Value(null_value=0)
        else:
            raise ValueError(f"Unsupported type: {type(v)} for key: {k}")
    return Struct(fields=fields)


def get_chatlog(stub, chatlog):
response = stub.Chat(chatlog)
results = []
Expand Down Expand Up @@ -46,7 +80,13 @@ def run(
logit_bias=str(logit_bias),
)
for role, content in messages:
grpc_chat = llm_chat_pb2.Chat(role=role, content=content)
if isinstance(content, str):
grpc_chat = llm_chat_pb2.Chat(role=role, content=Value(string_value=content))
else:
list_value = ListValue(
values=[Value(struct_value=dict_to_struct(item)) for item in content]
)
grpc_chat = llm_chat_pb2.Chat(role=role, content=Value(list_value=list_value))
grpc_chatlog.messages.append(grpc_chat)
return get_chatlog(stub, grpc_chatlog)

Expand Down Expand Up @@ -88,7 +128,13 @@ def run_stream(
)

for role, content in messages:
grpc_chat = llm_chat_pb2.Chat(role=role, content=content)
if isinstance(content, str):
grpc_chat = llm_chat_pb2.Chat(role=role, content=Value(string_value=content))
else:
list_value = ListValue(
values=[Value(struct_value=dict_to_struct(item)) for item in content]
)
grpc_chat = llm_chat_pb2.Chat(role=role, content=Value(list_value=list_value))
grpc_chatlog.messages.append(grpc_chat)

yield from stream_chatlog(stub, grpc_chatlog)
Expand Down
42 changes: 23 additions & 19 deletions src/simple_ai/api/grpc/chat/llm_chat_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

69 changes: 37 additions & 32 deletions src/simple_ai/api/grpc/chat/llm_chat_pb2.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
Expand All @@ -11,47 +12,39 @@ from typing import (

DESCRIPTOR: _descriptor.FileDescriptor

class Chat(_message.Message):
__slots__ = ["content", "role"]
CONTENT_FIELD_NUMBER: _ClassVar[int]
ROLE_FIELD_NUMBER: _ClassVar[int]
content: str
role: str
def __init__(self, role: _Optional[str] = ..., content: _Optional[str] = ...) -> None: ...

class ChatLogInput(_message.Message):
__slots__ = [
"frequence_penalty",
"logit_bias",
"max_tokens",
__slots__ = (
"messages",
"n",
"presence_penalty",
"stop",
"stream",
"max_tokens",
"temperature",
"top_p",
]
FREQUENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
MAX_TOKENS_FIELD_NUMBER: _ClassVar[int]
"n",
"stream",
"stop",
"presence_penalty",
"frequence_penalty",
"logit_bias",
)
MESSAGES_FIELD_NUMBER: _ClassVar[int]
N_FIELD_NUMBER: _ClassVar[int]
PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
STOP_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
MAX_TOKENS_FIELD_NUMBER: _ClassVar[int]
TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
TOP_P_FIELD_NUMBER: _ClassVar[int]
frequence_penalty: float
logit_bias: str
max_tokens: int
N_FIELD_NUMBER: _ClassVar[int]
STREAM_FIELD_NUMBER: _ClassVar[int]
STOP_FIELD_NUMBER: _ClassVar[int]
PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
FREQUENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
messages: _containers.RepeatedCompositeFieldContainer[Chat]
n: int
presence_penalty: float
stop: str
stream: bool
max_tokens: int
temperature: float
top_p: float
n: int
stream: bool
stop: str
presence_penalty: float
frequence_penalty: float
logit_bias: str
def __init__(
self,
messages: _Optional[_Iterable[_Union[Chat, _Mapping]]] = ...,
Expand All @@ -67,7 +60,19 @@ class ChatLogInput(_message.Message):
) -> None: ...

class ChatLogOutput(_message.Message):
__slots__ = ["messages"]
__slots__ = ("messages",)
MESSAGES_FIELD_NUMBER: _ClassVar[int]
messages: _containers.RepeatedCompositeFieldContainer[Chat]
def __init__(self, messages: _Optional[_Iterable[_Union[Chat, _Mapping]]] = ...) -> None: ...

class Chat(_message.Message):
__slots__ = ("role", "content")
ROLE_FIELD_NUMBER: _ClassVar[int]
CONTENT_FIELD_NUMBER: _ClassVar[int]
role: str
content: _struct_pb2.Value
def __init__(
self,
role: _Optional[str] = ...,
content: _Optional[_Union[_struct_pb2.Value, _Mapping]] = ...,
) -> None: ...
46 changes: 37 additions & 9 deletions src/simple_ai/api/grpc/chat/llm_chat_pb2_grpc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from . import llm_chat_pb2 as llm__chat__pb2

GRPC_GENERATED_VERSION = "1.64.1"
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = "1.65.0"
SCHEDULED_RELEASE_DATE = "June 25, 2024"
_version_not_supported = False

try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True

if _version_not_supported:
warnings.warn(
f"The grpc package installed is at version {GRPC_VERSION},"
+ f" but the generated code in llm_chat_pb2_grpc.py depends on"
+ f" grpcio>={GRPC_GENERATED_VERSION}."
+ f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
+ f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
+ f" This warning will become an error in {EXPECTED_ERROR_RELEASE},"
+ f" scheduled for release on {SCHEDULED_RELEASE_DATE}.",
RuntimeWarning,
)


class LanguageModelStub(object):
"""Interface exported by the server."""
Expand All @@ -18,16 +43,19 @@ def __init__(self, channel):
"/languagemodelchat.LanguageModel/Chat",
request_serializer=llm__chat__pb2.ChatLogInput.SerializeToString,
response_deserializer=llm__chat__pb2.ChatLogOutput.FromString,
_registered_method=True
)
self.Stream = channel.unary_stream(
"/languagemodelchat.LanguageModel/Stream",
request_serializer=llm__chat__pb2.ChatLogInput.SerializeToString,
response_deserializer=llm__chat__pb2.ChatLogOutput.FromString,
_registered_method=True
)


class LanguageModelServicer(object):
"""Interface exported by the server."""
"""Interface exported by the server.
"""

def Chat(self, request, context):
"""Simple RPC"""
Expand Down Expand Up @@ -56,18 +84,17 @@ def add_LanguageModelServicer_to_server(servicer, server):
),
}
generic_handler = grpc.method_handlers_generic_handler(
"languagemodelchat.LanguageModel", rpc_method_handlers
)
"languagemodelchat.LanguageModel", rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers("languagemodelchat.LanguageModel", rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class LanguageModel(object):
"""Interface exported by the server."""

@staticmethod
def Chat(
request,
def Chat(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -76,7 +103,7 @@ def Chat(
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
metadata=None
):
return grpc.experimental.unary_unary(
request,
Expand All @@ -92,11 +119,11 @@ def Chat(
wait_for_ready,
timeout,
metadata,
_registered_method=True
)

@staticmethod
def Stream(
request,
def Stream(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -105,7 +132,7 @@ def Stream(
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
metadata=None
):
return grpc.experimental.unary_stream(
request,
Expand All @@ -121,4 +148,5 @@ def Stream(
wait_for_ready,
timeout,
metadata,
_registered_method=True
)
Loading