From f84bee805ee18f9ed3b2f0fc0c52767d303ef109 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 15:26:26 +0800 Subject: [PATCH 01/15] add stream json parser --- taskweaver/utils/json_parser.py | 330 ++++++++++++++++++++++++++++++++ 1 file changed, 330 insertions(+) create mode 100644 taskweaver/utils/json_parser.py diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py new file mode 100644 index 00000000..8561f974 --- /dev/null +++ b/taskweaver/utils/json_parser.py @@ -0,0 +1,330 @@ +import itertools +from typing import Any, Iterable, List, Literal, NamedTuple, Optional, Tuple + +ParserEventType = Literal[ + "start_map", + "end_map", + "start_array", + "end_array", + "map_key", + "null", + "boolean", + # use number for both integer and double + # "integer", + # "double", + "number", + "string", + "ws", +] + + +class ParserEvent(NamedTuple): + prefix: str + event: ParserEventType + value: Any + value_str: str + is_end: bool + + +ParserStateType = Literal[ + "object", + "object_key", + "object_value", + "array", + "array_value", + "number", + "string", + "literal", +] + + +def reduce_events(events: Iterable[ParserEvent]) -> Iterable[ParserEvent]: + reduced: List[ParserEvent] = [] + cur: Optional[ParserEvent] = None + for ev in events: + if cur is None: + cur = ev + continue + if ev.event == cur.event: + cur = ParserEvent( + ev.prefix, + cur.event, + ev.value, + cur.value_str + ev.value_str, + ev.is_end, + ) + else: + reduced.append(cur) + cur = ev + if cur is not None: + reduced.append(cur) + return reduced + + +def is_ws(ch: str): + return ch == " " or ch == "\t" or ch == "\n" or ch == "\r" + + +def parse_json_stream(token_stream: Iterable[str]) -> Iterable[ParserEvent]: + buf: str = "" + is_end: bool = False + prefix_stack: List[str] = [] + state_stack: List[Tuple[ParserStateType, Any]] = [] + ev_queue: List[ParserEvent] = [] + + def add_event(ev: ParserEventType, value: Any, value_str: str, is_end: bool): + ev_queue.append( + ParserEvent( + "".join(prefix_stack), + ev, + value, + value_str, + is_end, + ), + ) + + def parse_ws(ch: str) -> bool: + if not is_ws(ch): + return False + add_event("ws", None, ch, False) + return True + + def parse_str_begin(ch: str, is_obj_key: bool = False) -> bool: + if ch == '"': + add_event("object_key" if is_obj_key else "string", "", "", False) # type: ignore + state_stack.append(("string", (False, "", "", is_obj_key))) + return True + return False + + def parse_value_begin(ch: str) -> bool: + if parse_ws(ch) or parse_str_begin(ch): + return True + if ch == "{": + add_event("start_map", None, ch, True) + state_stack.append(("object", None)) + return True + if ch == "[": + add_event("start_array", None, ch, True) + state_stack.append(("array", (0, False, False))) + return True + if ch in ["t", "f", "n"]: + literal_state: Tuple[str, ParserEventType, str, Any] + if ch == "t": + literal_state = (ch, "boolean", "true", True) + elif ch == "f": + literal_state = (ch, "boolean", "false", False) + else: + literal_state = (ch, "null", "null", None) + add_event(literal_state[1], None, ch, False) + state_stack.append(("literal", literal_state)) + return True + if ch == "-" or ch.isdigit(): + add_event("number", ch, ch, False) + state_stack.append(("number", (ch, False, False, False))) + return True + return False + + def parse_obj_begin(ch: str) -> bool: + if ch == "}": + add_event("end_map", None, ch, True) + state_stack.pop() + return True + if parse_ws(ch): + return True + if parse_str_begin(ch, True): + return True + return False 
+ + def parse_obj_value(ch: str, cur_state_ext: Tuple[str, bool, bool]) -> bool: + key, value_to_begin, value_to_end = cur_state_ext + if parse_ws(ch): + return True + if value_to_end: + prefix_stack.pop() + state_stack.pop() + if ch == ",": + return True + if ch == "}": + print(f"state check: {state_stack}") + add_event("end_map", None, ch, True) + state_stack.pop() # pop the object begin state + return True + raise Exception(f"invalid value after value of key {key}: {ch}") + if value_to_begin: + state_stack[-1] = ("object_value", (key, False, True)) + if parse_value_begin(ch): + return True + raise Exception(f"invalid value for key {key}: {ch}") + if ch == ":": + state_stack[-1] = ("object_value", (key, True, False)) + return True + return False + + def parse_array_begin(ch: str, cur_state_ext: Tuple[int, bool, bool]) -> bool: + idx, value_begins, require_value = cur_state_ext + if parse_ws(ch): + return True + if value_begins: + prefix_stack.pop() + if ch == ",": + state_stack[-1] = ("array", (idx + 1, False, True)) + return True + if ch == "]": + add_event("end_array", None, ch, True) + state_stack.pop() + return True + else: + if not require_value and ch == "]": + add_event("end_array", None, ch, True) + state_stack.pop() + return True + state_stack[-1] = ("array", (idx, True, False)) + prefix_stack.append(f"[{idx}]") + if parse_value_begin(ch): + return True + raise Exception(f"invalid value for index {idx}: {ch}") + return False + + def parse_str_value(ch: str, cur_state_ext: Tuple[bool, str, str, bool]) -> bool: + in_escape, escape_buf, value_buf, is_obj_key = cur_state_ext + ev: ParserEventType = "object_key" if is_obj_key else "string" # type: ignore + if in_escape and escape_buf.startswith("u"): + if ch in "0123456789abcdefABCDEF": + escape_buf += ch + else: + raise Exception(f"invalid unicode escape sequence: \\{escape_buf}{ch}") + if len(escape_buf) == 5: + new_ch = chr(int(escape_buf[1:], 16)) + value_buf += new_ch + add_event(ev, None, new_ch, False) + state_stack[-1] = ("string", (False, "", value_buf, is_obj_key)) + else: + state_stack[-1] = ("string", (True, escape_buf, value_buf, is_obj_key)) + return True + if in_escape: + assert escape_buf == "" + if ch == "u": + state_stack[-1] = ("string", (True, ch, value_buf, is_obj_key)) + return True + new_ch = "" + if ch == "n": + new_ch = "\n" + elif ch == "/": + new_ch = "/" + elif ch == "\\": + new_ch = "\\" + elif ch == "r": + new_ch = "\r" + elif ch == "r": + new_ch = "\r" + elif ch == "t": + new_ch = "\t" + elif ch == "b": + new_ch = "\b" + elif ch == "f": + new_ch = "\f" + elif ch == '"': + new_ch = '"' + else: + raise Exception(f"invalid escape sequence: \\{ch}") + value_buf += new_ch + add_event(ev, None, new_ch, False) + state_stack[-1] = ("string", (False, "", value_buf, is_obj_key)) + return True + if ch == '"': + add_event(ev, value_buf, "", True) + state_stack.pop() + if is_obj_key: + prefix_stack.append(value_buf) + state_stack.append(("object_value", (value_buf, False, False))) + return True + if ch == "\\": + state_stack[-1] = ("string", (True, "", value_buf, is_obj_key)) + return True + value_buf += ch + add_event(ev, None, ch, False) + state_stack[-1] = ("string", (False, "", value_buf, is_obj_key)) + return True + + def parse_literal_value( + ch: str, + cur_state_ext: Tuple[str, ParserEventType, str, Any], + ) -> bool: + buf, ev, literal, value = cur_state_ext + buf += ch + if buf == literal: + add_event(ev, value, buf, True) + state_stack.pop() + return True + if literal.startswith(buf): + add_event(ev, 
None, ch, False) + state_stack[-1] = ("literal", (buf, ev, literal, value)) + return True + raise Exception(f"invalid literal in parsing when expecting {literal}: {buf}") + + def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): + buf, in_exp, in_frac, in_exp_sign = cur_state_ext + if ch.isdigit() or ch == "." or ch == "e" or ch == "E" or ch == "+" or ch == "-": + buf += ch + add_event("number", None, ch, False) + state_stack[-1] = ("number", (buf, in_exp, in_frac, in_exp_sign)) + return True + num_val = float(buf) + add_event("number", num_val, "", True) + state_stack.pop() + return False + + def clear_ev_queue(): + result = ev_queue.copy() + result = reduce_events(result) + ev_queue.clear() + return result + + # while True: + for chunk in itertools.chain(token_stream, [None]): + # chunk = next(token_stream) + if chunk is None: + is_end = True + else: + buf += chunk + while True: + if len(buf) == 0 and not is_end: + break + cur_state, cur_state_ext = state_stack[-1] if len(state_stack) > 0 else (None, None) + ch = "" if buf == "" else buf[0] + buf = buf if buf == "" else buf[1:] + r = False + if cur_state is None: + r = parse_value_begin(ch) + elif cur_state == "object": + r = parse_obj_begin(ch) + elif cur_state == "string": + assert cur_state_ext is not None + r = parse_str_value(ch, cur_state_ext) + elif cur_state == "object_value": + assert cur_state_ext is not None + r = parse_obj_value(ch, cur_state_ext) + elif cur_state == "array": + assert cur_state_ext is not None + r = parse_array_begin(ch, cur_state_ext) + elif cur_state == "literal": + assert cur_state_ext is not None + r = parse_literal_value(ch, cur_state_ext) + elif cur_state == "number": + assert cur_state_ext is not None + r = parse_number(ch, cur_state_ext) + if not r: + # number needs to peek next token to determine if it's finished + # restore token to buffer when finishes + buf = ch + buf + r = True + continue + else: + raise Exception(f"not implemented handling for {cur_state}: {ch}") + if not r and not is_end: + raise Exception( + f"failed to parse {cur_state}: {ch} \n State: {state_stack} Prefix: {prefix_stack}", + ) + if is_end: + break + yield from clear_ev_queue() From 3181f595af1e865405ef4cf935a4fea9ffd6f507 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 15:50:26 +0800 Subject: [PATCH 02/15] trigger ws event properly --- taskweaver/utils/json_parser.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index 8561f974..d4aa739c 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -35,6 +35,7 @@ class ParserEvent(NamedTuple): "number", "string", "literal", + "ws", ] @@ -84,8 +85,14 @@ def add_event(ev: ParserEventType, value: Any, value_str: str, is_end: bool): ) def parse_ws(ch: str) -> bool: + is_in_ws = state_stack[-1][0] == "ws" if len(state_stack) > 0 else False + if not is_ws(ch): + if is_in_ws: + add_event("ws", None, "", True) + state_stack.pop() return False + state_stack.append(("ws", None)) add_event("ws", None, ch, False) return True @@ -145,7 +152,6 @@ def parse_obj_value(ch: str, cur_state_ext: Tuple[str, bool, bool]) -> bool: if ch == ",": return True if ch == "}": - print(f"state check: {state_stack}") add_event("end_map", None, ch, True) state_stack.pop() # pop the object begin state return True @@ -274,7 +280,7 @@ def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): state_stack.pop() return False - def 
clear_ev_queue(): + def process_ev_queue(): result = ev_queue.copy() result = reduce_events(result) ev_queue.clear() @@ -319,6 +325,14 @@ def clear_ev_queue(): buf = ch + buf r = True continue + elif cur_state == "ws": + r = parse_ws(ch) + if not r: + # ws also need to peek next token to determine the end + # restore token to buffer when finishes + buf = ch + buf + r = True + continue else: raise Exception(f"not implemented handling for {cur_state}: {ch}") if not r and not is_end: @@ -327,4 +341,4 @@ def clear_ev_queue(): ) if is_end: break - yield from clear_ev_queue() + yield from process_ev_queue() From 91bfafe90fb3f67a19c064ad5607beab44cc394f Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 18:11:11 +0800 Subject: [PATCH 03/15] add option to skip ws --- taskweaver/utils/json_parser.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index d4aa739c..2d9bf58b 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -39,10 +39,12 @@ class ParserEvent(NamedTuple): ] -def reduce_events(events: Iterable[ParserEvent]) -> Iterable[ParserEvent]: +def reduce_events(events: Iterable[ParserEvent], skip_ws: bool = True) -> Iterable[ParserEvent]: reduced: List[ParserEvent] = [] cur: Optional[ParserEvent] = None for ev in events: + if skip_ws and ev.event == "ws": + continue if cur is None: cur = ev continue @@ -66,7 +68,7 @@ def is_ws(ch: str): return ch == " " or ch == "\t" or ch == "\n" or ch == "\r" -def parse_json_stream(token_stream: Iterable[str]) -> Iterable[ParserEvent]: +def parse_json_stream(token_stream: Iterable[str], skip_ws: bool = False) -> Iterable[ParserEvent]: buf: str = "" is_end: bool = False prefix_stack: List[str] = [] @@ -92,7 +94,8 @@ def parse_ws(ch: str) -> bool: add_event("ws", None, "", True) state_stack.pop() return False - state_stack.append(("ws", None)) + if not is_in_ws: + state_stack.append(("ws", None)) add_event("ws", None, ch, False) return True @@ -282,11 +285,10 @@ def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): def process_ev_queue(): result = ev_queue.copy() - result = reduce_events(result) + result = reduce_events(result, skip_ws=skip_ws) ev_queue.clear() return result - # while True: for chunk in itertools.chain(token_stream, [None]): # chunk = next(token_stream) if chunk is None: From c1383bc25b29108a1e97d4916d92e90a8b99b4e4 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 18:30:09 +0800 Subject: [PATCH 04/15] parser root element check --- taskweaver/utils/json_parser.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index 2d9bf58b..5dd92457 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -27,11 +27,10 @@ class ParserEvent(NamedTuple): ParserStateType = Literal[ + "root", "object", - "object_key", "object_value", "array", - "array_value", "number", "string", "literal", @@ -72,7 +71,7 @@ def parse_json_stream(token_stream: Iterable[str], skip_ws: bool = False) -> Ite buf: str = "" is_end: bool = False prefix_stack: List[str] = [] - state_stack: List[Tuple[ParserStateType, Any]] = [] + state_stack: List[Tuple[ParserStateType, Any]] = [("root", (False, False))] ev_queue: List[ParserEvent] = [] def add_event(ev: ParserEventType, value: Any, value_str: str, is_end: bool): @@ -101,7 +100,7 @@ def parse_ws(ch: str) -> bool: def 
parse_str_begin(ch: str, is_obj_key: bool = False) -> bool: if ch == '"': - add_event("object_key" if is_obj_key else "string", "", "", False) # type: ignore + add_event("map_key" if is_obj_key else "string", "", "", False) state_stack.append(("string", (False, "", "", is_obj_key))) return True return False @@ -196,7 +195,7 @@ def parse_array_begin(ch: str, cur_state_ext: Tuple[int, bool, bool]) -> bool: def parse_str_value(ch: str, cur_state_ext: Tuple[bool, str, str, bool]) -> bool: in_escape, escape_buf, value_buf, is_obj_key = cur_state_ext - ev: ParserEventType = "object_key" if is_obj_key else "string" # type: ignore + ev: ParserEventType = "map_key" if is_obj_key else "string" if in_escape and escape_buf.startswith("u"): if ch in "0123456789abcdefABCDEF": escape_buf += ch @@ -283,6 +282,18 @@ def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): state_stack.pop() return False + def parse_root(ch: str, cur_state_ext: Tuple[bool, bool]): + has_root_elem, is_end = cur_state_ext + if parse_ws(ch): + return True + if ch == "": + state_stack[-1] = ("root", (has_root_elem, True)) + return True + if has_root_elem: + raise Exception(f"invalid token after root element: {ch}") + state_stack[-1] = ("root", (True, is_end)) + return parse_value_begin(ch) + def process_ev_queue(): result = ev_queue.copy() result = reduce_events(result, skip_ws=skip_ws) @@ -290,7 +301,6 @@ def process_ev_queue(): return result for chunk in itertools.chain(token_stream, [None]): - # chunk = next(token_stream) if chunk is None: is_end = True else: @@ -298,12 +308,12 @@ def process_ev_queue(): while True: if len(buf) == 0 and not is_end: break - cur_state, cur_state_ext = state_stack[-1] if len(state_stack) > 0 else (None, None) + cur_state, cur_state_ext = state_stack[-1] ch = "" if buf == "" else buf[0] buf = buf if buf == "" else buf[1:] r = False - if cur_state is None: - r = parse_value_begin(ch) + if cur_state == "root": + r = parse_root(ch, cur_state_ext) elif cur_state == "object": r = parse_obj_begin(ch) elif cur_state == "string": From 7f7a4ea76c9c6318da1afa320dd886593add0d29 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 18:52:46 +0800 Subject: [PATCH 05/15] format prefix style and provide ijson compat format --- taskweaver/utils/json_parser.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index 5dd92457..242d3267 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -38,7 +38,10 @@ class ParserEvent(NamedTuple): ] -def reduce_events(events: Iterable[ParserEvent], skip_ws: bool = True) -> Iterable[ParserEvent]: +def reduce_events( + events: Iterable[ParserEvent], + skip_ws: bool = True, +) -> Iterable[ParserEvent]: reduced: List[ParserEvent] = [] cur: Optional[ParserEvent] = None for ev in events: @@ -67,17 +70,25 @@ def is_ws(ch: str): return ch == " " or ch == "\t" or ch == "\n" or ch == "\r" -def parse_json_stream(token_stream: Iterable[str], skip_ws: bool = False) -> Iterable[ParserEvent]: +def parse_json_stream( + token_stream: Iterable[str], + skip_ws: bool = False, + ijson_prefix: bool = False, +) -> Iterable[ParserEvent]: buf: str = "" is_end: bool = False - prefix_stack: List[str] = [] + prefix_stack: List[Tuple[bool, str]] = [] state_stack: List[Tuple[ParserStateType, Any]] = [("root", (False, False))] ev_queue: List[ParserEvent] = [] def add_event(ev: ParserEventType, value: Any, value_str: str, is_end: bool): + if 
ijson_prefix: + prefix = ".".join("item" if is_arr else val for is_arr, val in prefix_stack) + else: + prefix = "".join(f"[{val}]" if is_arr else f".{val}" for is_arr, val in prefix_stack) ev_queue.append( ParserEvent( - "".join(prefix_stack), + prefix, ev, value, value_str, @@ -187,7 +198,7 @@ def parse_array_begin(ch: str, cur_state_ext: Tuple[int, bool, bool]) -> bool: state_stack.pop() return True state_stack[-1] = ("array", (idx, True, False)) - prefix_stack.append(f"[{idx}]") + prefix_stack.append((True, str(idx))) if parse_value_begin(ch): return True raise Exception(f"invalid value for index {idx}: {ch}") @@ -243,7 +254,7 @@ def parse_str_value(ch: str, cur_state_ext: Tuple[bool, str, str, bool]) -> bool add_event(ev, value_buf, "", True) state_stack.pop() if is_obj_key: - prefix_stack.append(value_buf) + prefix_stack.append((False, value_buf)) state_stack.append(("object_value", (value_buf, False, False))) return True if ch == "\\": From 09fa3b09066a293f139913805356fd582d61583c Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 20:36:25 +0800 Subject: [PATCH 06/15] allow skip extra content after first root element --- taskweaver/utils/json_parser.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index 242d3267..9fca512b 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -15,6 +15,7 @@ "number", "string", "ws", + "skip", ] @@ -74,6 +75,7 @@ def parse_json_stream( token_stream: Iterable[str], skip_ws: bool = False, ijson_prefix: bool = False, + skip_after_root: bool = False, ) -> Iterable[ParserEvent]: buf: str = "" is_end: bool = False @@ -294,16 +296,27 @@ def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): return False def parse_root(ch: str, cur_state_ext: Tuple[bool, bool]): - has_root_elem, is_end = cur_state_ext + has_root_elem, has_skip_cnt = cur_state_ext + + if has_skip_cnt and skip_after_root: + add_event("skip", None, ch, ch == "") + return True + if parse_ws(ch): return True if ch == "": - state_stack[-1] = ("root", (has_root_elem, True)) return True if has_root_elem: + if skip_after_root: + # detected content after first root element, skip if configured + state_stack[-1] = ("root", (True, True)) + add_event("skip", None, ch, False) + return True raise Exception(f"invalid token after root element: {ch}") - state_stack[-1] = ("root", (True, is_end)) - return parse_value_begin(ch) + else: + # first root element begins + state_stack[-1] = ("root", (True, has_skip_cnt)) + return parse_value_begin(ch) def process_ev_queue(): result = ev_queue.copy() From e903e920ebed224a1ddf3bed90f47a2fcd326fae Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 21:41:31 +0800 Subject: [PATCH 07/15] collect and reconstruct obj --- taskweaver/utils/json_parser.py | 44 +++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index 9fca512b..94ad9ccf 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -378,3 +378,47 @@ def process_ev_queue(): if is_end: break yield from process_ev_queue() + + +def parse_json(token_stream: Iterable[str], skip_after_root: bool = False) -> Any: + root_array: List[Any] = [] + obj_stack: List[Tuple[Literal["object", "array", "key"], Any]] = [ + ("array", root_array), + ] + + def add_value(val: Any): + cur_obj_t, cur_obj_v = obj_stack[-1] + if cur_obj_t == "array": + 
assert type(cur_obj_v) is list + cur_obj_v.append(val) # type: ignore + elif cur_obj_t == "key": + obj_stack.pop() + assert obj_stack[-1][0] == "object", f"unexpected stack state when adding key {obj_stack}" + obj_stack[-1][1][cur_obj_v] = val + else: + assert False, "object value need to have key" + + for ev in parse_json_stream(token_stream, skip_after_root=skip_after_root): + if not ev.is_end: + continue + evt = ev.event + val = ev.value + + if evt == "start_map": + obj_stack.append(("object", {})) + elif evt == "start_array": + obj_stack.append(("array", [])) + elif evt == "map_key": + obj_stack.append(("key", val)) + elif evt == "ws" or evt == "skip": + pass + elif evt == "end_map" or evt == "end_array": + obj_val = obj_stack.pop()[1] + add_value(obj_val) + elif evt == "boolean" or evt == "null" or evt == "number" or evt == "string": + add_value(val) + else: + assert f"unsupported parser event {evt}" + assert len(obj_stack) == 1 + assert len(root_array) == 1 + return root_array[0] From d380105c34c5f859dbcb70206a6759ca3c9f45d2 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Thu, 11 Jan 2024 22:28:48 +0800 Subject: [PATCH 08/15] add simple test for parser --- taskweaver/utils/json_parser.py | 4 +- tests/unit_tests/test_json_parser.py | 74 ++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/test_json_parser.py diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index 94ad9ccf..c08d1697 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -284,13 +284,15 @@ def parse_literal_value( raise Exception(f"invalid literal in parsing when expecting {literal}: {buf}") def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): + # TODO: support rigir buf, in_exp, in_frac, in_exp_sign = cur_state_ext if ch.isdigit() or ch == "." or ch == "e" or ch == "E" or ch == "+" or ch == "-": buf += ch add_event("number", None, ch, False) state_stack[-1] = ("number", (buf, in_exp, in_frac, in_exp_sign)) return True - num_val = float(buf) + is_float_mode = "." 
in buf or "e" in buf or "E" in buf + num_val = float(buf) if is_float_mode else int(buf) add_event("number", num_val, "", True) state_stack.pop() return False diff --git a/tests/unit_tests/test_json_parser.py b/tests/unit_tests/test_json_parser.py new file mode 100644 index 00000000..1b34891e --- /dev/null +++ b/tests/unit_tests/test_json_parser.py @@ -0,0 +1,74 @@ +import json +from typing import Any, List + +import pytest + +from taskweaver.utils import json_parser + +obj_cases: List[Any] = [ + ["hello", "world"], + "any_str", + { + "test_key": { + "str_array": ["hello", "world", "test"], + "another_key": {}, + "empty_array": [[[]], []], + }, + }, + [True, False, None], + [1, 2, 3], + 123.345, + {"val": 123.345}, + [ + { + "a": {}, + "b": {}, + "c": {}, + "d": {}, + "e": {}, + }, + ], + { + "test_key": { + "str_array": ["hello", "world", "test"], + "test another key": [ + "hello", + "world", + 1, + 2.0, + True, + False, + None, + { + "test yet another key": "test value", + "test yet key 2": '\r\n\u1234\ffdfd\tfdfv\b"', + }, + ], + True: False, + }, + }, +] + + +@pytest.mark.parametrize("obj", obj_cases) +def test_json_parser(obj: Any): + dumped_str = json.dumps(obj) + obj = json_parser.parse_json(json.dumps(obj)) + dumped_str2 = json.dumps(obj) + assert dumped_str == dumped_str2 + + +str_cases: List[str] = [ + ' { "a": [ true, false, null ] } ', + " \r \n \t [ \r \n \t true, false, null \r \n \t ] \r \n \t ", + ' \r \n \t "hello world" \r \n \t ', +] + + +@pytest.mark.parametrize("str_case", str_cases) +def test_json_parser_str(str_case: str): + obj = json.loads(str_case) + dumped_str = json.dumps(obj) + obj = json_parser.parse_json(str_case) + dumped_str2 = json.dumps(obj) + assert dumped_str == dumped_str2 From 958cb264d5ac6d0405b752cd95243752aab9389d Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Fri, 12 Jan 2024 14:18:22 +0800 Subject: [PATCH 09/15] add incomplete stream in translator --- taskweaver/role/translator.py | 55 ++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/taskweaver/role/translator.py b/taskweaver/role/translator.py index f1e99ab0..6afb23b5 100644 --- a/taskweaver/role/translator.py +++ b/taskweaver/role/translator.py @@ -1,7 +1,7 @@ import io import json from json import JSONDecodeError -from typing import Any, Callable, Dict, Iterable, Iterator, List, Literal, Optional, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union import ijson from injector import inject @@ -52,26 +52,32 @@ def stream_filter(s: Iterable[ChatMessageType]) -> Iterator[str]: yield c["content"] self.logger.info(f"LLM output: {llm_output}") - for d in self.parse_llm_output_stream(stream_filter(llm_output)): - type_str = d["type"] + value_buf: str = "" + for type_str, value, is_end in self.parse_llm_output_stream(stream_filter(llm_output)): + value_buf += value type: Optional[AttachmentType] = None - value = d["content"] if type_str == "message": - post_proxy.update_message(value) + post_proxy.update_message(value_buf, is_end=is_end) + value_buf = "" elif type_str == "send_to": - assert value in [ - "User", - "Planner", - "CodeInterpreter", - ], f"Invalid send_to value: {value}" - post_proxy.update_send_to(value) # type: ignore + if is_end: + assert value in [ + "User", + "Planner", + "CodeInterpreter", + ], f"Invalid send_to value: {value}" + post_proxy.update_send_to(value) # type: ignore + else: + # collect the whole content before updating post + pass else: try: type = AttachmentType(type_str) - 
post_proxy.update_attachment(value, type) + post_proxy.update_attachment(value_buf, type, is_end=is_end) + value_buf = "" except Exception as e: self.logger.warning( - f"Failed to parse attachment: {d} due to {str(e)}", + f"Failed to parse attachment: {type_str}-{value_buf} due to {str(e)}", ) continue parsed_type = ( @@ -84,7 +90,9 @@ def stream_filter(s: Iterable[ChatMessageType]) -> Iterator[str]: else None ) assert parsed_type is not None, f"Invalid type: {type_str}" - if early_stop is not None and early_stop(parsed_type, value): + + # check whether parsing should be triggered prematurely when each key parsing is finished + if is_end and early_stop is not None and early_stop(parsed_type, value): break if validation_func is not None: @@ -139,7 +147,7 @@ def parse_llm_output(self, llm_output: str) -> List[Dict[str, str]]: def parse_llm_output_stream( self, llm_output: Iterator[str], - ) -> Iterator[Dict[str, str]]: + ) -> Iterator[Tuple[str, str, bool]]: class StringIteratorIO(io.TextIOBase): def __init__(self, iter: Iterator[str]): self._iter = iter @@ -179,21 +187,22 @@ def read(self, n: Optional[int] = None): # use small buffer to get parse result as soon as acquired from LLM parser = ijson.parse(json_data_stream, buf_size=5) - element = {} + cur_type: Optional[str] = None + cur_content: Optional[str] = None try: for prefix, event, value in parser: if prefix == "response.item" and event == "map_key" and value == "type": - element["type"] = None + cur_type = None elif prefix == "response.item.type" and event == "string": - element["type"] = value + cur_type = value elif prefix == "response.item" and event == "map_key" and value == "content": - element["content"] = None + cur_content = None elif prefix == "response.item.content" and event == "string": - element["content"] = value + cur_content = value - if len(element) == 2 and None not in element.values(): - yield element - element = {} + if cur_type is not None and cur_content is not None: + yield cur_type, cur_content, True + cur_type, cur_content = None, None except ijson.JSONError as e: self.logger.warning( f"Failed to parse LLM output stream due to JSONError: {str(e)}", From 3b8065077a7b947251ef6a6b3252ab2c1dd31df7 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Fri, 12 Jan 2024 15:11:13 +0800 Subject: [PATCH 10/15] raise speciality parsing error and validate content --- taskweaver/utils/json_parser.py | 39 +++++++++++++++------ tests/unit_tests/test_json_parser.py | 51 ++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 10 deletions(-) diff --git a/taskweaver/utils/json_parser.py b/taskweaver/utils/json_parser.py index c08d1697..bac416a2 100644 --- a/taskweaver/utils/json_parser.py +++ b/taskweaver/utils/json_parser.py @@ -19,6 +19,10 @@ ] +class StreamJsonParserError(Exception): + pass + + class ParserEvent(NamedTuple): prefix: str event: ParserEventType @@ -170,12 +174,12 @@ def parse_obj_value(ch: str, cur_state_ext: Tuple[str, bool, bool]) -> bool: add_event("end_map", None, ch, True) state_stack.pop() # pop the object begin state return True - raise Exception(f"invalid value after value of key {key}: {ch}") + raise StreamJsonParserError(f"invalid value after value of key {key}: {ch}") if value_to_begin: state_stack[-1] = ("object_value", (key, False, True)) if parse_value_begin(ch): return True - raise Exception(f"invalid value for key {key}: {ch}") + raise StreamJsonParserError(f"invalid value for key {key}: {ch}") if ch == ":": state_stack[-1] = ("object_value", (key, True, False)) return True @@ -203,7 +207,7 @@ 
def parse_array_begin(ch: str, cur_state_ext: Tuple[int, bool, bool]) -> bool: prefix_stack.append((True, str(idx))) if parse_value_begin(ch): return True - raise Exception(f"invalid value for index {idx}: {ch}") + raise StreamJsonParserError(f"invalid value for index {idx}: {ch}") return False def parse_str_value(ch: str, cur_state_ext: Tuple[bool, str, str, bool]) -> bool: @@ -213,7 +217,7 @@ def parse_str_value(ch: str, cur_state_ext: Tuple[bool, str, str, bool]) -> bool if ch in "0123456789abcdefABCDEF": escape_buf += ch else: - raise Exception(f"invalid unicode escape sequence: \\{escape_buf}{ch}") + raise StreamJsonParserError(f"invalid unicode escape sequence: \\{escape_buf}{ch}") if len(escape_buf) == 5: new_ch = chr(int(escape_buf[1:], 16)) value_buf += new_ch @@ -247,7 +251,7 @@ def parse_str_value(ch: str, cur_state_ext: Tuple[bool, str, str, bool]) -> bool elif ch == '"': new_ch = '"' else: - raise Exception(f"invalid escape sequence: \\{ch}") + raise StreamJsonParserError(f"invalid escape sequence: \\{ch}") value_buf += new_ch add_event(ev, None, new_ch, False) state_stack[-1] = ("string", (False, "", value_buf, is_obj_key)) @@ -281,7 +285,7 @@ def parse_literal_value( add_event(ev, None, ch, False) state_stack[-1] = ("literal", (buf, ev, literal, value)) return True - raise Exception(f"invalid literal in parsing when expecting {literal}: {buf}") + raise StreamJsonParserError(f"invalid literal in parsing when expecting {literal}: {buf}") def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): # TODO: support rigir @@ -292,7 +296,10 @@ def parse_number(ch: str, cur_state_ext: Tuple[str, bool, bool, bool]): state_stack[-1] = ("number", (buf, in_exp, in_frac, in_exp_sign)) return True is_float_mode = "." in buf or "e" in buf or "E" in buf - num_val = float(buf) if is_float_mode else int(buf) + try: + num_val = float(buf) if is_float_mode else int(buf) + except ValueError: + raise StreamJsonParserError(f"invalid number literal {buf}") add_event("number", num_val, "", True) state_stack.pop() return False @@ -314,7 +321,7 @@ def parse_root(ch: str, cur_state_ext: Tuple[bool, bool]): state_stack[-1] = ("root", (True, True)) add_event("skip", None, ch, False) return True - raise Exception(f"invalid token after root element: {ch}") + raise StreamJsonParserError(f"invalid token after root element: {ch}") else: # first root element begins state_stack[-1] = ("root", (True, has_skip_cnt)) @@ -372,15 +379,27 @@ def process_ev_queue(): r = True continue else: - raise Exception(f"not implemented handling for {cur_state}: {ch}") + raise StreamJsonParserError(f"not implemented handling for {cur_state}: {ch}") if not r and not is_end: - raise Exception( + raise StreamJsonParserError( f"failed to parse {cur_state}: {ch} \n State: {state_stack} Prefix: {prefix_stack}", ) if is_end: break yield from process_ev_queue() + # post parsing checks + assert len(state_stack) > 0 + + final_root_type, final_root_state = state_stack[0] + assert final_root_type == "root" + + if not final_root_state[0]: + raise StreamJsonParserError("empty string with no element found") + + if len(state_stack) > 1: + raise StreamJsonParserError("incomplete JSON str ends prematurely") + def parse_json(token_stream: Iterable[str], skip_after_root: bool = False) -> Any: root_array: List[Any] = [] diff --git a/tests/unit_tests/test_json_parser.py b/tests/unit_tests/test_json_parser.py index 1b34891e..156ed165 100644 --- a/tests/unit_tests/test_json_parser.py +++ b/tests/unit_tests/test_json_parser.py @@ -53,6 +53,24 
@@ @pytest.mark.parametrize("obj", obj_cases) def test_json_parser(obj: Any): dumped_str = json.dumps(obj) + + # expect error with the JSON is incomplete + for i in range(len(dumped_str) - 1): + cur_incomplete_seg = dumped_str[:i] + + try: + float(cur_incomplete_seg) + # skip incomplete number that is valid JSON as well + continue + except ValueError: + pass + + with pytest.raises(json_parser.StreamJsonParserError): + json_parser.parse_json(cur_incomplete_seg) + + # proper parsing exception should raise before this + raise Exception("Failed to parse incomplete JSON: " + cur_incomplete_seg) + obj = json_parser.parse_json(json.dumps(obj)) dumped_str2 = json.dumps(obj) assert dumped_str == dumped_str2 @@ -72,3 +90,36 @@ def test_json_parser_str(str_case: str): obj = json_parser.parse_json(str_case) dumped_str2 = json.dumps(obj) assert dumped_str == dumped_str2 + + +bad_cases: List[str] = [ + " - ", + "'abc'", + "\\a", + "{} {}", + "[[[]}]", + '""""', + "{'abc': 'def'}", + "[[[[{{{{0}}}}]]]]", + " ", + "", + "((((()))))", + '"\\"', + # incomplete json + '"abc', + "[1,2,3", + '{"abc', + '{"abc":', + '{"abc":}', + '{"abc":1', + "123,456,789", + "undefined", + "None", + "{true: false}", +] + + +@pytest.mark.parametrize("bad_case", bad_cases) +def test_json_parser_bad(bad_case: str): + with pytest.raises(json_parser.StreamJsonParserError): + json_parser.parse_json(bad_case) From f17ba02b5100eebe1966d7a42c1236ca63fc428a Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Fri, 12 Jan 2024 15:58:05 +0800 Subject: [PATCH 11/15] add v2 parser into translator --- taskweaver/role/translator.py | 80 ++++++++++++++++++++++++++++- tests/unit_tests/test_translator.py | 53 +++++++++++++++---- 2 files changed, 123 insertions(+), 10 deletions(-) diff --git a/taskweaver/role/translator.py b/taskweaver/role/translator.py index 6afb23b5..c6974796 100644 --- a/taskweaver/role/translator.py +++ b/taskweaver/role/translator.py @@ -11,6 +11,7 @@ from taskweaver.memory import Attachment, Post from taskweaver.memory.attachment import AttachmentType from taskweaver.module.event_emitter import PostEventProxy, SessionEventEmitter +from taskweaver.utils import json_parser class PostTranslator: @@ -34,6 +35,7 @@ def raw_text_to_post( post_proxy: PostEventProxy, early_stop: Optional[Callable[[Union[AttachmentType, Literal["message", "send_to"]], str], bool]] = None, validation_func: Optional[Callable[[Post], None]] = None, + use_v2_parser: bool = False, ) -> None: """ Convert the raw text output of LLM to a Post object. 
@@ -53,7 +55,13 @@ def stream_filter(s: Iterable[ChatMessageType]) -> Iterator[str]: self.logger.info(f"LLM output: {llm_output}") value_buf: str = "" - for type_str, value, is_end in self.parse_llm_output_stream(stream_filter(llm_output)): + filtered_stream = stream_filter(llm_output) + parser_stream = ( + self.parse_llm_output_stream_v2(filtered_stream) + if use_v2_parser + else self.parse_llm_output_stream(filtered_stream) + ) + for type_str, value, is_end in parser_stream: value_buf += value type: Optional[AttachmentType] = None if type_str == "message": @@ -207,3 +215,73 @@ def read(self, n: Optional[int] = None): self.logger.warning( f"Failed to parse LLM output stream due to JSONError: {str(e)}", ) + + def parse_llm_output_stream_v2( + self, + llm_output: Iterator[str], + ) -> Iterator[Tuple[str, str, bool]]: + parser = json_parser.parse_json_stream(llm_output, skip_after_root=True) + root_element_prefix = ".response" + + list_begin, list_end = False, False + item_idx = 0 + + cur_content_sent: bool = False + cur_content_sent_end: bool = False + cur_type: Optional[str] = None + cur_content: Optional[str] = None + + try: + for ev in parser: + if ev.prefix == root_element_prefix: + if ev.event == "start_array": + list_begin = True + if ev.event == "end_array": + list_end = True + + if not list_begin or list_end: + continue + + cur_item_prefix = f"{root_element_prefix}[{item_idx}]" + if ev.prefix == cur_item_prefix: + if ev.event == "start_map": + cur_content_sent, cur_content_sent_end = False, False + cur_type, cur_content = None, None + if ev.event == "end_map": + if cur_type is None or cur_content is None: + raise Exception( + f"Incomplete generate kv pair in index {item_idx}. " + f"type: {cur_type} content {cur_content}", + ) + + if cur_content_sent and not cur_content_sent_end: + # possible incomplete string, trigger end prematurely + yield cur_type, "", True + + if not cur_content_sent: + yield cur_type, cur_content, True + + cur_content_sent, cur_content_sent_end = False, False + cur_type, cur_content = None, None + item_idx += 1 + + if ev.prefix == cur_item_prefix + ".type": + if ev.event == "string" and ev.is_end: + cur_type = ev.value + + if ev.prefix == cur_item_prefix + ".content": + if ev.event == "string": + if cur_type is not None: + cur_content_sent = True + yield cur_type, ev.value_str, ev.is_end + + assert not cur_content_sent_end, "Invalid state: already sent is_end marker" + if ev.is_end: + cur_content_sent_end = True + if ev.is_end: + cur_content = ev.value + + except json_parser.StreamJsonParserError as e: + self.logger.warning( + f"Failed to parse LLM output stream due to JSONError: {str(e)}", + ) diff --git a/tests/unit_tests/test_translator.py b/tests/unit_tests/test_translator.py index 3fb5962d..fd841437 100644 --- a/tests/unit_tests/test_translator.py +++ b/tests/unit_tests/test_translator.py @@ -1,6 +1,7 @@ from random import randint from typing import Iterator +import pytest from injector import Injector from taskweaver.llm.util import format_chat_message @@ -47,8 +48,13 @@ def response_str() -> Iterator[str]: attachment_list = list(attachments) assert len(attachment_list) == 8 + attachments = translator.parse_llm_output_stream_v2(response_str()) + attachment_list = list(a for a in attachments if a[2]) # only count is_end is true + assert len(attachment_list) == 8 + -def test_parse_llm(): +@pytest.mark.parametrize("use_v2_parser", [True, False]) +def test_parse_llm(use_v2_parser: bool): def early_stop(type: AttachmentType, text: str) -> bool: if type in 
[AttachmentType.python, AttachmentType.sample, AttachmentType.text]: return True @@ -62,6 +68,7 @@ def early_stop(type: AttachmentType, text: str) -> bool: llm_output=[format_chat_message("assistant", response_str1)], post_proxy=post_proxy, early_stop=early_stop, + use_v2_parser=use_v2_parser, ) response = post_proxy.end() assert response.message == "" @@ -78,6 +85,7 @@ def early_stop(type: AttachmentType, text: str) -> bool: translator.raw_text_to_post( llm_output=[format_chat_message("assistant", response_str1)], post_proxy=post_proxy, + use_v2_parser=use_v2_parser, ) response = post_proxy.end() @@ -89,23 +97,50 @@ def early_stop(type: AttachmentType, text: str) -> bool: def test_post_to_raw_text(): - post = Post.create(message="This is the message", send_from="CodeInterpreter", send_to="Planner") + post = Post.create( + message="This is the message", + send_from="CodeInterpreter", + send_to="Planner", + ) - prompt = translator.post_to_raw_text(post=post, if_format_message=True, if_format_send_to=True) + prompt = translator.post_to_raw_text( + post=post, + if_format_message=True, + if_format_send_to=True, + ) assert prompt == ( '{"response": [{"type": "send_to", "content": "Planner"}, {"type": "message", ' '"content": "This is the message"}]}' ) - prompt = translator.post_to_raw_text(post=post, if_format_message=False, if_format_send_to=False) + prompt = translator.post_to_raw_text( + post=post, + if_format_message=False, + if_format_send_to=False, + ) assert prompt == '{"response": []}' - post.add_attachment(Attachment.create(type="thought", content="This is the thought")) - post.add_attachment(Attachment.create(type="python", content="print('This is the code')")) + post.add_attachment( + Attachment.create(type="thought", content="This is the thought"), + ) + post.add_attachment( + Attachment.create(type="python", content="print('This is the code')"), + ) post.add_attachment(Attachment.create(type="text", content="This is the text")) - post.add_attachment(Attachment.create(type="sample", content="print('This is the sample code')")) + post.add_attachment( + Attachment.create(type="sample", content="print('This is the sample code')"), + ) post.add_attachment(Attachment.create(type="execution_status", content="SUCCESS")) - post.add_attachment(Attachment.create(type="execution_result", content="This is the execution result")) + post.add_attachment( + Attachment.create( + type="execution_result", + content="This is the execution result", + ), + ) - prompt = translator.post_to_raw_text(post=post, if_format_message=True, if_format_send_to=True) + prompt = translator.post_to_raw_text( + post=post, + if_format_message=True, + if_format_send_to=True, + ) assert prompt == response_str1 From 7592330bdf182291d1506a4e1d25f68246103632 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Fri, 12 Jan 2024 17:41:10 +0800 Subject: [PATCH 12/15] fix event emitter with streamed content --- taskweaver/module/event_emitter.py | 6 +++++- taskweaver/planner/planner.py | 8 ++++++-- taskweaver/role/translator.py | 19 +++++++++++++++---- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/taskweaver/module/event_emitter.py b/taskweaver/module/event_emitter.py index 20a4be41..70b55127 100644 --- a/taskweaver/module/event_emitter.py +++ b/taskweaver/module/event_emitter.py @@ -169,6 +169,10 @@ def update_attachment( if id is not None: attachment = self.post.attachment_list[-1] assert id == attachment.id + if type is not None: + assert type == attachment.type + attachment.content += message + attachment.extra = 
extra else: assert type is not None, "type is required when creating new attachment" attachment = Attachment.create( @@ -184,7 +188,7 @@ def update_attachment( { "type": type, "extra": extra, - "id": id, + "id": attachment.id, "is_end": is_end, }, ) diff --git a/taskweaver/planner/planner.py b/taskweaver/planner/planner.py index 2e8b434f..dd0b6b16 100644 --- a/taskweaver/planner/planner.py +++ b/taskweaver/planner/planner.py @@ -225,8 +225,12 @@ def check_post_validity(post: Post): assert post.send_to is not None, "send_to field is None" assert post.send_to != "Planner", "send_to field should not be Planner" assert post.message is not None, "message field is None" - assert post.attachment_list[0].type == AttachmentType.init_plan, "attachment type is not init_plan" - assert post.attachment_list[1].type == AttachmentType.plan, "attachment type is not plan" + assert ( + post.attachment_list[0].type == AttachmentType.init_plan + ), f"attachment type {post.attachment_list[0].type} is not init_plan" + assert ( + post.attachment_list[1].type == AttachmentType.plan + ), f"attachment type {post.attachment_list[1].type} is not plan" assert ( post.attachment_list[2].type == AttachmentType.current_plan_step ), "attachment type is not current_plan_step" diff --git a/taskweaver/role/translator.py b/taskweaver/role/translator.py index c6974796..1a2f8ea5 100644 --- a/taskweaver/role/translator.py +++ b/taskweaver/role/translator.py @@ -35,7 +35,7 @@ def raw_text_to_post( post_proxy: PostEventProxy, early_stop: Optional[Callable[[Union[AttachmentType, Literal["message", "send_to"]], str], bool]] = None, validation_func: Optional[Callable[[Post], None]] = None, - use_v2_parser: bool = False, + use_v2_parser: bool = True, ) -> None: """ Convert the raw text output of LLM to a Post object. 
@@ -61,6 +61,7 @@ def stream_filter(s: Iterable[ChatMessageType]) -> Iterator[str]: if use_v2_parser else self.parse_llm_output_stream(filtered_stream) ) + cur_attachment: Optional[Attachment] = None for type_str, value, is_end in parser_stream: value_buf += value type: Optional[AttachmentType] = None @@ -69,20 +70,30 @@ def stream_filter(s: Iterable[ChatMessageType]) -> Iterator[str]: value_buf = "" elif type_str == "send_to": if is_end: - assert value in [ + assert value_buf in [ "User", "Planner", "CodeInterpreter", ], f"Invalid send_to value: {value}" - post_proxy.update_send_to(value) # type: ignore + post_proxy.update_send_to(value_buf) # type: ignore + value_buf = "" else: # collect the whole content before updating post pass else: try: type = AttachmentType(type_str) - post_proxy.update_attachment(value_buf, type, is_end=is_end) + if cur_attachment is not None: + assert type == cur_attachment.type + cur_attachment = post_proxy.update_attachment( + value_buf, + type, + id=(cur_attachment.id if cur_attachment is not None else None), + is_end=is_end, + ) value_buf = "" + if is_end: + cur_attachment = None except Exception as e: self.logger.warning( f"Failed to parse attachment: {type_str}-{value_buf} due to {str(e)}", From 140688b8ddba2fe3438181872512311950a993cd Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Fri, 12 Jan 2024 17:42:59 +0800 Subject: [PATCH 13/15] show streaming content on animation status --- taskweaver/chat/console/chat.py | 44 ++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/taskweaver/chat/console/chat.py b/taskweaver/chat/console/chat.py index 591e1997..e31bfafa 100644 --- a/taskweaver/chat/console/chat.py +++ b/taskweaver/chat/console/chat.py @@ -278,6 +278,27 @@ def get_ani_frame(frame: int = 0): ani_frame = " " * frame_inx + "<=💡=>" + " " * (10 - frame_inx) return ani_frame + def format_status_message(limit: int): + incomplete_suffix = "..." 
+ incomplete_suffix_len = len(incomplete_suffix) + if len(cur_message_buffer) == 0: + if len(status_msg) > limit - 1: + return f" {status_msg[(limit - incomplete_suffix_len - 1):]}{incomplete_suffix}" + return " " + status_msg + + cur_key_display = style_line("[") + style_key(cur_key) + style_line("]") + cur_key_len = len(cur_key) + 2 # with extra bracket + cur_message_buffer_norm = cur_message_buffer.replace("\n", " ").replace( + "\r", + " ", + ) + + if len(cur_message_buffer_norm) < limit - cur_key_len - 1: + return f"{cur_key_display} {cur_message_buffer_norm}" + + status_msg_len = limit - cur_key_len - incomplete_suffix_len + return f"{cur_key_display}{incomplete_suffix}{cur_message_buffer_norm[-status_msg_len:]}" + last_time = 0 while True: clear_line() @@ -330,6 +351,7 @@ def get_ani_frame(frame: int = 0): key=cur_key, ), ) + cur_message_buffer = "" elif action == "round_error": error_message(opt) elif action == "status_update": @@ -340,14 +362,28 @@ def get_ani_frame(frame: int = 0): if self.exit_event.is_set(): break + cur_message_prefix: str = " TaskWeaver " + cur_ani_frame = get_ani_frame(counter) + cur_message_display_len = ( + terminal_column + - len(cur_message_prefix) + - 2 # separator for cur message prefix + - len(role) + - 2 # bracket for role + - len(cur_ani_frame) + - 2 # extra size for emoji in ani + ) + + cur_message_display = format_status_message(cur_message_display_len) + click.secho( - click.style(" TaskWeaver ", fg="white", bg="yellow") + click.style(cur_message_prefix, fg="white", bg="yellow") + click.style("▶ ", fg="yellow") + style_line("[") + style_role(role) - + style_line("] ") - + style_msg(status_msg) - + style_msg(get_ani_frame(counter)) + + style_line("]") + + style_msg(cur_message_display) + + style_msg(cur_ani_frame) + "\r", # f">>> [{style_role(role)}] {status_msg} {get_ani_frame(counter)}\r", nl=False, From b7e45b079deb1f0ac4364877935c7b584ffcb3cb Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Fri, 12 Jan 2024 17:45:38 +0800 Subject: [PATCH 14/15] adjust mock size --- taskweaver/llm/mock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taskweaver/llm/mock.py b/taskweaver/llm/mock.py index e598d8d7..1331cf41 100644 --- a/taskweaver/llm/mock.py +++ b/taskweaver/llm/mock.py @@ -327,7 +327,7 @@ def _get_from_playback_completion( content = cached_value["content"] cur_pos = 0 while cur_pos < len(content): - chunk_size = random.randint(3, 20) + chunk_size = random.randint(2, 8) next_pos = min(cur_pos + chunk_size, len(content)) yield format_chat_message(role, content[cur_pos:next_pos]) cur_pos = next_pos From 28c2a43f500a3fdc38cb8f395c3fa0c282cac0c9 Mon Sep 17 00:00:00 2001 From: Jack-Q Date: Sat, 13 Jan 2024 02:54:34 +0800 Subject: [PATCH 15/15] add smoother to evenly distribute LLM generations --- .../code_generator/code_generator.py | 1 + taskweaver/llm/__init__.py | 158 ++++++++++++++++-- taskweaver/llm/mock.py | 2 +- taskweaver/planner/planner.py | 1 + 4 files changed, 150 insertions(+), 12 deletions(-) diff --git a/taskweaver/code_interpreter/code_generator/code_generator.py b/taskweaver/code_interpreter/code_generator/code_generator.py index f563b1ca..a7458fbe 100644 --- a/taskweaver/code_interpreter/code_generator/code_generator.py +++ b/taskweaver/code_interpreter/code_generator/code_generator.py @@ -320,6 +320,7 @@ def early_stop(_type: AttachmentType, value: str) -> bool: llm_output=self.llm_api.chat_completion_stream( prompt, use_backup_engine=use_back_up_engine, + use_smoother=True, ), post_proxy=post_proxy, 
early_stop=early_stop, diff --git a/taskweaver/llm/__init__.py b/taskweaver/llm/__init__.py index 479873ee..92728508 100644 --- a/taskweaver/llm/__init__.py +++ b/taskweaver/llm/__init__.py @@ -1,4 +1,4 @@ -from typing import Any, Generator, List, Optional, Type +from typing import Any, Callable, Generator, List, Optional, Type from injector import Injector, inject @@ -112,18 +112,154 @@ def chat_completion_stream( max_tokens: Optional[int] = None, top_p: Optional[float] = None, stop: Optional[List[str]] = None, + use_smoother: bool = True, **kwargs: Any, ) -> Generator[ChatMessageType, None, None]: - return self.completion_service.chat_completion( - messages, - use_backup_engine, - stream, - temperature, - max_tokens, - top_p, - stop, - **kwargs, - ) + def get_generator() -> Generator[ChatMessageType, None, None]: + return self.completion_service.chat_completion( + messages, + use_backup_engine, + stream, + temperature, + max_tokens, + top_p, + stop, + **kwargs, + ) + + if use_smoother: + return self._stream_smoother(get_generator) + return get_generator() + + def _stream_smoother( + self, + stream_init: Callable[[], Generator[ChatMessageType, None, None]], + ) -> Generator[ChatMessageType, None, None]: + import random + import threading + import time + + min_sleep_interval = 0.1 + min_chunk_size = 2 + min_update_interval = 1 / 30 # 30Hz + + recv_start = time.time() + buffer_message: Optional[ChatMessageType] = None + buffer_content: str = "" + finished = False + + update_lock = threading.Lock() + update_cond = threading.Condition() + cur_base_speed: float = 10.0 + + def speed_normalize(speed: float): + return min(max(speed, 5), 600) + + def base_stream_puller(): + nonlocal buffer_message, buffer_content, finished, cur_base_speed + stream = stream_init() + + for msg in stream: + if msg["content"] == "": + continue + + with update_lock: + buffer_message = msg + buffer_content += msg["content"] + cur_time = time.time() + + new_speed = min(2e3, len(buffer_content) / (cur_time - recv_start)) + weight = min(1.0, len(buffer_content) / 80) + cur_base_speed = new_speed * weight + cur_base_speed * (1 - weight) + + with update_cond: + update_cond.notify() + + with update_lock: + finished = True + + thread = threading.Thread(target=base_stream_puller) + thread.start() + + sent_content: str = "" + sent_start: float = time.time() + next_update_time = time.time() + cur_update_speed = cur_base_speed + + while True: + if finished and len(buffer_content) - len(sent_content) < min_chunk_size * 5: + if buffer_message is not None and len(sent_content) < len( + buffer_content, + ): + new_pack = buffer_content[len(sent_content) :] + sent_content += new_pack + yield format_chat_message( + role=buffer_message["role"], + message=new_pack, + name=buffer_message["name"] if "name" in buffer_message else None, + ) + break + + if time.time() < next_update_time: + with update_cond: + update_cond.wait( + min(min_sleep_interval, next_update_time - time.time()), + ) + continue + + with update_lock: + cur_buf_message = buffer_message + total_len = len(buffer_content) + sent_len = len(sent_content) + rem_len = total_len - sent_len + + if cur_buf_message is None or len(buffer_content) - len(sent_content) < min_chunk_size: + # wait for more buffer + with update_cond: + update_cond.wait(min_sleep_interval) + continue + + if sent_start == 0.0: + # first chunk time + sent_start = time.time() + + cur_base_speed_norm = speed_normalize(cur_base_speed) + cur_actual_speed_norm = speed_normalize( + sent_len / (time.time() - 
(sent_start if not finished else recv_start)), + ) + target_speed = cur_base_speed_norm + (cur_base_speed_norm - cur_actual_speed_norm) * 0.25 + cur_update_speed = speed_normalize(0.5 * cur_update_speed + target_speed * 0.5) + + if cur_update_speed > min_chunk_size / min_update_interval: + chunk_time_target = min_update_interval + new_pack_size_target = chunk_time_target * cur_update_speed + else: + new_pack_size_target = min_chunk_size + chunk_time_target = new_pack_size_target / cur_update_speed + + rand_min = max( + min(rem_len, min_chunk_size), + int(0.8 * new_pack_size_target), + ) + rand_max = min(rem_len, int(1.2 * new_pack_size_target)) + new_pack_size = random.randint(rand_min, rand_max) if rand_max - rand_min > 1 else rand_min + + chunk_time = chunk_time_target / new_pack_size_target * new_pack_size + + new_pack = buffer_content[sent_len : (sent_len + new_pack_size)] + sent_content += new_pack + + yield format_chat_message( + role=cur_buf_message["role"], + message=new_pack, + name=cur_buf_message["name"] if "name" in cur_buf_message else None, + ) + + next_update_time = time.time() + chunk_time + with update_cond: + update_cond.wait(min(min_sleep_interval, chunk_time)) + + thread.join() def get_embedding(self, string: str) -> List[float]: return self.embedding_service.get_embeddings([string])[0] diff --git a/taskweaver/llm/mock.py b/taskweaver/llm/mock.py index 1331cf41..dde0a89b 100644 --- a/taskweaver/llm/mock.py +++ b/taskweaver/llm/mock.py @@ -329,6 +329,6 @@ def _get_from_playback_completion( while cur_pos < len(content): chunk_size = random.randint(2, 8) next_pos = min(cur_pos + chunk_size, len(content)) + time.sleep(self.config.playback_delay) # init delay yield format_chat_message(role, content[cur_pos:next_pos]) cur_pos = next_pos - time.sleep(self.config.playback_delay) diff --git a/taskweaver/planner/planner.py b/taskweaver/planner/planner.py index 5ab6666c..1b38957e 100644 --- a/taskweaver/planner/planner.py +++ b/taskweaver/planner/planner.py @@ -245,6 +245,7 @@ def check_post_validity(post: Post): llm_stream = self.llm_api.chat_completion_stream( chat_history, use_backup_engine=use_back_up_engine, + use_smoother=True, ) llm_output: List[str] = []
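
A minimal usage sketch of the streaming parser introduced by this series, assuming the module layout added in patch 01 (taskweaver/utils/json_parser.py). The helper fake_llm_stream and its chunk size of 7 are invented purely for illustration; parse_json_stream, parse_json, the ParserEvent fields, and the default prefix format such as ".response[0].content" are taken from the patches above.

from taskweaver.utils import json_parser


def fake_llm_stream():
    # Emit a JSON document in small, uneven fragments, as an LLM token stream would.
    text = '{"response": [{"type": "message", "content": "hello"}]}'
    for i in range(0, len(text), 7):
        yield text[i : i + 7]


# Event-level consumption: each ParserEvent carries a prefix such as
# ".response[0].content", an event type ("map_key", "string", "number", ...),
# the fragment of text seen so far (value_str), and is_end=True once the value
# is complete; for scalar and key events, value then holds the parsed Python value.
for ev in json_parser.parse_json_stream(fake_llm_stream(), skip_ws=True):
    print(ev.prefix, ev.event, repr(ev.value_str), ev.is_end)

# Whole-document reconstruction; the unit tests in patch 08 check that this
# round-trips with json.dumps().
obj = json_parser.parse_json(fake_llm_stream())
assert obj == {"response": [{"type": "message", "content": "hello"}]}

Because string events are emitted incrementally (is_end stays False until the closing quote arrives), parse_llm_output_stream_v2 from patch 11 can forward partial attachment content to the PostEventProxy while it is still being generated; patch 13 surfaces that partial content in the console status line, and patch 15 smooths the pacing of the streamed chunks.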