ray_rllib源码解析.txt  (Ray RLlib source-code notes)
#############################################################################################################Workers
class RolloutWorker(ParallelIteratorWorker):
class WorkerSet()
#############################################################################################################Train
class Trainable:
def default_resource_request():  # default resource configuration
def train_buffered():
while within the time budget or no result yet:
result = self.train()
results.append(result)
if done or SHOULD_CHECKPOINT or RESULT_DUPLICATE or max_buffer_length reached: break
return results
def train():
result = self.step()#trainer_cls.step()
return result
def step():  ===> must be overridden
def save_checkpoint(dir):  ===> must be overridden; saves the model
def load_checkpoint(checkpoint):  ===> must be overridden; loads the model
def setup():  ===> must be overridden; custom initialization
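
A minimal sketch of subclassing Trainable with the four overridable hooks above (hedged: the metric and file names are illustrative, not RLlib's):

import os
import pickle
from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config):                     # custom initialization
        self.lr = config.get("lr", 1e-3)
        self.value = 0.0

    def step(self):                              # one training iteration
        self.value += self.lr
        return {"my_metric": self.value}

    def save_checkpoint(self, checkpoint_dir):   # save the model
        path = os.path.join(checkpoint_dir, "state.pkl")
        with open(path, "wb") as f:
            pickle.dump(self.value, f)
        return path

    def load_checkpoint(self, checkpoint_path):  # load the model
        with open(checkpoint_path, "rb") as f:
            self.value = pickle.load(f)
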
<class 'ray.rllib.agents.callbacks.DefaultCallbacks'>
class DefaultCallbacks():
def on_episode_start():
def on_episode_step():
def on_episode_end():
def on_postprocess_trajectory():
def on_sample_end():
def on_learn_on_batch():
def on_train_result():
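
Sketch of a custom callback subclass, enabled via config["callbacks"] (hedged: the keyword-only signatures follow RLlib 1.x and may differ slightly in other versions):

from ray.rllib.agents.callbacks import DefaultCallbacks

class MyCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index=None, **kwargs):
        # record a custom per-episode metric
        episode.custom_metrics["episode_len"] = episode.length

    def on_train_result(self, *, trainer, result, **kwargs):
        # annotate the result dict returned by Trainer.train()
        result["callback_ok"] = True

# enabled via the trainer config: config["callbacks"] = MyCallbacks
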
class Trainer(Trainable):
self._env_id = self._register_if_needed(env or config.get("env"))#env_name
@override
def default_resource_request()
overrides the resource configuration
@override
def train():
overrides train() to add retry logic
for _ in range(1+3):  # call Trainable.train() at most four times in total
try:
result = Trainable.train(self)  # calls the parent train() once -> self.step(); note: neither Trainable nor Trainer defines step() itself (see build_trainer below)
except:
raise Exception
else:
break
return result
@override
def setup():
# create the environment
self.env_creator = functools.partial(gym_env_creator, env_descriptor=env)#return gym.make(env)
# merge configs
self.raw_user_config = config
self.config = Trainer.merge_trainer_configs(self._default_config,config)
# validate the config
self._validate_config(self.config, trainer_obj_or_none=self)
# custom callbacks hook
self.callbacks = self.config["callbacks"]()  # <class 'ray.rllib.agents.callbacks.DefaultCallbacks'>
@override
def cleanup():
@override
def save_checkpoint():
@override
def load_checkpoint():
@DeveloperAPI
def _make_workers():
return WorkerSet(env_creator=env_creator,validate_env=validate_env,policy_class=policy_class, trainer_config=config,num_workers=num_workers,logdir=self.logdir)
def evaluate():
self.evaluation_workers.remote_workers()
def _sync_weights_to_workers():  # push local weights to the remote workers
weights = ray.put(self.workers.local_worker().save())
worker_set.foreach_worker(lambda w: w.restore(ray.get(weights)))
def compute_single_action()
pp = self.workers.local_worker().preprocessors[policy_id]
result = self.get_policy(policy_id).compute_single_action()
def compute_actions()
actions, states, infos = policy.compute_actions()
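
Usage sketch of the Trainer API above (assumes ray[rllib] plus Gym's CartPole-v0; config values are illustrative, and older versions spell compute_single_action as compute_action):

import gym
import ray
from ray.rllib.agents import dqn

ray.init()
trainer = dqn.DQNTrainer(
    config={"framework": "torch", "num_workers": 1}, env="CartPole-v0")

result = trainer.train()          # Trainer.train() -> retry loop -> self.step()
checkpoint = trainer.save()       # save_checkpoint() under the hood

obs = gym.make("CartPole-v0").reset()
action = trainer.compute_single_action(obs)   # preprocess obs, then policy inference
print(result["episode_reward_mean"], checkpoint, action)
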
############################################################################################################# Loss
class QLoss:
def __init__():
q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked  # n-step TD target
self.td_error = q_t_selected - q_t_selected_target.detach()
self.loss = torch.mean(importance_weights.float() * huber_loss(self.td_error))  # Huber loss: quadratic for small errors, linear for large ones
def build_q_losses(Policy,model,_,train_batch):
q_t, q_logits_t, q_probs_t, _ #Q-network evaluation.
q_tp1, q_logits_tp1, q_probs_tp1, _ #Target Q-network evaluation.
q_loss = QLoss(q_t_selected, q_logits_t_selected, q_tp1_best,q_probs_tp1_best,
train_batch[PRIO_WEIGHTS],
train_batch[SampleBatch.REWARDS],
train_batch[SampleBatch.DONES].float(), config["gamma"],
config["n_step"], config["num_atoms"], config["v_min"],
config["v_max"])
return q_loss.loss
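
Plain-PyTorch sketch of what QLoss computes, reduced to the scalar-Q (non-distributional) case; the function and argument names here are illustrative:

import torch
import torch.nn.functional as F

def simple_q_loss(q_t_selected, q_tp1_best, rewards, dones,
                  importance_weights, gamma=0.99, n_step=1):
    # mask out the bootstrap term at terminal states
    q_tp1_best_masked = (1.0 - dones) * q_tp1_best
    # n-step TD target: r + gamma^n * max_a' Q_target(s', a')
    q_t_selected_target = rewards + gamma ** n_step * q_tp1_best_masked
    td_error = q_t_selected - q_t_selected_target.detach()
    # Huber loss: quadratic for small |td_error|, linear for large
    huber = F.smooth_l1_loss(q_t_selected, q_t_selected_target.detach(),
                             reduction="none")
    loss = torch.mean(importance_weights.float() * huber)
    return loss, td_error
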
#############################################################################################################Policy
class Policy():
def __init__()
self.observation_space
self.action_space
self.action_space_struct
self.global_timestep
self.exploration = self._create_exploration()
def _create_exploration():
config["Exploration"]
def compute_actions():
def compute_single_action(obs,state,prev_action,prev_reward):
out = self.compute_actions_from_input_dict()
def compute_actions_from_input_dict(input_dict: SampleBatch):
return self.compute_actions(...)
def compute_log_likelihoods():
def postprocess_trajectory():return sample_batch
def learn_on_batch():
grads, grad_info = self.compute_gradients(samples)
self.apply_gradients(grads)
return grad_info
def compute_gradients():
def apply_gradients():
def get_weights():
def set_weights():
def get_exploration_info():return self.get_exploration_state()
def get_num_samples_loaded_into_buffer():
def learn_on_loaded_batch()
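
Minimal custom Policy sketch following the interface above (the random-action behavior and dummy weights are purely illustrative):

from ray.rllib.policy.policy import Policy

class MyRandomPolicy(Policy):
    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.w = 0.0                               # dummy "weights"

    def compute_actions(self, obs_batch, state_batches=None,
                        prev_action_batch=None, prev_reward_batch=None,
                        info_batch=None, episodes=None, **kwargs):
        actions = [self.action_space.sample() for _ in obs_batch]
        return actions, [], {}                     # actions, state_outs, info

    def learn_on_batch(self, samples):
        return {}                                  # grad_info dict

    def get_weights(self):
        return {"w": self.w}

    def set_weights(self, weights):
        self.w = weights["w"]
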
class ModelV2:
def __init__(obs_space,action_space,num_outputs,model_config,name,framework):
def forward(input_dict,state,seq_lens):
def value_function():
def custom_loss():
def from_batch():
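
Sketch of a custom model implementing the ModelV2 interface for torch (layer sizes and the registered name are illustrative):

import torch.nn as nn
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class MyTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.body = nn.Sequential(nn.Linear(obs_space.shape[0], 64), nn.ReLU())
        self.logits = nn.Linear(64, num_outputs)
        self.value_branch = nn.Linear(64, 1)
        self._features = None

    def forward(self, input_dict, state, seq_lens):
        self._features = self.body(input_dict["obs"].float())
        return self.logits(self._features), state

    def value_function(self):
        return self.value_branch(self._features).squeeze(1)

ModelCatalog.register_custom_model("my_torch_model", MyTorchModel)
# config["model"] = {"custom_model": "my_torch_model"}
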
class TorchPolicy():
def __init__(model: ModelV2):
super().__init__():
self.dist_class = action_distribution_class  # action distribution class -> TorchDistributionWrapper (logp, entropy, kl, sample, sampled_action_logp)
self.model_gpu_towers.append(model_copy.to(self.devices[i]))
self.model = self.model_gpu_towers[0]
self.loss = loss = build_q_losses
# handles multi-GPU towers + workers
def compute_actions():
return self._compute_action_helper()
def _compute_action_helper():
self.exploration.before_compute_actions()
dist_class = self.dist_class
dist_inputs, state_out = self.model(input_dict, state_batches,seq_lens)
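
Sketch of how dist_inputs from the model become actions through the action-distribution class (TorchCategorical here, for a discrete action space; the random logits stand in for real model output):

import torch
from ray.rllib.models.torch.torch_action_dist import TorchCategorical

dist_inputs = torch.randn(4, 2)              # stand-in for model(input_dict, ...) logits
action_dist = TorchCategorical(dist_inputs, model=None)
actions = action_dist.sample()               # shape (4,)
logp = action_dist.sampled_action_logp()     # log-prob of the sampled actions
entropy = action_dist.entropy()
print(actions, logp, entropy)
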
class DQNTFPolicy():
pass
def build_policy_class(name, framework, loss_fn=build_q_losses, optimizer_fn=adam_optimizer, ...):
returns policy_cls, a subclass of TorchPolicy
model_cls = TorchModelV2()->ModelV2
def build_tf_policy():
returns policy_cls, a subclass of DynamicTFPolicy
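
Sketch of assembling a torch policy class via build_policy_class (hedged: keyword names follow RLlib 1.x's policy_template; the toy loss and optimizer are illustrative, not DQN's):

import torch
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch

def my_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    # toy behavior-cloning style loss, for illustration only
    return -action_dist.logp(train_batch[SampleBatch.ACTIONS]).mean()

def my_optimizer(policy, config):
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])

MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",
    loss_fn=my_loss,
    optimizer_fn=my_optimizer,
)
# The resulting class can be passed to build_trainer() as default_policy.
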
############################################################################################################# Replay buffer / execution plan
class LocalReplayBuffer:
def add_batch()
def replay()
def update_priorities()
def ParallelRollouts():
workers.sync_weights()
rollouts = from_actors(workers.remote_workers())->ParallelIterator
def execution_plan(workers: WorkerSet, config):
local_replay_buffer = LocalReplayBuffer(...)
rollouts = ParallelRollouts(workers, mode="bulk_sync")
store_op = rollouts.for_each(StoreToReplayBuffer(local_buffer=local_replay_buffer))  # Generate and store
replay_op = Replay(local_buffer=local_replay_buffer)  # Read and train
train_op = Concurrently([store_op, replay_op], mode="round_robin", output_indexes=[1])
return StandardMetricsReporting(train_op, workers, config)
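
Hedged reconstruction of the DQN-style execution_plan above (imports and constructor arguments follow RLlib 1.x's execution package; the replay-buffer arguments in particular vary between versions):

from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork

def execution_plan(workers, config):
    local_replay_buffer = LocalReplayBuffer(
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"])
    # Generate and store: sample from all workers, write into the buffer.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))
    # Read and train: replay a batch, take one SGD step, update the target net.
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(TrainOneStep(workers)) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))
    # Alternate both ops; only the training branch's output is reported.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])
    return StandardMetricsReporting(train_op, workers, config)
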
############################################################################################################# utils
def get_policy_class():if config["framework"] == "torch": return DQNTorchPolicy
#############################################################################################################Entry point
SimpleQTrainer = build_trainer(
name="SimpleQTrainer",
default_policy=SimpleQTFPolicy,
get_policy_class=get_policy_class,  # config["framework"] == "torch" -> SimpleQTorchPolicy
execution_plan=execution_plan,
default_config=DEFAULT_CONFIG,  # rllib/agents/dqn/simple_q DEFAULT_CONFIG
)
GenericOffPolicyTrainer = SimpleQTrainer.with_updates(
name="GenericOffPolicyTrainer",
# No Policy preference.
default_policy=None,
get_policy_class=None,
# Use DQN's config and exec. plan as base for
# all other OffPolicy algos.
default_config=DEFAULT_CONFIG,  # rllib/agents/dqn DEFAULT_CONFIG
validate_config=validate_config,
execution_plan=execution_plan
)
DQNTrainer = GenericOffPolicyTrainer.with_updates(
name="DQN",
default_policy=DQNTFPolicy,  # built via build_tf_policy
get_policy_class=get_policy_class,  # torch -> DQNTorchPolicy, built via build_policy_class
default_config=DEFAULT_CONFIG,  # rllib/agents/dqn DEFAULT_CONFIG
)
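
Usage sketch of with_updates() for deriving one more trainer variant, as the DQN family does above (the overridden config values here are illustrative):

from ray.rllib.agents.dqn import DQNTrainer, DEFAULT_CONFIG

MyDQNTrainer = DQNTrainer.with_updates(
    name="MyDQN",
    default_config=dict(DEFAULT_CONFIG, lr=5e-4, train_batch_size=64),
)
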
def build_trainer(name,default_policy,get_policy_class,execution_plan,default_config):
base = add_mixins(Trainer, mixins) #mixins = [] base->Trainer
class trainer_cls(Trainer):
_name = name
_default_config = default_config or COMMON_CONFIG -> DEFAULT_CONFIG
_policy_class = default_policy | get_policy_class(config)
def __init__():
self.workers = self._make_workers()
def setup():
super().setup(config)
def _init(config,env_creator):
self._policy_class = default_policy  # SimpleQTorchPolicy | SimpleQTFPolicy; first initialization goes through build_policy_class
self.workers = self._make_workers()
def step():  # the step() referenced by Trainable/Trainer is implemented here
step_results = next(self.train_exec_impl)->execution_plan(self.workers, config)
if __name__ == "__main__":
agent = dqn.DQNTrainer(_config, env=StockTradingEnv)
for epoch in range(0,10000000):
result = agent.train()
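
Sketch of plugging in a custom environment such as StockTradingEnv (StockTradingEnv itself is the author's class; MyEnv and the name "my_env" below are illustrative):

import gym
import numpy as np
from gym import spaces
from ray import tune

class MyEnv(gym.Env):
    def __init__(self, env_config):
        self.observation_space = spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self):
        return np.zeros(4, dtype=np.float32)

    def step(self, action):
        # one dummy transition: obs, reward, done, info
        return np.zeros(4, dtype=np.float32), 0.0, True, {}

tune.register_env("my_env", lambda cfg: MyEnv(cfg))
# agent = dqn.DQNTrainer(_config, env="my_env")   # or pass the class directly: env=MyEnv
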