diff --git a/dqn_zoo/iqn/agent.py b/dqn_zoo/iqn/agent.py
index e1dc997..80ea8fb 100644
--- a/dqn_zoo/iqn/agent.py
+++ b/dqn_zoo/iqn/agent.py
@@ -76,7 +76,9 @@ def select_action(rng_key, network_params, s_t):
       tau_t = _sample_tau(tau_key, (1, tau_samples))
       q_t = network.apply(
           network_params, apply_key, IqnInputs(s_t[None, ...], tau_t)
-      ).q_values[0]
+      ).q_values[
+          0
+      ]  # pytype: disable=wrong-arg-types  # jax-ndarray
       a_t = distrax.EpsilonGreedy(q_t, exploration_epsilon).sample(
           seed=policy_key
       )
@@ -188,15 +190,15 @@ def loss_fn(online_params, target_params, transitions, rng_key):
       _, *apply_keys = jax.random.split(rng_key, 4)
       dist_q_tm1 = network.apply(
           online_params, apply_keys[0], IqnInputs(transitions.s_tm1, tau_tm1)
-      ).q_dist
+      ).q_dist  # pytype: disable=wrong-arg-types  # jax-ndarray
       dist_q_t_selector = network.apply(
           target_params,
           apply_keys[1],
           IqnInputs(transitions.s_t, tau_t_selector),
-      ).q_dist
+      ).q_dist  # pytype: disable=wrong-arg-types  # jax-ndarray
       dist_q_target_t = network.apply(
           target_params, apply_keys[2], IqnInputs(transitions.s_t, tau_t)
-      ).q_dist
+      ).q_dist  # pytype: disable=wrong-arg-types  # jax-ndarray
       losses = _batch_quantile_q_learning(
           dist_q_tm1,
           tau_tm1,
@@ -229,7 +231,9 @@ def select_action(rng_key, network_params, s_t, exploration_epsilon):
       tau_t = _sample_tau(sample_key, (1, tau_samples_policy))
       q_t = network.apply(
           network_params, apply_key, IqnInputs(s_t[None, ...], tau_t)
-      ).q_values[0]
+      ).q_values[
+          0
+      ]  # pytype: disable=wrong-arg-types  # jax-ndarray
       a_t = distrax.EpsilonGreedy(q_t, exploration_epsilon).sample(
           seed=policy_key
       )