train_best.py
import pandas as pd
import numpy as np
import pickle
import optuna
from train_fqi import read_dataset, generate_dataset, get_cli_args, prepare_dataset
from erl_config import build_env
from trade_simulator import TradeSimulator
from trlib.algorithms.reinforcement.fqi import FQI
from trlib.policies.qfunction import ZeroQ
from trlib.policies.valuebased import EpsilonGreedy
from sklearn.ensemble import ExtraTreesRegressor


def load_policy(model_path):
    """Load a pickled FQI policy (EpsilonGreedy with a fitted Q-function) from disk."""
    with open(model_path, "rb") as f:
        policy = pickle.load(f)
    return policy


args = get_cli_args()
n_windows = 8
start_day = 7
n_train_days = 1
n_validation_days = 1
base_dir = "sqlite:///."
base_out_dir = '.'
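
# Walk-forward setup (derived from the constants above): window w trains on day
# start_day + w and is validated on the day that follows; each window has its own
# Optuna study (optuna_study.db) whose best hyperparameters are used to refit FQI.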
for window in range(n_windows):
    # Load the Optuna study for this window and prepare its output directory.
    study_path = base_dir + f"/trial_{window}_window/optuna_study.db"
    out_dir = base_out_dir + f"/trial_{window}_window/"
    loaded_study = optuna.load_study(study_name=None, storage=study_path)

    start_day_train = start_day + window
    end_day_train = start_day_train + n_train_days - 1
    sample_days_train = [start_day_train, end_day_train]

    # Read cached rollouts for the behaviour policies; generate any that are missing.
    policies = ['random_policy', 'long_only_policy', 'short_only_policy', 'flat_only_policy']
    dfs, dfs_unread = read_dataset(sample_days_train, policies=policies)
    if len(dfs_unread) > 0:
        dfs_train = generate_dataset(days_to_sample=sample_days_train,
                                     max_steps=args.max_steps,
                                     episodes=args.train_episodes,
                                     policies=dfs_unread)
        dfs += dfs_train
    if len(dfs) > 0:
        dfs = pd.concat(dfs)
    else:
        raise ValueError("No dataset!!")

    # Flatten the logged transitions into FQI training matrices.
    state_actions, rewards, next_states, absorbing = prepare_dataset(dfs)
    actions_values = [0, 1, 2]  # discrete action ids (see action_dim below)

    np.random.seed()
    seed = np.random.randint(100000)
    max_steps = args.max_steps
    # Evaluation environment: the validation day(s) immediately after the training day.
    env_args = {
        "env_name": "TradeSimulator-v0",
        "num_envs": 1,
        "max_step": max_steps,
        "state_dim": 8 + 2,  # factor_dim + (position, holding)
        "action_dim": 3,  # long, 0, short
        "if_discrete": True,
        "max_position": 1,
        "slippage": 7e-7,
        "num_sims": 1,
        "step_gap": 1,
        "env_class": TradeSimulator,
        "eval": True,
        "days": [end_day_train + 1, end_day_train + n_validation_days],
    }
    eval_env = build_env(TradeSimulator, env_args, -1)

    # Greedy policy (epsilon = 0) over a Q-function that starts at zero and is refitted by FQI.
    pi = EpsilonGreedy(actions_values, ZeroQ(), epsilon=0)

    # Best hyperparameters found by this window's Optuna study.
    max_iterations = loaded_study.best_params['iterations']
    n_estimators = loaded_study.best_params['n_estimators']
    max_depth = loaded_study.best_params['max_depth']
    min_split = loaded_study.best_params['min_samples_split']

    algorithm = FQI(mdp=eval_env, policy=pi, actions=actions_values, batch_size=5,
                    max_iterations=max_iterations, regressor_type=ExtraTreesRegressor,
                    random_state=seed, n_estimators=n_estimators, n_jobs=-1,
                    max_depth=max_depth, min_samples_split=min_split)

    # Run FQI and checkpoint the greedy policy after every iteration.
    for i in range(max_iterations):
        iteration = i + 1
        algorithm._iter(
            state_actions.to_numpy(dtype=np.float32),
            rewards.to_numpy(dtype=np.float32),
            next_states.to_numpy(dtype=np.float32),
            absorbing,
        )
        model_name = out_dir + f'Policy_iter{iteration}.pkl'
        with open(model_name, 'wb+') as f:
            pickle.dump(algorithm._policy, f)
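
# --- Usage sketch (not part of the training loop) ---------------------------------
# A minimal, hypothetical example of how a checkpointed policy could be reloaded and
# run on the evaluation simulator. It assumes the pickled trlib EpsilonGreedy policy
# exposes sample_action(state) and that build_env returns a Gym-style reset()/step()
# environment; adjust to the actual TradeSimulator interface before using it.
#
# policy = load_policy("./trial_0_window/Policy_iter1.pkl")  # hypothetical checkpoint
# state = eval_env.reset()
# done = False
# while not done:
#     action = policy.sample_action(state)
#     state, reward, done, info = eval_env.step(action)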