diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs
index 8868c1884d..a2b6652f87 100644
--- a/candle-examples/examples/reinforcement-learning/gym_env.rs
+++ b/candle-examples/examples/reinforcement-learning/gym_env.rs
@@ -42,7 +42,7 @@ impl GymEnv {
     /// Creates a new session of the specified OpenAI Gym environment.
     pub fn new(name: &str) -> Result<GymEnv> {
         Python::with_gil(|py| {
-            let gym = py.import("gymnasium")?;
+            let gym = py.import_bound("gymnasium")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name,))?;
             let action_space = env.getattr("action_space")?;
@@ -66,10 +66,10 @@ impl GymEnv {
     /// Resets the environment, returning the observation tensor.
     pub fn reset(&self, seed: u64) -> Result<Tensor> {
         let state: Vec<f32> = Python::with_gil(|py| {
-            let kwargs = PyDict::new(py);
+            let kwargs = PyDict::new_bound(py);
             kwargs.set_item("seed", seed)?;
-            let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
-            state.as_ref(py).get_item(0)?.extract()
+            let state = self.env.call_method_bound(py, "reset", (), Some(&kwargs))?;
+            state.bind(py).get_item(0)?.extract()
         })
         .map_err(w)?;
         Tensor::new(state, &Device::Cpu)
@@ -81,8 +81,10 @@ impl GymEnv {
         action: A,
     ) -> Result<Step<A>> {
         let (state, reward, terminated, truncated) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action.clone(),), None)?;
-            let step = step.as_ref(py);
+            let step = self
+                .env
+                .call_method_bound(py, "step", (action.clone(),), None)?;
+            let step = step.bind(py);
             let state: Vec<f32> = step.get_item(0)?.extract()?;
             let reward: f64 = step.get_item(1)?.extract()?;
             let terminated: bool = step.get_item(2)?.extract()?;
diff --git a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs
index 8f8f30bd6b..e382ad76da 100644
--- a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs
+++ b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs
@@ -24,13 +24,13 @@ fn w(res: PyErr) -> candle::Error {
 impl VecGymEnv {
     pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
         Python::with_gil(|py| {
-            let sys = py.import("sys")?;
+            let sys = py.import_bound("sys")?;
             let path = sys.getattr("path")?;
             let _ = path.call_method1(
                 "append",
                 ("candle-examples/examples/reinforcement-learning",),
             )?;
-            let gym = py.import("atari_wrappers")?;
+            let gym = py.import_bound("atari_wrappers")?;
             let make = gym.getattr("make")?;
             let env = make.call1((name, img_dir, nprocesses))?;
             let action_space = env.getattr("action_space")?;
@@ -60,10 +60,10 @@ impl VecGymEnv {
 
     pub fn step(&self, action: Vec<usize>) -> Result<Step> {
         let (obs, reward, is_done) = Python::with_gil(|py| {
-            let step = self.env.call_method(py, "step", (action,), None)?;
-            let step = step.as_ref(py);
+            let step = self.env.call_method_bound(py, "step", (action,), None)?;
+            let step = step.bind(py);
             let obs = step.get_item(0)?.call_method("flatten", (), None)?;
-            let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
+            let obs_buffer = pyo3::buffer::PyBuffer::get_bound(&obs)?;
             let obs: Vec<u8> = obs_buffer.to_vec(py)?;
             let reward: Vec<f32> = step.get_item(1)?.extract()?;
             let is_done: Vec<f32> = step.get_item(2)?.extract()?;
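
The hunks above all apply the same pyo3 "Bound" API migration: GIL-ref constructors (`py.import`, `PyDict::new`, `PyBuffer::get`) become their `*_bound` counterparts, and results of `call_method_bound` on a stored `PyObject` are re-attached to the GIL scope with `bind` instead of `as_ref`. Below is a minimal, self-contained sketch of that pattern, assuming pyo3 0.21 with the `auto-initialize` feature; the `Handle` struct, `seeded_sample`, and the use of Python's `random.Random` are illustrative stand-ins, not code from the candle examples.

```rust
use pyo3::prelude::*;
use pyo3::types::PyDict;

// Hypothetical wrapper around a stored Python object, mirroring the
// `env: PyObject` field used by GymEnv / VecGymEnv.
struct Handle {
    obj: PyObject,
}

impl Handle {
    // Construct a Python `random.Random` instance and keep it as a
    // GIL-independent handle.
    fn new() -> PyResult<Self> {
        Python::with_gil(|py| {
            // `import_bound` replaces the old GIL-ref `py.import`, returning a
            // Bound<'py, PyModule> tied to this `with_gil` scope.
            let random = py.import_bound("random")?;
            let obj = random.getattr("Random")?.call0()?.unbind();
            Ok(Self { obj })
        })
    }

    // Call `obj.seed(a=seed)` followed by `obj.random()` through the Bound API.
    fn seeded_sample(&self, seed: u64) -> PyResult<f64> {
        Python::with_gil(|py| {
            // Keyword arguments now live in a Bound dict and are passed by reference.
            let kwargs = PyDict::new_bound(py);
            kwargs.set_item("a", seed)?;
            // `call_method_bound` replaces `call_method` on Py<PyAny>; the returned
            // Py<PyAny> is re-attached to the GIL scope with `bind` before extraction.
            self.obj.call_method_bound(py, "seed", (), Some(&kwargs))?;
            let sample = self.obj.call_method_bound(py, "random", (), None)?;
            sample.bind(py).extract()
        })
    }
}

fn main() -> PyResult<()> {
    let handle = Handle::new()?;
    println!("{}", handle.seeded_sample(42)?);
    Ok(())
}
```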