Skip to content

Commit

Permalink
Use OwnedFd and don't implement Clone.
Browse files Browse the repository at this point in the history
Cloning a vsock stream results in two copies of the same file
descriptor, which could result in it being closed twice when they are
dropped. Instead, use an OwnedFd, and have try_clone duplicate it.
  • Loading branch information
qwandor committed Nov 27, 2023
1 parent 68dd544 commit f73fa01
Showing 1 changed file with 61 additions and 53 deletions.
114 changes: 61 additions & 53 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,29 @@ use nix::{
sockopt::{ReceiveTimeout, SendTimeout, SocketError},
AddressFamily, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
},
unistd::close,
};
use std::fs::File;
use std::io::{Error, ErrorKind, Read, Result, Write};
use std::mem::{self, size_of};
use std::mem::size_of;
use std::net::Shutdown;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::time::Duration;
use std::{fs::File, os::fd::OwnedFd};
use std::{
io::{Error, ErrorKind, Read, Result, Write},
os::fd::{AsFd, BorrowedFd},
};

pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
pub use nix::sys::socket::{SockaddrLike, VsockAddr};

fn new_socket() -> Result<RawFd> {
Ok(socket(
fn new_socket() -> Result<OwnedFd> {
let fd = socket(
AddressFamily::Vsock,
SockType::Stream,
SockFlag::SOCK_CLOEXEC,
None,
)?)
)?;
// SAFETY: We just created a new file descriptor, so we can take ownership of it.
unsafe { Ok(OwnedFd::from_raw_fd(fd)) }
}

/// An iterator that infinitely accepts connections on a VsockListener.
Expand All @@ -64,9 +68,9 @@ impl<'a> Iterator for Incoming<'a> {
}

/// A virtio socket server, listening for connections.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct VsockListener {
socket: RawFd,
socket: OwnedFd,
}

