Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: apply max payload option when encoding (fix #113) #114

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions engineioxide/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ mod body;
mod engine;
mod futures;
mod packet;
mod peekable;
mod transport;
87 changes: 86 additions & 1 deletion engineioxide/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,39 @@ impl Packet {
_ => panic!("Packet is not a binary"),
}
}

/// Get the max size the packet could have when serialized
///
/// If b64 is true, it returns the max size when serialized to base64
///
/// The base64 max size factor is `ceil(n / 3) * 4`
pub(crate) fn get_size_hint(&self, b64: bool) -> usize {
match self {
Packet::Open(_) => 151, // max possible size for the open packet serialized
Packet::Close => 1,
Packet::Ping => 1,
Packet::Pong => 1,
Packet::PingUpgrade => 6,
Packet::PongUpgrade => 6,
Packet::Message(msg) => 1 + msg.len(),
Packet::Upgrade => 1,
Packet::Noop => 1,
Packet::Binary(data) => {
if b64 {
1 + ((data.len() as f64) / 3.).ceil() as usize * 4
} else {
1 + data.len()
}
}
Packet::BinaryV3(data) => {
if b64 {
2 + ((data.len() as f64) / 3.).ceil() as usize * 4
} else {
1 + data.len()
}
}
}
}
}

/// Serialize a [Packet] to a [String] according to the Engine.IO protocol
Expand Down Expand Up @@ -179,7 +212,7 @@ mod tests {
use crate::config::EngineIoConfig;

use super::*;
use std::convert::TryInto;
use std::{convert::TryInto, time::Duration};

#[test]
fn test_open_packet() {
Expand Down Expand Up @@ -249,4 +282,56 @@ mod tests {
let packet: Packet = packet_str.try_into().unwrap();
assert_eq!(packet, Packet::BinaryV3(vec![1, 2, 3]));
}

#[test]
fn test_packet_get_size_hint() {
// Max serialized packet
let open = OpenPacket::new(
TransportType::Polling,
Sid::MAX,
&EngineIoConfig {
max_buffer_size: usize::MAX,
max_payload: u64::MAX,
ping_interval: Duration::MAX,
ping_timeout: Duration::MAX,
transports: TransportType::Polling as u8 | TransportType::Websocket as u8,
..Default::default()
},
);
let size = serde_json::to_string(&open).unwrap().len();
let packet = Packet::Open(open);
assert_eq!(packet.get_size_hint(false), size);

let packet = Packet::Close;
assert_eq!(packet.get_size_hint(false), 1);

let packet = Packet::Ping;
assert_eq!(packet.get_size_hint(false), 1);

let packet = Packet::Pong;
assert_eq!(packet.get_size_hint(false), 1);

let packet = Packet::PingUpgrade;
assert_eq!(packet.get_size_hint(false), 6);

let packet = Packet::PongUpgrade;
assert_eq!(packet.get_size_hint(false), 6);

let packet = Packet::Message("hello".to_string());
assert_eq!(packet.get_size_hint(false), 6);

let packet = Packet::Upgrade;
assert_eq!(packet.get_size_hint(false), 1);

let packet = Packet::Noop;
assert_eq!(packet.get_size_hint(false), 1);

let packet = Packet::Binary(vec![1, 2, 3]);
assert_eq!(packet.get_size_hint(false), 4);
assert_eq!(packet.get_size_hint(true), 5);

let packet = Packet::BinaryV3(vec![1, 2, 3]);
assert_eq!(packet.get_size_hint(false), 4);
assert_eq!(packet.get_size_hint(true), 6);
}
}
79 changes: 79 additions & 0 deletions engineioxide/src/peekable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use tokio::sync::mpsc::{error::TryRecvError, Receiver};

/// Peekable receiver for polling transport
/// It is a thin wrapper around a [`Receiver`](tokio::sync::mpsc::Receiver) that allows to peek the next packet without consuming it
///
/// Its main goal is to be able to peek the next packet without consuming it to calculate the
/// packet length when using polling transport to check if it fits according to the max_payload setting
#[derive(Debug)]
pub struct PeekableReceiver<T> {
rx: Receiver<T>,
next: Option<T>,
}
impl<T> PeekableReceiver<T> {
pub fn new(rx: Receiver<T>) -> Self {
Self { rx, next: None }
}
pub fn peek(&mut self) -> Option<&T> {
if self.next.is_none() {
self.next = self.rx.try_recv().ok();
}
self.next.as_ref()
}
pub async fn recv(&mut self) -> Option<T> {
if self.next.is_none() {
self.rx.recv().await
} else {
self.next.take()
}
}
pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
if self.next.is_none() {
self.rx.try_recv()
} else {
Ok(self.next.take().unwrap())
}
}

pub fn close(&mut self) {
self.rx.close()
}
}

