From 0062cf2585723c1f1fdc220850946de1e10dc31b Mon Sep 17 00:00:00 2001
From: "panxuchen.pxc"
Date: Tue, 20 Feb 2024 15:10:04 +0800
Subject: [PATCH 1/2] fix parse_func and fault_handler

---
 README.md                                      |  4 ++--
 docs/sphinx_doc/source/tutorial/201-agent.md   |  4 ++--
 docs/sphinx_doc/source/tutorial/204-service.md |  4 ++--
 src/agentscope/agents/dict_dialog_agent.py     | 17 ++++++++++++++---
 src/agentscope/models/model.py                 |  9 +++++++++
 5 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 3f412fb53..687d7b530 100644
--- a/README.md
+++ b/README.md
@@ -253,8 +253,8 @@ from agentscope.agents import AgentBase
 
 class MyAgent(AgentBase):
     def reply(self, x):
-        # Do something here, e.g. calling your model
-        response = self.model(x)
+        # Do something here, e.g. calling your model and taking the raw field as your agent's response
+        response = self.model(x).raw
 
         return response
 ```
diff --git a/docs/sphinx_doc/source/tutorial/201-agent.md b/docs/sphinx_doc/source/tutorial/201-agent.md
index 7885ba28d..f4d513e53 100644
--- a/docs/sphinx_doc/source/tutorial/201-agent.md
+++ b/docs/sphinx_doc/source/tutorial/201-agent.md
@@ -90,10 +90,10 @@ def reply(self, x: dict = None) -> dict:
     prompt = self.engine.join(self.sys_prompt, self.memory.get_memory())
 
     # Invoke the language model with the prepared prompt
-    response = self.model(prompt, parse_func=json.loads, fault_handler=lambda x: {"speak": x})
+    response = self.model(prompt).text
 
     # Format the response and create a message object
-    msg = Msg(self.name, response.get("speak", None) or response, **response)
+    msg = Msg(self.name, response)
 
     # Record the message to memory and return it
     self.memory.add(msg)
diff --git a/docs/sphinx_doc/source/tutorial/204-service.md b/docs/sphinx_doc/source/tutorial/204-service.md
index 448627a28..9703171a7 100644
--- a/docs/sphinx_doc/source/tutorial/204-service.md
+++ b/docs/sphinx_doc/source/tutorial/204-service.md
@@ -118,12 +118,12 @@ class YourAgent(AgentBase):
         prompt += params_prompt
 
         # Get the model response
-        model_response = self.model(prompt)
+        model_response = self.model(prompt).text
 
         # Parse the model response and call the create_file function
         # Additional extraction functions might be necessary
         try:
-            kwargs = json.loads(model_response.content)
+            kwargs = json.loads(model_response)
             create_file(**kwargs)
         except:
             # Error handling
diff --git a/src/agentscope/agents/dict_dialog_agent.py b/src/agentscope/agents/dict_dialog_agent.py
index 9e9315892..c4d725044 100644
--- a/src/agentscope/agents/dict_dialog_agent.py
+++ b/src/agentscope/agents/dict_dialog_agent.py
@@ -7,10 +7,21 @@
 
 from ..message import Msg
 from .agent import AgentBase
+from ..models.model import ModelResponse
 from ..prompt import PromptEngine
 from ..prompt import PromptType
 
 
+def parse_dict(response: ModelResponse) -> ModelResponse:
+    """Parse function for DictDialogAgent"""
+    return ModelResponse(raw=json.loads(response.text))
+
+
+def default_response(response: ModelResponse) -> ModelResponse:
+    """The default response of fault_handler"""
+    return ModelResponse(raw={"speak": response.text})
+
+
 class DictDialogAgent(AgentBase):
     """An agent that generates response in a dict format, where user can
     specify the required fields in the response via prompt, e.g.
@@ -40,8 +51,8 @@ def __init__(
         model_config_name: str = None,
         use_memory: bool = True,
         memory_config: Optional[dict] = None,
-        parse_func: Optional[Callable[..., Any]] = json.loads,
-        fault_handler: Optional[Callable[..., Any]] = lambda x: {"speak": x},
+        parse_func: Optional[Callable[..., Any]] = parse_dict,
+        fault_handler: Optional[Callable[..., Any]] = default_response,
         max_retries: Optional[int] = 3,
         prompt_type: Optional[PromptType] = PromptType.LIST,
     ) -> None:
@@ -129,7 +140,7 @@ def reply(self, x: dict = None) -> dict:
             parse_func=self.parse_func,
             fault_handler=self.fault_handler,
             max_retries=self.max_retries,
-        ).text
+        ).raw
 
         # logging raw messages in debug mode
         logger.debug(json.dumps(response, indent=4))
diff --git a/src/agentscope/models/model.py b/src/agentscope/models/model.py
index 02624066b..2496120aa 100644
--- a/src/agentscope/models/model.py
+++ b/src/agentscope/models/model.py
@@ -107,6 +107,15 @@ def raw(self) -> dict:
         """Raw dictionary field."""
         return self._raw
 
+    def __str__(self) -> str:
+        serialized_fields = {
+            "text": self.text,
+            "embedding": self.embedding,
+            "image_urls": self.image_urls,
+            "raw": self.raw,
+        }
+        return json.dumps(serialized_fields, indent=4)
+
 
 def _response_parse_decorator(
     model_call: Callable,

From 1cb112e4c3242f4f70a94e5064080016bb51f665 Mon Sep 17 00:00:00 2001
From: "panxuchen.pxc"
Date: Tue, 20 Feb 2024 15:17:20 +0800
Subject: [PATCH 2/2] update docstring

---
 src/agentscope/agents/dict_dialog_agent.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/agentscope/agents/dict_dialog_agent.py b/src/agentscope/agents/dict_dialog_agent.py
index c4d725044..be376105a 100644
--- a/src/agentscope/agents/dict_dialog_agent.py
+++ b/src/agentscope/agents/dict_dialog_agent.py
@@ -71,11 +71,13 @@ def __init__(
             use_memory (`bool`, defaults to `True`):
                 Whether the agent has memory.
             memory_config (`Optional[dict]`, defaults to `None`):
                 The config of memory.
-            parse_func (`Optional[Callable[..., Any]]`, defaults to `None`):
+            parse_func (`Optional[Callable[..., Any]]`,
+                defaults to `parse_dict`):
                 The function used to parse the model output,
                 e.g. `json.loads`, which is used to extract json from the
                 output.
-            fault_handler (`Optional[Callable[..., Any]]`, defaults to `None`):
+            fault_handler (`Optional[Callable[..., Any]]`,
+                defaults to `default_response`):
                 The function used to handle the fault when parse_func fails
                 to parse the model output.
             max_retries (`Optional[int]`, defaults to `None`):
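
With these two patches applied, `parse_func` and `fault_handler` receive and return `ModelResponse` objects rather than bare strings, as `parse_dict` and `default_response` above illustrate. The sketch below shows how a custom pair might be passed to `DictDialogAgent` under this interface; it assumes `DictDialogAgent` is importable from `agentscope.agents` and accepts `name` and `sys_prompt` arguments, neither of which appears in the diffs above.

```python
# A minimal sketch, assuming the ModelResponse-based interface from PATCH 1/2.
# The `agentscope.agents` import path for DictDialogAgent and the `name` /
# `sys_prompt` arguments are assumptions; only parse_func, fault_handler,
# max_retries and model_config_name are confirmed by the diffs above.
import json

from agentscope.agents import DictDialogAgent  # assumed export path
from agentscope.models.model import ModelResponse


def parse_fenced_json(response: ModelResponse) -> ModelResponse:
    """Strip Markdown code fences before parsing the model output as JSON."""
    text = response.text.strip()
    fence = "`" * 3  # a Markdown code fence
    if text.startswith(fence):
        text = text.strip("`")
        # Drop an optional language tag such as "json" left behind by the fence.
        if text.lstrip().startswith("json"):
            text = text.lstrip()[len("json"):]
    return ModelResponse(raw=json.loads(text))


def fallback_response(response: ModelResponse) -> ModelResponse:
    """Wrap unparseable output in the same dict shape as default_response."""
    return ModelResponse(raw={"speak": response.text})


agent = DictDialogAgent(
    name="assistant",  # assumed argument
    sys_prompt="Reply in JSON with a 'speak' field.",  # assumed argument
    model_config_name="my_model_config",  # placeholder config name
    parse_func=parse_fenced_json,
    fault_handler=fallback_response,
    max_retries=3,
)
```

Because both helpers share the `ModelResponse -> ModelResponse` signature introduced here, they can be swapped or composed without touching the agent's `reply` logic.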