From 714a758c20a242ba0271b78c9fe316c12e4fb954 Mon Sep 17 00:00:00 2001 From: qbc Date: Thu, 7 Mar 2024 12:11:27 +0800 Subject: [PATCH] [feature] as_studio (#44) --- examples/conversation/conversation.py | 68 ++-- examples/werewolf/werewolf.py | 201 ++++++------ .../werewolf/{utils.py => werewolf_utils.py} | 0 setup.py | 10 +- src/agentscope/agents/user_agent.py | 5 +- src/agentscope/models/__init__.py | 3 +- src/agentscope/utils/logging_utils.py | 61 ++++ src/agentscope/web/README.md | 34 +- src/agentscope/web/studio/studio.py | 297 ++++++++++++++++++ src/agentscope/web/studio/utils.py | 186 +++++++++++ tests/model_test.py | 13 +- 11 files changed, 747 insertions(+), 131 deletions(-) rename examples/werewolf/{utils.py => werewolf_utils.py} (100%) create mode 100644 src/agentscope/web/studio/studio.py create mode 100644 src/agentscope/web/studio/utils.py diff --git a/examples/conversation/conversation.py b/examples/conversation/conversation.py index ff926ae76..8eb51af56 100644 --- a/examples/conversation/conversation.py +++ b/examples/conversation/conversation.py @@ -5,36 +5,44 @@ from agentscope.agents.user_agent import UserAgent from agentscope.pipelines.functional import sequentialpipeline -agentscope.init( - model_configs=[ - { - "model_type": "openai", - "config_name": "gpt-3.5-turbo", - "model_name": "gpt-3.5-turbo", - "api_key": "xxx", # Load from env if not provided - "organization": "xxx", # Load from env if not provided - "generate_args": { - "temperature": 0.5, + +def main() -> None: + """A conversation demo""" + + agentscope.init( + model_configs=[ + { + "model_type": "openai", + "config_name": "gpt-3.5-turbo", + "model_name": "gpt-3.5-turbo", + "api_key": "xxx", # Load from env if not provided + "organization": "xxx", # Load from env if not provided + "generate_args": { + "temperature": 0.5, + }, + }, + { + "model_type": "post_api_chat", + "config_name": "my_post_api", + "api_url": "https://xxx", + "headers": {}, }, - }, - { - "model_type": "post_api_chat", - "config_name": "my_post_api", - "api_url": "https://xxx", - "headers": {}, - }, - ], -) + ], + ) + + # Init two agents + dialog_agent = DialogAgent( + name="Assistant", + sys_prompt="You're a helpful assistant.", + model_config_name="gpt-3.5-turbo", # replace by your model config name + ) + user_agent = UserAgent() + + # start the conversation between user and assistant + x = None + while x is None or x.content != "exit": + x = sequentialpipeline([dialog_agent, user_agent], x) -# Init two agents -dialog_agent = DialogAgent( - name="Assistant", - sys_prompt="You're a helpful assistant.", - model_config_name="gpt-3.5-turbo", # replace by your model config name -) -user_agent = UserAgent() -# start the conversation between user and assistant -x = None -while x is None or x.content != "exit": - x = sequentialpipeline([dialog_agent, user_agent], x) +if __name__ == "__main__": + main() diff --git a/examples/werewolf/werewolf.py b/examples/werewolf/werewolf.py index 96af8a0b2..b7a2b2f94 100644 --- a/examples/werewolf/werewolf.py +++ b/examples/werewolf/werewolf.py @@ -3,7 +3,7 @@ from functools import partial from prompt import Prompts -from utils import ( +from werewolf_utils import ( check_winning, update_alive_players, majority_vote, @@ -15,99 +15,122 @@ from agentscope.pipelines.functional import sequentialpipeline import agentscope -# default settings -HostMsg = partial(Msg, name="Moderator", echo=True) -healing, poison = True, True -MAX_WEREWOLF_DISCUSSION_ROUND = 3 -MAX_GAME_ROUND = 6 -# read model and agent configs, and initialize agents automatically -survivors = agentscope.init( - model_configs="./configs/model_configs.json", - agent_configs="./configs/agent_configs.json", -) -roles = ["werewolf", "werewolf", "villager", "villager", "seer", "witch"] -wolves, witch, seer = survivors[:2], survivors[-1], survivors[-2] - -# start the game -for i in range(1, MAX_GAME_ROUND + 1): - # night phase, werewolves discuss - hint = HostMsg(content=Prompts.to_wolves.format(n2s(wolves))) - with msghub(wolves, announcement=hint) as hub: - for _ in range(MAX_WEREWOLF_DISCUSSION_ROUND): - x = sequentialpipeline(wolves) - if x.get("agreement", False): - break - # werewolves vote - hint = HostMsg(content=Prompts.to_wolves_vote) - votes = [extract_name_and_id(wolf(hint).content)[0] for wolf in wolves] - # broadcast the result to werewolves - dead_player = [majority_vote(votes)] - hub.broadcast( - HostMsg(content=Prompts.to_wolves_res.format(dead_player[0])), - ) +# pylint: disable=too-many-statements +def main() -> None: + """werewolf game""" + # default settings + HostMsg = partial(Msg, name="Moderator", echo=True) + healing, poison = True, True + MAX_WEREWOLF_DISCUSSION_ROUND = 3 + MAX_GAME_ROUND = 6 + # read model and agent configs, and initialize agents automatically + survivors = agentscope.init( + model_configs="./configs/model_configs.json", + agent_configs="./configs/agent_configs.json", + ) + roles = ["werewolf", "werewolf", "villager", "villager", "seer", "witch"] + wolves, witch, seer = survivors[:2], survivors[-1], survivors[-2] + + # start the game + for _ in range(1, MAX_GAME_ROUND + 1): + # night phase, werewolves discuss + hint = HostMsg(content=Prompts.to_wolves.format(n2s(wolves))) + with msghub(wolves, announcement=hint) as hub: + for _ in range(MAX_WEREWOLF_DISCUSSION_ROUND): + x = sequentialpipeline(wolves) + if x.get("agreement", False): + break - # witch - healing_used_tonight = False - if witch in survivors: - if healing: + # werewolves vote + hint = HostMsg(content=Prompts.to_wolves_vote) + votes = [ + extract_name_and_id(wolf(hint).content)[0] for wolf in wolves + ] + # broadcast the result to werewolves + dead_player = [majority_vote(votes)] + hub.broadcast( + HostMsg(content=Prompts.to_wolves_res.format(dead_player[0])), + ) + + # witch + healing_used_tonight = False + if witch in survivors: + if healing: + hint = HostMsg( + content=Prompts.to_witch_resurrect.format_map( + { + "witch_name": witch.name, + "dead_name": dead_player[0], + }, + ), + ) + if witch(hint).get("resurrect", False): + healing_used_tonight = True + dead_player.pop() + healing = False + + if poison and not healing_used_tonight: + x = witch(HostMsg(content=Prompts.to_witch_poison)) + if x.get("eliminate", False): + dead_player.append(extract_name_and_id(x.content)[0]) + poison = False + + # seer + if seer in survivors: hint = HostMsg( - content=Prompts.to_witch_resurrect.format_map( - {"witch_name": witch.name, "dead_name": dead_player[0]}, - ), + content=Prompts.to_seer.format(seer.name, n2s(survivors)), ) - if witch(hint).get("resurrect", False): - healing_used_tonight = True - dead_player.pop() - healing = False - - if poison and not healing_used_tonight: - x = witch(HostMsg(content=Prompts.to_witch_poison)) - if x.get("eliminate", False): - dead_player.append(extract_name_and_id(x.content)[0]) - poison = False - - # seer - if seer in survivors: - hint = HostMsg( - content=Prompts.to_seer.format(seer.name, n2s(survivors)), - ) - x = seer(hint) - - player, idx = extract_name_and_id(x.content) - role = "werewolf" if roles[idx] == "werewolf" else "villager" - hint = HostMsg(content=Prompts.to_seer_result.format(player, role)) - seer.observe(hint) - - survivors, wolves = update_alive_players(survivors, wolves, dead_player) - if check_winning(survivors, wolves, "Moderator"): - break - - # daytime discussion - content = ( - Prompts.to_all_danger.format(n2s(dead_player)) - if dead_player - else Prompts.to_all_peace - ) - hints = [ - HostMsg(content=content), - HostMsg(content=Prompts.to_all_discuss.format(n2s(survivors))), - ] - with msghub(survivors, announcement=hints) as hub: - # discuss - x = sequentialpipeline(survivors) - - # vote - hint = HostMsg(content=Prompts.to_all_vote.format(n2s(survivors))) - votes = [extract_name_and_id(_(hint).content)[0] for _ in survivors] - vote_res = majority_vote(votes) - # broadcast the result to all players - result = HostMsg(content=Prompts.to_all_res.format(vote_res)) - hub.broadcast(result) - - survivors, wolves = update_alive_players(survivors, wolves, vote_res) + x = seer(hint) + player, idx = extract_name_and_id(x.content) + role = "werewolf" if roles[idx] == "werewolf" else "villager" + hint = HostMsg(content=Prompts.to_seer_result.format(player, role)) + seer.observe(hint) + + survivors, wolves = update_alive_players( + survivors, + wolves, + dead_player, + ) if check_winning(survivors, wolves, "Moderator"): break - hub.broadcast(HostMsg(content=Prompts.to_all_continue)) + # daytime discussion + content = ( + Prompts.to_all_danger.format(n2s(dead_player)) + if dead_player + else Prompts.to_all_peace + ) + hints = [ + HostMsg(content=content), + HostMsg(content=Prompts.to_all_discuss.format(n2s(survivors))), + ] + with msghub(survivors, announcement=hints) as hub: + # discuss + x = sequentialpipeline(survivors) + + # vote + hint = HostMsg(content=Prompts.to_all_vote.format(n2s(survivors))) + votes = [ + extract_name_and_id(_(hint).content)[0] for _ in survivors + ] + vote_res = majority_vote(votes) + # broadcast the result to all players + result = HostMsg(content=Prompts.to_all_res.format(vote_res)) + hub.broadcast(result) + + survivors, wolves = update_alive_players( + survivors, + wolves, + vote_res, + ) + + if check_winning(survivors, wolves, "Moderator"): + break + + hub.broadcast(HostMsg(content=Prompts.to_all_continue)) + + +if __name__ == "__main__": + main() diff --git a/examples/werewolf/utils.py b/examples/werewolf/werewolf_utils.py similarity index 100% rename from examples/werewolf/utils.py rename to examples/werewolf/werewolf_utils.py diff --git a/setup.py b/setup.py index 7f3a74ec9..2a9e6a99f 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,8 @@ test_requires = ["pytest", "pytest-cov", "pre-commit"] +gradio_requires = ["gradio==4.19.1", "modelscope_studio==0.0.5"] + # released requires minimal_requires = [ "loguru", @@ -47,7 +49,7 @@ "Flask==3.0.0", "Flask-Cors==4.0.0", "Flask-SocketIO==5.3.6", - "dashscope", + "dashscope==1.14.1", ] distribute_requires = minimal_requires + rpc_requires @@ -60,6 +62,7 @@ + service_requires + doc_requires + test_requires + + gradio_requires ) with open("README.md", "r", encoding="UTF-8") as fh: @@ -93,4 +96,9 @@ "Operating System :: OS Independent", ], python_requires=">=3.9", + entry_points={ + "console_scripts": [ + "as_studio=agentscope.web.studio.studio:run_app", + ], + }, ) diff --git a/src/agentscope/agents/user_agent.py b/src/agentscope/agents/user_agent.py index ac51ddab5..334c42ed9 100644 --- a/src/agentscope/agents/user_agent.py +++ b/src/agentscope/agents/user_agent.py @@ -6,6 +6,7 @@ from agentscope.agents import AgentBase from agentscope.message import Msg +from agentscope.web.studio.utils import user_input class UserAgent(AgentBase): @@ -62,7 +63,7 @@ def reply( # TODO: To avoid order confusion, because `input` print much quicker # than logger.chat time.sleep(0.5) - content = input(f"{self.name}: ") + content = user_input() kwargs = {} if required_keys is not None: @@ -85,8 +86,6 @@ def reply( **kwargs, # type: ignore[arg-type] ) - self.speak(msg) - # Add to memory self.memory.add(msg) diff --git a/src/agentscope/models/__init__.py b/src/agentscope/models/__init__.py index e7f6e7b5d..1cce37c76 100644 --- a/src/agentscope/models/__init__.py +++ b/src/agentscope/models/__init__.py @@ -135,9 +135,10 @@ def read_model_configs( # check if name is unique for cfg in format_configs: if cfg.config_name in _MODEL_CONFIGS: - raise ValueError( + logger.warning( f"config_name [{cfg.config_name}] already exists.", ) + continue _MODEL_CONFIGS[cfg.config_name] = cfg # print the loaded model configs diff --git a/src/agentscope/utils/logging_utils.py b/src/agentscope/utils/logging_utils.py index 295ce1c0b..7817b89d5 100644 --- a/src/agentscope/utils/logging_utils.py +++ b/src/agentscope/utils/logging_utils.py @@ -3,10 +3,17 @@ import json import os import sys +import threading from typing import Optional, Literal, Union, Any from loguru import logger +from agentscope.web.studio.utils import ( + generate_image_from_name, + send_msg, + get_reset_msg, +) + LOG_LEVEL = Literal[ "TRACE", "DEBUG", @@ -115,12 +122,66 @@ def _chat(message: Union[str, dict], *args: Any, **kwargs: Any) -> None: "\n".join(print_str).replace("{", "{{").replace("}", "}}") ) logger.log(LEVEL_CHAT_LOG, print_str, *args, **kwargs) + + thread_name = threading.current_thread().name + if thread_name != "MainThread": + log_gradio(message, thread_name, **kwargs) return message = str(message).replace("{", "{{").replace("}", "}}") logger.log(LEVEL_CHAT_LOG, message, *args, **kwargs) +def log_gradio(message: dict, thread_name: str, **kwargs: Any) -> None: + """Send chat message to gradio. + + Args: + message (`dict`): + The message to be logged. It should have "name"(or "role") and + "content" keys, and the message will be logged as ": + ". + thread_name (`str`): + The name of the thread. + """ + if thread_name != "MainThread": + get_reset_msg(uid=thread_name) + name = message.get("name", "default") or message.get("role", "default") + avatar = kwargs.get("avatar", None) or generate_image_from_name( + message["name"], + ) + + msg = message["content"] + flushing = True + if "url" in message: + flushing = False + for i in range(len(message["url"])): + msg += "\n" + f"""""" + if "audio_path" in message: + flushing = False + for i in range(len(message["audio_path"])): + msg += ( + "\n" + + f"""""" + ) + if "video_path" in message: + flushing = False + for i in range(len(message["video_path"])): + msg += ( + "\n" + + f"""""" + ) + + send_msg( + msg, + role=name, + uid=thread_name, + flushing=flushing, + avatar=avatar, + ) + + def _level_format(record: dict) -> str: """Format the log record.""" if record["level"].name == LEVEL_CHAT_LOG: diff --git a/src/agentscope/web/README.md b/src/agentscope/web/README.md index 96defa987..1a2685012 100644 --- a/src/agentscope/web/README.md +++ b/src/agentscope/web/README.md @@ -1,9 +1,11 @@ -# AgentScope Web UI +# Web UI + +## AgentScope Web UI A user interface for AgentScope, which is a tool for monitoring and analyzing the communication of agents in a multi-agent application. -## Quick Start +### Quick Start To start a web UI, you can run the following python code: ```python @@ -29,7 +31,7 @@ agentscope.init( # ... ) ``` -## A Running Example +### A Running Example The home page of web UI, which lists all available projects and runs in the given saving path. @@ -38,4 +40,28 @@ given saving path. By clicking a running instance, we can observe more details. -![The running details](https://img.alicdn.com/imgextra/i2/O1CN01AZtsf31MIHm4FmjjO_!!6000000001411-0-tps-3104-1849.jpg) \ No newline at end of file +![The running details](https://img.alicdn.com/imgextra/i2/O1CN01AZtsf31MIHm4FmjjO_!!6000000001411-0-tps-3104-1849.jpg) + + +## AgentScope Studio + +A running-time interface for AgentScope, which is a tool for monitoring +the communication of agents in a multi-agent application. + +### How to Use +To start a studio, you can run the following python code: + +```python +as_studio path/to/your/script.py +``` +Remark: in `path/to/your/script.py`, there should be a `main` function. + +### An Example + +Run the following code in the root directory of this project after you setup the configs in `examples/conversation/conversation.py`: +```python +as_studio examples/conversation/conversation.py +``` +The following interface will be launched at `localhost:7860`. + +![](https://gw.alicdn.com/imgextra/i3/O1CN01X673v81WaHV1oCxEN_!!6000000002804-0-tps-2992-1498.jpg) diff --git a/src/agentscope/web/studio/studio.py b/src/agentscope/web/studio/studio.py new file mode 100644 index 000000000..08281d57e --- /dev/null +++ b/src/agentscope/web/studio/studio.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- +"""run web ui""" +import argparse +import os +import sys +import threading +import time +from collections import defaultdict +from typing import Optional, Callable +import traceback +import gradio as gr +import modelscope_studio as mgr + +from agentscope.web.studio.utils import ( + send_player_input, + get_chat_msg, + SYS_MSG_PREFIX, + ResetException, + check_uuid, + send_msg, + generate_image_from_name, + audio2text, + send_reset_msg, +) + +MAX_NUM_DISPLAY_MSG = 20 +FAIL_COUNT_DOWN = 30 + + +def init_uid_list() -> list: + """Initialize an empty list for storing user IDs.""" + return [] + + +glb_history_dict = defaultdict(init_uid_list) +glb_signed_user = [] + + +def reset_glb_var(uid: str) -> None: + """Reset global variables for a given user ID.""" + global glb_history_dict + glb_history_dict[uid] = init_uid_list() + + +def get_chat(uid: str) -> list[list]: + """Retrieve chat messages for a given user ID.""" + uid = check_uuid(uid) + global glb_history_dict + line = get_chat_msg(uid=uid) + # TODO: Optimize the display effect, currently there is a problem of + # output display jumping + if line: + glb_history_dict[uid] += [line] + dial_msg = [] + for line in glb_history_dict[uid]: + _, msg = line + if isinstance(msg, dict): + dial_msg.append(line) + else: + # User chat, format: (msg, None) + dial_msg.append(line) + return dial_msg[-MAX_NUM_DISPLAY_MSG:] + + +def send_audio(audio_term: str, uid: str) -> None: + """Convert audio input to text and send as a chat message.""" + uid = check_uuid(uid) + content = audio2text(audio_path=audio_term) + send_player_input(content, uid=uid) + msg = f"""{content} + """ + send_msg(msg, is_player=True, role="Me", uid=uid, avatar=None) + + +def send_image(image_term: str, uid: str) -> None: + """Send an image as a chat message.""" + uid = check_uuid(uid) + send_player_input(image_term, uid=uid) + + msg = f"""""" + avatar = generate_image_from_name("Me") + send_msg(msg, is_player=True, role="Me", uid=uid, avatar=avatar) + + +def send_message(msg: str, uid: str) -> str: + """Send a generic message to the player.""" + uid = check_uuid(uid) + send_player_input(msg, uid=uid) + avatar = generate_image_from_name("Me") + send_msg(msg, is_player=True, role="Me", uid=uid, avatar=avatar) + return "" + + +def fn_choice(data: gr.EventData, uid: str) -> None: + """Handle a selection event from the chatbot interface.""" + uid = check_uuid(uid) + # pylint: disable=protected-access + send_player_input(data._data["value"], uid=uid) + + +def import_function_from_path( + module_path: str, + function_name: str, + module_name: Optional[str] = None, +) -> Callable: + """Import a function from the given module path.""" + import importlib.util + + script_dir = os.path.dirname(os.path.abspath(module_path)) + + # Temporarily add a script directory to sys.path + original_sys_path = sys.path[:] + sys.path.insert(0, script_dir) + + try: + # If a module name is not provided, you can use the filename ( + # without extension) as the module name + if module_name is None: + module_name = os.path.splitext(os.path.basename(module_path))[0] + # Creating module specifications and loading modules + spec = importlib.util.spec_from_file_location( + module_name, + module_path, + ) + if spec is not None: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Getting a function from a module + function = getattr(module, function_name) + else: + raise ImportError( + f"Could not find module spec for {module_name} at" + f" {module_path}", + ) + except AttributeError as exc: + raise AttributeError( + f"The module '{module_name}' does not have a function named '" + f"{function_name}'. Please put your code in the main function, " + f"read README.md for details.", + ) from exc + finally: + # Restore the original sys.path + sys.path = original_sys_path + + return function + + +# pylint: disable=too-many-statements +def run_app() -> None: + """Entry point for the web UI application.""" + parser = argparse.ArgumentParser() + parser.add_argument("script", type=str, help="Script file to run") + args = parser.parse_args() + + # Make sure script_path is an absolute path + script_path = os.path.abspath(args.script) + + # Get the directory where the script is located + script_dir = os.path.dirname(script_path) + # Save the current working directory + # Change the current working directory to the directory where + os.chdir(script_dir) + + def start_game() -> None: + """Start the main game loop.""" + uid = threading.currentThread().name + main = import_function_from_path(script_path, "main") + + while True: + try: + main() + except ResetException: + print(f"Reset Successfully:{uid} ") + except Exception as e: + trace_info = "".join( + traceback.TracebackException.from_exception(e).format(), + ) + for i in range(FAIL_COUNT_DOWN, 0, -1): + send_msg( + f"{SYS_MSG_PREFIX} error {trace_info}, reboot " + f"in {i} seconds", + uid=uid, + ) + time.sleep(1) + reset_glb_var(uid) + + def check_for_new_session(uid: str) -> None: + """ + Check for a new user session and start a game thread if necessary. + """ + uid = check_uuid(uid) + if uid not in glb_signed_user: + glb_signed_user.append(uid) + print("==========Signed User==========") + print(f"Total number of users: {len(glb_signed_user)}") + run_thread = threading.Thread( + target=start_game, + name=uid, + ) + run_thread.start() + + with gr.Blocks() as demo: + warning_html_code = """ +
+

If you want to start over, please click the + reset + button and refresh the page

+
+ """ + gr.HTML(warning_html_code) + uuid = gr.Textbox(label="modelscope_uuid", visible=False) + + with gr.Row(): + chatbot = mgr.Chatbot( + label="Dialog", + show_label=False, + bubble_full_width=False, + visible=True, + ) + + with gr.Column(): + user_chat_input = gr.Textbox( + label="user_chat_input", + placeholder="Say something here", + show_label=False, + ) + send_button = gr.Button(value="📣Send") + with gr.Row(): + audio = gr.Accordion("Audio input", open=False) + with audio: + audio_term = gr.Audio( + visible=True, + type="filepath", + format="wav", + ) + submit_audio_button = gr.Button(value="Send Audio") + image = gr.Accordion("Image input", open=False) + with image: + image_term = gr.Image( + visible=True, + height=300, + interactive=True, + type="filepath", + ) + submit_image_button = gr.Button(value="Send Image") + with gr.Column(): + reset_button = gr.Button(value="Reset") + + # submit message + send_button.click( + send_message, + [user_chat_input, uuid], + user_chat_input, + ) + user_chat_input.submit( + send_message, + [user_chat_input, uuid], + user_chat_input, + ) + + submit_audio_button.click( + send_audio, + inputs=[audio_term, uuid], + outputs=[audio_term], + ) + + submit_image_button.click( + send_image, + inputs=[image_term, uuid], + outputs=[image_term], + ) + + reset_button.click(send_reset_msg, inputs=[uuid]) + + chatbot.custom(fn=fn_choice, inputs=[uuid]) + + demo.load( + check_for_new_session, + inputs=[uuid], + every=0.5, + ) + + demo.load( + get_chat, + inputs=[uuid], + outputs=[chatbot], + every=0.5, + ) + demo.queue() + demo.launch() + + +if __name__ == "__main__": + run_app() diff --git a/src/agentscope/web/studio/utils.py b/src/agentscope/web/studio/utils.py new file mode 100644 index 000000000..fc86048fc --- /dev/null +++ b/src/agentscope/web/studio/utils.py @@ -0,0 +1,186 @@ +# -*- coding: utf-8 -*- +"""web ui utils""" +import os +import threading +from typing import Optional +import hashlib +from multiprocessing import Queue +from collections import defaultdict + +from PIL import Image +import gradio as gr + +from dashscope.audio.asr import RecognitionCallback, Recognition + +SYS_MSG_PREFIX = "【SYSTEM】" + + +def init_uid_queues() -> dict: + """Initializes and returns a dictionary of user-specific queues.""" + return { + "glb_queue_chat_msg": Queue(), + "glb_queue_user_input": Queue(), + "glb_queue_reset_msg": Queue(), + } + + +glb_uid_dict = defaultdict(init_uid_queues) + + +def send_msg( + msg: str, + is_player: bool = False, + role: Optional[str] = None, + uid: Optional[str] = None, + flushing: bool = False, + avatar: Optional[str] = None, + msg_id: Optional[str] = None, +) -> None: + """Sends a message to the web UI.""" + global glb_uid_dict + glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"] + if is_player: + glb_queue_chat_msg.put( + [ + { + "text": msg, + "name": role, + "flushing": flushing, + "avatar": avatar, + }, + None, + ], + ) + else: + glb_queue_chat_msg.put( + [ + None, + { + "text": msg, + "name": role, + "flushing": flushing, + "avatar": avatar, + "id": msg_id, + }, + ], + ) + + +def get_chat_msg(uid: Optional[str] = None) -> list: + """Retrieves the next chat message from the queue, if available.""" + global glb_uid_dict + glb_queue_chat_msg = glb_uid_dict[uid]["glb_queue_chat_msg"] + if not glb_queue_chat_msg.empty(): + line = glb_queue_chat_msg.get(block=False) + if line is not None: + return line + return [] + + +def send_player_input(msg: str, uid: Optional[str] = None) -> None: + """Sends player input to the web UI.""" + global glb_uid_dict + glb_queue_user_input = glb_uid_dict[uid]["glb_queue_user_input"] + glb_queue_user_input.put([None, msg]) + + +def get_player_input( + uid: Optional[str] = None, +) -> list[str]: + """Gets player input from the web UI or command line.""" + global glb_uid_dict + glb_queue_user_input = glb_uid_dict[uid]["glb_queue_user_input"] + content = glb_queue_user_input.get(block=True)[1] + if content == "**Reset**": + glb_uid_dict[uid] = init_uid_queues() + raise ResetException + return content + + +def send_reset_msg(uid: Optional[str] = None) -> None: + """Sends a reset message to the web UI.""" + uid = check_uuid(uid) + global glb_uid_dict + glb_queue_reset_msg = glb_uid_dict[uid]["glb_queue_reset_msg"] + glb_queue_reset_msg.put([None, "**Reset**"]) + + +def get_reset_msg(uid: Optional[str] = None) -> None: + """Retrieves a reset message from the queue, if available.""" + global glb_uid_dict + glb_queue_reset_msg = glb_uid_dict[uid]["glb_queue_reset_msg"] + if not glb_queue_reset_msg.empty(): + content = glb_queue_reset_msg.get(block=True)[1] + if content == "**Reset**": + glb_uid_dict[uid] = init_uid_queues() + raise ResetException + + +class ResetException(Exception): + """Custom exception to signal a reset action in the application.""" + + +def check_uuid(uid: Optional[str]) -> str: + """Checks whether a UUID is provided or generates a default one.""" + if not uid or uid == "": + if os.getenv("MODELSCOPE_ENVIRONMENT") == "studio": + raise gr.Error("Please login first") + uid = "local_user" + return uid + + +def generate_image_from_name(name: str) -> str: + """Generates an image based on the hash of the given name.""" + from agentscope.file_manager import file_manager + + # Using hashlib to generate a hash of the name + hash_func = hashlib.md5() + hash_func.update(name.encode("utf-8")) + hash_value = hash_func.hexdigest() + + # Extract the first 6 characters of the hash value as the hexadecimal + # representation of the color + # generate a color value between #000000 and #ffffff + color_hex = "#" + hash_value[:6] + color_rgb = Image.new("RGB", (1, 1), color_hex).getpixel((0, 0)) + + image_filepath = os.path.join(file_manager.dir_root, f"{name}_image.png") + + # Check if the image already exists + if os.path.exists(image_filepath): + return image_filepath + + # If the image does not exist, generate and save it + width, height = 200, 200 + image = Image.new("RGB", (width, height), color_rgb) + + image.save(image_filepath) + + return image_filepath + + +def audio2text(audio_path: str) -> str: + """Converts audio file at the given path to text using ASR.""" + # dashscope.api_key = "" + callback = RecognitionCallback() + rec = Recognition( + model="paraformer-realtime-v1", + format="wav", + sample_rate=16000, + callback=callback, + ) + + result = rec.call(audio_path) + return " ".join([s["text"] for s in result["output"]["sentence"]]) + + +def user_input() -> str: + """get user input""" + thread_name = threading.current_thread().name + if thread_name == "MainThread": + content = input("User input: ") + else: + content = get_player_input( + uid=threading.current_thread().name, + ) + return content diff --git a/tests/model_test.py b/tests/model_test.py index 54eeb2c00..0409b2513 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -5,6 +5,8 @@ from typing import Any import unittest +from unittest.mock import patch, MagicMock + from agentscope.models import ( ModelResponse, @@ -46,7 +48,8 @@ def test_model_registry(self) -> None: PostAPIModelWrapperBase, ) - def test_load_model_configs(self) -> None: + @patch("loguru.logger.warning") + def test_load_model_configs(self, mock_logging: MagicMock) -> None: """Test to load model configs""" configs = [ { @@ -83,8 +86,12 @@ def test_load_model_configs(self) -> None: self.assertEqual(model.config_name, "gpt-4") self.assertRaises(ValueError, load_model_by_config_name, "my_post_api") - # automatically detect model with the same id - self.assertRaises(ValueError, read_model_configs, configs[0]) + # load model with the same id + read_model_configs(configs=configs[0], clear_existing=False) + mock_logging.assert_called_once_with( + "config_name [gpt-4] already exists.", + ) + read_model_configs( configs={ "model_type": "TestModelWrapperSimple",