Skip to content

Commit

Permalink
Replace deco with standard multiprocessing pool
Browse files Browse the repository at this point in the history
  • Loading branch information
notadamking committed Jul 6, 2019
1 parent afacd02 commit 9d4ad1c
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 71 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ python ./cli.py --help
or simply run the project with default options:

```bash
python ./cli.py optimize-train-test
python ./cli.py optimize
```

If you have a standard set of configs you want to run the trader against, you can specify a config file to load configuration from. Rename config/config.ini.dist to config/config.ini and run

```bash
python ./cli.py --from-config config/config.ini optimize-train-test
python ./cli.py --from-config config/config.ini optimize
```

```bash
python ./cli.py optimize-train-test
python ./cli.py optimize
```

### Testing with vagrant
Expand All @@ -92,7 +92,7 @@ Note: With vagrant you cannot take full advantage of your GPU, so it is mainly for
If you want to run everything within a docker container, then just use:

```bash
./run-with-docker (cpu|gpu) (yes|no) optimize-train-test
./run-with-docker (cpu|gpu) (yes|no) optimize
```

- cpu - start the container using CPU requirements
Expand All @@ -101,7 +101,7 @@ If you want to run everything within a docker container, then just use:
Note: in case using yes as second argument, use

```bash
python ./cli.py --params-db-path "postgres://rl_trader:rl_trader@localhost" optimize-train-test
python ./cli.py --params-db-path "postgres://rl_trader:rl_trader@localhost" optimize
```

The database and its data are persisted under `data/postgres` locally.
Expand Down
32 changes: 19 additions & 13 deletions cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

from deco import concurrent
import multiprocessing

from lib.RLTrader import RLTrader
from lib.cli.RLTraderCLI import RLTraderCLI
Expand All @@ -12,27 +11,34 @@
args = trader_cli.get_args()


@concurrent(processes=args.parallel_jobs)
def run_concurrent_optimize(trader: RLTrader, args):
trader.optimize(args.trials, args.trials, args.parallel_jobs)
def run_concurrent_optimize():
    """Worker entry point: build a fresh RLTrader in this process and run optimization.

    Each spawned process constructs its own RLTrader from the module-level
    parsed CLI ``args`` (a new instance per process, since process memory is
    not shared) and runs ``args.trials`` optimization trials.
    """
    worker_trader = RLTrader(**vars(args))
    worker_trader.optimize(args.trials)


def concurrent_optimize():
    """Fan out hyperparameter optimization across ``args.parallel_jobs`` processes.

    Spawns one ``multiprocessing.Process`` per requested parallel job, each
    running ``run_concurrent_optimize`` (which builds its own RLTrader), starts
    them all, then blocks until every worker has finished.

    Note: removed a leftover debug ``print(processes)`` and replaced the
    manual append loop with a comprehension; the unused loop index is now ``_``.
    """
    processes = [
        multiprocessing.Process(target=run_concurrent_optimize)
        for _ in range(args.parallel_jobs)
    ]

    for process in processes:
        process.start()

    # Join separately from start so all workers run concurrently rather
    # than serially (start+join in one loop would serialize them).
    for process in processes:
        process.join()


if __name__ == '__main__':
logger = init_logger(__name__, show_debug=args.debug)
trader = RLTrader(**vars(args), logger=logger)

if args.command == 'optimize':
run_concurrent_optimize(trader, args)
concurrent_optimize()
elif args.command == 'train':
trader.train(n_epochs=args.epochs)
elif args.command == 'test':
trader.test(model_epoch=args.model_epoch, should_render=args.no_render)
elif args.command == 'optimize-train-test':
run_concurrent_optimize(trader, args)
trader.train(
n_epochs=args.train_epochs,
test_trained_model=args.no_test,
render_trained_model=args.no_render
)
elif args.command == 'update-static-data':
download_data_async()
6 changes: 2 additions & 4 deletions lib/RLTrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from os import path
from typing import Dict

from deco import concurrent
from stable_baselines.common.base_class import BaseRLModel
from stable_baselines.common.policies import BasePolicy, MlpPolicy
from stable_baselines.common.policies import BasePolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines.common import set_global_seeds
from stable_baselines import PPO2
Expand All @@ -31,7 +30,7 @@ class RLTrader:
data_provider = None
study_name = None

def __init__(self, modelClass: BaseRLModel = PPO2, policyClass: BasePolicy = MlpPolicy, exchange_args: Dict = {}, **kwargs):
def __init__(self, modelClass: BaseRLModel = PPO2, policyClass: BasePolicy = MlpLnLstmPolicy, exchange_args: Dict = {}, **kwargs):
self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))

self.Model = modelClass
Expand Down Expand Up @@ -162,7 +161,6 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e

return -1 * last_reward

@concurrent
def optimize(self, n_trials: int = 100, n_parallel_jobs: int = 1, *optimize_params):
try:
self.optuna_study.optimize(
Expand Down
6 changes: 0 additions & 6 deletions lib/cli/RLTraderCLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ def __init__(self):

subparsers = self.parser.add_subparsers(help='Command', dest="command")

opt_train_test_parser = subparsers.add_parser('optimize-train-test', description='Optimize train and test')
opt_train_test_parser.add_argument('--trials', type=int, default=20, help='Number of trials')
opt_train_test_parser.add_argument('--train-epochs', type=int, default=10, help='Train for how many epochs')
opt_train_test_parser.add_argument('--no-render', action='store_false', help='Should render the model')
opt_train_test_parser.add_argument('--no-test', action='store_false', help='Should test the model')

optimize_parser = subparsers.add_parser('optimize', description='Optimize model parameters')
optimize_parser.add_argument('--trials', type=int, default=1, help='Number of trials')

Expand Down
1 change: 0 additions & 1 deletion requirements.base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ statsmodels==0.10.0rc2
empyrical
ccxt
psycopg2
deco
configparser
42 changes: 0 additions & 42 deletions update_data.py

This file was deleted.

0 comments on commit 9d4ad1c

Please sign in to comment.