diff --git a/.vscode/settings.json b/.vscode/settings.json index 3c676c7..a8b6131 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,7 +10,7 @@ "editor.formatOnSave": true, "editor.tabSize": 4, "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, "editor.rulers": [88] }, diff --git a/chatstream/__init__.py b/chatstream/__init__.py index 90aa33c..5fb694c 100644 --- a/chatstream/__init__.py +++ b/chatstream/__init__.py @@ -13,6 +13,7 @@ import functools import inspect import json +import os import sys import time from typing import ( @@ -29,9 +30,9 @@ cast, ) -import shiny.experimental as x import tiktoken from htmltools import HTMLDependency +from openai import AsyncOpenAI from shiny import Inputs, Outputs, Session, module, reactive, render, ui from .openai_types import ( @@ -41,17 +42,16 @@ openai_model_context_limits, ) -if "pyodide" in sys.modules: - from . import openai_pyodide as openai -else: - import openai - if sys.version_info < (3, 10): from typing_extensions import ParamSpec, TypeGuard else: from typing import ParamSpec, TypeGuard +client = AsyncOpenAI( + api_key=os.environ["OPENAI_API_KEY"], # this is also the default, it can be omitted +) + DEFAULT_MODEL: OpenAiModel = "gpt-3.5-turbo" DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant." DEFAULT_TEMPERATURE = 0.7 @@ -175,12 +175,14 @@ def __init__( text_input_placeholder: str | Callable[[], str] | None = None, button_label: str | Callable[[], str] = "Ask", throttle: float | Callable[[], float] = DEFAULT_THROTTLE, - query_preprocessor: Callable[[str], str] - | Callable[[str], Awaitable[str]] - | None = None, - answer_preprocessor: Callable[[str], ui.TagChild] - | Callable[[str], Awaitable[ui.TagChild]] - | None = None, + query_preprocessor: ( + Callable[[str], str] | Callable[[str], Awaitable[str]] | None + ) = None, + answer_preprocessor: ( + Callable[[str], ui.TagChild] + | Callable[[str], Awaitable[ui.TagChild]] + | None + ) = None, debug: bool = False, ): self.input = input @@ -222,15 +224,15 @@ def __init__( tuple[ChatCompletionStreaming, ...] ] = reactive.Value(tuple()) - self.streaming_chat_string_pieces: reactive.Value[ - tuple[str, ...] - ] = reactive.Value(tuple()) + self.streaming_chat_string_pieces: reactive.Value[tuple[str, ...]] = ( + reactive.Value(tuple()) + ) self._ask_trigger = reactive.Value(0) - self.session_messages: reactive.Value[ - tuple[ChatMessageEnriched, ...] - ] = reactive.Value(tuple()) + self.session_messages: reactive.Value[tuple[ChatMessageEnriched, ...]] = ( + reactive.Value(tuple()) + ) self.hide_query_ui: reactive.Value[bool] = reactive.Value(False) @@ -255,13 +257,13 @@ async def finalize_streaming_result(): current_batch = self.streaming_chat_messages_batch() for message in current_batch: - if "content" in message["choices"][0]["delta"]: + if message.choices[0].delta.content: self.streaming_chat_string_pieces.set( self.streaming_chat_string_pieces() - + (message["choices"][0]["delta"]["content"],) + + (message.choices[0].delta.content,) ) - finish_reason = message["choices"][0]["finish_reason"] + finish_reason = message.choices[0].finish_reason if finish_reason in ["stop", "length"]: # If we got here, we know that streaming_chat_string is not None. current_message_str = "".join(self.streaming_chat_string_pieces()) @@ -353,10 +355,9 @@ async def perform_query(): # this Task (which would block other computation to happen, like running # reactive stuff). messages: StreamResult[ChatCompletionStreaming] = stream_to_reactive( - openai.ChatCompletion.acreate( # pyright: ignore[reportUnknownMemberType, reportGeneralTypeIssues] + client.chat.completions.create( # pyright: ignore model=self.model(), - api_key=self.api_key(), - messages=outgoing_messages_normalized, + messages=outgoing_messages_normalized, # pyright: ignore stream=True, temperature=self.temperature(), **extra_kwargs, @@ -418,7 +419,7 @@ def query_ui(): return ui.div() return ui.div( - x.ui.input_text_area( + ui.input_text_area( "query", None, # value="2+2", diff --git a/chatstream/openai_types.py b/chatstream/openai_types.py index 303decd..34938fe 100644 --- a/chatstream/openai_types.py +++ b/chatstream/openai_types.py @@ -31,7 +31,7 @@ openai_models: list[OpenAiModel] = list(openai_model_context_limits) -class Usage(TypedDict): +class Usage: completion_tokens: int # Note: this doesn't seem to be present in all cases. prompt_tokens: int total_tokens: int @@ -42,11 +42,11 @@ class ChatMessage(TypedDict): role: Literal["system", "user", "assistant"] -class ChoiceDelta(TypedDict): +class ChoiceDelta: content: str -class ChoiceBase(TypedDict): +class ChoiceBase: finish_reason: Literal["stop", "length"] | None index: int @@ -59,13 +59,13 @@ class ChoiceStreaming(ChoiceBase): delta: ChoiceDelta -class ChatCompletionBase(TypedDict): +class ChatCompletionBase: id: str created: int model: str -class ChatCompletionNonStreaming(TypedDict): +class ChatCompletionNonStreaming: object: Literal["chat.completion"] choices: list[ChoiceNonStreaming] usage: Usage diff --git a/examples/doc_query/app.py b/examples/doc_query/app.py index f636e53..9ec2e3f 100644 --- a/examples/doc_query/app.py +++ b/examples/doc_query/app.py @@ -8,8 +8,8 @@ from pathlib import Path from typing import Generator, Sequence, cast -import chromadb # pyright: ignore[reportMissingTypeStubs] -import chromadb.api # pyright: ignore[reportMissingTypeStubs] +import chromadb +import chromadb.api import pypdf import shiny.experimental as x import tiktoken @@ -114,7 +114,7 @@ def add_context_to_query(query: str) -> str: # because the extra stuff just won't be used. n_documents = math.ceil(max_context_tokens / APPROX_DOCUMENT_SIZE) - results = collection.query( + results = collection.query( # pyright: ignore[reportUnknownMemberType] query_texts=[query], n_results=min(collection.count(), n_documents), ) @@ -258,7 +258,7 @@ def extract_text_from_pdf(pdf_path: str | Path) -> str: async def add_file_content_to_db( - collection: chromadb.api.Collection, + collection: chromadb.api.Collection, # pyright: ignore[reportPrivateImportUsage] file: str | Path, label: str, debug: bool = False, @@ -283,7 +283,7 @@ async def add_file_content_to_db( for i in range(len(text_chunks)): p.set(value=i, message="Adding text to database...") await asyncio.sleep(0) - collection.add( + collection.add( # pyright: ignore[reportUnknownMemberType] documents=text_chunks[i], metadatas={"filename": label, "page": str(i)}, ids=f"{label}-{i}", diff --git a/pyrightconfig.json b/pyrightconfig.json index de262a2..9121e1c 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,5 +1,5 @@ { - "ignore": ["venv", "typings"], + "ignore": ["venv", "typings", "docs"], "typeCheckingMode": "strict", "reportUnusedFunction": "none", "reportPrivateUsage": "none", diff --git a/setup.cfg b/setup.cfg index 2102337..6562891 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,7 @@ setup_requires = setuptools install_requires = shiny>=0.3.3 - openai;platform_system!="Emscripten" + openai>=1.13.3;platform_system!="Emscripten" tiktoken tests_require = pytest>=3