Skip to content

Commit

Permalink
Add ENOBUFS handling for unsolicited messages
Browse files Browse the repository at this point in the history
This can happen when large burst of messages come all of a sudden, which
happen very easily when routing protocols are involved (e.g. BGP). The
current implementation incorrectly assumes that any failure to read from
the socket is akin to the socket closed. This is not the case.

This adds handling for this specific error, which translates to a
wrapper struct in the unsolicited messages stream: either a message, or
an overrun. This lets applications handle best for their usecase such
event: either resync because messages are lost, or do nothing if the
listening is informational only (e.g. logging).
  • Loading branch information
Tuetuopay committed Oct 25, 2022
1 parent 6c73fd1 commit 7ed06ec
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 37 deletions.
24 changes: 16 additions & 8 deletions examples/audit_netlink_events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::process;

use netlink_proto::{
new_connection,
packet::NetlinkEvent,
sys::{protocols::NETLINK_AUDIT, SocketAddr},
};

Expand All @@ -45,10 +46,10 @@ async fn main() -> Result<(), String> {
// - `handle` is a `Handle` to the `Connection`. We use it to send netlink
// messages and receive responses to these messages.
//
// - `messages` is a channel receiver through which we receive messages that we
// - `events` is a channel receiver through which we receive messages that we
// have not sollicated, ie that are not response to a request we made. In this
// example, we'll receive the audit event through that channel.
let (conn, mut handle, mut messages) = new_connection(NETLINK_AUDIT)
let (conn, mut handle, mut events) = new_connection(NETLINK_AUDIT)
.map_err(|e| format!("Failed to create a new netlink connection: {}", e))?;

// Spawn the `Connection` so that it starts polling the netlink
Expand Down Expand Up @@ -85,13 +86,20 @@ async fn main() -> Result<(), String> {
}
});

