Skip to content

Commit

Permalink
Merge pull request #17 from wch/openai-1_0
Browse files Browse the repository at this point in the history
Fix compatibility with openai>=1.0.0
  • Loading branch information
wch authored Apr 11, 2024
2 parents e906390 + b2bc562 commit 88b7fe8
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"editor.formatOnSave": true,
"editor.tabSize": 4,
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
},
"editor.rulers": [88]
},
Expand Down
51 changes: 26 additions & 25 deletions chatstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import functools
import inspect
import json
import os
import sys
import time
from typing import (
Expand All @@ -29,9 +30,9 @@
cast,
)

import shiny.experimental as x
import tiktoken
from htmltools import HTMLDependency
from openai import AsyncOpenAI
from shiny import Inputs, Outputs, Session, module, reactive, render, ui

from .openai_types import (
Expand All @@ -41,17 +42,16 @@
openai_model_context_limits,
)

if "pyodide" in sys.modules:
from . import openai_pyodide as openai
else:
import openai

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec, TypeGuard
else:
from typing import ParamSpec, TypeGuard


client = AsyncOpenAI(
api_key=os.environ["OPENAI_API_KEY"], # this is also the default, it can be omitted
)

DEFAULT_MODEL: OpenAiModel = "gpt-3.5-turbo"
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
DEFAULT_TEMPERATURE = 0.7
Expand Down Expand Up @@ -175,12 +175,14 @@ def __init__(
text_input_placeholder: str | Callable[[], str] | None = None,
button_label: str | Callable[[], str] = "Ask",
throttle: float | Callable[[], float] = DEFAULT_THROTTLE,
query_preprocessor: Callable[[str], str]
| Callable[[str], Awaitable[str]]
| None = None,
answer_preprocessor: Callable[[str], ui.TagChild]
| Callable[[str], Awaitable[ui.TagChild]]
| None = None,
query_preprocessor: (
Callable[[str], str] | Callable[[str], Awaitable[str]] | None
) = None,
answer_preprocessor: (
Callable[[str], ui.TagChild]
| Callable[[str], Awaitable[ui.TagChild]]
| None
) = None,
debug: bool = False,
):
self.input = input
Expand Down Expand Up @@ -222,15 +224,15 @@ def __init__(
tuple[ChatCompletionStreaming, ...]
] = reactive.Value(tuple())

self.streaming_chat_string_pieces: reactive.Value[
tuple[str, ...]
] = reactive.Value(tuple())
self.streaming_chat_string_pieces: reactive.Value[tuple[str, ...]] = (
reactive.Value(tuple())
)

self._ask_trigger = reactive.Value(0)

self.session_messages: reactive.Value[
tuple[ChatMessageEnriched, ...]
] = reactive.Value(tuple())
self.session_messages: reactive.Value[tuple[ChatMessageEnriched, ...]] = (
reactive.Value(tuple())
)

self.hide_query_ui: reactive.Value[bool] = reactive.Value(False)

Expand All @@ -255,13 +257,13 @@ async def finalize_streaming_result():
current_batch = self.streaming_chat_messages_batch()

for message in current_batch:
if "content" in message["choices"][0]["delta"]:
if message.choices[0].delta.content:
self.streaming_chat_string_pieces.set(
self.streaming_chat_string_pieces()
+ (message["choices"][0]["delta"]["content"],)
+ (message.choices[0].delta.content,)
)

finish_reason = message["choices"][0]["finish_reason"]
finish_reason = message.choices[0].finish_reason
if finish_reason in ["stop", "length"]:
# If we got here, we know that streaming_chat_string is not None.
current_message_str = "".join(self.streaming_chat_string_pieces())
Expand Down Expand Up @@ -353,10 +355,9 @@ async def perform_query():
# this Task (which would block other computation to happen, like running
# reactive stuff).
messages: StreamResult[ChatCompletionStreaming] = stream_to_reactive(
openai.ChatCompletion.acreate( # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues]
client.chat.completions.create( # pyright: ignore
model=self.model(),
api_key=self.api_key(),
messages=outgoing_messages_normalized,
messages=outgoing_messages_normalized, # pyright: ignore
stream=True,
temperature=self.temperature(),
**extra_kwargs,
Expand Down Expand Up @@ -418,7 +419,7 @@ def query_ui():
return ui.div()

return ui.div(
x.ui.input_text_area(
ui.input_text_area(
"query",
None,
# value="2+2",
Expand Down
10 changes: 5 additions & 5 deletions chatstream/openai_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
openai_models: list[OpenAiModel] = list(openai_model_context_limits)


class Usage(TypedDict):
class Usage:
completion_tokens: int # Note: this doesn't seem to be present in all cases.
prompt_tokens: int
total_tokens: int
Expand All @@ -42,11 +42,11 @@ class ChatMessage(TypedDict):
role: Literal["system", "user", "assistant"]


class ChoiceDelta(TypedDict):
class ChoiceDelta:
content: str


class ChoiceBase(TypedDict):
class ChoiceBase:
finish_reason: Literal["stop", "length"] | None
index: int

Expand All @@ -59,13 +59,13 @@ class ChoiceStreaming(ChoiceBase):
delta: ChoiceDelta


class ChatCompletionBase(TypedDict):
class ChatCompletionBase:
id: str
created: int
model: str


class ChatCompletionNonStreaming(TypedDict):
class ChatCompletionNonStreaming:
object: Literal["chat.completion"]
choices: list[ChoiceNonStreaming]
usage: Usage
Expand Down
10 changes: 5 additions & 5 deletions examples/doc_query/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from pathlib import Path
from typing import Generator, Sequence, cast

import chromadb # pyright: ignore[reportMissingTypeStubs]
import chromadb.api # pyright: ignore[reportMissingTypeStubs]
import chromadb
import chromadb.api
import pypdf
import shiny.experimental as x
import tiktoken
Expand Down Expand Up @@ -114,7 +114,7 @@ def add_context_to_query(query: str) -> str:
# because the extra stuff just won't be used.
n_documents = math.ceil(max_context_tokens / APPROX_DOCUMENT_SIZE)

results = collection.query(
results = collection.query( # pyright: ignore[reportUnknownMemberType]
query_texts=[query],
n_results=min(collection.count(), n_documents),
)
Expand Down Expand Up @@ -258,7 +258,7 @@ def extract_text_from_pdf(pdf_path: str | Path) -> str:


async def add_file_content_to_db(
collection: chromadb.api.Collection,
collection: chromadb.api.Collection, # pyright: ignore[reportPrivateImportUsage]
file: str | Path,
label: str,
debug: bool = False,
Expand All @@ -283,7 +283,7 @@ async def add_file_content_to_db(
for i in range(len(text_chunks)):
p.set(value=i, message="Adding text to database...")
await asyncio.sleep(0)
collection.add(
collection.add( # pyright: ignore[reportUnknownMemberType]
documents=text_chunks[i],
metadatas={"filename": label, "page": str(i)},
ids=f"{label}-{i}",
Expand Down
2 changes: 1 addition & 1 deletion pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"ignore": ["venv", "typings"],
"ignore": ["venv", "typings", "docs"],
"typeCheckingMode": "strict",
"reportUnusedFunction": "none",
"reportPrivateUsage": "none",
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ setup_requires =
setuptools
install_requires =
shiny>=0.3.3
openai;platform_system!="Emscripten"
openai>=1.13.3;platform_system!="Emscripten"
tiktoken
tests_require =
pytest>=3
Expand Down

0 comments on commit 88b7fe8

Please sign in to comment.