Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use VecDeque<OwnedFd> and Vec<OwnedFd> for in_fds/out_fds #666

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 19 additions & 39 deletions wayland-backend/src/rs/socket.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Wayland socket manipulation

use std::collections::VecDeque;
use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;

Expand Down Expand Up @@ -37,7 +38,7 @@ impl Socket {
/// The `fds` slice should not be longer than `MAX_FDS_OUT`, and the `bytes`
/// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
/// end may lose some data.
pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> IoResult<usize> {
pub fn send_msg(&self, bytes: &[u8], fds: &[OwnedFd]) -> IoResult<usize> {
let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;

if !fds.is_empty() {
Expand All @@ -64,7 +65,7 @@ impl Socket {
/// The `buffer` slice should be at least `MAX_BYTES_OUT` long and the `fds`
/// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
/// be lost.
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> IoResult<(usize, usize)> {
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut VecDeque<OwnedFd>) -> IoResult<usize> {
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut iov = [IoSliceMut::new(buffer)];
Expand All @@ -75,19 +76,15 @@ impl Socket {
RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC,
)?;

let mut fd_count = 0;
let received_fds = cmsg_buffer
.drain()
.filter_map(|cmsg| match cmsg {
RecvAncillaryMessage::ScmRights(fds) => Some(fds),
_ => None,
})
.flatten();
for (fd, place) in received_fds.zip(fds.iter_mut()) {
fd_count += 1;
*place = fd.into_raw_fd();
}
Ok((msg.bytes, fd_count))
fds.extend(received_fds);
Ok(msg.bytes)
}
}

Expand Down Expand Up @@ -119,9 +116,9 @@ impl AsRawFd for Socket {
pub struct BufferedSocket {
socket: Socket,
in_data: Buffer<u32>,
in_fds: Buffer<RawFd>,
in_fds: VecDeque<OwnedFd>,
out_data: Buffer<u32>,
out_fds: Buffer<RawFd>,
out_fds: Vec<OwnedFd>,
}

impl BufferedSocket {
Expand All @@ -130,9 +127,9 @@ impl BufferedSocket {
Self {
socket,
in_data: Buffer::new(2 * MAX_BYTES_OUT / 4), // Incoming buffers are twice as big in order to be
in_fds: Buffer::new(2 * MAX_FDS_OUT), // able to store leftover data if needed
in_fds: VecDeque::new(), // able to store leftover data if needed
out_data: Buffer::new(MAX_BYTES_OUT / 4),
out_fds: Buffer::new(MAX_FDS_OUT),
out_fds: Vec::new(),
}
}

Expand All @@ -146,13 +143,7 @@ impl BufferedSocket {
let bytes = unsafe {
::std::slice::from_raw_parts(words.as_ptr() as *const u8, words.len() * 4)
};
let fds = self.out_fds.get_contents();
let written = self.socket.send_msg(bytes, fds)?;
for &fd in fds {
// once the fds are sent, we can close them
unsafe { rustix::io::close(fd) };
}
written
self.socket.send_msg(bytes, &self.out_fds)?
};
self.out_data.offset(written / 4);
self.out_data.move_to_front();
Expand All @@ -168,14 +159,9 @@ impl BufferedSocket {
// if false is returned, it means there is not enough space
// in the buffer
fn attempt_write_message(&mut self, msg: &Message<u32, RawFd>) -> IoResult<bool> {
match write_to_buffers(
msg,
self.out_data.get_writable_storage(),
self.out_fds.get_writable_storage(),
) {
Ok((bytes_out, fds_out)) => {
match write_to_buffers(msg, self.out_data.get_writable_storage(), &mut self.out_fds) {
Ok(bytes_out) => {
self.out_data.advance(bytes_out);
self.out_fds.advance(fds_out);
Ok(true)
}
Err(MessageWriteError::BufferTooSmall) => Ok(false),
Expand Down Expand Up @@ -212,23 +198,20 @@ impl BufferedSocket {
pub fn fill_incoming_buffers(&mut self) -> IoResult<()> {
// reorganize the buffers
self.in_data.move_to_front();
self.in_fds.move_to_front();
// receive a message
let (in_bytes, in_fds) = {
let in_bytes = {
let words = self.in_data.get_writable_storage();
let bytes = unsafe {
::std::slice::from_raw_parts_mut(words.as_ptr() as *mut u8, words.len() * 4)
};
let fds = self.in_fds.get_writable_storage();
self.socket.rcv_msg(bytes, fds)?
self.socket.rcv_msg(bytes, &mut self.in_fds)?
};
if in_bytes == 0 {
// the other end of the socket was closed
return Err(rustix::io::Errno::PIPE.into());
}
// advance the storage
self.in_data.advance(in_bytes / 4 + usize::from(in_bytes % 4 > 0));
self.in_fds.advance(in_fds);
Ok(())
}

Expand All @@ -244,19 +227,16 @@ impl BufferedSocket {
where
F: FnMut(u32, u16) -> Option<&'static [ArgumentType]>,
{
let (msg, read_data, read_fd) = {
let (msg, read_data) = {
let data = self.in_data.get_contents();
let fds = self.in_fds.get_contents();
if data.len() < 2 {
return Err(MessageParseError::MissingData);
}
let object_id = data[0];
let opcode = (data[1] & 0x0000_FFFF) as u16;
if let Some(sig) = signature(object_id, opcode) {
match parse_message(data, sig, fds) {
Ok((msg, rest_data, rest_fds)) => {
(msg, data.len() - rest_data.len(), fds.len() - rest_fds.len())
}
match parse_message(data, sig, &mut self.in_fds) {
Ok((msg, rest_data)) => (msg, data.len() - rest_data.len()),
Err(e) => return Err(e),
}
} else {
Expand All @@ -266,7 +246,6 @@ impl BufferedSocket {
};

self.in_data.offset(read_data);
self.in_fds.offset(read_fd);

Ok(msg)
}
Expand Down Expand Up @@ -313,6 +292,7 @@ impl<T: Copy + Default> Buffer<T> {
///
/// This only sets the counter of occupied space back to zero,
/// allowing previous content to be overwritten.
#[allow(unused)]
fn clear(&mut self) {
self.occupied = 0;
self.offset = 0;
Expand Down
53 changes: 23 additions & 30 deletions wayland-backend/src/rs/wire.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
//! Types and routines used to manipulate arguments from the wire format

use std::collections::VecDeque;
use std::ffi::CStr;
use std::os::unix::io::RawFd;
use std::os::unix::io::{BorrowedFd, OwnedFd};
use std::os::unix::io::{FromRawFd, RawFd};
use std::ptr;
use std::{ffi::CStr, os::unix::prelude::AsRawFd};

use crate::protocol::{Argument, ArgumentType, Message};

Expand Down Expand Up @@ -72,10 +73,9 @@
pub fn write_to_buffers(
msg: &Message<u32, RawFd>,
payload: &mut [u32],
mut fds: &mut [RawFd],
) -> Result<(usize, usize), MessageWriteError> {
fds: &mut Vec<OwnedFd>,
) -> Result<usize, MessageWriteError> {
let orig_payload_len = payload.len();
let orig_fds_len = fds.len();
// Helper function to write a u32 or a RawFd to its buffer
fn write_buf<T>(u: T, payload: &mut [T]) -> Result<&mut [T], MessageWriteError> {
if let Some((head, tail)) = payload.split_first_mut() {
Expand Down Expand Up @@ -113,8 +113,6 @@

let (header, mut payload) = payload.split_at_mut(2);

let mut pending_fds = Vec::new();

// write the contents in the buffer
for arg in &msg.args {
// Just to make the borrow checker happy
Expand All @@ -135,28 +133,19 @@
payload = write_array_to_payload(a, old_payload)?;
}
Argument::Fd(fd) => {
let old_fds = fds;
let dup_fd = unsafe { BorrowedFd::borrow_raw(fd) }
.try_clone_to_owned()
.map_err(MessageWriteError::DupFdFailed)?;
let raw_dup_fd = dup_fd.as_raw_fd();
pending_fds.push(dup_fd);
fds = write_buf(raw_dup_fd, old_fds)?;
fds.push(dup_fd);
payload = old_payload;
}
}
}

// if we reached here, the whole message was written successfully, we can drop the pending_fds they
// don't need to be closed, sendmsg will take ownership of them
for fd in pending_fds {
std::mem::forget(fd);
}

let wrote_size = (free_size - payload.len()) * 4;
header[0] = msg.sender_id;
header[1] = ((wrote_size as u32) << 16) | u32::from(msg.opcode);
Ok((orig_payload_len - payload.len(), orig_fds_len - fds.len()))
Ok(orig_payload_len - payload.len())
}

/// Attempts to parse a single wayland message with the given signature.
Expand All @@ -167,11 +156,11 @@
///
/// Errors if the message is malformed.
#[allow(clippy::type_complexity)]
pub fn parse_message<'a, 'b>(
pub fn parse_message<'a>(
raw: &'a [u32],
signature: &[ArgumentType],
fds: &'b [RawFd],
) -> Result<(Message<u32, OwnedFd>, &'a [u32], &'b [RawFd]), MessageParseError> {
fds: &mut VecDeque<OwnedFd>,
) -> Result<(Message<u32, OwnedFd>, &'a [u32]), MessageParseError> {
// helper function to read arrays
fn read_array_from_payload(
array_len: usize,
Expand Down Expand Up @@ -203,18 +192,21 @@
return Err(MessageParseError::MissingData);
}

let fd_len = signature.iter().filter(|x| matches!(x, ArgumentType::Fd)).count();
if fd_len > fds.len() {
return Err(MessageParseError::MissingFD);

Check warning on line 197 in wayland-backend/src/rs/wire.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/wire.rs#L197

Added line #L197 was not covered by tests
}

let (mut payload, rest) = raw.split_at(len);
payload = &payload[2..];
let mut fds = fds;

let arguments = signature
.iter()
.map(|argtype| {
if let ArgumentType::Fd = *argtype {
// don't consume input but fd
if let Some((&front, tail)) = fds.split_first() {
fds = tail;
Ok(Argument::Fd(unsafe { OwnedFd::from_raw_fd(front) }))
if let Some(front) = fds.pop_front() {
Ok(Argument::Fd(front))
} else {
Err(MessageParseError::MissingFD)
}
Expand Down Expand Up @@ -255,7 +247,7 @@
.collect::<Result<SmallVec<_>, MessageParseError>>()?;

let msg = Message { sender_id, opcode, args: arguments };
Ok((msg, rest, fds))
Ok((msg, rest))
}

#[cfg(test)]
Expand All @@ -268,7 +260,7 @@
#[test]
fn into_from_raw_cycle() {
let mut bytes_buffer = vec![0; 1024];
let mut fd_buffer = [0; 10];
let mut fd_buffer = Vec::new();

let msg = Message {
sender_id: 42,
Expand All @@ -284,9 +276,10 @@
],
};
// write the message to the buffers
write_to_buffers(&msg, &mut bytes_buffer[..], &mut fd_buffer[..]).unwrap();
write_to_buffers(&msg, &mut bytes_buffer[..], &mut fd_buffer).unwrap();
// read them back
let (rebuilt, _, _) = parse_message(
let mut fd_buffer = VecDeque::from(fd_buffer);
let (rebuilt, _) = parse_message(
&bytes_buffer[..],
&[
ArgumentType::Uint,
Expand All @@ -297,7 +290,7 @@
ArgumentType::NewId,
ArgumentType::Int,
],
&fd_buffer[..],
&mut fd_buffer,
)
.unwrap();
assert_eq!(rebuilt.map_fd(IntoRawFd::into_raw_fd), msg);
Expand Down
Loading