/*
 * http://github.com/dusty-nv/jetson-reinforcement
 */

#ifndef __REINFORCEMENT_LEARNING_AGENT_H_
#define __REINFORCEMENT_LEARNING_AGENT_H_

#include <stdio.h>
#include <stdint.h>
#include <string>

#include "aiAgent.h"
#include "pyTensor.h"

/**
 * Default name of the Python module to load
 */
#define DEFAULT_RL_MODULE "DQN"

/**
 * Default name of the Python function from the user's module
 * which infers the next action from the current state.
 * The expected function is of the form `def next_action(state):`,
 * where state is a PyTorch tensor representing the environment,
 * and the function returns the predicted action.
 */
#define DEFAULT_NEXT_ACTION "next_action"

/**
 * Default name of the Python function from the user's module
 * which receives rewards and performs training.
 * The expected reward function is of the form
 * `def next_reward(state, reward, new_episode):`, where the
 * function accepts the reward and returns the predicted action.
 */
#define DEFAULT_NEXT_REWARD "next_reward"

/**
 * Default name of the Python function for loading model checkpoints
 */
#define DEFAULT_LOAD_MODEL "load_model"

/**
 * Default name of the Python function for saving model checkpoints
 */
#define DEFAULT_SAVE_MODEL "save_model"
class rlAgent : public aiAgent
{
public:
	/**
	 * Create a new instance of a module for training an agent.
	 */
	static rlAgent* Create( uint32_t numInputs, uint32_t numActions,
						const char* module=DEFAULT_RL_MODULE,
						const char* nextAction=DEFAULT_NEXT_ACTION,
						const char* nextReward=DEFAULT_NEXT_REWARD,
						const char* loadModel=DEFAULT_LOAD_MODEL,
						const char* saveModel=DEFAULT_SAVE_MODEL );

	/**
	 * Create a new instance of a module for training an agent.
	 */
	static rlAgent* Create( uint32_t width, uint32_t height,
						uint32_t channels, uint32_t numActions,
						const char* module=DEFAULT_RL_MODULE,
						const char* nextAction=DEFAULT_NEXT_ACTION,
						const char* nextReward=DEFAULT_NEXT_REWARD,
						const char* loadModel=DEFAULT_LOAD_MODEL,
						const char* saveModel=DEFAULT_SAVE_MODEL );
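
	/*
	 * Example (an illustrative sketch only; the dimensions, action count,
	 * and reliance on the default DQN module are assumed example values):
	 *
	 *    // agent with a flat 64-element input state and 4 discrete actions
	 *    rlAgent* agent = rlAgent::Create(64, 4);
	 *
	 *    // agent operating on 84x84 RGB image input with 4 discrete actions
	 *    rlAgent* imgAgent = rlAgent::Create(84, 84, 3, 4);
	 */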

	/**
	 * Destructor
	 */
	virtual ~rlAgent();

	/**
	 * From the input state, predict the next action.
	 */
	virtual bool NextAction( Tensor* state, int* action );

	/**
	 * Issue the next reward and perform a training iteration.
	 */
	virtual bool NextReward( float reward, bool end_episode );
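
	/*
	 * A typical control loop is sketched below; `env` and its Observe(),
	 * Step(), and Done() methods are hypothetical placeholders for the
	 * user's environment and are not part of this API:
	 *
	 *    Tensor* state  = env.Observe();              // current state tensor
	 *    int     action = 0;
	 *
	 *    if( agent->NextAction(state, &action) )
	 *    {
	 *        const float reward = env.Step(action);   // apply the chosen action
	 *        const bool  done   = env.Done();         // true at end of episode
	 *
	 *        agent->NextReward(reward, done);          // reward + training step
	 *    }
	 */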

	/**
	 * Load model checkpoint
	 */
	virtual bool LoadCheckpoint( const char* filename );

	/**
	 * Save model checkpoint
	 */
	virtual bool SaveCheckpoint( const char* filename );
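
	/*
	 * For example (the checkpoint filename below is an arbitrary choice):
	 *
	 *    agent->SaveCheckpoint("my_agent_checkpoint.pth");
	 *
	 *    // later, or in a subsequent run
	 *    agent->LoadCheckpoint("my_agent_checkpoint.pth");
	 */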

	/**
	 * Globally load the Python scripting interpreter.
	 * LoadInterpreter is automatically called before tensors or scripts are run.
	 * It can optionally be called by the user at the beginning of their program
	 * to load Python at that time.  It has internal protections so that the
	 * interpreter is only loaded once.
	 */
	static bool LoadInterpreter();
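
	/*
	 * For example, to load the interpreter explicitly at program startup
	 * (optional, since it is otherwise loaded automatically on first use):
	 *
	 *    if( !rlAgent::LoadInterpreter() )
	 *        printf("failed to load Python interpreter\n");
	 */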

	/**
	 * Load Python script module
	 */
	bool LoadModule( const char* module );

	/**
	 * Load Python script module (with arguments)
	 */
	bool LoadModule( const char* module, int argc, char** argv );

	/**
	 * Return the TypeID of the agent.
	 */
	virtual TypeID GetType() const	{ return TYPE_RL; }

	/**
	 * TypeID of rlAgent objects.
	 */
	const TypeID TYPE_RL = TYPE_AI | (1 << 1);

protected:
	rlAgent();

	virtual bool Init( uint32_t width, uint32_t height, uint32_t channels,
				    uint32_t numActions, const char* module,
				    const char* nextAction, const char* nextReward,
				    const char* loadModel, const char* saveModel,
				    const char* optimizer="RMSprop", float learning_rate=0.001,
				    uint32_t replay_mem=10000, uint32_t batch_size=64, float gamma=0.9,
				    float epsilon_start=0.9, float epsilon_end=0.05, float epsilon_decay=200,
				    bool use_lstm=true, int lstm_size=256, bool allow_random=true, bool debug_mode=false );

#ifdef USE_LUA
	lua_State* L;		/**< Lua/Torch7 operating environment */
	THCState*  THC;		/**< cutorch state */
#endif

	//bool mNewEpisode;

	uint32_t mInputWidth;		/**< width of the input state */
	uint32_t mInputHeight;		/**< height of the input state */
	uint32_t mNumInputs;		/**< number of inputs to the agent */
	uint32_t mNumActions;		/**< number of discrete actions the agent can select */

	Tensor* mRewardTensor;		/**< tensor for passing the reward to the script */
	Tensor* mActionTensor;		/**< tensor for retrieving the action from the script */

	enum
	{
		ACTION_FUNCTION = 0,
		REWARD_FUNCTION,
		LOAD_FUNCTION,
		SAVE_FUNCTION,
		NUM_FUNCTIONS
	};

	std::string mModuleName;				/**< name of the user's Python module */
	void*       mModuleObj;				/**< Python module object */
	void*       mFunction[NUM_FUNCTIONS];		/**< Python function objects */
	void*       mFunctionArgs[NUM_FUNCTIONS];	/**< arguments for each Python function */
	std::string mFunctionName[NUM_FUNCTIONS];	/**< names of the Python functions */

	static bool scriptingLoaded;			/**< true once the Python interpreter has been loaded */
};

#endif