-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalue_function.py
executable file
·41 lines (31 loc) · 1.22 KB
/
value_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from tabular_policy import TabularPolicy
class ValueFunction:
    """Base class for state-value functions.

    Subclasses supply the storage by overriding ``update``, ``merge`` and
    ``get_value``.  The Q-value computation and the policy extraction
    implemented here rely only on ``get_value``.
    """

    def update(self, state, value):
        """Record ``value`` as the value of ``state``; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement update()"
        )

    def merge(self, value_table):
        """Merge ``value_table`` into this value function; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement merge()"
        )

    def get_value(self, state):
        """Return the value of ``state``; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_value()"
        )

    def get_q_value(self, mdp, state, action):
        """Return the Q-value of ``action`` in ``state``.

        The result is the probability-weighted sum, over every
        ``(new_state, probability)`` pair yielded by
        ``mdp.get_transitions(state, action)``, of the reward for that
        transition plus the discounted value of ``new_state``.  It is
        ``0.0`` when there are no transitions.
        """
        # The discount factor does not depend on the transition, so read it once.
        discount = mdp.get_discount_factor()
        return sum(
            (
                probability
                * (
                    mdp.get_reward(state, action, new_state)
                    + (discount * self.get_value(new_state))
                )
                for (new_state, probability) in mdp.get_transitions(state, action)
            ),
            0.0,
        )

    def extract_policy(self, mdp):
        """Return a policy from this value function.

        For every state in ``mdp.get_states()`` the policy is set to the
        action from ``mdp.get_actions(state)`` with the highest Q-value.
        Ties keep the action encountered first, and a state without any
        action (or whose Q-values never exceed ``-inf``) is left unset.
        """
        policy = TabularPolicy()
        for state in mdp.get_states():
            max_q = float("-inf")
            for action in mdp.get_actions(state):
                q_value = self.get_q_value(mdp, state, action)
                # Strict comparison: only a strictly better action replaces
                # the one already stored for this state.
                if q_value > max_q:
                    policy.update(state, action)
                    max_q = q_value
        return policy