diff --git a/src/hugchat/hugchat.py b/src/hugchat/hugchat.py index 0cacfa7..97ad905 100644 --- a/src/hugchat/hugchat.py +++ b/src/hugchat/hugchat.py @@ -831,6 +831,7 @@ def chat( ), _stream_yield_all=_stream_yield_all, web_search=web_search, + conversation=conversation ) return msg diff --git a/src/hugchat/message.py b/src/hugchat/message.py index b4e962c..d762942 100644 --- a/src/hugchat/message.py +++ b/src/hugchat/message.py @@ -1,18 +1,28 @@ -from typing import Generator, Union +from typing import Generator, Union, List +from .types.tool import Tool +from .types.file import File +from .types.message import Conversation from .exceptions import ChatError, ModelOverloadedError import json -RESPONSE_TYPE_FINAL = "finalAnswer" -RESPONSE_TYPE_STREAM = "stream" -RESPONSE_TYPE_WEB = "webSearch" -RESPONSE_TYPE_STATUS = "status" -RESPONSE_TYPE_IMAGE = "file" -MSGTYPE_ERROR = "error" -MSGSTATUS_PENDING = 0 -MSGSTATUS_RESOLVED = 1 -MSGSTATUS_REJECTED = 2 +class ResponseTypes: + FINAL = "finalAnswer" + STREAM = "stream" + TOOL = "tool" + FILE = "file" + WEB = "webSearch" + STATUS = "status" + + +class MessageStatus: + PENDING = 0 + RESOLVED = 1 + REJECTED = 2 + + +MSGTYPE_ERROR = "error" class WebSearchSource: @@ -35,7 +45,7 @@ class Message(Generator): - web_search_sources: list[WebSearchSource] = list() - text: str = "" - web_search_done: bool = not web_search - - msg_status: int = MSGSTATUS_PENDING + - msg_status: int = MessageStatus.PENDING - error: Union[Exception, None] = None A wrapper of `Generator` that receives and process the response @@ -49,20 +59,26 @@ class Message(Generator): for res in msg: ... # process else: - if msg.done() == MSGSTATUS_REJECTED: + if msg.done() == MessageStatus.REJECTED: raise msg.error # or simply use: final = msg.wait_until_done() """ - g: Generator _stream_yield_all: bool = False - web_search: bool = False - web_search_sources: list = [] _result_text: str = "" + + gen: Generator + + web_search: bool = False + web_search_sources: List[WebSearchSource] = [] web_search_done: bool = not web_search - msg_status: int = MSGSTATUS_PENDING + + tools_used: List[Tool] = [] + files_created: List[File] = [] + + msg_status: int = MessageStatus.PENDING error: Union[Exception, None] = None def __init__( @@ -70,17 +86,18 @@ def __init__( g: Generator, _stream_yield_all: bool = False, web_search: bool = False, + conversation: Conversation = None ) -> None: - self.g = g + self.gen = g self._stream_yield_all = _stream_yield_all self.web_search = web_search - self.image_link = None + self.conversation = conversation @property def text(self) -> str: self._result_text = self.wait_until_done() return self._result_text - + @text.setter def text(self, v: str) -> None: self._result_text = v @@ -93,64 +110,84 @@ def _filterResponse(self, obj: dict): raise ChatError(f"No `type` and `message` returned: {obj}") def __next__(self) -> dict: - if self.msg_status == MSGSTATUS_RESOLVED: + if self.msg_status == MessageStatus.RESOLVED: raise StopIteration - elif self.msg_status == MSGSTATUS_REJECTED: + + elif self.msg_status == MessageStatus.REJECTED: if self.error is not None: raise self.error else: - raise Exception( - "Message stauts is `Rejected` but no error found") + raise Exception("Message status is `Rejected` but no error found") try: - a: dict = next(self.g) - self._filterResponse(a) - t: str = a["type"] + data: dict = next(self.gen) + self._filterResponse(data) + data_type: str = data["type"] message_type: str = "" - if t == RESPONSE_TYPE_FINAL: - self._result_text = a["text"] - self.msg_status = MSGSTATUS_RESOLVED - elif t == RESPONSE_TYPE_WEB: + + # set _result_text if this is the final iteration of the chat message + if data_type == ResponseTypes.FINAL: + self._result_text = data["text"] + self.msg_status = MessageStatus.RESOLVED + + # Handle web response type + elif data_type == ResponseTypes.WEB: # gracefully pass unparseable webpages - if message_type != MSGTYPE_ERROR and a.__contains__("sources"): + if message_type != MSGTYPE_ERROR and data.__contains__("sources"): self.web_search_sources.clear() - sources = a["sources"] + sources = data["sources"] for source in sources: wss = WebSearchSource() wss.title = source["title"] wss.link = source["link"] self.web_search_sources.append(wss) - elif "messageType" in a: - message_type: str = a["messageType"] + + # Handle what is done when a tool completes + elif data_type == ResponseTypes.TOOL: + if data["subtype"] == "result": + tool = Tool(data["uuid"], data["result"]) + self.tools_used.append(tool) + + # Handle what is done when a file is created + elif data_type == ResponseTypes.FILE: + file = File(data["sha"], data["name"], data["mime"], self.conversation) + self.files_created.append(file) + + # replace null characters with an empty string + elif data_type == ResponseTypes.STREAM: + data["token"] = data["token"].replace('\u0000', '') + + elif "messageType" in data: + message_type: str = data["messageType"] if message_type == MSGTYPE_ERROR: - self.error = ChatError(a["message"]) - self.msg_status = MSGSTATUS_REJECTED - if t == RESPONSE_TYPE_STREAM: + self.error = ChatError(data["message"]) + self.msg_status = MessageStatus.REJECTED + + if data_type == ResponseTypes.STREAM: self.web_search_done = True - elif t == RESPONSE_TYPE_STATUS: + + elif data_type == ResponseTypes.STATUS: pass - elif t == RESPONSE_TYPE_IMAGE: - if not a.__contains__("sha"): - self.image_link = a["image_link"] + else: - if "Model is overloaded" in str(a): + if "Model is overloaded" in str(data): self.error = ModelOverloadedError( "Model is overloaded, please try again later or switch to another model." ) - self.msg_status = MSGSTATUS_REJECTED - elif a.__contains__(MSGTYPE_ERROR): - self.error = ChatError(a[MSGTYPE_ERROR]) - self.msg_status = MSGSTATUS_REJECTED + self.msg_status = MessageStatus.REJECTED + elif data.__contains__(MSGTYPE_ERROR): + self.error = ChatError(data[MSGTYPE_ERROR]) + self.msg_status = MessageStatus.REJECTED else: - self.error = ChatError(f"Unknown json response: {a}") + self.error = ChatError(f"Unknown json response: {data}") # If _stream_yield_all is True, yield all responses from the server. - if self._stream_yield_all or t == RESPONSE_TYPE_STREAM: - return a + if self._stream_yield_all or data_type == ResponseTypes.STREAM: + return data else: return self.__next__() except StopIteration: - if self.msg_status == MSGSTATUS_PENDING: + if self.msg_status == MessageStatus.PENDING: self.error = ChatError( "Stream of responses has abruptly ended (final answer has not been received)." ) @@ -159,7 +196,7 @@ def __next__(self) -> dict: except Exception as e: # print("meet error: ", str(e)) self.error = e - self.msg_status = MSGSTATUS_REJECTED + self.msg_status = MessageStatus.REJECTED raise self.error def __iter__(self): @@ -171,10 +208,10 @@ def throw( __val=None, __tb=None, ): - return self.g.throw(__typ, __val, __tb) + return self.gen.throw(__typ, __val, __tb) def send(self, __value): - return self.g.send(__value) + return self.gen.send(__value) def get_final_text(self) -> str: """ @@ -190,12 +227,19 @@ def get_search_sources(self) -> list: """ return self.web_search_sources - def get_image_link(self) -> str: + def get_tools_used(self) -> list: + """ + :Return: + - self.tools_used + """ + return self.tools_used + + def get_files_created(self) -> list: """ :Return: - - self.image_link + - self.files_created """ - return self.image_link + return self.files_created def search_enabled(self) -> bool: """ @@ -213,10 +257,13 @@ def wait_until_done(self) -> str: """ while not self.is_done(): self.__next__() - if self.is_done() == MSGSTATUS_RESOLVED: + + if self.is_done() == MessageStatus.RESOLVED: return self._result_text + elif self.error is not None: raise self.error + else: raise Exception("Rejected but no error captured!") @@ -226,9 +273,9 @@ def is_done(self): - self.msg_status 3 status: - - MSGSTATUS_PENDING = 0 # running - - MSGSTATUS_RESOLVED = 1 # done with no error(maybe?) - - MSGSTATUS_REJECTED = 2 # error raised + - MessageStatus.PENDING = 0 # running + - MessageStatus.RESOLVED = 1 # done with no error(maybe?) + - MessageStatus.REJECTED = 2 # error raised """ return self.msg_status @@ -247,11 +294,13 @@ def __str__(self): def __getitem__(self, key: str) -> str: print("_getitem_") self.wait_until_done() - print("done") + if key == "text": return self.text + elif key == "web_search": return self.web_search + elif key == "web_search_sources": return self.web_search_sources @@ -267,7 +316,3 @@ def __iadd__(self, other: str) -> str: self.wait_until_done() self.text += other return self.text - - -if __name__ == "__main__": - pass diff --git a/src/hugchat/types/file.py b/src/hugchat/types/file.py new file mode 100644 index 0000000..cfe9644 --- /dev/null +++ b/src/hugchat/types/file.py @@ -0,0 +1,33 @@ +from .message import Conversation + + +class File: + ''' + Class used to represent files created by the model + ''' + + def __init__(self, sha: str, name: str, mime: str, conversation: Conversation): + self.sha = sha + self.name = name + self.mime = mime + + self.conversation = conversation + self.url = self.get_url() + + def get_url(self) -> str: + """ + Gets the url for the given file + """ + + return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}" + + def download_file(self, chatBot) -> bytes: + """ + Downloads the given file + """ + + r = chatBot.session.get(self.url) + return r.content + + def __str__(self) -> str: + return f"File(url={self.url}, sha={self.sha}, name={self.name}, mime={self.mime})" diff --git a/src/hugchat/types/tool.py b/src/hugchat/types/tool.py new file mode 100644 index 0000000..4354286 --- /dev/null +++ b/src/hugchat/types/tool.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + + +@dataclass +class Tool: + ''' + Class used to represent tools used by the model + ''' + + uuid: str + result: str + + def __str__(self) -> str: + return f"Tool(uuid={self.uuid}, result={self.result})" diff --git a/src/integration_test.py b/src/integration_test.py index 7a6339d..dc60b36 100644 --- a/src/integration_test.py +++ b/src/integration_test.py @@ -7,7 +7,6 @@ import pytest -from .hugchat.message import MSGSTATUS_RESOLVED, Message, MSGTYPE_ERROR, RESPONSE_TYPE_WEB from .hugchat import hugchat, cli from .hugchat.login import Login import sys diff --git a/src/unit_test.py b/src/unit_test.py index 81fb0e4..271a746 100644 --- a/src/unit_test.py +++ b/src/unit_test.py @@ -5,7 +5,7 @@ # import os import logging -from .hugchat.message import MSGSTATUS_RESOLVED, RESPONSE_TYPE_FINAL, Message, MSGTYPE_ERROR, RESPONSE_TYPE_WEB +from .hugchat.message import MessageStatus, ResponseTypes, Message, MSGTYPE_ERROR import sys logging.basicConfig(level=logging.DEBUG) @@ -39,7 +39,7 @@ class Test(object): def test_web_search_failed_results(self): response_list = [{ "type": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "messageType": MSGTYPE_ERROR, "message": @@ -50,9 +50,9 @@ def test_web_search_failed_results(self): ] }, { "type": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "messageType": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "sources": [{ "title": "1", "link": "2", @@ -60,7 +60,7 @@ def test_web_search_failed_results(self): }] }, { "type": - RESPONSE_TYPE_WEB, + ResponseTypes.WEB, "messageType": MSGTYPE_ERROR, "message": @@ -70,7 +70,7 @@ def test_web_search_failed_results(self): "https://www.accuweather.com/en/gb/london/ec4a-2/weather-forecast/328328" ] }, { - "type": RESPONSE_TYPE_FINAL, + "type": ResponseTypes.FINAL, "messageType": "answer", "text": "Funny joke" }]