Skip to content

Commit

Permalink
Merge pull request #92 from liuooo/support_streaming_options
Browse files Browse the repository at this point in the history
Support "include_usage" streaming options
  • Loading branch information
liuooo authored Nov 23, 2024
2 parents 1b2cdfe + 29e9fce commit 44eeaf4
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 0 deletions.
5 changes: 5 additions & 0 deletions app/core/runner/llm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def run(
tools: List = None,
tool_choice="auto",
stream=False,
stream_options=None,
extra_body=None,
temperature=None,
top_p=None,
Expand All @@ -38,6 +39,10 @@ def run(
if "n" in model_params:
raise ValueError("n is not allowed in model_params")
chat_params.update(model_params)
if stream_options:
if isinstance(stream_options, dict):
if "include_usage" in stream_options:
chat_params["stream_options"] = {"include_usage": bool(stream_options["include_usage"])}
if temperature:
chat_params["temperature"] = temperature
if top_p:
Expand Down
4 changes: 4 additions & 0 deletions app/core/runner/llm_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def handle_llm_response(
for chunk in response_stream:
logging.debug(chunk)

if chunk.usage:
self.event_handler.pub_message_usage(chunk)
continue

if not chunk.choices:
continue

Expand Down
18 changes: 18 additions & 0 deletions app/core/runner/pub_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,24 @@ def pub_message_in_progress(self, message):
events.ThreadMessageInProgress(data=_data_adjust_message(message), event="thread.message.in_progress")
)

def pub_message_usage(self, chunk):
    """
    Publish token-usage statistics for a streamed LLM response.

    The streaming API currently has no dedicated usage event, so the
    usage payload is tunnelled inside a ``thread.message.in_progress``
    event (to be replaced once an official usage event exists).
    """
    # Minimal message skeleton: only ``id`` and ``metadata.usage`` carry
    # real information; the remaining fields are placeholders required
    # by the event schema.
    payload = {
        "id": chunk.id,
        "content": [],
        "created_at": 0,
        "object": "thread.message",
        "role": "assistant",
        "status": "in_progress",
        "thread_id": "",
        "metadata": {"usage": chunk.usage.json()},
    }
    event = events.ThreadMessageInProgress(data=payload, event="thread.message.in_progress")
    self.pub_event(event)

def pub_message_completed(self, message):
self.pub_event(
events.ThreadMessageCompleted(data=_data_adjust_message(message), event="thread.message.completed")
Expand Down
1 change: 1 addition & 0 deletions app/core/runner/thread_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __run_step(
tools=[tool.openai_function for tool in tools],
tool_choice="auto" if len(run_steps) < self.max_step else "none",
stream=True,
stream_options=run.stream_options,
extra_body=run.extra_body,
temperature=run.temperature,
top_p=run.top_p,
Expand Down
2 changes: 2 additions & 0 deletions app/models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class RunBase(BaseModel):
failed_at: Optional[datetime] = Field(default=None)
additional_instructions: Optional[str] = Field(default=None, max_length=32768, sa_column=Column(TEXT))
extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
incomplete_details: Optional[str] = Field(default=None) # 未完成详情
max_completion_tokens: Optional[int] = Field(default=None) # 最大完成长度
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
Expand All @@ -74,6 +75,7 @@ class RunCreate(BaseModel):
tools: Optional[list] = []
extra_body: Optional[dict[str, Union[dict[str, Union[Authentication, Any]], Any]]] = {}
stream: Optional[bool] = False
stream_options: Optional[dict] = Field(default=None, sa_column=Column(JSON))
additional_messages: Optional[list[MessageCreate]] = Field(default=[], sa_column=Column(JSON)) # 消息列表
max_completion_tokens: Optional[int] = None # 最大完成长度
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
Expand Down
10 changes: 10 additions & 0 deletions examples/run_assistant_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging

from openai import AssistantEventHandler
from openai.types.beta import AssistantStreamEvent
from openai.types.beta.assistant_stream_event import ThreadMessageInProgress
from openai.types.beta.threads.message import Message
from openai.types.beta.threads.runs import ToolCall, ToolCallDelta

Expand Down Expand Up @@ -47,6 +49,11 @@ def on_text_delta(self, delta, snapshot) -> None:
def on_text_done(self, text) -> None:
logging.info("text done: %s\n", text)

def on_event(self, event: AssistantStreamEvent) -> None:
    """Log ``thread.message.in_progress`` stream events; ignore all others."""
    if not isinstance(event, ThreadMessageInProgress):
        return
    logging.info("event: %s\n", event)


if __name__ == "__main__":
assistant = client.beta.assistants.create(
name="Assistant Demo",
Expand All @@ -70,5 +77,8 @@ def on_text_done(self, text) -> None:
thread_id=thread.id,
assistant_id=assistant.id,
event_handler=event_handler,
extra_body={
"stream_options": {"include_usage": True}
}
) as stream:
stream.until_done()
31 changes: 31 additions & 0 deletions migrations/versions/2024-11-15-18-30_5b2b73d0fdf6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""empty message
Revision ID: 5b2b73d0fdf6
Revises: b217fafdb5f0
Create Date: 2024-11-15 18:30:43.391344
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel


# revision identifiers, used by Alembic.
revision: str = '5b2b73d0fdf6'  # id of this migration
down_revision: Union[str, None] = 'b217fafdb5f0'  # parent migration this one builds on
branch_labels: Union[str, Sequence[str], None] = None  # no named branch
depends_on: Union[str, Sequence[str], None] = None  # no cross-branch dependency


def upgrade() -> None:
    """Add the nullable JSON ``stream_options`` column to the ``run`` table."""
    stream_options_column = sa.Column('stream_options', sa.JSON(), nullable=True)
    op.add_column('run', stream_options_column)


def downgrade() -> None:
    """Revert the upgrade step by dropping ``run.stream_options``."""
    op.drop_column('run', 'stream_options')

0 comments on commit 44eeaf4

Please sign in to comment.