Skip to content

Commit

Permalink
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 511317999
  • Loading branch information
Dqn Zoo Team authored and jqdm committed Mar 2, 2023
1 parent c1d1125 commit 21b3d90
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions dqn_zoo/iqn/agent.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,7 +76,9 @@ def select_action(rng_key, network_params, s_t):
tau_t = _sample_tau(tau_key, (1, tau_samples))
q_t = network.apply(
network_params, apply_key, IqnInputs(s_t[None, ...], tau_t)
).q_values[0]
).q_values[
0
] # pytype: disable=wrong-arg-types # jax-ndarray
a_t = distrax.EpsilonGreedy(q_t, exploration_epsilon).sample(
seed=policy_key
)
Expand Down Expand Up @@ -188,15 +190,15 @@ def loss_fn(online_params, target_params, transitions, rng_key):
_, *apply_keys = jax.random.split(rng_key, 4)
dist_q_tm1 = network.apply(
online_params, apply_keys[0], IqnInputs(transitions.s_tm1, tau_tm1)
).q_dist
).q_dist # pytype: disable=wrong-arg-types # jax-ndarray
dist_q_t_selector = network.apply(
target_params,
apply_keys[1],
IqnInputs(transitions.s_t, tau_t_selector),
).q_dist
).q_dist # pytype: disable=wrong-arg-types # jax-ndarray
dist_q_target_t = network.apply(
target_params, apply_keys[2], IqnInputs(transitions.s_t, tau_t)
).q_dist
).q_dist # pytype: disable=wrong-arg-types # jax-ndarray
losses = _batch_quantile_q_learning(
dist_q_tm1,
tau_tm1,
Expand Down Expand Up @@ -229,7 +231,9 @@ def select_action(rng_key, network_params, s_t, exploration_epsilon):
tau_t = _sample_tau(sample_key, (1, tau_samples_policy))
q_t = network.apply(
network_params, apply_key, IqnInputs(s_t[None, ...], tau_t)
).q_values[0]
).q_values[
0
] # pytype: disable=wrong-arg-types # jax-ndarray
a_t = distrax.EpsilonGreedy(q_t, exploration_epsilon).sample(
seed=policy_key
)
Expand Down

0 comments on commit 21b3d90

Please sign in to comment.