diff --git a/README.md b/README.md index ae65c7a..49bbc10 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,12 @@ conversation_list = chatbot.get_conversation_list() # Get the available models (not hardcore) models = chatbot.get_available_llm_models() +# Get image link. +# Work only for model "CohereForAI/c4ai-command-r-plus" +chat_result = chatbot.chat("Draw a cat.") +print(chat_result.get_final_text()) +print(chat_result.get_image_link()) + # Switch model with given index chatbot.switch_llm(0) # Switch to the first model chatbot.switch_llm(1) # Switch to the second model diff --git a/src/hugchat/hugchat.py b/src/hugchat/hugchat.py index d1d8016..0cacfa7 100644 --- a/src/hugchat/hugchat.py +++ b/src/hugchat/hugchat.py @@ -711,6 +711,7 @@ def _stream_query( resp.encoding = 'utf-8' if resp.status_code != 200: + retry_count -= 1 if retry_count <= 0: raise exceptions.ChatError( @@ -725,6 +726,12 @@ def _stream_query( if obj.__contains__("type"): _type = obj["type"] + if _type == "file": + _sha = obj["sha"] + _image_link = f"{self.hf_base_url}/chat/conversation/{conversation}/output/{_sha}" + yield {"type": _type, "image_link": _image_link} + continue + if _type == "finalAnswer": final_answer = obj break_flag = True diff --git a/src/hugchat/message.py b/src/hugchat/message.py index b26caf0..b4e962c 100644 --- a/src/hugchat/message.py +++ b/src/hugchat/message.py @@ -7,6 +7,7 @@ RESPONSE_TYPE_STREAM = "stream" RESPONSE_TYPE_WEB = "webSearch" RESPONSE_TYPE_STATUS = "status" +RESPONSE_TYPE_IMAGE = "file" MSGTYPE_ERROR = "error" MSGSTATUS_PENDING = 0 @@ -73,6 +74,7 @@ def __init__( self.g = g self._stream_yield_all = _stream_yield_all self.web_search = web_search + self.image_link = None @property def text(self) -> str: @@ -127,6 +129,9 @@ def __next__(self) -> dict: self.web_search_done = True elif t == RESPONSE_TYPE_STATUS: pass + elif t == RESPONSE_TYPE_IMAGE: + if not a.__contains__("sha"): + self.image_link = a["image_link"] else: if "Model is overloaded" in str(a): self.error = ModelOverloadedError( @@ -185,6 +190,13 @@ def get_search_sources(self) -> list: """ return self.web_search_sources + def get_image_link(self) -> str: + """ + :Return: + - self.image_link + """ + return self.image_link + def search_enabled(self) -> bool: """ :Return: