Skip to content

Commit

Permalink
Updated chat sessions to be able to take the no_stream option (#91)
Browse files Browse the repository at this point in the history
Following #89, in order to run o1 models you need to supply the options:
```bash
gpt --no_stream --temperature=1 --model=o1-preview
```
since o1 models only support the default temperature of 1 and do not support streaming.

Whilst `--no_stream` already works for `--prompt` and `--execute`, it does not
work for chat mode. This PR adds chat-mode support by reading the option from
the command-line arguments and passing it to `CLIChatSession` as an attribute.
  • Loading branch information
  • Loading branch information
williamjameshandley authored Nov 17, 2024
1 parent e448f07 commit b4f1f22
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
6 changes: 3 additions & 3 deletions gptcli/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def run_non_interactive(args, assistant):


class CLIChatSession(ChatSession):
def __init__(self, assistant: Assistant, markdown: bool, show_price: bool):
def __init__(self, assistant: Assistant, markdown: bool, show_price: bool, stream: bool):
listeners = [
CLIChatListener(markdown),
LoggingChatListener(),
Expand All @@ -240,13 +240,13 @@ def __init__(self, assistant: Assistant, markdown: bool, show_price: bool):
listeners.append(PriceChatListener(assistant))

listener = CompositeChatListener(listeners)
super().__init__(assistant, listener)
super().__init__(assistant, listener, stream)


def run_interactive(args, assistant):
logger.info("Starting a new chat session. Assistant config: %s", assistant.config)
session = CLIChatSession(
assistant=assistant, markdown=args.markdown, show_price=args.show_price
assistant=assistant, markdown=args.markdown, show_price=args.show_price, stream=not args.no_stream
)
history_filename = os.path.expanduser("~/.config/gpt-cli/history")
os.makedirs(os.path.dirname(history_filename), exist_ok=True)
Expand Down
4 changes: 3 additions & 1 deletion gptcli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ def __init__(
self,
assistant: Assistant,
listener: ChatListener,
stream: bool = True,
):
self.assistant = assistant
self.messages: List[Message] = assistant.init_messages()
self.user_prompts: List[Tuple[Message, ModelOverrides]] = []
self.listener = listener
self.stream = stream

def _clear(self):
self.messages = self.assistant.init_messages()
Expand All @@ -112,7 +114,7 @@ def _respond(self, overrides: ModelOverrides) -> bool:
usage: Optional[UsageEvent] = None
try:
completion_iter = self.assistant.complete_chat(
self.messages, override_params=overrides
self.messages, override_params=overrides, stream=self.stream
)

with self.listener.response_streamer() as stream:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_simple_input():
assistant_message = {"role": "assistant", "content": expected_response}

assistant_mock.complete_chat.assert_called_once_with(
[system_message, user_message], override_params={}
[system_message, user_message], override_params={}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[mock.call(user_message), mock.call(assistant_message)]
Expand All @@ -66,7 +66,7 @@ def test_clear():

assistant_mock.complete_chat.assert_called_once_with(
[system_message, {"role": "user", "content": "user_message"}],
override_params={},
override_params={}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[
Expand All @@ -93,7 +93,7 @@ def test_clear():

assistant_mock.complete_chat.assert_called_once_with(
[system_message, {"role": "user", "content": "user_message_1"}],
override_params={},
override_params={}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[
Expand Down Expand Up @@ -128,7 +128,7 @@ def test_rerun():

assistant_mock.complete_chat.assert_called_once_with(
[system_message, {"role": "user", "content": "user_message"}],
override_params={},
override_params={}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[
Expand All @@ -150,7 +150,7 @@ def test_rerun():

assistant_mock.complete_chat.assert_called_once_with(
[system_message, {"role": "user", "content": "user_message"}],
override_params={},
override_params={}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[
Expand All @@ -175,7 +175,7 @@ def test_args():
assistant_message = {"role": "assistant", "content": expected_response}

assistant_mock.complete_chat.assert_called_once_with(
[system_message, user_message], override_params={"arg1": "value1"}
[system_message, user_message], override_params={"arg1": "value1"}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[mock.call(user_message), mock.call(assistant_message)]
Expand All @@ -191,7 +191,7 @@ def test_args():
assert should_continue

assistant_mock.complete_chat.assert_called_once_with(
[system_message, user_message], override_params={"arg1": "value1"}
[system_message, user_message], override_params={"arg1": "value1"}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls([mock.call(assistant_message)])

Expand Down Expand Up @@ -250,7 +250,7 @@ def test_openai_error():
assert should_continue

assistant_mock.complete_chat.assert_called_once_with(
[system_message, user_message], override_params={}
[system_message, user_message], override_params={}, stream=True,
)
listener_mock.on_chat_message.assert_has_calls(
[
Expand Down

0 comments on commit b4f1f22

Please sign in to comment.