Skip to content

Commit

Permalink
Merge pull request hermit-os#1298 from cagatay-y/unexpose-transfertokens
Browse files Browse the repository at this point in the history
 virtq: don't expose TransferTokens to the drivers
  • Loading branch information
mkroening authored Jul 5, 2024
2 parents 97354d3 + 2f8dbfa commit 7da9130
Show file tree
Hide file tree
Showing 5 changed files with 687 additions and 903 deletions.
11 changes: 5 additions & 6 deletions src/drivers/fs/virtio_fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use crate::drivers::virtio::transport::mmio::{ComCfg, IsrStatus, NotifCfg};
use crate::drivers::virtio::transport::pci::{ComCfg, IsrStatus, NotifCfg};
use crate::drivers::virtio::virtqueue::error::VirtqError;
use crate::drivers::virtio::virtqueue::split::SplitVq;
use crate::drivers::virtio::virtqueue::{AsSliceU8, BufferType, Virtq, VqIndex, VqSize};
use crate::drivers::virtio::virtqueue::{
AsSliceU8, BufferToken, BufferType, Virtq, VqIndex, VqSize,
};
use crate::fs::fuse::{self, FuseInterface, Rsp, RspHeader};

/// A wrapper struct for the raw configuration structure.
Expand Down Expand Up @@ -176,11 +178,8 @@ impl FuseInterface for VirtioFsDriver {
]
};

