From fa40ab2611da7a40cd1fb3910ecef6d2beac9652 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:56:34 -0600 Subject: [PATCH 1/9] Implement "tool" and "file" capabilities Still not completely finished, currently a work in progress --- src/hugchat/hugchat.py | 1 + src/hugchat/message.py | 33 +++++++++++++++++++++++++++++++-- src/hugchat/types/file.py | 24 ++++++++++++++++++++++++ src/hugchat/types/tool.py | 14 ++++++++++++++ 4 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 src/hugchat/types/file.py create mode 100644 src/hugchat/types/tool.py diff --git a/src/hugchat/hugchat.py b/src/hugchat/hugchat.py index d1d8016..41eaa43 100644 --- a/src/hugchat/hugchat.py +++ b/src/hugchat/hugchat.py @@ -824,6 +824,7 @@ def chat( ), _stream_yield_all=_stream_yield_all, web_search=web_search, + conversation=conversation ) return msg diff --git a/src/hugchat/message.py b/src/hugchat/message.py index b26caf0..4921e64 100644 --- a/src/hugchat/message.py +++ b/src/hugchat/message.py @@ -1,10 +1,15 @@ from typing import Generator, Union +from .types.tool import Tool +from .types.file import File +from .types.message import Conversation from .exceptions import ChatError, ModelOverloadedError import json RESPONSE_TYPE_FINAL = "finalAnswer" RESPONSE_TYPE_STREAM = "stream" +RESPONSE_TYPE_TOOL = "tool" # with subtypes "call" and "result" +RESPONSE_TYPE_FILE = "file" RESPONSE_TYPE_WEB = "webSearch" RESPONSE_TYPE_STATUS = "status" MSGTYPE_ERROR = "error" @@ -59,6 +64,8 @@ class Message(Generator): _stream_yield_all: bool = False web_search: bool = False web_search_sources: list = [] + tools_used: list = [] + files_created: list = [] _result_text: str = "" web_search_done: bool = not web_search msg_status: int = MSGSTATUS_PENDING @@ -69,10 +76,12 @@ def __init__( g: Generator, _stream_yield_all: bool = False, web_search: bool = False, + conversation: Conversation = None ) -> None: self.g = g self._stream_yield_all = _stream_yield_all self.web_search = web_search + self.conversation = conversation @property def text(self) -> str: @@ -97,8 +106,7 @@ def __next__(self) -> dict: if self.error is not None: raise self.error else: - raise Exception( - "Message stauts is `Rejected` but no error found") + raise Exception("Message status is `Rejected` but no error found") try: a: dict = next(self.g) @@ -118,6 +126,13 @@ def __next__(self) -> dict: wss.title = source["title"] wss.link = source["link"] self.web_search_sources.append(wss) + elif t == RESPONSE_TYPE_TOOL: + if a["subtype"] == "result": + tool = Tool(a["uuid"], a["result"]) + self.tools_used.append(tool) + elif t == RESPONSE_TYPE_FILE: + file = File(a["sha"], a["name"], a["mime"], self.conversation) + self.files_created.append(file) elif "messageType" in a: message_type: str = a["messageType"] if message_type == MSGTYPE_ERROR: @@ -185,6 +200,20 @@ def get_search_sources(self) -> list: """ return self.web_search_sources + def get_tools_used(self) -> list: + """ + :Return: + - self.tools_used + """ + return self.tools_used + + def get_files_created(self) -> list: + """ + :Return: + - self.files_created + """ + return self.files_created + def search_enabled(self) -> bool: """ :Return: diff --git a/src/hugchat/types/file.py b/src/hugchat/types/file.py new file mode 100644 index 0000000..2d28d61 --- /dev/null +++ b/src/hugchat/types/file.py @@ -0,0 +1,24 @@ +from .message import Conversation + + +class File: + ''' + Class used to represent files created by the model + ''' + + def __init__(self, sha: str, name: str, mime: str, conversation: Conversation): + self.sha = sha + self.name = name + self.mime = mime + + self.conversation = conversation + self.url = self.get_url() + + def get_url(self) -> str: + print(self.conversation) + print(dir(self.conversation)) + print(self.conversation.id) + return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}" + + def __str__(self) -> str: + return f"File(url={self.url}, sha={self.sha}, name={self.name}, mime={self.mime})" diff --git a/src/hugchat/types/tool.py b/src/hugchat/types/tool.py new file mode 100644 index 0000000..4354286 --- /dev/null +++ b/src/hugchat/types/tool.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class Tool: + ''' + Class used to represent tools used by the model + ''' + + uuid: str + result: str + + def __str__(self) -> str: + return f"Tool(uuid={self.uuid}, result={self.result})" From 59bfbe617cf166d02d73eeb324dcde2d154c9e2c Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:57:07 -0600 Subject: [PATCH 2/9] Remove debugging print statements --- src/hugchat/types/file.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/hugchat/types/file.py b/src/hugchat/types/file.py index 2d28d61..9ff4e4c 100644 --- a/src/hugchat/types/file.py +++ b/src/hugchat/types/file.py @@ -15,9 +15,6 @@ def __init__(self, sha: str, name: str, mime: str, conversation: Conversation): self.url = self.get_url() def get_url(self) -> str: - print(self.conversation) - print(dir(self.conversation)) - print(self.conversation.id) return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}" def __str__(self) -> str: From 686c92cf1b3c4f992613f70dace4cd10092365a6 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:15:12 -0600 Subject: [PATCH 3/9] Implement download_file feature Example usage: ```python f = open(f"out.{file.mime.split('/')[1]}", "bw") f.write(file.download_file(chatbot)) f.close() ``` --- src/hugchat/types/file.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/hugchat/types/file.py b/src/hugchat/types/file.py index 9ff4e4c..78b98f8 100644 --- a/src/hugchat/types/file.py +++ b/src/hugchat/types/file.py @@ -15,7 +15,20 @@ def __init__(self, sha: str, name: str, mime: str, conversation: Conversation): self.url = self.get_url() def get_url(self) -> str: + """ + Gets the url for the given file + """ + return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}" + def download_file(self, chatBot: "ChatBot") -> bytes: + """ + Downloads the given file + """ + + r = chatBot.session.get(self.url) + print(f'"{self.url}"') + return r.content + def __str__(self) -> str: return f"File(url={self.url}, sha={self.sha}, name={self.name}, mime={self.mime})" From 6a904f0da245f1dea30cb1d1a0526ff7e52d4eaf Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:27:23 -0600 Subject: [PATCH 4/9] Many changes in message.py * Refactored RESPONSE_TYPE_* to class ResponseTypes.* * Refactored MSGSTATUS_* to class MessageStatus.* * Renamed some internal variables in __next__ function to be clearer * Add some short inline comments --- src/hugchat/message.py | 152 ++++++++++++++++++++++++----------------- 1 file changed, 88 insertions(+), 64 deletions(-) diff --git a/src/hugchat/message.py b/src/hugchat/message.py index 4921e64..c034c2d 100644 --- a/src/hugchat/message.py +++ b/src/hugchat/message.py @@ -1,4 +1,4 @@ -from typing import Generator, Union +from typing import Generator, Union, List from .types.tool import Tool from .types.file import File @@ -6,17 +6,23 @@ from .exceptions import ChatError, ModelOverloadedError import json -RESPONSE_TYPE_FINAL = "finalAnswer" -RESPONSE_TYPE_STREAM = "stream" -RESPONSE_TYPE_TOOL = "tool" # with subtypes "call" and "result" -RESPONSE_TYPE_FILE = "file" -RESPONSE_TYPE_WEB = "webSearch" -RESPONSE_TYPE_STATUS = "status" -MSGTYPE_ERROR = "error" -MSGSTATUS_PENDING = 0 -MSGSTATUS_RESOLVED = 1 -MSGSTATUS_REJECTED = 2 +class ResponseTypes: + FINAL = "finalAnswer" + STREAM = "stream" + TOOL = "tool" + FILE = "file" + WEB = "webSearch" + STATUS = "status" + + +class MessageStatus: + PENDING = 0 + RESOLVED = 1 + REJECTED = 2 + + +MSGTYPE_ERROR = "error" class WebSearchSource: @@ -39,7 +45,7 @@ class Message(Generator): - web_search_sources: list[WebSearchSource] = list() - text: str = "" - web_search_done: bool = not web_search - - msg_status: int = MSGSTATUS_PENDING + - msg_status: int = MessageStatus.PENDING - error: Union[Exception, None] = None A wrapper of `Generator` that receives and process the response @@ -53,22 +59,26 @@ class Message(Generator): for res in msg: ... # process else: - if msg.done() == MSGSTATUS_REJECTED: + if msg.done() == MessageStatus.REJECTED: raise msg.error # or simply use: final = msg.wait_until_done() """ - g: Generator _stream_yield_all: bool = False - web_search: bool = False - web_search_sources: list = [] - tools_used: list = [] - files_created: list = [] _result_text: str = "" + + gen: Generator + + web_search: bool = False + web_search_sources: List[WebSearchSource] = [] web_search_done: bool = not web_search - msg_status: int = MSGSTATUS_PENDING + + tools_used: List[Tool] = [] + files_created: List[File] = [] + + msg_status: int = MessageStatus.PENDING error: Union[Exception, None] = None def __init__( @@ -78,7 +88,7 @@ def __init__( web_search: bool = False, conversation: Conversation = None ) -> None: - self.g = g + self.gen = g self._stream_yield_all = _stream_yield_all self.web_search = web_search self.conversation = conversation @@ -87,7 +97,7 @@ def __init__( def text(self) -> str: self._result_text = self.wait_until_done() return self._result_text - + @text.setter def text(self, v: str) -> None: self._result_text = v @@ -100,67 +110,80 @@ def _filterResponse(self, obj: dict): raise ChatError(f"No `type` and `message` returned: {obj}") def __next__(self) -> dict: - if self.msg_status == MSGSTATUS_RESOLVED: + if self.msg_status == MessageStatus.RESOLVED: raise StopIteration - elif self.msg_status == MSGSTATUS_REJECTED: + + elif self.msg_status == MessageStatus.REJECTED: if self.error is not None: raise self.error else: raise Exception("Message status is `Rejected` but no error found") try: - a: dict = next(self.g) - self._filterResponse(a) - t: str = a["type"] + data: dict = next(self.gen) + self._filterResponse(data) + data_type: str = data["type"] message_type: str = "" - if t == RESPONSE_TYPE_FINAL: - self._result_text = a["text"] - self.msg_status = MSGSTATUS_RESOLVED - elif t == RESPONSE_TYPE_WEB: + + # set _result_text if this is the final iteration of the chat message + if data_type == ResponseTypes.FINAL: + self._result_text = data["text"] + self.msg_status = MessageStatus.RESOLVED + + # Handle web response type + elif data_type == ResponseTypes.WEB: # gracefully pass unparseable webpages - if message_type != MSGTYPE_ERROR and a.__contains__("sources"): + if message_type != MSGTYPE_ERROR and data.__contains__("sources"): self.web_search_sources.clear() - sources = a["sources"] + sources = data["sources"] for source in sources: wss = WebSearchSource() wss.title = source["title"] wss.link = source["link"] self.web_search_sources.append(wss) - elif t == RESPONSE_TYPE_TOOL: - if a["subtype"] == "result": - tool = Tool(a["uuid"], a["result"]) + + # Handle what is done when a tool completes + elif data_type == ResponseTypes.TOOL: + if data["subtype"] == "result": + tool = Tool(data["uuid"], data["result"]) self.tools_used.append(tool) - elif t == RESPONSE_TYPE_FILE: - file = File(a["sha"], a["name"], a["mime"], self.conversation) + + # Handle what is done when a file is created + elif data_type == ResponseTypes.FILE: + file = File(data["sha"], data["name"], data["mime"], self.conversation) self.files_created.append(file) - elif "messageType" in a: - message_type: str = a["messageType"] + + elif "messageType" in data: + message_type: str = data["messageType"] if message_type == MSGTYPE_ERROR: - self.error = ChatError(a["message"]) - self.msg_status = MSGSTATUS_REJECTED - if t == RESPONSE_TYPE_STREAM: + self.error = ChatError(data["message"]) + self.msg_status = MessageStatus.REJECTED + + if data_type == ResponseTypes.STREAM: self.web_search_done = True - elif t == RESPONSE_TYPE_STATUS: + + elif data_type == ResponseTypes.STATUS: pass + else: - if "Model is overloaded" in str(a): + if "Model is overloaded" in str(data): self.error = ModelOverloadedError( "Model is overloaded, please try again later or switch to another model." ) - self.msg_status = MSGSTATUS_REJECTED - elif a.__contains__(MSGTYPE_ERROR): - self.error = ChatError(a[MSGTYPE_ERROR]) - self.msg_status = MSGSTATUS_REJECTED + self.msg_status = MessageStatus.REJECTED + elif data.__contains__(MSGTYPE_ERROR): + self.error = ChatError(data[MSGTYPE_ERROR]) + self.msg_status = MessageStatus.REJECTED else: - self.error = ChatError(f"Unknown json response: {a}") + self.error = ChatError(f"Unknown json response: {data}") # If _stream_yield_all is True, yield all responses from the server. - if self._stream_yield_all or t == RESPONSE_TYPE_STREAM: - return a + if self._stream_yield_all or data_type == ResponseTypes.STREAM: + return data else: return self.__next__() except StopIteration: - if self.msg_status == MSGSTATUS_PENDING: + if self.msg_status == MessageStatus.PENDING: self.error = ChatError( "Stream of responses has abruptly ended (final answer has not been received)." ) @@ -169,7 +192,7 @@ def __next__(self) -> dict: except Exception as e: # print("meet error: ", str(e)) self.error = e - self.msg_status = MSGSTATUS_REJECTED + self.msg_status = MessageStatus.REJECTED raise self.error def __iter__(self): @@ -181,10 +204,10 @@ def throw( __val=None, __tb=None, ): - return self.g.throw(__typ, __val, __tb) + return self.gen.throw(__typ, __val, __tb) def send(self, __value): - return self.g.send(__value) + return self.gen.send(__value) def get_final_text(self) -> str: """ @@ -230,10 +253,13 @@ def wait_until_done(self) -> str: """ while not self.is_done(): self.__next__() - if self.is_done() == MSGSTATUS_RESOLVED: + + if self.is_done() == MessageStatus.RESOLVED: return self._result_text + elif self.error is not None: raise self.error + else: raise Exception("Rejected but no error captured!") @@ -243,9 +269,9 @@ def is_done(self): - self.msg_status 3 status: - - MSGSTATUS_PENDING = 0 # running - - MSGSTATUS_RESOLVED = 1 # done with no error(maybe?) - - MSGSTATUS_REJECTED = 2 # error raised + - MessageStatus.PENDING = 0 # running + - MessageStatus.RESOLVED = 1 # done with no error(maybe?) + - MessageStatus.REJECTED = 2 # error raised """ return self.msg_status @@ -264,11 +290,13 @@ def __str__(self): def __getitem__(self, key: str) -> str: print("_getitem_") self.wait_until_done() - print("done") + if key == "text": return self.text + elif key == "web_search": return self.web_search + elif key == "web_search_sources": return self.web_search_sources @@ -284,7 +312,3 @@ def __iadd__(self, other: str) -> str: self.wait_until_done() self.text += other return self.text - - -if __name__ == "__main__": - pass From 95737aff080e41cab82f377fe53b934b8783b786 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:32:52 -0600 Subject: [PATCH 5/9] Remove debugging print statement --- src/hugchat/types/file.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/hugchat/types/file.py b/src/hugchat/types/file.py index 78b98f8..01335d6 100644 --- a/src/hugchat/types/file.py +++ b/src/hugchat/types/file.py @@ -27,7 +27,6 @@ def download_file(self, chatBot: "ChatBot") -> bytes: """ r = chatBot.session.get(self.url) - print(f'"{self.url}"') return r.content def __str__(self) -> str: From d89c5d1993d526a5f1b7a2947e12c48341ed9d39 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:44:38 -0600 Subject: [PATCH 6/9] Update unit_test.py --- src/unit_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/unit_test.py b/src/unit_test.py index 81fb0e4..271a746 100644 --- a/src/unit_test.py +++ b/src/unit_test.py @@ -5,7 +5,7 @@ # import os import logging -from .hugchat.message import MSGSTATUS_RESOLVED, RESPONSE_TYPE_FINAL, Message, MSGTYPE_ERROR, RESPONSE_TYPE_WEB +from .hugchat.message import MessageStatus, ResponseTypes, Message, MSGTYPE_ERROR import sys logging.basicConfig(level=logging.DEBUG) @@ -39,7 +39,7 @@ class Test(object): def test_web_search_failed_results(self): response_list = [{ "type": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "messageType": MSGTYPE_ERROR, "message": @@ -50,9 +50,9 @@ def test_web_search_failed_results(self): ] }, { "type": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "messageType": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "sources": [{ "title": "1", "link": "2", @@ -60,7 +60,7 @@ def test_web_search_failed_results(self): }] }, { "type": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "messageType": MSGTYPE_ERROR, "message": @@ -70,7 +70,7 @@ def test_web_search_failed_results(self): "https://www.accuweather.com/en/gb/london/ec4a-2/weather-forecast/328328" ] }, { - "type": RESPONSE_TYPE_FINAL, + "type": ResponseTypes.FINAL, "messageType": "answer", "text": "Funny joke" }] From 154aaa8b909fbbce3889a8bd194ad6da646091f6 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:45:36 -0600 Subject: [PATCH 7/9] Remove "ChatBot" type hinting --- src/hugchat/types/file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hugchat/types/file.py b/src/hugchat/types/file.py index 01335d6..cfe9644 100644 --- a/src/hugchat/types/file.py +++ b/src/hugchat/types/file.py @@ -21,7 +21,7 @@ def get_url(self) -> str: return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}" - def download_file(self, chatBot: "ChatBot") -> bytes: + def download_file(self, chatBot) -> bytes: """ Downloads the given file """ From 460e773e7871fd43137678a6667320d3f2413161 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:46:58 -0600 Subject: [PATCH 8/9] Update integration_test.py --- src/integration_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/integration_test.py b/src/integration_test.py index 7a6339d..dc60b36 100644 --- a/src/integration_test.py +++ b/src/integration_test.py @@ -7,7 +7,6 @@ import pytest -from .hugchat.message import MSGSTATUS_RESOLVED, Message, MSGTYPE_ERROR, RESPONSE_TYPE_WEB from .hugchat import hugchat, cli from .hugchat.login import Login import sys From 1b090b426762de6e23b1f9fbd2aa077b5da64f42 Mon Sep 17 00:00:00 2001 From: Whitelisted <77711834+Whitelisted1@users.noreply.github.com> Date: Tue, 4 Jun 2024 18:28:56 -0600 Subject: [PATCH 9/9] Fix null characters appended to token stream More of a temporary solution, as we do not know what is causing this --- src/hugchat/message.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/hugchat/message.py b/src/hugchat/message.py index c034c2d..d762942 100644 --- a/src/hugchat/message.py +++ b/src/hugchat/message.py @@ -153,6 +153,10 @@ def __next__(self) -> dict: file = File(data["sha"], data["name"], data["mime"], self.conversation) self.files_created.append(file) + # replace null characters with an empty string + elif data_type == ResponseTypes.STREAM: + data["token"] = data["token"].replace('\u0000', '') + elif "messageType" in data: message_type: str = data["messageType"] if message_type == MSGTYPE_ERROR: