Skip to content

Commit

Permalink
Use VecDeque<OwnedFd> and Vec<OwnedFd> for in_fds/out_fds
Browse files Browse the repository at this point in the history
This avoids some awkward and potentially error-prone code, and should
perform no worse than how the `Buffer` type was used.
  • Loading branch information
ids1024 committed Nov 1, 2023
1 parent edd0f60 commit 2a505ed
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 69 deletions.
58 changes: 19 additions & 39 deletions wayland-backend/src/rs/socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Wayland socket manipulation
use std::collections::VecDeque;
use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;

Expand Down Expand Up @@ -37,7 +38,7 @@ impl Socket {
/// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
/// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
/// end may lose some data.
pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> IoResult<usize> {
pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {

Check warning on line 41 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L41

Added line #L41 was not covered by tests
let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
let iov = [IoSlice::new(bytes)];

Expand Down Expand Up @@ -65,7 +66,7 @@ impl Socket {
/// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
/// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
/// be lost.
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> IoResult<(usize, usize)> {
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {

Check warning on line 69 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L69

Added line #L69 was not covered by tests
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut iov = [IoSliceMut::new(buffer)];
Expand All @@ -76,19 +77,15 @@ impl Socket {
RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC,
)?;

let mut fd_count = 0;
let received_fds = cmsg_buffer
.drain()
.filter_map(|cmsg| match cmsg {
RecvAncillaryMessage::ScmRights(fds) => Some(fds),
_ => None,
})
.flatten();
for (fd, place) in received_fds.zip(fds.iter_mut()) {
fd_count += 1;
*place = fd.into_raw_fd();
}
Ok((msg.bytes, fd_count))
fds.extend(received_fds);
Ok(msg.bytes)

Check warning on line 88 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L87-L88

Added lines #L87 - L88 were not covered by tests
}
}

Expand Down Expand Up @@ -120,9 +117,9 @@ impl AsRawFd for Socket {
pub struct BufferedSocket {
socket: Socket,
in_data: Buffer<u32>,
in_fds: Buffer<RawFd>,
in_fds: VecDeque<OwnedFd>,
out_data: Buffer<u32>,
out_fds: Buffer<RawFd>,
out_fds: Vec<OwnedFd>,
}

impl BufferedSocket {
Expand All @@ -131,9 +128,9 @@ impl BufferedSocket {
Self {
socket,
in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), // Incoming buffers are twice as big in order to be
in_fds: Buffer::new(2 * MAX_FDS_OUT), // able to store leftover data if needed
in_fds: VecDeque::new(), // able to store leftover data if needed

Check warning on line 131 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L131

Added line #L131 was not covered by tests
out_data: Buffer::new(MAX_BYTES_OUT / 4),
out_fds: Buffer::new(MAX_FDS_OUT),
out_fds: Vec::new(),

Check warning on line 133 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L133

Added line #L133 was not covered by tests
}
}

Expand All @@ -147,13 +144,7 @@ impl BufferedSocket {
let bytes = unsafe {
::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4)
};
let fds = self.out_fds.get_contents();
let written = self.socket.send_msg(bytes, fds)?;
for &fd in fds {
// once the fds are sent, we can close them
unsafe { rustix::io::close(fd) };
}
written
self.socket.send_msg(bytes, &self.out_fds)?

Check warning on line 147 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L147

Added line #L147 was not covered by tests
};
self.out_data.offset(written / 4);
self.out_data.move_to_front();
Expand All @@ -169,14 +160,9 @@ impl BufferedSocket {
// if false is returned, it means there is not enough space
// in the buffer
fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
match write_to_buffers(
msg,
self.out_data.get_writable_storage(),
self.out_fds.get_writable_storage(),
) {
Ok((bytes_out, fds_out)) => {
match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
Ok(bytes_out) => {

Check warning on line 164 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L163-L164

Added lines #L163 - L164 were not covered by tests
self.out_data.advance(bytes_out);
self.out_fds.advance(fds_out);
Ok(true)
}
Err(MessageWriteError::BufferTooSmall) => Ok(false),
Expand Down Expand Up @@ -213,23 +199,20 @@ impl BufferedSocket {
pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
// reorganize the buffers
self.in_data.move_to_front();
self.in_fds.move_to_front();
// receive a message
let (in_bytes, in_fds) = {
let in_bytes = {

Check warning on line 203 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L203

Added line #L203 was not covered by tests
let words = self.in_data.get_writable_storage();
let bytes = unsafe {
::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4)
};
let fds = self.in_fds.get_writable_storage();
self.socket.rcv_msg(bytes, fds)?
self.socket.rcv_msg(bytes, &mut self.in_fds)?

Check warning on line 208 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L208

Added line #L208 was not covered by tests
};
if in_bytes == 0 {
// the other end of the socket was closed
return Err(rustix::io::Errno::PIPE.into());
}
// advance the storage
self.in_data.advance(in_bytes / 4 + usize::from(in_bytes % 4 > 0));
self.in_fds.advance(in_fds);
Ok(())
}

Expand All @@ -245,19 +228,16 @@ impl BufferedSocket {
where
F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
{
let (msg, read_data, read_fd) = {
let (msg, read_data) = {

Check warning on line 231 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L231

Added line #L231 was not covered by tests
let data = self.in_data.get_contents();
let fds = self.in_fds.get_contents();
if data.len() < 2 {
return Err(MessageParseError::MissingData);
}
let object_id = data[0];
let opcode = (data[1] & 0x0000_FFFF) as u16;
if let Some(sig) = signature(object_id, opcode) {
match parse_message(data, sig, fds) {
Ok((msg, rest_data, rest_fds)) => {
(msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
}
match parse_message(data, sig, &mut self.in_fds) {
Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),

Check warning on line 240 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L239-L240

Added lines #L239 - L240 were not covered by tests
Err(e) => return Err(e),
}
} else {
Expand All @@ -267,7 +247,6 @@ impl BufferedSocket {
};

self.in_data.offset(read_data);
self.in_fds.offset(read_fd);

Ok(msg)
}
Expand Down Expand Up @@ -314,6 +293,7 @@ impl<T: Copy + Default> Buffer<T> {
///
/// This only sets the counter of occupied space back to zero,
/// allowing previous content to be overwritten.
#[allow(unused)]
fn clear(&mut self) {
self.occupied = 0;
self.offset = 0;
Expand Down
48 changes: 18 additions & 30 deletions wayland-backend/src/rs/wire.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Types and routines used to manipulate arguments from the wire format
use std::collections::VecDeque;
use std::ffi::CStr;
use std::os::unix::io::RawFd;
use std::os::unix::io::{BorrowedFd, OwnedFd};
use std::os::unix::io::{FromRawFd, RawFd};
use std::ptr;
use std::{ffi::CStr, os::unix::prelude::AsRawFd};

use crate::protocol::{Argument, ArgumentType, Message};

Expand Down Expand Up @@ -72,10 +73,9 @@ impl std::fmt::Display for MessageParseError {
pub fn write_to_buffers(
msg: &Message<u32, RawFd>,
payload: &mut [u32],
mut fds: &mut [RawFd],
) -> Result<(usize, usize), MessageWriteError> {
fds: &mut Vec<OwnedFd>,
) -> Result<usize, MessageWriteError> {

Check warning on line 77 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L76-L77

Added lines #L76 - L77 were not covered by tests
let orig_payload_len = payload.len();
let orig_fds_len = fds.len();
// Helper function to write a u32 or a RawFd to its buffer
fn write_buf<T>(u: T, payload: &mut [T]) -> Result<&mut [T], MessageWriteError> {
if let Some((head, tail)) = payload.split_first_mut() {
Expand Down Expand Up @@ -113,8 +113,6 @@ pub fn write_to_buffers(

let (header, mut payload) = payload.split_at_mut(2);

let mut pending_fds = Vec::new();

// write the contents in the buffer
for arg in &msg.args {
// Just to make the borrow checker happy
Expand All @@ -135,28 +133,19 @@ pub fn write_to_buffers(
payload = write_array_to_payload(a, old_payload)?;
}
Argument::Fd(fd) => {
let old_fds = fds;
let dup_fd = unsafe { BorrowedFd::borrow_raw(fd) }
.try_clone_to_owned()
.map_err(MessageWriteError::DupFdFailed)?;
let raw_dup_fd = dup_fd.as_raw_fd();
pending_fds.push(dup_fd);
fds = write_buf(raw_dup_fd, old_fds)?;
fds.push(dup_fd);

Check warning on line 139 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L139

Added line #L139 was not covered by tests
payload = old_payload;
}
}
}

// if we reached here, the whole message was written successfully, we can drop the pending_fds they
// don't need to be closed, sendmsg will take ownership of them
for fd in pending_fds {
std::mem::forget(fd);
}

let wrote_size = (free_size - payload.len()) * 4;
header[0] = msg.sender_id;
header[1] = ((wrote_size as u32) << 16) | u32::from(msg.opcode);
Ok((orig_payload_len - payload.len(), orig_fds_len - fds.len()))
Ok(orig_payload_len - payload.len())

Check warning on line 148 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L148

Added line #L148 was not covered by tests
}

/// Attempts to parse a single wayland message with the given signature.
Expand All @@ -167,11 +156,11 @@ pub fn write_to_buffers(
///
/// Errors if the message is malformed.
#[allow(clippy::type_complexity)]
pub fn parse_message<'a, 'b>(
pub fn parse_message<'a>(

Check warning on line 159 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L159

Added line #L159 was not covered by tests
raw: &'a [u32],
signature: &[ArgumentType],
fds: &'b [RawFd],
) -> Result<(Message<u32, OwnedFd>, &'a [u32], &'b [RawFd]), MessageParseError> {
fds: &mut VecDeque<OwnedFd>,
) -> Result<(Message<u32, OwnedFd>, &'a [u32]), MessageParseError> {

Check warning on line 163 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L162-L163

Added lines #L162 - L163 were not covered by tests
// helper function to read arrays
fn read_array_from_payload(
array_len: usize,
Expand Down Expand Up @@ -205,16 +194,14 @@ pub fn parse_message<'a, 'b>(

let (mut payload, rest) = raw.split_at(len);
payload = &payload[2..];
let mut fds = fds;

let arguments = signature
.iter()
.map(|argtype| {
if let ArgumentType::Fd = *argtype {
// don't consume input but fd
if let Some((&front, tail)) = fds.split_first() {
fds = tail;
Ok(Argument::Fd(unsafe { OwnedFd::from_raw_fd(front) }))
if let Some(front) = fds.pop_front() {
Ok(Argument::Fd(front))

Check warning on line 204 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L203-L204

Added lines #L203 - L204 were not covered by tests
} else {
Err(MessageParseError::MissingFD)
}
Expand Down Expand Up @@ -255,7 +242,7 @@ pub fn parse_message<'a, 'b>(
.collect::<Result<SmallVec<_>, MessageParseError>>()?;

let msg = Message { sender_id, opcode, args: arguments };
Ok((msg, rest, fds))
Ok((msg, rest))

Check warning on line 245 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L245

Added line #L245 was not covered by tests
}

#[cfg(test)]
Expand All @@ -268,7 +255,7 @@ mod tests {
#[test]
fn into_from_raw_cycle() {
let mut bytes_buffer = vec![0; 1024];
let mut fd_buffer = [0; 10];
let mut fd_buffer = Vec::new();

let msg = Message {
sender_id: 42,
Expand All @@ -284,9 +271,10 @@ mod tests {
],
};
// write the message to the buffers
write_to_buffers(&msg, &mut bytes_buffer[..], &mut fd_buffer[..]).unwrap();
write_to_buffers(&msg, &mut bytes_buffer[..], &mut fd_buffer).unwrap();
// read them back
let (rebuilt, _, _) = parse_message(
let mut fd_buffer = VecDeque::from(fd_buffer);
let (rebuilt, _) = parse_message(
&bytes_buffer[..],
&[
ArgumentType::Uint,
Expand All @@ -297,7 +285,7 @@ mod tests {
ArgumentType::NewId,
ArgumentType::Int,
],
&fd_buffer[..],
&mut fd_buffer,
)
.unwrap();
assert_eq!(rebuilt.map_fd(IntoRawFd::into_raw_fd), msg);
Expand Down

0 comments on commit 2a505ed

Please sign in to comment.