let transfer_tkn = self.vqueues[1]
.clone()
.prep_transfer_from_raw(send, recv, BufferType::Direct)
.unwrap();
transfer_tkn.dispatch_blocking()?;
let buffer_tkn = BufferToken::from_existing(send, recv).unwrap();
self.vqueues[1].dispatch_blocking(buffer_tkn, BufferType::Direct)?;
Ok(unsafe {
Rsp {
headers: rsp_headers.assume_init(),
Expand Down
159 changes: 90 additions & 69 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ cfg_if::cfg_if! {
}
}

use alloc::boxed::Box;
use alloc::rc::Rc;
use alloc::vec::Vec;
use core::cmp::Ordering;
Expand Down Expand Up @@ -61,8 +60,8 @@ impl CtrlQueue {

pub struct RxQueues {
vqs: Vec<Rc<dyn Virtq>>,
poll_sender: async_channel::Sender<Box<BufferToken>>,
poll_receiver: async_channel::Receiver<Box<BufferToken>>,
poll_sender: async_channel::Sender<BufferToken>,
poll_receiver: async_channel::Receiver<BufferToken>,
is_multi: bool,
}

Expand All @@ -77,12 +76,12 @@ impl RxQueues {
}
}

/// Takes care if handling packets correctly which need some processing after being received.
/// This currently include nothing. But in the future it might include among others::
/// Takes care of handling packets correctly which need some processing after being received.
/// This currently include nothing. But in the future it might include among others:
/// * Calculating missing checksums
/// * Merging receive buffers, by simply checking the poll_queue (if VIRTIO_NET_F_MRG_BUF)
fn post_processing(buffer_tkn: Box<BufferToken>) -> Result<Box<BufferToken>, VirtioNetError> {
Ok(buffer_tkn)
fn post_processing(_buffer_tkn: &mut BufferToken) -> Result<(), VirtioNetError> {
Ok(())
}

/// Adds a given queue to the underlying vector and populates the queue with RecvBuffers.
Expand All @@ -102,7 +101,7 @@ impl RxQueues {
//
let spec = BuffSpec::Single(Bytes::new(rx_size).unwrap());
for _ in 0..num_buff {
let buff_tkn = match vq.clone().prep_buffer(None, Some(spec.clone())) {
let buff_tkn = match BufferToken::new(None, Some(spec.clone())) {
Ok(tkn) => tkn,
Err(_vq_err) => {
error!("Setup of network queue failed, which should not happen!");
Expand All @@ -113,10 +112,12 @@ impl RxQueues {
// BufferTokens are directly provided to the queue
// TransferTokens are directly dispatched
// Transfers will be awaited at the queue
match buff_tkn
.provide(BufferType::Direct)
.dispatch_await(self.poll_sender.clone(), false)
{
match vq.dispatch_await(
buff_tkn,
self.poll_sender.clone(),
false,
BufferType::Direct,
) {
Ok(_) => (),
Err(_) => {
error!("Descriptor IDs were exhausted earlier than expected.");
Expand All @@ -133,7 +134,7 @@ impl RxQueues {
}
}

fn get_next(&mut self) -> Option<Box<BufferToken>> {
fn get_next(&mut self) -> Option<BufferToken> {
let transfer = self.poll_receiver.try_recv();

transfer
Expand Down Expand Up @@ -181,8 +182,8 @@ impl RxQueues {
/// to the respective queue structures.
pub struct TxQueues {
vqs: Vec<Rc<dyn Virtq>>,
poll_sender: async_channel::Sender<Box<BufferToken>>,
poll_receiver: async_channel::Receiver<Box<BufferToken>>,
poll_sender: async_channel::Sender<BufferToken>,
poll_receiver: async_channel::Receiver<BufferToken>,
ready_queue: Vec<BufferToken>,
/// Indicates, whether the Driver/Device are using multiple
/// queues for communication.
Expand Down Expand Up @@ -253,13 +254,11 @@ impl TxQueues {
let num_buff: u16 = vq.size().into();

for _ in 0..num_buff {
self.ready_queue.push(
vq.clone()
.prep_buffer(Some(spec.clone()), None)
.unwrap()
.write_seq(Some(&Hdr::default()), None::<&Hdr>)
.unwrap(),
)
let mut buffer_tkn = BufferToken::new(Some(spec.clone()), None).unwrap();
buffer_tkn
.write_seq(Some(&Hdr::default()), None::<&Hdr>)
.unwrap();
self.ready_queue.push(buffer_tkn)
}
} else {
// Virtio specification v1.1. - 5.1.6.2 point 5.
Expand All @@ -275,13 +274,11 @@ impl TxQueues {
let num_buff: u16 = vq.size().into();

for _ in 0..num_buff {
self.ready_queue.push(
vq.clone()
.prep_buffer(Some(spec.clone()), None)
.unwrap()
.write_seq(Some(&Hdr::default()), None::<&Hdr>)
.unwrap(),
)
let mut buffer_tkn = BufferToken::new(Some(spec.clone()), None).unwrap();
buffer_tkn
.write_seq(Some(&Hdr::default()), None::<&Hdr>)
.unwrap();
self.ready_queue.push(buffer_tkn)
}
}
} else {
Expand Down Expand Up @@ -317,24 +314,24 @@ impl TxQueues {
self.poll();
}

while let Ok(buffer_token) = self.poll_receiver.try_recv() {
let mut tkn = buffer_token.reset();
let (send_len, _) = tkn.len();
while let Ok(mut buffer_token) = self.poll_receiver.try_recv() {
buffer_token.reset();
let (send_len, _) = buffer_token.len();

match send_len.cmp(&len) {
Ordering::Less => {}
Ordering::Equal => return Some((tkn, 0)),
Ordering::Equal => return Some((buffer_token, 0)),
Ordering::Greater => {
tkn.restr_size(Some(len), None).unwrap();
return Some((tkn, 0));
buffer_token.restr_size(Some(len), None).unwrap();
return Some((buffer_token, 0));
}
}
}

// As usize is currently safe as the minimal usize is defined as 16bit in rust.
let spec = BuffSpec::Single(Bytes::new(len).unwrap());

match self.vqs[0].clone().prep_buffer(Some(spec), None) {
match BufferToken::new(Some(spec), None) {
Ok(tkn) => Some((tkn, 0)),
Err(_) => {
// Here it is possible if multiple queues are enabled to get another buffertoken from them!
Expand Down Expand Up @@ -457,9 +454,13 @@ impl NetworkDriver for VirtioNetDriver {
.into();
}

buff_tkn
.provide(BufferType::Direct)
.dispatch_await(self.send_vqs.poll_sender.clone(), false)
self.send_vqs.vqs[0]
.dispatch_await(
buff_tkn,
self.send_vqs.poll_sender.clone(),
false,
BufferType::Direct,
)
.unwrap();

result
Expand All @@ -470,16 +471,16 @@ impl NetworkDriver for VirtioNetDriver {

fn receive_packet(&mut self) -> Option<(RxToken, TxToken)> {
match self.recv_vqs.get_next() {
Some(transfer) => {
let transfer = match RxQueues::post_processing(transfer) {
Some(mut buffer_tkn) => {
match RxQueues::post_processing(&mut buffer_tkn) {
Ok(trf) => trf,
Err(vnet_err) => {
warn!("Post processing failed. Err: {:?}", vnet_err);
return None;
}
};

let (_, recv_data_opt) = transfer.as_slices().unwrap();
let (_, recv_data_opt) = buffer_tkn.as_slices().unwrap();
let mut recv_data = recv_data_opt.unwrap();

// If the given length isn't 1, we currently fail.
Expand All @@ -491,10 +492,15 @@ impl NetworkDriver for VirtioNetDriver {

// drop packets with invalid packet size
if packet.len() < HEADER_SIZE {
transfer
.reset()
.provide(BufferType::Direct)
.dispatch_await(self.recv_vqs.poll_sender.clone(), false)
buffer_tkn.reset();

self.recv_vqs.vqs[0]
.dispatch_await(
buffer_tkn,
self.recv_vqs.poll_sender.clone(),
false,
BufferType::Direct,
)
.unwrap();

return None;
Expand All @@ -509,45 +515,60 @@ impl NetworkDriver for VirtioNetDriver {
let num_buffers = header.num_buffers;

vec_data.extend_from_slice(&packet[mem::size_of::<Hdr>()..]);
transfer
.reset()
.provide(BufferType::Direct)
.dispatch_await(self.recv_vqs.poll_sender.clone(), false)
buffer_tkn.reset();
self.recv_vqs.vqs[0]
.dispatch_await(
buffer_tkn,
self.recv_vqs.poll_sender.clone(),
false,
BufferType::Direct,
)
.unwrap();

num_buffers
};

for _ in 1..num_buffers.to_ne() {
let transfer =
match RxQueues::post_processing(self.recv_vqs.get_next().unwrap()) {
Ok(trf) => trf,
Err(vnet_err) => {
warn!("Post processing failed. Err: {:?}", vnet_err);
return None;
}
};

let (_, recv_data_opt) = transfer.as_slices().unwrap();
let mut buffer_tkn = self.recv_vqs.get_next().unwrap();
match RxQueues::post_processing(&mut buffer_tkn) {
Ok(trf) => trf,
Err(vnet_err) => {
warn!("Post processing failed. Err: {:?}", vnet_err);
return None;
}
};

let (_, recv_data_opt) = buffer_tkn.as_slices().unwrap();
let mut recv_data = recv_data_opt.unwrap();
let packet = recv_data.pop().unwrap();
vec_data.extend_from_slice(packet);
transfer
.reset()
.provide(BufferType::Direct)
.dispatch_await(self.recv_vqs.poll_sender.clone(), false)
buffer_tkn.reset();

self.recv_vqs.vqs[0]
.dispatch_await(
buffer_tkn,
self.recv_vqs.poll_sender.clone(),
false,
BufferType::Direct,
)
.unwrap();
}

Some((RxToken::new(vec_data), TxToken::new()))
} else {
error!("Empty transfer, or with wrong buffer layout. Reusing and returning error to user-space network driver...");
transfer
.reset()
buffer_tkn.reset();
buffer_tkn
.write_seq(None::<&Hdr>, Some(&Hdr::default()))
.unwrap()
.provide(BufferType::Direct)
.dispatch_await(self.recv_vqs.poll_sender.clone(), false)
.unwrap();

self.recv_vqs.vqs[0]
.dispatch_await(
buffer_tkn,
self.recv_vqs.poll_sender.clone(),
false,
BufferType::Direct,
)
.unwrap();

None
Expand Down
Loading

0 comments on commit 7da9130

Please sign in to comment.