Skip to content

Commit

Permalink
Re-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Nov 4, 2024
1 parent b438e6a commit 99477b8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
66 changes: 47 additions & 19 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,8 +748,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
Args:
task (`str`): The task to perform
stream (`bool`, *optional*, defaults to `False`): Whether to stream the logs of the agent's interactions.
reset (`bool`, *optional*, defaults to `True`): Whether to reset the agent's state before running it.
Example:
```py
Expand All @@ -766,7 +764,15 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
self.initialize_for_run()
else:
self.logs.append({"task": task})
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)

def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
Expand All @@ -779,25 +785,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
self.logs[-1]["error"] = e
finally:
iteration += 1
if stream:
yield self.logs[-1]

if iteration == self.max_iterations:
if final_answer is None:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
else:
final_step_log = self.logs[-1]

if stream:
yield self.logs[-1]

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
yield final_step_log
else:
return final_answer

yield final_answer

def direct_run(self, task: str):
"""
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
finally:
iteration += 1

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer

return final_answer

def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
"""
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/agents/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
for message in pull_message(step_log):
yield message


if isinstance(step_log["final_answer"], AgentText):
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log["final_answer"].to_string()}\n```")
yield ChatMessage(
role="assistant", content=f"**Final answer:**\n```\n{step_log['final_answer'].to_string()}\n```"
)
elif isinstance(step_log["final_answer"], AgentImage):
yield ChatMessage(
role="assistant",
Expand Down
8 changes: 4 additions & 4 deletions tests/agents/test_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
# limitations under the License.

import unittest
from transformers.agents.monitoring import stream_to_gradio
from transformers.agents.agents import ReactCodeAgent, ReactJsonAgent, AgentError

from transformers.agents.agent_types import AgentImage
from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
from transformers.agents.monitoring import stream_to_gradio


class TestMonitoring(unittest.TestCase):
def test_streaming_agent_text_output(self):
Expand Down Expand Up @@ -76,9 +78,7 @@ def dummy_llm_engine(prompt, **kwargs):
outputs = list(stream_to_gradio(agent, task="Test task"))

# Check that the error message is yielded
print("OUTPUTTTTS", outputs)
self.assertEqual(len(outputs), 3)
final_message = outputs[-1]
self.assertEqual(final_message.role, "assistant")
self.assertIn("Simulated agent error", final_message.content)

0 comments on commit 99477b8

Please sign in to comment.