diff --git a/src/transformers/agents/agents.py b/src/transformers/agents/agents.py
index e9cf10687c723e..73b7186d25a3c7 100644
--- a/src/transformers/agents/agents.py
+++ b/src/transformers/agents/agents.py
@@ -748,8 +748,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
 
         Args:
             task (`str`): The task to perform
-            stream (`bool`, *optional*, defaults to `False`): Whether to stream the logs of the agent's interactions.
-            reset (`bool`, *optional*, defaults to `True`): Whether to reset the agent's state before running it.
 
         Example:
         ```py
@@ -766,7 +764,15 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
             self.initialize_for_run()
         else:
             self.logs.append({"task": task})
+        if stream:
+            return self.stream_run(task)
+        else:
+            return self.direct_run(task)
 
+    def stream_run(self, task: str):
+        """
+        Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
+        """
         final_answer = None
         iteration = 0
         while final_answer is None and iteration < self.max_iterations:
@@ -779,25 +785,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
                 self.logs[-1]["error"] = e
             finally:
                 iteration += 1
-                if stream:
-                    yield self.logs[-1]
-
-        if iteration == self.max_iterations:
-            if final_answer is None:
-                error_message = "Reached max iterations."
-                final_step_log = {"error": AgentMaxIterationsError(error_message)}
-                self.logs.append(final_step_log)
-                self.logger.error(error_message, exc_info=1)
-                final_answer = self.provide_final_answer(task)
-                final_step_log["final_answer"] = final_answer
-        else:
-            final_step_log = self.logs[-1]
-
-        if stream:
+                yield self.logs[-1]
+
+        if final_answer is None and iteration == self.max_iterations:
+            error_message = "Reached max iterations."
+            final_step_log = {"error": AgentMaxIterationsError(error_message)}
+            self.logs.append(final_step_log)
+            self.logger.error(error_message, exc_info=1)
+            final_answer = self.provide_final_answer(task)
+            final_step_log["final_answer"] = final_answer
             yield final_step_log
 
-        else:
-            return final_answer
+        yield final_answer
+
+    def direct_run(self, task: str):
+        """
+        Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
+        """
+        final_answer = None
+        iteration = 0
+        while final_answer is None and iteration < self.max_iterations:
+            try:
+                if self.planning_interval is not None and iteration % self.planning_interval == 0:
+                    self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
+                step_logs = self.step()
+                if "final_answer" in step_logs:
+                    final_answer = step_logs["final_answer"]
+            except AgentError as e:
+                self.logger.error(e, exc_info=1)
+                self.logs[-1]["error"] = e
+            finally:
+                iteration += 1
+
+        if final_answer is None and iteration == self.max_iterations:
+            error_message = "Reached max iterations."
+            final_step_log = {"error": AgentMaxIterationsError(error_message)}
+            self.logs.append(final_step_log)
+            self.logger.error(error_message, exc_info=1)
+            final_answer = self.provide_final_answer(task)
+            final_step_log["final_answer"] = final_answer
+
+        return final_answer
 
     def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
         """
diff --git a/src/transformers/agents/monitoring.py b/src/transformers/agents/monitoring.py
index 93101a3ee1f654..73b84f05fb0eca 100644
--- a/src/transformers/agents/monitoring.py
+++ b/src/transformers/agents/monitoring.py
@@ -59,9 +59,10 @@ def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
             for message in pull_message(step_log):
                 yield message
 
-
     if isinstance(step_log["final_answer"], AgentText):
-        yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log["final_answer"].to_string()}\n```")
+        yield ChatMessage(
+            role="assistant", content=f"**Final answer:**\n```\n{step_log['final_answer'].to_string()}\n```"
+        )
     elif isinstance(step_log["final_answer"], AgentImage):
         yield ChatMessage(
             role="assistant",
diff --git a/tests/agents/test_monitoring.py b/tests/agents/test_monitoring.py
index 2df00f7efd87e5..da33ea5f1a204f 100644
--- a/tests/agents/test_monitoring.py
+++ b/tests/agents/test_monitoring.py
@@ -14,9 +14,11 @@
 # limitations under the License.
 import unittest
 
-from transformers.agents.monitoring import stream_to_gradio
-from transformers.agents.agents import ReactCodeAgent, ReactJsonAgent, AgentError
+
 from transformers.agents.agent_types import AgentImage
+from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
+from transformers.agents.monitoring import stream_to_gradio
+
 
 class TestMonitoring(unittest.TestCase):
     def test_streaming_agent_text_output(self):
@@ -76,9 +78,7 @@ def dummy_llm_engine(prompt, **kwargs):
         outputs = list(stream_to_gradio(agent, task="Test task"))
 
         # Check that the error message is yielded
-        print("OUTPUTTTTS", outputs)
         self.assertEqual(len(outputs), 3)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("Simulated agent error", final_message.content)
-
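Usage sketch for reviewers: after this split, `run()` keeps its public signature but dispatches to `direct_run()` by default, or to `stream_run()` with `stream=True`, which yields each step log and then the final answer as its last item. The snippet below is illustrative only and not part of the diff: `dummy_llm_engine` is a placeholder in the spirit of the existing tests, and the `Action:` JSON format is assumed to match what `ReactJsonAgent`'s default tool parser expects.

```py
from transformers.agents.agents import ReactJsonAgent


def dummy_llm_engine(prompt, **kwargs):
    # Placeholder engine: always emits a final_answer tool call.
    return 'Action:\n{"action": "final_answer", "action_input": {"answer": "Done."}}'


agent = ReactJsonAgent(tools=[], llm_engine=dummy_llm_engine)

# Direct mode (default): run() delegates to direct_run() and returns only the final answer.
print(agent.run("Test task"))

# Streaming mode: run() delegates to stream_run(), a generator that yields dict step logs
# and then the final answer itself; this is what stream_to_gradio() iterates over.
for step_log in agent.run("Test task", stream=True):
    print(step_log)
```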