impl VsockListener {
Expand All @@ -81,10 +85,10 @@ impl VsockListener {

let socket = new_socket()?;

bind(socket, addr)?;
bind(socket.as_raw_fd(), addr)?;

// rust stdlib uses a 128 connection backlog
listen(socket, 128)?;
listen(socket.as_raw_fd(), 128)?;

Ok(Self { socket })
}
Expand All @@ -96,12 +100,14 @@ impl VsockListener {

/// The local socket address of the listener.
pub fn local_addr(&self) -> Result<VsockAddr> {
Ok(getsockname(self.socket)?)
Ok(getsockname(self.socket.as_raw_fd())?)
}

/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> Result<Self> {
Ok(self.clone())
Ok(Self {
socket: self.socket.try_clone()?,
})
}

/// Accept a new incoming connection from this listener.
Expand All @@ -116,7 +122,7 @@ impl VsockListener {
let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
let socket = unsafe {
accept4(
self.socket,
self.socket.as_raw_fd(),
&mut vsock_addr as *mut _ as *mut sockaddr,
&mut vsock_addr_len,
SOCK_CLOEXEC,
Expand All @@ -139,7 +145,7 @@ impl VsockListener {

/// Retrieve the latest error associated with the underlying socket.
pub fn take_error(&self) -> Result<Option<Error>> {
let error = SocketError.get(self.socket)?;
let error = SocketError.get(self.socket.as_raw_fd())?;
Ok(if error == 0 {
None
} else {
Expand All @@ -150,7 +156,7 @@ impl VsockListener {
/// Move this stream in and out of nonblocking mode.
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
Err(Error::last_os_error())
} else {
Ok(())
Expand All @@ -160,34 +166,34 @@ impl VsockListener {

impl AsRawFd for VsockListener {
fn as_raw_fd(&self) -> RawFd {
self.socket
self.socket.as_raw_fd()
}
}

impl AsFd for VsockListener {
fn as_fd(&self) -> BorrowedFd {
self.socket.as_fd()
}
}

impl FromRawFd for VsockListener {
unsafe fn from_raw_fd(socket: RawFd) -> Self {
Self { socket }
Self {
socket: OwnedFd::from_raw_fd(socket),
}
}
}

impl IntoRawFd for VsockListener {
fn into_raw_fd(self) -> RawFd {
let fd = self.socket;
mem::forget(self);
fd
}
}

impl Drop for VsockListener {
fn drop(&mut self) {
let _ = close(self.socket);
self.socket.into_raw_fd()
}
}

/// A virtio stream between a local and a remote socket.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct VsockStream {
socket: RawFd,
socket: OwnedFd,
}

impl VsockStream {
Expand All @@ -200,9 +206,9 @@ impl VsockStream {
));
}

let sock = new_socket()?;
connect(sock, addr)?;
Ok(unsafe { Self::from_raw_fd(sock) })
let socket = new_socket()?;
connect(socket.as_raw_fd(), addr)?;
Ok(Self { socket })
}

/// Open a connection to a remote host with specified cid and port.
Expand All @@ -212,12 +218,12 @@ impl VsockStream {

/// Virtio socket address of the remote peer associated with this connection.
pub fn peer_addr(&self) -> Result<VsockAddr> {
Ok(getpeername(self.socket)?)
Ok(getpeername(self.socket.as_raw_fd())?)
}

/// Virtio socket address of the local address associated with this connection.
pub fn local_addr(&self) -> Result<VsockAddr> {
Ok(getsockname(self.socket)?)
Ok(getsockname(self.socket.as_raw_fd())?)
}

/// Shutdown the read, write, or both halves of this connection.
Expand All @@ -227,29 +233,31 @@ impl VsockStream {
Shutdown::Read => socket::Shutdown::Read,
Shutdown::Both => socket::Shutdown::Both,
};
Ok(shutdown(self.socket, how)?)
Ok(shutdown(self.socket.as_raw_fd(), how)?)
}

/// Create a new independently owned handle to the underlying socket.
pub fn try_clone(&self) -> Result<Self> {
Ok(self.clone())
Ok(Self {
socket: self.socket.try_clone()?,
})
}

/// Set the timeout on read operations.
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
let timeout = Self::timeval_from_duration(dur)?.into();
Ok(ReceiveTimeout.set(self.socket, &timeout)?)
Ok(ReceiveTimeout.set(self.socket.as_raw_fd(), &timeout)?)
}

/// Set the timeout on write operations.
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
let timeout = Self::timeval_from_duration(dur)?.into();
Ok(SendTimeout.set(self.socket, &timeout)?)
Ok(SendTimeout.set(self.socket.as_raw_fd(), &timeout)?)
}

/// Retrieve the latest error associated with the underlying socket.
pub fn take_error(&self) -> Result<Option<Error>> {
let error = SocketError.get(self.socket)?;
let error = SocketError.get(self.socket.as_raw_fd())?;
Ok(if error == 0 {
None
} else {
Expand All @@ -260,7 +268,7 @@ impl VsockStream {
/// Move this stream in and out of nonblocking mode.
pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
Err(Error::last_os_error())
} else {
Ok(())
Expand Down Expand Up @@ -319,13 +327,13 @@ impl Write for VsockStream {

impl Read for &VsockStream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
Ok(recv(self.socket, buf, MsgFlags::empty())?)
Ok(recv(self.socket.as_raw_fd(), buf, MsgFlags::empty())?)
}
}

impl Write for &VsockStream {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
Ok(send(self.socket, buf, MsgFlags::MSG_NOSIGNAL)?)
Ok(send(self.socket.as_raw_fd(), buf, MsgFlags::MSG_NOSIGNAL)?)
}

fn flush(&mut self) -> Result<()> {
Expand All @@ -335,27 +343,27 @@ impl Write for &VsockStream {

impl AsRawFd for VsockStream {
fn as_raw_fd(&self) -> RawFd {
self.socket
self.socket.as_raw_fd()
}
}

impl AsFd for VsockStream {
fn as_fd(&self) -> BorrowedFd {
self.socket.as_fd()
}
}

impl FromRawFd for VsockStream {
unsafe fn from_raw_fd(socket: RawFd) -> Self {
Self { socket }
Self {
socket: OwnedFd::from_raw_fd(socket),
}
}
}

impl IntoRawFd for VsockStream {
fn into_raw_fd(self) -> RawFd {
let fd = self.socket;
mem::forget(self);
fd
}
}

impl Drop for VsockStream {
fn drop(&mut self) {
let _ = close(self.socket);
self.socket.into_raw_fd()
}
}

Expand Down

0 comments on commit f73fa01

Please sign in to comment.