Skip to content

Commit

Permalink
Fix type annotation in IQN agent revealed by jnp.ndarray == jax.Array.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 512042734
  • Loading branch information
jqdm committed Mar 2, 2023
1 parent 21b3d90 commit 289379e
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions dqn_zoo/iqn/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
def _sample_tau(
rng_key: parts.PRNGKey,
shape: Tuple[int, ...],
) -> Tuple[jnp.ndarray]:
) -> jnp.ndarray:
"""Samples tau values uniformly between 0 and 1."""
return jax.random.uniform(rng_key, shape=shape)

Expand Down Expand Up @@ -76,9 +76,7 @@ def select_action(rng_key, network_params, s_t):
tau_t = _sample_tau(tau_key, (1, tau_samples))
q_t = network.apply(
network_params, apply_key, IqnInputs(s_t[None, ...], tau_t)
).q_values[
0
] # pytype: disable=wrong-arg-types # jax-ndarray
).q_values[0]
a_t = distrax.EpsilonGreedy(q_t, exploration_epsilon).sample(
seed=policy_key
)
Expand Down

0 comments on commit 289379e

Please sign in to comment.