From 0096e0e384853bb1e2c61205bb584835d875ff9c Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Tue, 30 Apr 2024 21:06:19 +0100 Subject: [PATCH] fix: non-jittable code --- gymnax/environments/minatar/asterix.py | 2 +- gymnax/environments/minatar/freeway.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gymnax/environments/minatar/asterix.py b/gymnax/environments/minatar/asterix.py index d42a786..099779b 100755 --- a/gymnax/environments/minatar/asterix.py +++ b/gymnax/environments/minatar/asterix.py @@ -332,7 +332,7 @@ def step_entities( return ( state.replace(entities=entities, move_timer=move_timer), reward, - bool(done > 0), + jnp.bool_(done > 0), ) diff --git a/gymnax/environments/minatar/freeway.py b/gymnax/environments/minatar/freeway.py index 192410e..2527fc6 100755 --- a/gymnax/environments/minatar/freeway.py +++ b/gymnax/environments/minatar/freeway.py @@ -199,7 +199,7 @@ def step_agent( win_cond = pos == 0 reward = win_cond * 1.0 pos = jax.lax.select(win_cond, 9, pos) - return state.replace(pos=pos, move_timer=move_timer), reward, win_cond.item() + return state.replace(pos=pos, move_timer=move_timer), reward, win_cond def step_cars(state: EnvState) -> EnvState: