Skip to content

Commit

Permalink
Also adapt the RL example.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 1, 2024
1 parent 64e0534 commit ad9b80e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
14 changes: 8 additions & 6 deletions candle-examples/examples/reinforcement-learning/gym_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl GymEnv {
/// Creates a new session of the specified OpenAI Gym environment.
pub fn new(name: &str) -> Result<GymEnv> {
Python::with_gil(|py| {
let gym = py.import("gymnasium")?;
let gym = py.import_bound("gymnasium")?;
let make = gym.getattr("make")?;
let env = make.call1((name,))?;
let action_space = env.getattr("action_space")?;
Expand All @@ -66,10 +66,10 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
let kwargs = PyDict::new_bound(py);
kwargs.set_item("seed", seed)?;
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
state.bind(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(state, &Device::Cpu)
Expand All @@ -81,8 +81,10 @@ impl GymEnv {
action: A,
) -> Result<Step<A>> {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let step = self
.env
.call_method_bound(py, "step", (action.clone(),), None)?;
let step = step.bind(py);
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let terminated: bool = step.get_item(2)?.extract()?;
Expand Down
10 changes: 5 additions & 5 deletions candle-examples/examples/reinforcement-learning/vec_gym_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {
let sys = py.import("sys")?;
let sys = py.import_bound("sys")?;
let path = sys.getattr("path")?;
let _ = path.call_method1(
"append",
("candle-examples/examples/reinforcement-learning",),
)?;
let gym = py.import("atari_wrappers")?;
let gym = py.import_bound("atari_wrappers")?;
let make = gym.getattr("make")?;
let env = make.call1((name, img_dir, nprocesses))?;
let action_space = env.getattr("action_space")?;
Expand Down Expand Up @@ -60,10 +60,10 @@ impl VecGymEnv {

pub fn step(&self, action: Vec<usize>) -> Result<Step> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action,), None)?;
let step = step.as_ref(py);
let step = self.env.call_method_bound(py, "step", (action,), None)?;
let step = step.bind(py);
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
let reward: Vec<f32> = step.get_item(1)?.extract()?;
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
Expand Down

0 comments on commit ad9b80e

Please sign in to comment.