Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port from nix to rustix #658

Merged
merged 1 commit into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions wayland-backend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ features = [
"const_new", # 1.51
]

[dependencies.nix]
version = "0.26.0"
default-features = false
[dependencies.rustix]
version = "0.38.17"
features = [
"event",
"fs",
"poll",
"socket",
"uio",
"net",
"process",
]

[build-dependencies]
Expand Down
23 changes: 13 additions & 10 deletions wayland-backend/src/rs/server_impl/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
ffi::CString,
os::unix::io::OwnedFd,
os::unix::io::{AsFd, BorrowedFd, OwnedFd},
os::unix::{io::RawFd, net::UnixStream},
sync::Arc,
};
Expand Down Expand Up @@ -295,13 +295,10 @@ impl<D> Client<D> {

#[cfg(any(target_os = "linux", target_os = "android"))]
pub(crate) fn get_credentials(&self) -> Credentials {
use std::os::unix::io::AsRawFd;
let creds = nix::sys::socket::getsockopt(
self.socket.as_raw_fd(),
nix::sys::socket::sockopt::PeerCredentials,
)
.expect("getsockopt failed!?");
Credentials { pid: creds.pid(), uid: creds.uid(), gid: creds.gid() }
let creds =
rustix::net::sockopt::get_socket_peercred(&self.socket).expect("getsockopt failed!?");
let pid = rustix::process::Pid::as_raw(Some(creds.pid));
Credentials { pid, uid: creds.uid.as_raw(), gid: creds.gid.as_raw() }
}

#[cfg(not(any(target_os = "linux", target_os = "android")))]
Expand Down Expand Up @@ -336,7 +333,7 @@ impl<D> Client<D> {
&mut self,
) -> std::io::Result<(Message<u32, OwnedFd>, Object<Data<D>>)> {
if self.killed {
return Err(nix::errno::Errno::EPIPE.into());
return Err(rustix::io::Errno::PIPE.into());
}
loop {
let map = &self.map;
Expand All @@ -358,7 +355,7 @@ impl<D> Client<D> {
}
Err(MessageParseError::Malformed) => {
self.kill(DisconnectReason::ConnectionClosed);
return Err(nix::errno::Errno::EPROTO.into());
return Err(rustix::io::Errno::PROTO.into());
}
};

Expand Down Expand Up @@ -659,6 +656,12 @@ impl<D> Client<D> {
}
}

impl<D> AsFd for Client<D> {
fn as_fd(&self) -> BorrowedFd<'_> {
self.socket.as_fd()
}
}

#[derive(Debug)]
pub(crate) struct ClientStore<D: 'static> {
clients: Vec<Option<Client<D>>>,
Expand Down
38 changes: 16 additions & 22 deletions wayland-backend/src/rs/server_impl/common_poll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{
os::unix::io::{AsRawFd, FromRawFd},
os::unix::io::AsRawFd,
os::unix::io::{BorrowedFd, OwnedFd},
sync::{Arc, Mutex},
};
Expand All @@ -16,15 +16,15 @@ use crate::{
};

#[cfg(any(target_os = "linux", target_os = "android"))]
use nix::sys::epoll::*;
use rustix::event::epoll;

#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "netbsd",
target_os = "openbsd"
))]
use nix::sys::event::*;
use rustix::event::kqueue::*;
use smallvec::SmallVec;

#[derive(Debug)]
Expand All @@ -35,7 +35,7 @@ pub struct InnerBackend<D: 'static> {
impl<D> InnerBackend<D> {
pub fn new() -> Result<Self, InitError> {
#[cfg(any(target_os = "linux", target_os = "android"))]
let poll_fd = epoll_create1(EpollCreateFlags::EPOLL_CLOEXEC)
let poll_fd = epoll::create(epoll::CreateFlags::CLOEXEC)
.map_err(Into::into)
.map_err(InitError::Io)?;

Expand All @@ -47,9 +47,7 @@ impl<D> InnerBackend<D> {
))]
let poll_fd = kqueue().map_err(Into::into).map_err(InitError::Io)?;

Ok(Self {
state: Arc::new(Mutex::new(State::new(unsafe { OwnedFd::from_raw_fd(poll_fd) }))),
})
Ok(Self { state: Arc::new(Mutex::new(State::new(poll_fd))) })
}

