forked from Albert-Z-Guo/Deep-Reinforcement-Stock-Trading
-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate.py
94 lines (76 loc) · 3.61 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse
import importlib
import logging
import os
import sys

import numpy as np
# np.random.seed(3) # for reproducible Keras operations

from utils import *
# Command-line options controlling which trained model is evaluated, on which
# price series, and with how much starting cash.
parser = argparse.ArgumentParser(description='command line options')
parser.add_argument('--model_to_load', action="store", dest="model_to_load",
                    default='DQN_ep10', help="model name")
parser.add_argument('--stock_name', action="store", dest="stock_name",
                    default='^GSPC_2018', help="stock name")
parser.add_argument('--initial_balance', action="store", dest="initial_balance",
                    default=50000, type=int, help='initial balance')
inputs = parser.parse_args()

# Expose the parsed options as module-level names used by the rest of the script.
model_to_load = inputs.model_to_load
# The agent family is the text before the first underscore, e.g. 'DQN' for 'DQN_ep10'.
model_name = model_to_load.partition('_')[0]
stock_name, initial_balance = inputs.stock_name, inputs.initial_balance
# Whether to plot the evaluation results once the run finishes.
display = True
# Window length passed to generate_combined_state() when building each state.
window_size = 10
# Human-readable labels for the agent's discrete action indices, used when
# logging exploratory actions. Index 1 dispatches to buy(t) in the evaluation
# loop, so it must be labelled 'Buy' (it was previously mislabelled 'Hold').
action_dict = {0: 'Hold', 1: 'Buy', 2: 'Sell'}
# Load the agent implementation for the requested model family: the module
# agents.<model_name> (e.g. agents.DQN), which must provide an Agent class.
model = importlib.import_module('agents.' + model_name)
def hold():
    """Log a hold action; the agent's balance and inventory are not touched."""
    logger.info('Hold')
def buy(t):
    """Buy one share at the price for step ``t``.

    Debits ``stock_prices[t]`` from the module-level ``agent``'s balance,
    adds that price to its inventory, records ``t`` in ``agent.buy_dates``,
    and logs the trade. No affordability check is done here; the caller
    only invokes this when ``agent.balance > stock_prices[t]``.

    Args:
        t: Index into ``stock_prices`` of the trading step.
    """
    price = stock_prices[t]
    agent.balance -= price
    agent.inventory.append(price)
    agent.buy_dates.append(t)
    logger.info('Buy: ${:.2f}'.format(price))
def sell(t):
    """Sell the oldest held share at the price for step ``t``.

    Credits ``stock_prices[t]`` to the module-level ``agent``'s balance,
    removes the earliest purchase price from its inventory (first in,
    first out), stores the realised profit in the module-level ``reward``,
    records ``t`` in ``agent.sell_dates``, and logs the trade.

    Args:
        t: Index into ``stock_prices`` of the trading step.

    Raises:
        IndexError: if the inventory is empty (the caller guards against this).
    """
    global reward
    price = stock_prices[t]
    agent.balance += price
    bought_price = agent.inventory.pop(0)
    reward = profit = price - bought_price
    agent.sell_dates.append(t)
    logger.info('Sell: ${:.2f} | Profit: ${:.2f}'.format(price, profit))
# Configure the root logger to write this run's evaluation log to
# logs/<model_name>_evaluation_<stock_name>.log. mode='w' overwrites the log
# from any previous run with the same model/stock pair.
logger = logging.getLogger()
# logging.FileHandler raises FileNotFoundError if the target directory is
# missing (e.g. a fresh checkout), so create it first; no-op if it exists.
os.makedirs('logs', exist_ok=True)
handler = logging.FileHandler('logs/{}_evaluation_{}.log'.format(model_name, stock_name), mode='w')
handler.setFormatter(logging.Formatter(fmt='[%(asctime)s.%(msecs)03d %(filename)s:%(lineno)3s] %(message)s', datefmt='%m/%d/%Y %H:%M:%S'))
logger.addHandler(handler)
logger.setLevel(logging.INFO)
# Run complete evaluation episodes until one yields a non-zero result from
# evaluate_portfolio_performance() -- a hack to avoid the stationary case.
# The original retry loop was unbounded. It never terminated if every attempt
# produced a zero return (e.g. an agent that acts identically on each attempt
# and never trades), or if the price series has fewer than two points so the
# final step -- where the return is computed -- is never reached. Retries are
# now capped.
max_evaluation_attempts = 10

portfolio_return = 0
evaluation_attempt = 0
while portfolio_return == 0 and evaluation_attempt < max_evaluation_attempts:
    evaluation_attempt += 1
    # hold()/buy()/sell() operate on these module-level `agent` and `stock_prices`.
    agent = model.Agent(state_dim=13, balance=initial_balance, is_eval=True, model_name=model_to_load)
    stock_prices = stock_close_prices(stock_name)
    trading_period = len(stock_prices) - 1
    state = generate_combined_state(0, window_size, stock_prices, agent.balance, len(agent.inventory))

    for t in range(1, trading_period + 1):
        if model_name == 'DDPG':
            # act() output is argmax-ed, so presumably one score per action.
            actions = agent.act(state, t)
            action = np.argmax(actions)
        else:
            # Keep the raw model output so that an action from act() that
            # differs from the model's best action is logged as exploration.
            actions = agent.model.predict(state)[0]
            action = agent.act(state)

        # Built from the pre-trade balance/inventory, as before.
        next_state = generate_combined_state(t, window_size, stock_prices, agent.balance, len(agent.inventory))
        # Value the holdings carried into step t at the previous step's price.
        # Valuing them at stock_prices[t] made every return rate exactly zero:
        # buy()/sell() trade at stock_prices[t], so they leave
        # balance + len(inventory) * stock_prices[t] unchanged.
        previous_portfolio_value = len(agent.inventory) * stock_prices[t - 1] + agent.balance

        # execute position
        logger.info('Step: {}'.format(t))
        if action != np.argmax(actions):
            logger.info("\t\t'{}' is an exploration.".format(action_dict[action]))
        if action == 0:
            hold()
        if action == 1 and agent.balance > stock_prices[t]:
            buy(t)
        if action == 2 and len(agent.inventory) > 0:
            sell(t)

        current_portfolio_value = len(agent.inventory) * stock_prices[t] + agent.balance
        agent.return_rates.append((current_portfolio_value - previous_portfolio_value) / previous_portfolio_value)
        agent.portfolio_values.append(current_portfolio_value)
        state = next_state

        # Last step of the episode: summarise it (result drives the retry loop).
        if t == trading_period:
            portfolio_return = evaluate_portfolio_performance(agent, logger)

if portfolio_return == 0:
    logger.warning('Portfolio return still zero after {} evaluation attempt(s).'.format(evaluation_attempt))

if display:
    # plot_portfolio_transaction_history(stock_name, agent)
    # plot_portfolio_performance_comparison(stock_name, agent)
    plot_all(stock_name, agent)