#[cfg(test)]
mod tests {
use tokio::sync::Mutex;

#[tokio::test]
async fn peek() {
use super::PeekableReceiver;
use crate::packet::Packet;
use tokio::sync::mpsc::channel;

let (tx, rx) = channel(1);
let rx = Mutex::new(PeekableReceiver::new(rx));
let mut rx = rx.lock().await;

assert!(rx.peek().is_none());

tx.send(Packet::Ping).await.unwrap();
assert_eq!(rx.peek(), Some(&Packet::Ping));
assert_eq!(rx.recv().await, Some(Packet::Ping));
assert!(rx.peek().is_none());

tx.send(Packet::Pong).await.unwrap();
assert_eq!(rx.peek(), Some(&Packet::Pong));
assert_eq!(rx.recv().await, Some(Packet::Pong));
assert!(rx.peek().is_none());

tx.send(Packet::Close).await.unwrap();
assert_eq!(rx.peek(), Some(&Packet::Close));
assert_eq!(rx.recv().await, Some(Packet::Close));
assert!(rx.peek().is_none());

tx.send(Packet::Close).await.unwrap();
assert_eq!(rx.peek(), Some(&Packet::Close));
assert_eq!(rx.recv().await, Some(Packet::Close));
assert!(rx.peek().is_none());
}
}
11 changes: 7 additions & 4 deletions engineioxide/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use tokio::{
use tokio_tungstenite::tungstenite;
use tracing::debug;

use crate::{config::EngineIoConfig, errors::Error, packet::Packet, service::ProtocolVersion};
use crate::{
config::EngineIoConfig, errors::Error, packet::Packet, peekable::PeekableReceiver,
service::ProtocolVersion,
};
use crate::{sid_generator::Sid, transport::TransportType};

/// Http Request data used to create a socket
Expand Down Expand Up @@ -118,7 +121,7 @@ where
/// * From the [encoder](crate::service::encoder) if the transport is polling
/// * From the fn [`on_ws_req_init`](crate::engine::EngineIo) if the transport is websocket
/// * Automatically via the [`close_session fn`](crate::engine::EngineIo::close_session) as a fallback. Because with polling transport, if the client is not currently polling then the encoder will never be able to close the channel
pub(crate) internal_rx: Mutex<Receiver<Packet>>,
pub(crate) internal_rx: Mutex<PeekableReceiver<Packet>>,

/// Channel to send [Packet] to the internal connection
internal_tx: mpsc::Sender<Packet>,
Expand Down Expand Up @@ -166,7 +169,7 @@ where
protocol,
transport: AtomicU8::new(transport as u8),

internal_rx: Mutex::new(internal_rx),
internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)),
internal_tx,

heartbeat_rx: Mutex::new(heartbeat_rx),
Expand Down Expand Up @@ -409,7 +412,7 @@ where
protocol: ProtocolVersion::V4,
transport: AtomicU8::new(TransportType::Websocket as u8),

internal_rx: Mutex::new(internal_rx),
internal_rx: Mutex::new(PeekableReceiver::new(internal_rx)),
internal_tx,

heartbeat_rx: Mutex::new(heartbeat_rx),
Expand Down
14 changes: 9 additions & 5 deletions engineioxide/src/transport/polling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
packet::{OpenPacket, Packet},
service::ProtocolVersion,
sid_generator::Sid,
transport::polling::payload::Payload,
DisconnectReason, SocketReq,
};

Expand Down Expand Up @@ -49,7 +50,7 @@ where

engine.handler.on_connect(socket);

let packet: String = Packet::Open(packet).try_into()?;
let packet: String = Packet::Open(packet).try_into().unwrap();
let packet = {
#[cfg(feature = "v3")]
{
Expand Down Expand Up @@ -97,13 +98,16 @@ where

debug!("[sid={sid}] polling request");

let max_payload = engine.config.max_payload;

#[cfg(feature = "v3")]
let (payload, is_binary) = payload::encoder(rx, protocol, socket.supports_binary).await?;
let Payload { data, has_binary } =
payload::encoder(rx, protocol, socket.supports_binary, max_payload).await?;
#[cfg(not(feature = "v3"))]
let (payload, is_binary) = payload::encoder(rx, protocol).await?;
let Payload { data, has_binary } = payload::encoder(rx, protocol, max_payload).await?;

debug!("[sid={sid}] sending data: {:?}", payload);
Ok(http_response(StatusCode::OK, payload, is_binary)?)
debug!("[sid={sid}] sending data: {:?}", data);
Ok(http_response(StatusCode::OK, data, has_binary)?)
}

/// Handle http polling post request
Expand Down
Loading