Skip to content

Commit

Permalink
Re-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Nov 4, 2024
1 parent b438e6a commit a0dd3bf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
66 changes: 47 additions & 19 deletions src/transformers/agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,8 +748,6 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
Args:
task (`str`): The task to perform
stream (`bool`, *optional*, defaults to `False`): Whether to stream the logs of the agent's interactions.
reset (`bool`, *optional*, defaults to `True`): Whether to reset the agent's state before running it.
Example:
```py
Expand All @@ -766,7 +764,15 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
self.initialize_for_run()
else:
self.logs.append({"task": task})
if stream:
return self.stream_run(task)
else:
return self.direct_run(task)

def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
Expand All @@ -779,25 +785,47 @@ def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
self.logs[-1]["error"] = e
finally:
iteration += 1
if stream:
yield self.logs[-1]

if iteration == self.max_iterations:
if final_answer is None:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
else:
final_step_log = self.logs[-1]

if stream:
yield self.logs[-1]

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer
yield final_step_log
else:
return final_answer

yield final_answer

def direct_run(self, task: str):
"""
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
except AgentError as e:
self.logger.error(e, exc_info=1)
self.logs[-1]["error"] = e
finally:
iteration += 1

if final_answer is None and iteration == self.max_iterations:
error_message = "Reached max iterations."
final_step_log = {"error": AgentMaxIterationsError(error_message)}
self.logs.append(final_step_log)
self.logger.error(error_message, exc_info=1)
final_answer = self.provide_final_answer(task)
final_step_log["final_answer"] = final_answer

return final_answer

def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
"""
Expand Down
5 changes: 5 additions & 0 deletions tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,8 @@ def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
assert "You can also give requests to team members." not in agent.system_prompt
assert "<<managed_agents_descriptions>>" not in agent.system_prompt
assert "You can also give requests to team members." in manager_agent.system_prompt


if __name__ == "__main__":
tester = AgentTests()
tester.test_react_fails_max_iterations()

0 comments on commit a0dd3bf

Please sign in to comment.