forked from valohai/qlearning-simple
gambler.py
from enums import *
import random


class Gambler:
    def __init__(self, learning_rate=0.1, discount=0.95, exploration_rate=1.0, iterations=10000):
        self.q_table = [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]  # Spreadsheet (Q-table) for reward accounting
        self.learning_rate = learning_rate  # How much we value a new Q-value over the current one
        self.discount = discount  # How much we value future reward over immediate reward
        self.exploration_rate = exploration_rate  # Initial exploration rate
        self.exploration_delta = 1.0 / iterations  # Shift from exploration to exploitation

    def get_next_action(self, state):
        # Explore (gamble) or exploit (greedy)?
        if random.random() > self.exploration_rate:
            return self.greedy_action(state)
        else:
            return self.random_action()

    def greedy_action(self, state):
        # Is the FORWARD reward bigger?
        if self.q_table[FORWARD][state] > self.q_table[BACKWARD][state]:
            return FORWARD
        # Is the BACKWARD reward bigger?
        elif self.q_table[BACKWARD][state] > self.q_table[FORWARD][state]:
            return BACKWARD
        # Rewards are equal, take a random action
        return FORWARD if random.random() < 0.5 else BACKWARD

    def random_action(self):
        return FORWARD if random.random() < 0.5 else BACKWARD
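
    # For reference: update() below implements the standard one-step Q-learning rule,
    #   Q(s, a) <- Q(s, a) + learning_rate * (reward + discount * max_a' Q(s', a') - Q(s, a)),
    # where greedy_action(new_state) supplies the max_a' Q(s', a') term.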

    def update(self, old_state, new_state, action, reward):
        # Old Q-table value
        old_value = self.q_table[action][old_state]
        # What would be our best next action?
        future_action = self.greedy_action(new_state)
        # What is the reward for that best next action?
        future_reward = self.q_table[future_action][new_state]
        # Main Q-table update
        new_value = old_value + self.learning_rate * (reward + self.discount * future_reward - old_value)
        self.q_table[action][old_state] = new_value

        # Finally, shift our exploration_rate toward zero (less gambling)
        if self.exploration_rate > 0:
            self.exploration_rate -= self.exploration_delta
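
For context, here is a minimal usage sketch. It is not part of the original repository, which drives Gambler from its own game script; the SimpleWalk environment below is a hypothetical stand-in with five positions (matching the width of the Q-table rows), and the FORWARD/BACKWARD constants are assumed to come from the project's enums module, as in the imports above.

# usage_sketch.py -- illustrative example only
from enums import *  # assumed to provide the FORWARD and BACKWARD action indices
from gambler import Gambler


class SimpleWalk:
    """Hypothetical five-position walk; reaching the last cell pays a reward."""

    def __init__(self, length=5):
        self.length = length
        self.position = 0

    def step(self, action):
        # Move, clamp to the board, and pay out when the goal cell is reached
        self.position += 1 if action == FORWARD else -1
        self.position = max(0, min(self.length - 1, self.position))
        if self.position == self.length - 1:
            reward = 10
            self.position = 0  # restart the walk after reaching the goal
        else:
            reward = 0
        return self.position, reward


if __name__ == "__main__":
    game = SimpleWalk()
    gambler = Gambler(iterations=10000)
    state = game.position
    for _ in range(10000):
        action = gambler.get_next_action(state)
        new_state, reward = game.step(action)
        gambler.update(state, new_state, action, reward)
        state = new_state
    print(gambler.q_table)  # learned Q-values for FORWARD/BACKWARD at each position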