Skip to content

Commit

Permalink
Merge pull request #232 from Whitelisted1/master
Browse files Browse the repository at this point in the history
Implemented tool capabilities
  • Loading branch information
Soulter authored Jun 7, 2024
2 parents 325776e + b4ac2e8 commit 3191bee
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 72 deletions.
1 change: 1 addition & 0 deletions src/hugchat/hugchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def chat(
),
_stream_yield_all=_stream_yield_all,
web_search=web_search,
conversation=conversation
)
return msg

Expand Down
175 changes: 110 additions & 65 deletions src/hugchat/message.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from typing import Generator, Union
from typing import Generator, Union, List

from .types.tool import Tool
from .types.file import File
from .types.message import Conversation
from .exceptions import ChatError, ModelOverloadedError
import json

RESPONSE_TYPE_FINAL = "finalAnswer"
RESPONSE_TYPE_STREAM = "stream"
RESPONSE_TYPE_WEB = "webSearch"
RESPONSE_TYPE_STATUS = "status"
RESPONSE_TYPE_IMAGE = "file"
MSGTYPE_ERROR = "error"

MSGSTATUS_PENDING = 0
MSGSTATUS_RESOLVED = 1
MSGSTATUS_REJECTED = 2
class ResponseTypes:
FINAL = "finalAnswer"
STREAM = "stream"
TOOL = "tool"
FILE = "file"
WEB = "webSearch"
STATUS = "status"


class MessageStatus:
PENDING = 0
RESOLVED = 1
REJECTED = 2


MSGTYPE_ERROR = "error"


class WebSearchSource:
Expand All @@ -35,7 +45,7 @@ class Message(Generator):
- web_search_sources: list[WebSearchSource] = list()
- text: str = ""
- web_search_done: bool = not web_search
- msg_status: int = MSGSTATUS_PENDING
- msg_status: int = MessageStatus.PENDING
- error: Union[Exception, None] = None
A wrapper of `Generator` that receives and process the response
Expand All @@ -49,38 +59,45 @@ class Message(Generator):
for res in msg:
... # process
else:
if msg.done() == MSGSTATUS_REJECTED:
if msg.done() == MessageStatus.REJECTED:
raise msg.error
# or simply use:
final = msg.wait_until_done()
"""

g: Generator
_stream_yield_all: bool = False
web_search: bool = False
web_search_sources: list = []
_result_text: str = ""

gen: Generator

web_search: bool = False
web_search_sources: List[WebSearchSource] = []
web_search_done: bool = not web_search
msg_status: int = MSGSTATUS_PENDING

tools_used: List[Tool] = []
files_created: List[File] = []

msg_status: int = MessageStatus.PENDING
error: Union[Exception, None] = None

def __init__(
self,
g: Generator,
_stream_yield_all: bool = False,
web_search: bool = False,
conversation: Conversation = None
) -> None:
self.g = g
self.gen = g
self._stream_yield_all = _stream_yield_all
self.web_search = web_search
self.image_link = None
self.conversation = conversation

@property
def text(self) -> str:
self._result_text = self.wait_until_done()
return self._result_text

@text.setter
def text(self, v: str) -> None:
self._result_text = v
Expand All @@ -93,64 +110,84 @@ def _filterResponse(self, obj: dict):
raise ChatError(f"No `type` and `message` returned: {obj}")

def __next__(self) -> dict:
if self.msg_status == MSGSTATUS_RESOLVED:
if self.msg_status == MessageStatus.RESOLVED:
raise StopIteration
elif self.msg_status == MSGSTATUS_REJECTED:

elif self.msg_status == MessageStatus.REJECTED:
if self.error is not None:
raise self.error
else:
raise Exception(
"Message stauts is `Rejected` but no error found")
raise Exception("Message status is `Rejected` but no error found")

try:
a: dict = next(self.g)
self._filterResponse(a)
t: str = a["type"]
data: dict = next(self.gen)
self._filterResponse(data)
data_type: str = data["type"]
message_type: str = ""
if t == RESPONSE_TYPE_FINAL:
self._result_text = a["text"]
self.msg_status = MSGSTATUS_RESOLVED
elif t == RESPONSE_TYPE_WEB:

# set _result_text if this is the final iteration of the chat message
if data_type == ResponseTypes.FINAL:
self._result_text = data["text"]
self.msg_status = MessageStatus.RESOLVED

# Handle web response type
elif data_type == ResponseTypes.WEB:
# gracefully pass unparseable webpages
if message_type != MSGTYPE_ERROR and a.__contains__("sources"):
if message_type != MSGTYPE_ERROR and data.__contains__("sources"):
self.web_search_sources.clear()
sources = a["sources"]
sources = data["sources"]
for source in sources:
wss = WebSearchSource()
wss.title = source["title"]
wss.link = source["link"]
self.web_search_sources.append(wss)
elif "messageType" in a:
message_type: str = a["messageType"]

# Handle what is done when a tool completes
elif data_type == ResponseTypes.TOOL:
if data["subtype"] == "result":
tool = Tool(data["uuid"], data["result"])
self.tools_used.append(tool)

# Handle what is done when a file is created
elif data_type == ResponseTypes.FILE:
file = File(data["sha"], data["name"], data["mime"], self.conversation)
self.files_created.append(file)

# replace null characters with an empty string
elif data_type == ResponseTypes.STREAM:
data["token"] = data["token"].replace('\u0000', '')

elif "messageType" in data:
message_type: str = data["messageType"]
if message_type == MSGTYPE_ERROR:
self.error = ChatError(a["message"])
self.msg_status = MSGSTATUS_REJECTED
if t == RESPONSE_TYPE_STREAM:
self.error = ChatError(data["message"])
self.msg_status = MessageStatus.REJECTED

if data_type == ResponseTypes.STREAM:
self.web_search_done = True
elif t == RESPONSE_TYPE_STATUS:

elif data_type == ResponseTypes.STATUS:
pass
elif t == RESPONSE_TYPE_IMAGE:
if not a.__contains__("sha"):
self.image_link = a["image_link"]

else:
if "Model is overloaded" in str(a):
if "Model is overloaded" in str(data):
self.error = ModelOverloadedError(
"Model is overloaded, please try again later or switch to another model."
)
self.msg_status = MSGSTATUS_REJECTED
elif a.__contains__(MSGTYPE_ERROR):
self.error = ChatError(a[MSGTYPE_ERROR])
self.msg_status = MSGSTATUS_REJECTED
self.msg_status = MessageStatus.REJECTED
elif data.__contains__(MSGTYPE_ERROR):
self.error = ChatError(data[MSGTYPE_ERROR])
self.msg_status = MessageStatus.REJECTED
else:
self.error = ChatError(f"Unknown json response: {a}")
self.error = ChatError(f"Unknown json response: {data}")

# If _stream_yield_all is True, yield all responses from the server.
if self._stream_yield_all or t == RESPONSE_TYPE_STREAM:
return a
if self._stream_yield_all or data_type == ResponseTypes.STREAM:
return data
else:
return self.__next__()
except StopIteration:
if self.msg_status == MSGSTATUS_PENDING:
if self.msg_status == MessageStatus.PENDING:
self.error = ChatError(
"Stream of responses has abruptly ended (final answer has not been received)."
)
Expand All @@ -159,7 +196,7 @@ def __next__(self) -> dict:
except Exception as e:
# print("meet error: ", str(e))
self.error = e
self.msg_status = MSGSTATUS_REJECTED
self.msg_status = MessageStatus.REJECTED
raise self.error

def __iter__(self):
Expand All @@ -171,10 +208,10 @@ def throw(
__val=None,
__tb=None,
):
return self.g.throw(__typ, __val, __tb)
return self.gen.throw(__typ, __val, __tb)

def send(self, __value):
return self.g.send(__value)
return self.gen.send(__value)

def get_final_text(self) -> str:
"""
Expand All @@ -190,12 +227,19 @@ def get_search_sources(self) -> list:
"""
return self.web_search_sources

def get_image_link(self) -> str:
def get_tools_used(self) -> list:
"""
:Return:
- self.tools_used
"""
return self.tools_used

def get_files_created(self) -> list:
"""
:Return:
- self.image_link
- self.files_created
"""
return self.image_link
return self.files_created

def search_enabled(self) -> bool:
"""
Expand All @@ -213,10 +257,13 @@ def wait_until_done(self) -> str:
"""
while not self.is_done():
self.__next__()
if self.is_done() == MSGSTATUS_RESOLVED:

if self.is_done() == MessageStatus.RESOLVED:
return self._result_text

elif self.error is not None:
raise self.error

else:
raise Exception("Rejected but no error captured!")

Expand All @@ -226,9 +273,9 @@ def is_done(self):
- self.msg_status
3 status:
- MSGSTATUS_PENDING = 0 # running
- MSGSTATUS_RESOLVED = 1 # done with no error(maybe?)
- MSGSTATUS_REJECTED = 2 # error raised
- MessageStatus.PENDING = 0 # running
- MessageStatus.RESOLVED = 1 # done with no error(maybe?)
- MessageStatus.REJECTED = 2 # error raised
"""
return self.msg_status

