Skip to content

Commit

Permalink
Re-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Nov 11, 2024
1 parent b438e6a commit d46bba3
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 42 deletions.
66 changes: 47 additions & 19 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,8 +748,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
Args:
task (`str`): The task to perform
stream (`bool`, *optional*, defaults to `False`): Whether to stream the logs of the agent's interactions.
reset (`bool`, *optional*, defaults to `True`): Whether to reset the agent's state before running it.
Example:
```py
Expand All @@ -766,7 +764,15 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
self.initialize_for_run()
else:
self.logs.append({"task": task})
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)

def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
Expand All @@ -779,25 +785,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
self.logs[-1]["error"] = e
finally:
iteration += 1
if stream:
yield self.logs[-1]

if iteration == self.max_iterations:
if final_answer is None:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
else:
final_step_log = self.logs[-1]

if stream:
yield self.logs[-1]

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
yield final_step_log
else:
return final_answer

yield final_answer

def direct_run(self, task: str):
    """
    Execute the agent loop to completion and return only the final answer
    (no intermediate streaming): should be launched only in the `run` method.
    """
    answer = None
    step_number = 0
    while answer is None and step_number < self.max_iterations:
        try:
            # Interleave a planning step every `planning_interval` iterations (if configured).
            needs_planning = (
                self.planning_interval is not None and step_number % self.planning_interval == 0
            )
            if needs_planning:
                self.planning_step(task, is_first_step=(step_number == 0), iteration=step_number)
            step_logs = self.step()
            if "final_answer" in step_logs:
                answer = step_logs["final_answer"]
        except AgentError as exc:
            # Record the failure on the latest log entry and keep iterating.
            self.logger.error(exc, exc_info=1)
            self.logs[-1]["error"] = exc
        finally:
            step_number += 1

    # Budget exhausted without an answer: log the overrun and ask the LLM for a best-effort answer.
    if answer is None and step_number == self.max_iterations:
        overrun_message = "Reached max iterations."
        last_log = {"error": AgentMaxIterationsError(overrun_message)}
        self.logs.append(last_log)
        self.logger.error(overrun_message, exc_info=1)
        answer = self.provide_final_answer(task)
        last_log["final_answer"] = answer

    return answer

def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
"""
Expand Down
37 changes: 25 additions & 12 deletions src/transformers/agents/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@
from .agents import ReactAgent


def pull_message(step_log: dict):
def pull_message(step_log: dict, test_mode: bool = True):
try:
from gradio import ChatMessage
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
if test_mode:

class ChatMessage:
role: str
content: dict
else:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")

if step_log.get("rationale"):
yield ChatMessage(role="assistant", content=step_log["rationale"])
Expand All @@ -46,31 +52,38 @@ def pull_message(step_log: dict):
)


def stream_to_gradio(agent: ReactAgent, task: str, test_mode: bool = False, **kwargs):
    """
    Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.

    Args:
        agent (`ReactAgent`): The agent to run.
        task (`str`): The task to perform.
        test_mode (`bool`, *optional*, defaults to `False`): When `True` and gradio is
            not installed, a lightweight stand-in for `gradio.ChatMessage` is used so
            the function remains usable in test environments without gradio.
        kwargs: Extra keyword arguments forwarded to `agent.run`.
    """
    try:
        from gradio import ChatMessage
    except ImportError:
        if test_mode:
            # Minimal stand-in mirroring the parts of gradio.ChatMessage used here.
            # FIX: a bare class carrying only annotations cannot be instantiated with
            # keyword arguments (`ChatMessage(role=..., content=...)` would raise
            # TypeError), so an explicit __init__ is required.
            class ChatMessage:
                def __init__(self, role: str, content):
                    self.role = role
                    self.content = content
        else:
            raise ImportError("Gradio should be installed in order to launch a gradio demo.")

    # FIX: initialize so `step_log` is defined even if the run yields nothing
    # (avoids UnboundLocalError after the loop).
    step_log = None
    for step_log in agent.run(task, stream=True, **kwargs):
        if isinstance(step_log, dict):
            for message in pull_message(step_log, test_mode=test_mode):
                yield message

    final_answer = step_log  # Last log is the run's final_answer

    if isinstance(final_answer, AgentText):
        yield ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```",
        )
    elif isinstance(final_answer, AgentImage):
        yield ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield ChatMessage(role="assistant", content=str(final_answer))
26 changes: 15 additions & 11 deletions tests/agents/test_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@
# limitations under the License.

import unittest
from transformers.agents.monitoring import stream_to_gradio
from transformers.agents.agents import ReactCodeAgent, ReactJsonAgent, AgentError

from transformers.agents.agent_types import AgentImage
from transformers.agents.agents import AgentError, ReactCodeAgent, ReactJsonAgent
from transformers.agents.monitoring import stream_to_gradio

class TestMonitoring(unittest.TestCase):

class MonitoringTester(unittest.TestCase):
def test_streaming_agent_text_output(self):
# Create a dummy LLM engine that returns a final answer
def dummy_llm_engine(prompt, **kwargs):
return "final_answer('This is the final answer.')"
return """
Code:
````
final_answer('This is the final answer.')
```"""

agent = ReactCodeAgent(
tools=[],
Expand All @@ -31,7 +37,7 @@ def dummy_llm_engine(prompt, **kwargs):
)

# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task"))
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

# Check that the final output is a ChatMessage with the expected content
self.assertEqual(len(outputs), 3)
Expand All @@ -51,10 +57,10 @@ def dummy_llm_engine(prompt, **kwargs):
)

# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png")))
outputs = list(stream_to_gradio(agent, task="Test task", image=AgentImage(value="path.png"), test_mode=True))

# Check that the final output is a ChatMessage with the expected content
self.assertEqual(len(outputs), 3)
self.assertEqual(len(outputs), 2)
final_message = outputs[-1]
self.assertEqual(final_message.role, "assistant")
self.assertIsInstance(final_message.content, dict)
Expand All @@ -64,7 +70,7 @@ def dummy_llm_engine(prompt, **kwargs):
def test_streaming_with_agent_error(self):
# Create a dummy LLM engine that raises an error
def dummy_llm_engine(prompt, **kwargs):
raise AgentError("Simulated agent error.")
raise AgentError("Simulated agent error")

agent = ReactCodeAgent(
tools=[],
Expand All @@ -73,12 +79,10 @@ def dummy_llm_engine(prompt, **kwargs):
)

# Use stream_to_gradio to capture the output
outputs = list(stream_to_gradio(agent, task="Test task"))
outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

# Check that the error message is yielded
print("OUTPUTTTTS", outputs)
self.assertEqual(len(outputs), 3)
final_message = outputs[-1]
self.assertEqual(final_message.role, "assistant")
self.assertIn("Simulated agent error", final_message.content)

0 comments on commit d46bba3

Please sign in to comment.