pub fn flush(&self, client: Option<ClientId>) -> std::io::Result<()> {
Expand Down Expand Up @@ -80,18 +78,20 @@ impl<D> InnerBackend<D> {

#[cfg(any(target_os = "linux", target_os = "android"))]
pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
use std::os::unix::io::AsFd;

let poll_fd = self.poll_fd();
let mut dispatched = 0;
loop {
let mut events = [EpollEvent::empty(); 32];
let nevents = epoll_wait(poll_fd.as_raw_fd(), &mut events, 0)?;
let mut events = epoll::EventVec::with_capacity(32);
epoll::wait(poll_fd.as_fd(), &mut events, 0)?;

if nevents == 0 {
if events.is_empty() {
break;
}

for event in events.iter().take(nevents) {
let id = InnerClientId::from_u64(event.data());
for event in events.iter() {
let id = InnerClientId::from_u64(event.data.u64());
// remove the cb while we call it, to gracefully handle reentrancy
if let Ok(count) = self.dispatch_events_for(data, id) {
dispatched += count;
Expand All @@ -111,19 +111,13 @@ impl<D> InnerBackend<D> {
target_os = "openbsd"
))]
pub fn dispatch_all_clients(&self, data: &mut D) -> std::io::Result<usize> {
use std::time::Duration;

let poll_fd = self.poll_fd();
let mut dispatched = 0;
loop {
let mut events = [KEvent::new(
0,
EventFilter::EVFILT_READ,
EventFlag::empty(),
FilterFlag::empty(),
0,
0,
); 32];

let nevents = kevent(poll_fd.as_raw_fd(), &[], &mut events, 0)?;
let mut events = Vec::with_capacity(32);
let nevents = unsafe { kevent(&poll_fd, &[], &mut events, Some(Duration::ZERO))? };

if nevents == 0 {
break;
Expand Down
33 changes: 17 additions & 16 deletions wayland-backend/src/rs/server_impl/handle.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::{
ffi::CString,
os::unix::io::OwnedFd,
os::unix::{
io::{AsRawFd, RawFd},
net::UnixStream,
},
os::unix::{io::RawFd, net::UnixStream},
sync::{Arc, Mutex, Weak},
};

Expand Down Expand Up @@ -314,15 +311,19 @@ impl<D> ErasedState for State<D> {
stream: UnixStream,
data: Arc<dyn ClientData>,
) -> std::io::Result<InnerClientId> {
let client_fd = stream.as_raw_fd();
let id = self.clients.create_client(stream, data);
let client = self.clients.get_client(id.clone()).unwrap();

// register the client to the internal epoll
#[cfg(any(target_os = "linux", target_os = "android"))]
let ret = {
use nix::sys::epoll::*;
let mut evt = EpollEvent::new(EpollFlags::EPOLLIN, id.as_u64());
epoll_ctl(self.poll_fd.as_raw_fd(), EpollOp::EpollCtlAdd, client_fd, &mut evt)
use rustix::event::epoll;
epoll::add(
&self.poll_fd,
client,
epoll::EventData::new_u64(id.as_u64()),
epoll::EventFlags::IN,
)
};

#[cfg(any(
Expand All @@ -332,17 +333,17 @@ impl<D> ErasedState for State<D> {
target_os = "openbsd"
))]
let ret = {
use nix::sys::event::*;
let evt = KEvent::new(
client_fd as usize,
EventFilter::EVFILT_READ,
EventFlag::EV_ADD | EventFlag::EV_RECEIPT,
FilterFlag::empty(),
0,
use rustix::event::kqueue::*;
use std::os::unix::io::{AsFd, AsRawFd};

let evt = Event::new(
EventFilter::Read(client.as_fd().as_raw_fd()),
EventFlags::ADD | EventFlags::RECEIPT,
id.as_u64() as isize,
);

kevent_ts(self.poll_fd.as_raw_fd(), &[evt], &mut [], None).map(|_| ())
let mut events = Vec::new();
unsafe { kevent(&self.poll_fd, &[evt], &mut events, None).map(|_| ()) }
};

match ret {
Expand Down
66 changes: 39 additions & 27 deletions wayland-backend/src/rs/socket.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
//! Wayland socket manipulation

use std::io::{ErrorKind, IoSlice, IoSliceMut, Result as IoResult};
use std::os::unix::io::{AsFd, BorrowedFd, OwnedFd};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, IntoRawFd, OwnedFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;

use nix::sys::socket;
use rustix::net::{
recvmsg, sendmsg, RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
SendAncillaryMessage, SendFlags,
};

use crate::protocol::{ArgumentType, Message};

Expand Down Expand Up @@ -35,14 +38,19 @@
/// slice should not be longer than `MAX_BYTES_OUT` otherwise the receiving
/// end may lose some data.
pub fn send_msg(&self, bytes: &[u8], fds: &[RawFd]) -> IoResult<usize> {
let flags = socket::MsgFlags::MSG_DONTWAIT | socket::MsgFlags::MSG_NOSIGNAL;
let flags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL;
let iov = [IoSlice::new(bytes)];

if !fds.is_empty() {
let cmsgs = [socket::ControlMessage::ScmRights(fds)];
Ok(socket::sendmsg::<()>(self.stream.as_raw_fd(), &iov, &cmsgs, flags, None)?)
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(fds.len()))];
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut cmsg_space);
let fds =
unsafe { slice::from_raw_parts(fds.as_ptr() as *const BorrowedFd, fds.len()) };
cmsg_buffer.push(SendAncillaryMessage::ScmRights(fds));
Ok(sendmsg(self, &iov, &mut cmsg_buffer, flags)?)
} else {
Ok(socket::sendmsg::<()>(self.stream.as_raw_fd(), &iov, &[], flags, None)?)
let mut cmsg_buffer = SendAncillaryBuffer::new(&mut []);
Ok(sendmsg(self, &iov, &mut cmsg_buffer, flags)?)
}
}

Expand All @@ -58,25 +66,27 @@
/// slice `MAX_FDS_OUT` long, otherwise some data of the received message may
/// be lost.
pub fn rcv_msg(&self, buffer: &mut [u8], fds: &mut [RawFd]) -> IoResult<(usize, usize)> {
let mut cmsg = nix::cmsg_space!([RawFd; MAX_FDS_OUT]);
let mut cmsg_space = vec![0; rustix::cmsg_space!(ScmRights(MAX_FDS_OUT))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let mut iov = [IoSliceMut::new(buffer)];
let msg = socket::recvmsg::<()>(
self.stream.as_raw_fd(),
let msg = recvmsg(
&self.stream,
&mut iov[..],
Some(&mut cmsg),
socket::MsgFlags::MSG_DONTWAIT
| socket::MsgFlags::MSG_CMSG_CLOEXEC
| socket::MsgFlags::MSG_NOSIGNAL,
&mut cmsg_buffer,
RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC,
)?;

let mut fd_count = 0;
let received_fds = msg.cmsgs().flat_map(|cmsg| match cmsg {
socket::ControlMessageOwned::ScmRights(s) => s,
_ => Vec::new(),
});
let received_fds = cmsg_buffer
.drain()
.filter_map(|cmsg| match cmsg {
RecvAncillaryMessage::ScmRights(fds) => Some(fds),
_ => None,

Check warning on line 84 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L84

Added line #L84 was not covered by tests
})
.flatten();
for (fd, place) in received_fds.zip(fds.iter_mut()) {
fd_count += 1;
*place = fd;
*place = fd.into_raw_fd();
}
Ok((msg.bytes, fd_count))
}
Expand Down Expand Up @@ -141,7 +151,7 @@
let written = self.socket.send_msg(bytes, fds)?;
for &fd in fds {
// once the fds are sent, we can close them
let _ = ::nix::unistd::close(fd);
unsafe { rustix::io::close(fd) };
}
written
};
Expand Down Expand Up @@ -192,7 +202,7 @@
if !self.attempt_write_message(msg)? {
// If this fails again, this means the message is too big
// to be transmitted at all
return Err(::nix::errno::Errno::E2BIG.into());
return Err(rustix::io::Errno::TOOBIG.into());

Check warning on line 205 in wayland-backend/src/rs/socket.rs

View check run for this annotation

Codecov / codecov/patch

wayland-backend/src/rs/socket.rs#L205

Added line #L205 was not covered by tests
}
}
Ok(())
Expand All @@ -215,7 +225,7 @@
};
if in_bytes == 0 {
// the other end of the socket was closed
return Err(::nix::errno::Errno::EPIPE.into());
return Err(rustix::io::Errno::PIPE.into());
}
// advance the storage
self.in_data.advance(in_bytes / 4 + usize::from(in_bytes % 4 > 0));
Expand Down Expand Up @@ -342,14 +352,14 @@
use crate::protocol::{AllowNull, Argument, ArgumentType, Message};

use std::ffi::CString;
use std::os::unix::io::RawFd;
use std::os::unix::io::BorrowedFd;
use std::os::unix::prelude::IntoRawFd;

use smallvec::smallvec;

fn same_file(a: RawFd, b: RawFd) -> bool {
let stat1 = ::nix::sys::stat::fstat(a).unwrap();
let stat2 = ::nix::sys::stat::fstat(b).unwrap();
fn same_file(a: BorrowedFd, b: BorrowedFd) -> bool {
let stat1 = rustix::fs::fstat(a).unwrap();
let stat2 = rustix::fs::fstat(b).unwrap();
stat1.st_dev == stat2.st_dev && stat1.st_ino == stat2.st_ino
}

Expand All @@ -366,7 +376,9 @@
assert_eq!(msg1.args.len(), msg2.args.len());
for (arg1, arg2) in msg1.args.iter().zip(msg2.args.iter()) {
if let (Argument::Fd(fd1), Argument::Fd(fd2)) = (arg1, arg2) {
assert!(same_file(fd1.as_raw_fd(), fd2.as_raw_fd()));
let fd1 = unsafe { BorrowedFd::borrow_raw(fd1.as_raw_fd()) };
let fd2 = unsafe { BorrowedFd::borrow_raw(fd2.as_raw_fd()) };
assert!(same_file(fd1, fd2));
} else {
assert_eq!(arg1, arg2);
}
Expand Down
2 changes: 1 addition & 1 deletion wayland-backend/src/sys/client_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ impl ConnectionState {
#[inline]
fn store_and_return_error(&mut self, err: std::io::Error) -> WaylandError {
// check if it was actually a protocol error
let err = if err.raw_os_error() == Some(nix::errno::Errno::EPROTO as i32) {
let err = if err.raw_os_error() == Some(rustix::io::Errno::PROTO.raw_os_error()) {
let mut object_id = 0;
let mut interface = std::ptr::null();
let code = unsafe {
Expand Down
Loading
Loading