diff --git a/dqn_zoo/iqn/agent.py b/dqn_zoo/iqn/agent.py index 80ea8fb..1ce858a 100644 --- a/dqn_zoo/iqn/agent.py +++ b/dqn_zoo/iqn/agent.py @@ -45,7 +45,7 @@ def _sample_tau( rng_key: parts.PRNGKey, shape: Tuple[int, ...], -) -> Tuple[jnp.ndarray]: +) -> jnp.ndarray: """Samples tau values uniformly between 0 and 1.""" return jax.random.uniform(rng_key, shape=shape) @@ -76,9 +76,7 @@ def select_action(rng_key, network_params, s_t): tau_t = _sample_tau(tau_key, (1, tau_samples)) q_t = network.apply( network_params, apply_key, IqnInputs(s_t[None, ...], tau_t) - ).q_values[ - 0 - ] # pytype: disable=wrong-arg-types # jax-ndarray + ).q_values[0] a_t = distrax.EpsilonGreedy(q_t, exploration_epsilon).sample( seed=policy_key )