diff --git a/examples/tcp_client.rs b/examples/tcp_client.rs index 413f6b6..d887e57 100644 --- a/examples/tcp_client.rs +++ b/examples/tcp_client.rs @@ -6,9 +6,7 @@ use socket2::Type; fn main() -> io::Result<()> { let socket = socket2::Socket::new(socket2::Domain::IPV4, Type::STREAM, None)?; let poller = polling::Poller::new()?; - unsafe { - poller.add(&socket, Event::new(0, true, true))?; - } + poller.add(&socket, Event::new(0, true, true))?; let addr = net::SocketAddr::new(net::Ipv4Addr::LOCALHOST.into(), 8080); socket.set_nonblocking(true)?; let _ = socket.connect(&addr.into()); diff --git a/examples/two-listeners.rs b/examples/two-listeners.rs index bf54eee..510ad84 100644 --- a/examples/two-listeners.rs +++ b/examples/two-listeners.rs @@ -10,10 +10,8 @@ fn main() -> io::Result<()> { l2.set_nonblocking(true)?; let poller = Poller::new()?; - unsafe { - poller.add(&l1, Event::readable(1))?; - poller.add(&l2, Event::readable(2))?; - } + poller.add(&l1, Event::readable(1))?; + poller.add(&l2, Event::readable(2))?; println!("You can connect to the server using `nc`:"); println!(" $ nc 127.0.0.1 8001"); diff --git a/src/kqueue.rs b/src/kqueue.rs index 3e0b044..58ae6e2 100644 --- a/src/kqueue.rs +++ b/src/kqueue.rs @@ -84,11 +84,17 @@ impl Poller { /// # Safety /// /// The file descriptor must be valid and it must last until it is deleted. - pub unsafe fn add(&self, fd: RawFd, ev: Event, mode: PollMode) -> io::Result<()> { - self.add_source(SourceId::Fd(fd))?; + pub fn add(&self, fd: BorrowedFd<'_>, ev: Event, mode: PollMode) -> io::Result<()> { + let rawfd = fd.as_raw_fd(); + + // SAFETY: `rawfd` is valid as it is from `BorrowedFd`. And + // this block never closes / deletes `rawfd`. + unsafe { + self.add_source(SourceId::Fd(rawfd))?; + } // File descriptors don't need to be added explicitly, so just modify the interest. - self.modify(BorrowedFd::borrow_raw(fd), ev, mode) + self.modify(fd, ev, mode) } /// Modifies an existing file descriptor. diff --git a/src/lib.rs b/src/lib.rs index 1a2e449..d2ca3d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -507,7 +507,7 @@ impl Poller { /// poller.delete(&source)?; /// # std::io::Result::Ok(()) /// ``` - pub unsafe fn add(&self, source: impl AsRawSource, interest: Event) -> io::Result<()> { + pub fn add(&self, source: impl AsSource, interest: Event) -> io::Result<()> { self.add_with_mode(source, interest, PollMode::Oneshot) } @@ -526,9 +526,9 @@ impl Poller { /// /// If the operating system does not support the specified mode, this function /// will return an error. - pub unsafe fn add_with_mode( + pub fn add_with_mode( &self, - source: impl AsRawSource, + source: impl AsSource, interest: Event, mode: PollMode, ) -> io::Result<()> { @@ -538,7 +538,7 @@ impl Poller { "the key is not allowed to be `usize::MAX`", )); } - self.poller.add(source.raw(), interest, mode) + self.poller.add(source.as_fd(), interest, mode) } /// Modifies the interest in a file descriptor or socket. diff --git a/tests/concurrent_modification.rs b/tests/concurrent_modification.rs index ab3e5fb..c4b1965 100644 --- a/tests/concurrent_modification.rs +++ b/tests/concurrent_modification.rs @@ -20,9 +20,7 @@ fn concurrent_add() -> io::Result<()> { }) .add(|| { thread::sleep(Duration::from_millis(100)); - unsafe { - poller.add(&reader, Event::readable(0))?; - } + poller.add(&reader, Event::readable(0))?; writer.write_all(&[1])?; Ok(()) }) @@ -46,9 +44,7 @@ fn concurrent_add() -> io::Result<()> { fn concurrent_modify() -> io::Result<()> { let (reader, mut writer) = tcp_pair()?; let poller = Poller::new()?; - unsafe { - poller.add(&reader, Event::none(0))?; - } + poller.add(&reader, Event::none(0))?; let mut events = Events::new(); @@ -84,9 +80,7 @@ fn concurrent_interruption() -> io::Result<()> { let (reader, _writer) = tcp_pair()?; let poller = Poller::new()?; - unsafe { - poller.add(&reader, Event::none(0))?; - } + poller.add(&reader, Event::none(0))?; let mut events = Events::new(); let events_borrow = &mut events; diff --git a/tests/io.rs b/tests/io.rs index dc42103..70445e9 100644 --- a/tests/io.rs +++ b/tests/io.rs @@ -8,9 +8,7 @@ use std::time::Duration; fn basic_io() { let poller = Poller::new().unwrap(); let (read, mut write) = tcp_pair().unwrap(); - unsafe { - poller.add(&read, Event::readable(1)).unwrap(); - } + poller.add(&read, Event::readable(1)).unwrap(); // Nothing should be available at first. let mut events = Events::new(); @@ -42,7 +40,7 @@ fn basic_io() { #[test] fn insert_twice() { #[cfg(unix)] - use std::os::unix::io::AsRawFd; + use std::os::unix::io::AsFd; #[cfg(windows)] use std::os::windows::io::AsRawSocket; @@ -50,18 +48,16 @@ fn insert_twice() { let read = Arc::new(read); let poller = Poller::new().unwrap(); - unsafe { - #[cfg(unix)] - let read = read.as_raw_fd(); - #[cfg(windows)] - let read = read.as_raw_socket(); + #[cfg(unix)] + let read = read.as_fd(); + #[cfg(windows)] + let read = read.as_raw_socket(); - poller.add(read, Event::readable(1)).unwrap(); - assert_eq!( - poller.add(read, Event::readable(1)).unwrap_err().kind(), - io::ErrorKind::AlreadyExists - ); - } + poller.add(read, Event::readable(1)).unwrap(); + assert_eq!( + poller.add(read, Event::readable(1)).unwrap_err().kind(), + io::ErrorKind::AlreadyExists + ); write.write_all(&[1]).unwrap(); let mut events = Events::new(); diff --git a/tests/many_connections.rs b/tests/many_connections.rs index 6a74c9e..3bc3f8c 100644 --- a/tests/many_connections.rs +++ b/tests/many_connections.rs @@ -22,9 +22,7 @@ fn many_connections() { let poller = polling::Poller::new().unwrap(); for (i, reader, _) in connections.iter() { - unsafe { - poller.add(reader, polling::Event::readable(*i)).unwrap(); - } + poller.add(reader, polling::Event::readable(*i)).unwrap(); } let mut events = Events::new(); diff --git a/tests/multiple_pollers.rs b/tests/multiple_pollers.rs index 18f0efd..1b84470 100644 --- a/tests/multiple_pollers.rs +++ b/tests/multiple_pollers.rs @@ -18,14 +18,12 @@ fn level_triggered() { // Register the source into both pollers. let (mut reader, mut writer) = tcp_pair().unwrap(); - unsafe { - poller1 - .add_with_mode(&reader, Event::readable(1), PollMode::Level) - .unwrap(); - poller2 - .add_with_mode(&reader, Event::readable(2), PollMode::Level) - .unwrap(); - } + poller1 + .add_with_mode(&reader, Event::readable(1), PollMode::Level) + .unwrap(); + poller2 + .add_with_mode(&reader, Event::readable(2), PollMode::Level) + .unwrap(); // Neither poller should have any events. assert_eq!( @@ -139,14 +137,12 @@ fn edge_triggered() { // Register the source into both pollers. let (mut reader, mut writer) = tcp_pair().unwrap(); - unsafe { - poller1 - .add_with_mode(&reader, Event::readable(1), PollMode::Edge) - .unwrap(); - poller2 - .add_with_mode(&reader, Event::readable(2), PollMode::Edge) - .unwrap(); - } + poller1 + .add_with_mode(&reader, Event::readable(1), PollMode::Edge) + .unwrap(); + poller2 + .add_with_mode(&reader, Event::readable(2), PollMode::Edge) + .unwrap(); // Neither poller should have any events. assert_eq!( @@ -256,14 +252,12 @@ fn oneshot_triggered() { // Register the source into both pollers. let (mut reader, mut writer) = tcp_pair().unwrap(); - unsafe { - poller1 - .add_with_mode(&reader, Event::readable(1), PollMode::Oneshot) - .unwrap(); - poller2 - .add_with_mode(&reader, Event::readable(2), PollMode::Oneshot) - .unwrap(); - } + poller1 + .add_with_mode(&reader, Event::readable(1), PollMode::Oneshot) + .unwrap(); + poller2 + .add_with_mode(&reader, Event::readable(2), PollMode::Oneshot) + .unwrap(); // Neither poller should have any events. assert_eq!( diff --git a/tests/other_modes.rs b/tests/other_modes.rs index 407e42b..bbe3d61 100644 --- a/tests/other_modes.rs +++ b/tests/other_modes.rs @@ -16,7 +16,8 @@ fn level_triggered() { // Create our poller and register our streams. let poller = Poller::new().unwrap(); - if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) } + if poller + .add_with_mode(&reader, Event::readable(reader_token), PollMode::Level) .is_err() { // Only panic if we're on a platform that should support level mode. @@ -104,7 +105,8 @@ fn edge_triggered() { // Create our poller and register our streams. let poller = Poller::new().unwrap(); - if unsafe { poller.add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) } + if poller + .add_with_mode(&reader, Event::readable(reader_token), PollMode::Edge) .is_err() { // Only panic if we're on a platform that should support level mode. @@ -194,14 +196,13 @@ fn edge_oneshot_triggered() { // Create our poller and register our streams. let poller = Poller::new().unwrap(); - if unsafe { - poller.add_with_mode( + if poller + .add_with_mode( &reader, Event::readable(reader_token), PollMode::EdgeOneshot, ) - } - .is_err() + .is_err() { // Only panic if we're on a platform that should support level mode. cfg_if::cfg_if! {