forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenv.py
129 lines (105 loc) · 4.65 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import re
import traceback
from datatypes import ParseError, StepOutput, TaskState
from tasks.base import Task
from openhands.controller.state.state import State
class SimplifiedEnv:
    """Minimal environment that grades an agent's proposed solutions.

    Each call to ``step`` treats the agent's message as a solution proposal,
    grades it with ``task.success``, enforces the iteration / proposal budgets
    from ``task_config``, and records the resulting feedback in
    ``self.task_state``.
    """

    INVALID_INPUT_MESSAGE = (
        "I don't understand your input. \n"
        'If you want to execute code, please use <execute_ipython> YOUR_CODE_HERE </execute_ipython>.\n'
        'If you want to give me an answer, please use <solution> YOUR_SOLUTION_HERE </solution>.\n'
        'For example: The answer to the question is <solution> 42 </solution>. \n'
    )

    # Non-greedy + DOTALL: each <solution> block is captured separately and
    # its contents may span multiple lines.
    _SOLUTION_PATTERN = re.compile(r'<solution>(.*?)</solution>', re.DOTALL)

    def __init__(self, agent_state: State, task: Task, task_config: dict[str, int]):
        """Create an environment for one task.

        Args:
            agent_state: Agent controller state. ``iteration`` is read by the
                budget checks; ``propose_solution_count`` is read here and
                written by ``step`` so the proposal counter survives when a new
                ``SimplifiedEnv`` is built over the same state.
            task: Task whose ``success(answer)`` method grades a proposal.
            task_config: Budget limits; must contain ``'max_iterations'`` and
                ``'max_propose_solution'``.
        """
        self.agent_state = agent_state
        self.task = task
        self.task_config = task_config
        # Restore the proposal counter if a previous step() stored it on
        # agent_state; otherwise start counting from zero.
        agent_action_count = {
            'propose_solution': getattr(self.agent_state, 'propose_solution_count', 0),
            'use_tool': 0,
            'invalid_action': 0,
        }
        self.task_state = TaskState(agent_action_count=agent_action_count)

    def step(self, lm_message: str) -> TaskState:
        """Grade one agent message and update the episode state.

        Every call counts as one solution proposal, whether or not the message
        contains a parsable ``<solution>`` block.

        Args:
            lm_message: The agent's latest message text.

        Returns:
            The updated ``TaskState``. Its ``latest_output`` is refreshed only
            while the episode is still running (see ``log_output``).
        """
        observation = self.handle_propose_solution(lm_message)
        self.check_max_iteration()

        # Remaining budget: (iterations left, solution proposals left).
        turn_info = (
            self.task_config['max_iterations'] - self.agent_state.iteration,
            self.task_config['max_propose_solution']
            - self.task_state.agent_action_count['propose_solution'],
        )
        output = StepOutput(
            observation=observation,
            success=self.task_state.success,
            turn_info=turn_info,
        )

        # Persist the counter on agent_state so __init__ can restore it.
        self.agent_state.propose_solution_count = self.task_state.agent_action_count[
            'propose_solution'
        ]
        self.log_output(output)
        return self.task_state

    def handle_propose_solution(self, lm_message: str) -> str | None:
        """Parse and grade a proposed solution.

        The ``propose_solution`` counter is incremented before parsing, so an
        unparsable message still consumes a proposal. When the answer is
        correct, ``self.task_state`` is marked finished and successful with
        ``terminate_reason = 'task_success'``.

        Returns:
            ``None`` when the message was parsed and graded (right or wrong);
            ``INVALID_INPUT_MESSAGE`` when ``parse_propose_solution`` raises
            ``ParseError``; otherwise the formatted traceback of any other
            exception raised while parsing/grading, which ``step`` uses as
            the observation.
        """
        self.task_state.agent_action_count['propose_solution'] += 1
        try:
            parsed = self.parse_propose_solution(lm_message)
            task_success = self.check_task_success(parsed['answer'])
        except ParseError:
            return SimplifiedEnv.INVALID_INPUT_MESSAGE
        except Exception:
            # Deliberate best-effort: grader errors become the observation
            # instead of crashing the episode.
            return traceback.format_exc()
        else:
            if task_success:
                # No early exit from step() here: it still builds the
                # StepOutput; finished=True is what terminates the episode.
                self.task_state.finished = True
                self.task_state.success = True
                self.task_state.terminate_reason = 'task_success'
            return None

    def parse_propose_solution(self, lm_message: str) -> dict:
        """Extract the answer from ``<solution>...</solution>`` blocks.

        The contents of every block are stripped and joined with ``'\\n'``.

        Returns:
            ``{'answer': <joined text>}``.

        Raises:
            ParseError: if the joined answer is the empty string, i.e. there is
                no block or a single whitespace-only block.
                NOTE(review): several blank blocks join to ``'\\n'`` and are
                passed on as an answer; kept as-is to match existing scoring.
        """
        answer = '\n'.join(
            match.strip() for match in self._SOLUTION_PATTERN.findall(lm_message)
        )
        if not answer:
            raise ParseError('No answer found.')
        return {'answer': answer}

    def log_output(self, output: StepOutput) -> None:
        """Store ``output`` as ``self.task_state.latest_output``.

        The stored value is ``output.to_dict()`` with ``output.to_str()`` added
        under the ``'content'`` key. Nothing is stored once the episode is
        finished. NOTE(review): presumably callers detect termination by
        ``latest_output`` not being refreshed — confirm before changing.
        """
        if self.task_state.finished:
            return
        content = output.to_str()
        latest = output.to_dict()
        latest['content'] = content
        self.task_state.latest_output = latest

    def check_task_success(self, answer: str) -> bool:
        """Return whether ``answer`` solves the task, as judged by ``task.success``."""
        return self.task.success(answer)

    def check_max_iteration(self) -> None:
        """Finish the episode (unsuccessfully) once a budget is exhausted.

        Does nothing if ``self.task_state.finished`` is already set (e.g. the
        task succeeded). The proposal budget is checked first, so it decides
        ``terminate_reason`` when both budgets are exhausted at once.
        """
        if self.task_state.finished:
            return

        if (
            self.task_state.agent_action_count['propose_solution']
            >= self.task_config['max_propose_solution']
        ):
            reason = 'max_propose_steps'
        elif self.agent_state.iteration >= self.task_config['max_iterations']:
            reason = 'max_iterations'
        else:
            return

        self.task_state.finished = True
        self.task_state.success = False
        self.task_state.terminate_reason = reason