Expand All @@ -247,11 +294,13 @@ def __str__(self):
def __getitem__(self, key: str) -> str:
print("_getitem_")
self.wait_until_done()
print("done")

if key == "text":
return self.text

elif key == "web_search":
return self.web_search

elif key == "web_search_sources":
return self.web_search_sources

Expand All @@ -267,7 +316,3 @@ def __iadd__(self, other: str) -> str:
self.wait_until_done()
self.text += other
return self.text


if __name__ == "__main__":
pass
33 changes: 33 additions & 0 deletions src/hugchat/types/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from .message import Conversation


class File:
'''
Class used to represent files created by the model
'''

def __init__(self, sha: str, name: str, mime: str, conversation: Conversation):
self.sha = sha
self.name = name
self.mime = mime

self.conversation = conversation
self.url = self.get_url()

def get_url(self) -> str:
"""
Gets the url for the given file
"""

return f"https://huggingface.co/chat/conversation/{self.conversation.id}/output/{self.sha}"

def download_file(self, chatBot) -> bytes:
"""
Downloads the given file
"""

r = chatBot.session.get(self.url)
return r.content

def __str__(self) -> str:
return f"File(url={self.url}, sha={self.sha}, name={self.name}, mime={self.mime})"
14 changes: 14 additions & 0 deletions src/hugchat/types/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass


@dataclass
class Tool:
'''
Class used to represent tools used by the model
'''

uuid: str
result: str

def __str__(self) -> str:
return f"Tool(uuid={self.uuid}, result={self.result})"
1 change: 0 additions & 1 deletion src/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pytest

from .hugchat.message import MSGSTATUS_RESOLVED, Message, MSGTYPE_ERROR, RESPONSE_TYPE_WEB
from .hugchat import hugchat, cli
from .hugchat.login import Login
import sys
Expand Down
Loading

0 comments on commit 3191bee

Please sign in to comment.