diff --git a/src/utils.rs b/src/utils.rs index 8ec1f975..98d9e29e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -12,10 +12,12 @@ use nix::sys::time::{TimeVal, TimeValLike}; use nix::unistd::read; use std::io::Write; use std::mem::size_of; +use std::os::unix::io::IntoRawFd; use std::os::unix::io::RawFd; use std::thread::sleep; -use std::time::{Duration, Instant, SystemTime}; +use std::time::{Duration, SystemTime}; use vmm_sys_util::epoll::{ControlOperation, Epoll, EpollEvent, EventSet}; +use vmm_sys_util::timerfd::TimerFd; use crate::common::{NitroCliErrorEnum, NitroCliFailure, NitroCliResult}; use crate::new_nitro_cli_failure; @@ -146,81 +148,110 @@ impl Console { disconnect_timeout_sec: Option, ) -> NitroCliResult<()> { // Initialize variables - let (epoll, mut events, mut epoll_timeout_us, mut start_epoll_time) = ( - Epoll::new().map_err(|e| { + let epoll = Epoll::new().map_err(|e| { + new_nitro_cli_failure!( + &format!("Failed to create epoll: {:?}", e), + NitroCliErrorEnum::EpollError + ) + })?; + + // Add console fd to epoll + epoll + .ctl( + ControlOperation::Add, + self.fd, + EpollEvent::new(EventSet::IN, self.fd as u64), + ) + .map_err(|e| { + new_nitro_cli_failure!( + &format!("Failed to add fd to epoll: {:?}", e), + NitroCliErrorEnum::EpollError + ) + })?; + + // If the function call provides a disconnect timeout, create a timerfd, + // arm it and then add it to epoll + if let Some(disconnect_timeout) = disconnect_timeout_sec { + // Create timerfd + let mut timerfd = TimerFd::new().map_err(|e| { new_nitro_cli_failure!( - &format!("Failed to create epoll: {:?}", e), + &format!("Failed to initialize timerfd: {:?}", e), NitroCliErrorEnum::EpollError ) - })?, - [EpollEvent::default(); 1], - 0, - Instant::now(), - ); - - if disconnect_timeout_sec.is_some() { - let epoll_event = EpollEvent::new(EventSet::IN, self.fd as u64); + })?; + + // Arm timerfd with disconnect_timeout seconds + timerfd + .reset(Duration::from_secs(disconnect_timeout), None) + .map_err(|e| { + new_nitro_cli_failure!( + &format!("Failed to arm timerfd: {:?}", e), + NitroCliErrorEnum::EpollError + ) + })?; + + // Add timerfd fd to epoll + let timerfd_fd = timerfd.into_raw_fd(); epoll - .ctl(ControlOperation::Add, self.fd, epoll_event) + .ctl( + ControlOperation::Add, + timerfd_fd, + EpollEvent::new(EventSet::IN, timerfd_fd as u64), + ) .map_err(|e| { new_nitro_cli_failure!( &format!("Failed to add fd to epoll: {:?}", e), NitroCliErrorEnum::EpollError ) })?; - - // Disconnect timeout in microseconds - epoll_timeout_us = (disconnect_timeout_sec.unwrap_or(0) * 1000 * 1000) as i128; } - loop { - if disconnect_timeout_sec.is_some() { - start_epoll_time = Instant::now(); - - // Use epoll_wait to exit the blocking state when the fd is ready to be read - // or when the disconnect time has passed - let num_events = epoll - .wait((epoll_timeout_us / 1000) as i32, &mut events) - .map_err(|e| { - new_nitro_cli_failure!( - &format!("Failed to wait on epoll: {:?}", e), - NitroCliErrorEnum::EpollError - ) - })?; - - // If the timeout expires, no event happend and the console disconnects - if num_events == 0 { - break; - } - } + // Allow only one epoll event to happen at a given time + let mut events = [EpollEvent::default(); 1]; - let mut buffer = [0u8; BUFFER_SIZE]; - let size = read(self.fd, &mut buffer).map_err(|e| { + loop { + // Wait for kernel notification that one of the fds is available + let num_events = epoll.wait(-1, &mut events).map_err(|e| { new_nitro_cli_failure!( - &format!("Failed to read data from the console: {:?}", e), - NitroCliErrorEnum::EnclaveConsoleReadError + &format!("Failed to wait epoll: {:?}", e), + NitroCliErrorEnum::EpollError ) })?; - if size == 0 { - break; - } - - if size > 0 { - output.write(&buffer[..size]).map_err(|e| { - new_nitro_cli_failure!( - &format!( - "Failed to write data from the console to the given stream: {:?}", - e - ), - NitroCliErrorEnum::EnclaveConsoleWriteOutputError - ) - })?; - } + // Check if any event triggered, because an interrupt could unblock the wait + // without any of the requested events to occur + if num_events == 1 { + match events[0].fd() { + // Check if console fd triggered + fd if fd == self.fd => { + let mut buffer = [0u8; BUFFER_SIZE]; + let size = read(self.fd, &mut buffer).map_err(|e| { + new_nitro_cli_failure!( + &format!("Failed to read data from the console: {:?}", e), + NitroCliErrorEnum::EnclaveConsoleReadError + ) + })?; + + if size == 0 { + break; + } - // Account for the read/write/epoll_wait elapsed time - if disconnect_timeout_sec.is_some() { - epoll_timeout_us -= start_epoll_time.elapsed().as_micros() as i128; + if size > 0 { + output.write(&buffer[..size]).map_err(|e| { + new_nitro_cli_failure!( + &format!( + "Failed to write data from the \ + console to the given stream: {:?}", + e + ), + NitroCliErrorEnum::EnclaveConsoleWriteOutputError + ) + })?; + } + } + // Check if timerfd triggered + _ => break, + } } }