Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fine grained JSON stream parser #119

Merged
merged 16 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions taskweaver/chat/console/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,27 @@ def get_ani_frame(frame: int = 0):
ani_frame = " " * frame_inx + "<=💡=>" + " " * (10 - frame_inx)
return ani_frame

def format_status_message(limit: int):
    """Build the one-line status text, guaranteed to fit in *limit* columns.

    Two modes (reads enclosing-scope state):
      - no streaming buffer yet -> show ``status_msg`` (head-truncated);
      - streaming in progress   -> show ``[cur_key]`` plus the tail of the
        normalized message buffer (tail keeps the freshest streamed text).
    """
    incomplete_suffix = "..."
    incomplete_suffix_len = len(incomplete_suffix)
    if len(cur_message_buffer) == 0:
        if len(status_msg) > limit - 1:
            # BUGFIX: keep the *head* of the message and append the ellipsis.
            # The previous slice `status_msg[(limit - ... - 1):]` kept the tail
            # from a fixed index, so long messages overflowed `limit`.
            # Result width: 1 (space) + (limit - 4) + 3 ("...") == limit.
            return f" {status_msg[:(limit - incomplete_suffix_len - 1)]}{incomplete_suffix}"
        return " " + status_msg

    cur_key_display = style_line("[") + style_key(cur_key) + style_line("]")
    cur_key_len = len(cur_key) + 2  # with extra bracket
    # Collapse newlines so the status line never spills onto a second row.
    cur_message_buffer_norm = cur_message_buffer.replace("\n", " ").replace(
        "\r",
        " ",
    )

    if len(cur_message_buffer_norm) < limit - cur_key_len - 1:
        return f"{cur_key_display} {cur_message_buffer_norm}"

    # Too long: show a leading ellipsis plus the newest tail of the buffer.
    status_msg_len = limit - cur_key_len - incomplete_suffix_len
    return f"{cur_key_display}{incomplete_suffix}{cur_message_buffer_norm[-status_msg_len:]}"

last_time = 0
while True:
clear_line()
Expand Down Expand Up @@ -330,6 +351,7 @@ def get_ani_frame(frame: int = 0):
key=cur_key,
),
)
cur_message_buffer = ""
elif action == "round_error":
error_message(opt)
elif action == "status_update":
Expand All @@ -340,14 +362,28 @@ def get_ani_frame(frame: int = 0):
if self.exit_event.is_set():
break

cur_message_prefix: str = " TaskWeaver "
cur_ani_frame = get_ani_frame(counter)
cur_message_display_len = (
terminal_column
- len(cur_message_prefix)
- 2 # separator for cur message prefix
- len(role)
- 2 # bracket for role
- len(cur_ani_frame)
- 2 # extra size for emoji in ani
)

cur_message_display = format_status_message(cur_message_display_len)