// Finally, start receiving event through the `messages` channel.
// Finally, start receiving event through the `events` channel.
println!("Starting to print audit events... press ^C to interrupt");
while let Some((message, _addr)) = messages.next().await {
if let NetlinkPayload::Error(err_message) = message.payload {
eprintln!("received an error message: {:?}", err_message);
} else {
println!("{:?}", message);
while let Some(event) = events.next().await {
match event {
NetlinkEvent::Message((message, _addr)) => {
if let NetlinkPayload::Error(err_message) = message.payload {
eprintln!("received an error message: {:?}", err_message);
} else {
println!("{:?}", message);
}
}
NetlinkEvent::Overrun => {
println!("Netlink socket overrun. Some messages were lost");
}
}
}

Expand Down
28 changes: 17 additions & 11 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use futures::{
};
use log::{error, warn};
use netlink_packet_core::{
NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
NetlinkDeserializable, NetlinkEvent, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
};

use crate::{
Expand Down Expand Up @@ -46,7 +46,7 @@ where

/// Channel used to transmit to the ConnectionHandle the unsolicited
/// messages received from the socket (multicast messages for instance).
unsolicited_messages_tx: Option<UnboundedSender<(NetlinkMessage<T>, SocketAddr)>>,
unsolicited_messages_tx: Option<UnboundedSender<NetlinkEvent<(NetlinkMessage<T>, SocketAddr)>>>,

socket_closed: bool,
}
Expand All @@ -59,7 +59,7 @@ where
{
pub(crate) fn new(
requests_rx: UnboundedReceiver<Request<T>>,
unsolicited_messages_tx: UnboundedSender<(NetlinkMessage<T>, SocketAddr)>,
unsolicited_messages_tx: UnboundedSender<NetlinkEvent<(NetlinkMessage<T>, SocketAddr)>>,
protocol: isize,
) -> io::Result<Self> {
let socket = S::new(protocol)?;
Expand Down Expand Up @@ -125,10 +125,14 @@ where
loop {
trace!("polling socket");
match socket.as_mut().poll_next(cx) {
Poll::Ready(Some((message, addr))) => {
Poll::Ready(Some(NetlinkEvent::Message((message, addr)))) => {
trace!("read datagram from socket");
self.protocol.handle_message(message, addr);
}
Poll::Ready(Some(NetlinkEvent::Overrun)) => {
warn!("netlink socket buffer full");
self.protocol.handle_buffer_full();
}
Poll::Ready(None) => {
warn!("netlink socket stream shut down");
self.socket_closed = true;
Expand Down Expand Up @@ -159,11 +163,13 @@ where

pub fn forward_unsolicited_messages(&mut self) {
if self.unsolicited_messages_tx.is_none() {
while let Some((message, source)) = self.protocol.incoming_requests.pop_front() {
warn!(
"ignoring unsolicited message {:?} from {:?}",
message, source
);
while let Some(event) = self.protocol.incoming_requests.pop_front() {
match event {
NetlinkEvent::Message((message, source)) => {
warn!("ignoring unsolicited message {message:?} from {source:?}")
}
NetlinkEvent::Overrun => warn!("ignoring unsolicited socket overrun"),
}
}
return;
}
Expand All @@ -177,11 +183,11 @@ where
..
} = self;

while let Some((message, source)) = protocol.incoming_requests.pop_front() {
while let Some(event) = protocol.incoming_requests.pop_front() {
if unsolicited_messages_tx
.as_mut()
.unwrap()
.unbounded_send((message, source))
.unbounded_send(event)
.is_err()
{
// The channel is unbounded so the only error that can
Expand Down
30 changes: 27 additions & 3 deletions src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ use crate::{
codecs::NetlinkMessageCodec,
sys::{AsyncSocket, SocketAddr},
};
use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable};
use netlink_packet_core::{
NetlinkDeserializable, NetlinkEvent, NetlinkMessage, NetlinkSerializable,
};

/// Buffer overrun condition
const ENOBUFS: i32 = 105;

pub struct NetlinkFramed<T, S, C> {
socket: S,
Expand All @@ -38,7 +43,7 @@ where
S: AsyncSocket,
C: NetlinkMessageCodec,
{
type Item = (NetlinkMessage<T>, SocketAddr);
type Item = NetlinkEvent<(NetlinkMessage<T>, SocketAddr)>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let Self {
Expand All @@ -50,7 +55,9 @@ where

loop {
match C::decode::<T>(reader) {
Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
Ok(Some(item)) => {
return Poll::Ready(Some(NetlinkEvent::Message((item, *in_addr))))
}
Ok(None) => {}
Err(e) => {
error!("unrecoverable error in decoder: {:?}", e);
Expand All @@ -63,6 +70,23 @@ where

*in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
Ok(addr) => addr,
// When receiving messages in multicast mode (i.e. we subscribed to
// notifications), the kernel will not wait for us to read datagrams before
// sending more. The receive buffer has a finite size, so once it is full (no
// more message can fit in), new messages will be dropped and recv calls will
// return `ENOBUFS`.
// This needs to be handled for applications to resynchronize with the contents
// of the kernel if necessary.
// We don't need to do anything special:
// - contents of the reader is still valid because we won't have partial messages
// in there anyways (large enough buffer)
// - contents of the socket's internal buffer is still valid because the kernel
// won't put partial data in it
Err(e) if e.raw_os_error() == Some(ENOBUFS) => {
// ENOBUFS
warn!("netlink socket buffer full");
return Poll::Ready(Some(NetlinkEvent::Overrun));
}
Err(e) => {
error!("failed to read from netlink socket: {:?}", e);
return Poll::Ready(None);
Expand Down
43 changes: 31 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
//! use futures::stream::StreamExt;
//! use netlink_packet_audit::{
//! AuditMessage,
//! NetlinkEvent,
//! NetlinkMessage,
//! NetlinkPayload,
//! StatusMessage,
Expand Down Expand Up @@ -44,11 +45,11 @@
//! // - `handle` is a `Handle` to the `Connection`. We use it to send
//! // netlink messages and receive responses to these messages.
//! //
//! // - `messages` is a channel receiver through which we receive
//! // - `events` is a channel receiver through which we receive
//! // messages that we have not solicited, ie that are not
//! // response to a request we made. In this example, we'll receive
//! // the audit event through that channel.
//! let (conn, mut handle, mut messages) = new_connection(NETLINK_AUDIT)
//! let (conn, mut handle, mut events) = new_connection(NETLINK_AUDIT)
//! .map_err(|e| format!("Failed to create a new netlink connection: {}", e))?;
//!
//! // Spawn the `Connection` so that it starts polling the netlink
Expand Down Expand Up @@ -85,13 +86,23 @@
//! }
//! });
//!
//! // Finally, start receiving event through the `messages` channel.
//! // Finally, start receiving event through the `events` channel.
//! println!("Starting to print audit events... press ^C to interrupt");
//! while let Some((message, _addr)) = messages.next().await {
//! if let NetlinkPayload::Error(err_message) = message.payload {
//! eprintln!("received an error message: {:?}", err_message);
//! } else {
//! println!("{:?}", message);
//! while let Some(event) = events.next().await {
//! match event {
//! NetlinkEvent::Message((message, _addr)) => {
//! if let NetlinkPayload::Error(err_message) = message.payload {
//! eprintln!("received an error message: {:?}", err_message);
//! } else {
//! println!("{:?}", message);
//! }
//! }
//! // Netlink sockets have a finite receive buffer that can fill up if there are more
//! // messages sent by the kernel than we can read.
//! // In this case at least one message has been lost.
//! NetlinkEvent::Overrun => {
//! println!("Netlink socket overrun. Some messages were lost");
//! }
//! }
//! }
//!
Expand Down Expand Up @@ -229,7 +240,9 @@ pub fn new_connection<T>(
) -> io::Result<(
Connection<T>,
ConnectionHandle<T>,
UnboundedReceiver<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
UnboundedReceiver<
packet::NetlinkEvent<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
>,
)>
where
T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin,
Expand All @@ -245,7 +258,9 @@ pub fn new_connection_with_socket<T, S>(
) -> io::Result<(
Connection<T, S>,
ConnectionHandle<T>,
UnboundedReceiver<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
UnboundedReceiver<
packet::NetlinkEvent<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
>,
)>
where
T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin,
Expand All @@ -262,15 +277,19 @@ pub fn new_connection_with_codec<T, S, C>(
) -> io::Result<(
Connection<T, S, C>,
ConnectionHandle<T>,
UnboundedReceiver<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
UnboundedReceiver<
packet::NetlinkEvent<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
>,
)>
where
T: Debug + packet::NetlinkSerializable + packet::NetlinkDeserializable + Unpin,
S: sys::AsyncSocket,
C: NetlinkMessageCodec,
{
let (requests_tx, requests_rx) = unbounded::<Request<T>>();
let (messages_tx, messages_rx) = unbounded::<(packet::NetlinkMessage<T>, sys::SocketAddr)>();
let (messages_tx, messages_rx) = unbounded::<
packet::NetlinkEvent<(packet::NetlinkMessage<T>, sys::SocketAddr)>,
>();
Ok((
Connection::new(requests_rx, messages_tx, protocol)?,
ConnectionHandle::new(requests_tx),
Expand Down
13 changes: 10 additions & 3 deletions src/protocol/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use std::{
};

use netlink_packet_core::{
constants::*, NetlinkDeserializable, NetlinkMessage, NetlinkPayload, NetlinkSerializable,
constants::*, NetlinkDeserializable, NetlinkEvent, NetlinkMessage,
NetlinkPayload, NetlinkSerializable,
};

use super::Request;
Expand Down Expand Up @@ -53,7 +54,8 @@ pub(crate) struct Protocol<T, M> {
pub incoming_responses: VecDeque<Response<T, M>>,

/// Requests from remote peers
pub incoming_requests: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
pub incoming_requests:
VecDeque<NetlinkEvent<(NetlinkMessage<T>, SocketAddr)>>,

/// The messages to be sent out
pub outgoing_messages: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
Expand All @@ -80,10 +82,15 @@ where
if let hash_map::Entry::Occupied(entry) = self.pending_requests.entry(request_id) {
Self::handle_response(&mut self.incoming_responses, entry, message);
} else {
self.incoming_requests.push_back((message, source));
self.incoming_requests
.push_back(NetlinkEvent::Message((message, source)));
}
}

pub fn handle_buffer_full(&mut self) {
self.incoming_requests.push_back(NetlinkEvent::Overrun);
}

fn handle_response(
incoming_responses: &mut VecDeque<Response<T, M>>,
entry: hash_map::OccupiedEntry<RequestId, PendingRequest<M>>,
Expand Down

0 comments on commit 7ed06ec

Please sign in to comment.