diff --git a/gymnax/environments/misc/rooms.py b/gymnax/environments/misc/rooms.py index fbb3f13..8aece95 100644 --- a/gymnax/environments/misc/rooms.py +++ b/gymnax/environments/misc/rooms.py @@ -111,7 +111,7 @@ def step_env( # Caculate the disocunted return discounted_reward = normal_reward - 0.9 *( state.time / params.max_steps_in_episode) - reward = jax.lax.cond(params.discounted_reward, lambda: discounted_reward, lambda: int(normal_reward)) + reward = jax.lax.cond(params.discounted_reward, lambda: discounted_reward, lambda: jnp.int32(normal_reward)) # Update state dict and evaluate termination conditions state = EnvState(new_pos, state.goal, state.time + 1) done = self.is_terminal(state, params)