click.secho(
click.style(" TaskWeaver ", fg="white", bg="yellow")
click.style(cur_message_prefix, fg="white", bg="yellow")
+ click.style("▶ ", fg="yellow")
+ style_line("[")
+ style_role(role)
+ style_line("] ")
+ style_msg(status_msg)
+ style_msg(get_ani_frame(counter))
+ style_line("]")
+ style_msg(cur_message_display)
+ style_msg(cur_ani_frame)
+ "\r",
# f">>> [{style_role(role)}] {status_msg} {get_ani_frame(counter)}\r",
nl=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def early_stop(_type: AttachmentType, value: str) -> bool:
llm_output=self.llm_api.chat_completion_stream(
prompt,
use_backup_engine=use_back_up_engine,
use_smoother=True,
),
post_proxy=post_proxy,
early_stop=early_stop,
Expand Down
158 changes: 147 additions & 11 deletions taskweaver/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Generator, List, Optional, Type
from typing import Any, Callable, Generator, List, Optional, Type

from injector import Injector, inject

Expand Down Expand Up @@ -112,18 +112,154 @@ def chat_completion_stream(
max_tokens: Optional[int] = None,
top_p: Optional[float] = None,
stop: Optional[List[str]] = None,
use_smoother: bool = True,
**kwargs: Any,
) -> Generator[ChatMessageType, None, None]:
return self.completion_service.chat_completion(
messages,
use_backup_engine,
stream,
temperature,
max_tokens,
top_p,
stop,
**kwargs,
)
def get_generator() -> Generator[ChatMessageType, None, None]:
return self.completion_service.chat_completion(
messages,
use_backup_engine,
stream,
temperature,
max_tokens,
top_p,
stop,
**kwargs,
)

if use_smoother:
return self._stream_smoother(get_generator)
return get_generator()

def _stream_smoother(
    self,
    stream_init: Callable[[], Generator[ChatMessageType, None, None]],
) -> Generator[ChatMessageType, None, None]:
    """Re-emit an LLM chat stream at a visually smooth, adaptive pace.

    A background thread pulls chunks from the underlying stream into a
    shared buffer as fast as they arrive; this generator drains the buffer
    in small, randomly-sized packs whose pacing tracks the producer's
    observed throughput, so the console output appears as steady typing
    instead of bursty chunks.

    :param stream_init: zero-arg callable that creates the source generator
        (called on the background thread, not here).
    :return: generator of ChatMessageType chunks; concatenated content is
        identical to the source stream's.
    """
    # Local imports keep these out of module scope; only needed here.
    import random
    import threading
    import time

    min_sleep_interval = 0.1          # longest single wait between polls (s)
    min_chunk_size = 2                # never emit fewer chars than this
    min_update_interval = 1 / 30  # 30Hz  # cap redraw rate at ~30 fps

    recv_start = time.time()          # t0 for measuring producer speed
    buffer_message: Optional[ChatMessageType] = None  # most recent raw msg (for role/name)
    buffer_content: str = ""          # all content received so far (append-only)
    finished = False                  # producer done flag

    # NOTE(review): `update_lock` guards the shared buffer state, while
    # `update_cond` (which owns its *own* internal lock) is only used for
    # sleeping/wakeups — the condition is never waited on under
    # `update_lock`, so notify/wait act as a timed nudge, not a strict
    # condition-variable protocol. Verify this pairing is intentional.
    update_lock = threading.Lock()
    update_cond = threading.Condition()
    cur_base_speed: float = 10.0      # EWMA of producer speed (chars/s)

    def speed_normalize(speed: float):
        # Clamp speed to a sane display range: 5..600 chars/s.
        return min(max(speed, 5), 600)

    def base_stream_puller():
        # Producer: drain the real stream into the shared buffer.
        nonlocal buffer_message, buffer_content, finished, cur_base_speed
        stream = stream_init()

        for msg in stream:
            if msg["content"] == "":
                continue

            with update_lock:
                buffer_message = msg
                buffer_content += msg["content"]
                cur_time = time.time()

                # Overall speed since start, capped at 2000 chars/s; blend
                # into the EWMA with weight ramping up over the first 80 chars
                # so early noisy samples count less.
                new_speed = min(2e3, len(buffer_content) / (cur_time - recv_start))
                weight = min(1.0, len(buffer_content) / 80)
                cur_base_speed = new_speed * weight + cur_base_speed * (1 - weight)

            with update_cond:
                update_cond.notify()

        with update_lock:
            finished = True

    thread = threading.Thread(target=base_stream_puller)
    thread.start()

    sent_content: str = ""            # chars already yielded downstream
    # NOTE(review): `sent_start` is initialized to time.time(), so the
    # `if sent_start == 0.0` "first chunk" branch below can never fire —
    # likely intended to start at 0.0. Confirm before changing.
    sent_start: float = time.time()
    next_update_time = time.time()    # earliest time we may emit again
    cur_update_speed = cur_base_speed # smoothed emission speed (chars/s)

    while True:
        # Once the producer is done and little remains, flush the tail in
        # one final pack and stop.
        if finished and len(buffer_content) - len(sent_content) < min_chunk_size * 5:
            if buffer_message is not None and len(sent_content) < len(
                buffer_content,
            ):
                new_pack = buffer_content[len(sent_content) :]
                sent_content += new_pack
                yield format_chat_message(
                    role=buffer_message["role"],
                    message=new_pack,
                    name=buffer_message["name"] if "name" in buffer_message else None,
                )
            break

        # Not yet time for the next pack: sleep until then (or until the
        # producer nudges us), then re-evaluate.
        if time.time() < next_update_time:
            with update_cond:
                update_cond.wait(
                    min(min_sleep_interval, next_update_time - time.time()),
                )
            continue

        # Snapshot shared state under the lock.
        with update_lock:
            cur_buf_message = buffer_message
            total_len = len(buffer_content)
            sent_len = len(sent_content)
            rem_len = total_len - sent_len

        if cur_buf_message is None or len(buffer_content) - len(sent_content) < min_chunk_size:
            # wait for more buffer
            with update_cond:
                update_cond.wait(min_sleep_interval)
            continue

        if sent_start == 0.0:
            # first chunk time
            sent_start = time.time()

        # Steer actual emission speed toward the producer's speed: overshoot
        # the base by 25% of the current lag, then smooth (50/50 EWMA).
        cur_base_speed_norm = speed_normalize(cur_base_speed)
        cur_actual_speed_norm = speed_normalize(
            sent_len / (time.time() - (sent_start if not finished else recv_start)),
        )
        target_speed = cur_base_speed_norm + (cur_base_speed_norm - cur_actual_speed_norm) * 0.25
        cur_update_speed = speed_normalize(0.5 * cur_update_speed + target_speed * 0.5)

        # Pick pack size vs. interval: at high speed, emit every frame
        # (min_update_interval) with proportionally bigger packs; at low
        # speed, emit minimum-size packs less often.
        if cur_update_speed > min_chunk_size / min_update_interval:
            chunk_time_target = min_update_interval
            new_pack_size_target = chunk_time_target * cur_update_speed
        else:
            new_pack_size_target = min_chunk_size
            chunk_time_target = new_pack_size_target / cur_update_speed

        # Jitter the pack size +/-20% around the target for a natural feel,
        # clamped to what remains in the buffer.
        rand_min = max(
            min(rem_len, min_chunk_size),
            int(0.8 * new_pack_size_target),
        )
        rand_max = min(rem_len, int(1.2 * new_pack_size_target))
        new_pack_size = random.randint(rand_min, rand_max) if rand_max - rand_min > 1 else rand_min

        # Scale the wait so the effective chars/s matches the target speed.
        chunk_time = chunk_time_target / new_pack_size_target * new_pack_size

        new_pack = buffer_content[sent_len : (sent_len + new_pack_size)]
        sent_content += new_pack

        yield format_chat_message(
            role=cur_buf_message["role"],
            message=new_pack,
            name=cur_buf_message["name"] if "name" in cur_buf_message else None,
        )

        next_update_time = time.time() + chunk_time
        with update_cond:
            update_cond.wait(min(min_sleep_interval, chunk_time))

    thread.join()

def get_embedding(self, string: str) -> List[float]:
    """Embed a single string via the embedding service and return its vector."""
    vectors = self.embedding_service.get_embeddings([string])
    return vectors[0]
Expand Down
4 changes: 2 additions & 2 deletions taskweaver/llm/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def _get_from_playback_completion(
content = cached_value["content"]
cur_pos = 0
while cur_pos < len(content):
chunk_size = random.randint(3, 20)
chunk_size = random.randint(2, 8)
next_pos = min(cur_pos + chunk_size, len(content))
time.sleep(self.config.playback_delay) # init delay
yield format_chat_message(role, content[cur_pos:next_pos])
cur_pos = next_pos
time.sleep(self.config.playback_delay)
6 changes: 5 additions & 1 deletion taskweaver/module/event_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def update_attachment(
if id is not None:
attachment = self.post.attachment_list[-1]
assert id == attachment.id
if type is not None:
assert type == attachment.type
attachment.content += message
attachment.extra = extra
else:
assert type is not None, "type is required when creating new attachment"
attachment = Attachment.create(
Expand All @@ -184,7 +188,7 @@ def update_attachment(
{
"type": type,
"extra": extra,
"id": id,
"id": attachment.id,
"is_end": is_end,
},
)
Expand Down
9 changes: 7 additions & 2 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,12 @@ def check_post_validity(post: Post):
assert post.send_to is not None, "send_to field is None"
assert post.send_to != "Planner", "send_to field should not be Planner"
assert post.message is not None, "message field is None"
assert post.attachment_list[0].type == AttachmentType.init_plan, "attachment type is not init_plan"
assert post.attachment_list[1].type == AttachmentType.plan, "attachment type is not plan"
assert (
post.attachment_list[0].type == AttachmentType.init_plan
), f"attachment type {post.attachment_list[0].type} is not init_plan"
assert (
post.attachment_list[1].type == AttachmentType.plan
), f"attachment type {post.attachment_list[1].type} is not plan"
assert (
post.attachment_list[2].type == AttachmentType.current_plan_step
), "attachment type is not current_plan_step"
Expand All @@ -241,6 +245,7 @@ def check_post_validity(post: Post):
llm_stream = self.llm_api.chat_completion_stream(
chat_history,
use_backup_engine=use_back_up_engine,
use_smoother=True,
)

llm_output: List[str] = []
Expand Down
Loading