Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(socketio/ns): store ns path and rooms as Cow<'static, str> #124

Merged
merged 23 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
49665c3
feat(socketio/ns): store ns path as `Cow<'static, str>`
Totodore Oct 21, 2023
14d00e7
Merge branch 'main' into feat-cow-namespace-and-rooms
Totodore Oct 21, 2023
20f89a5
fix(clippy): useless conversions and immediate dereference
Totodore Oct 28, 2023
6b4ed75
Merge remote-tracking branch 'origin/main' into feat-cow-namespace-an…
Totodore Oct 28, 2023
837e7b9
feat(socketio/packet): move event type from String to `Cow<'static, s…
Totodore Oct 28, 2023
8ae3e16
feat(socketio/packet): remove `ConnectErrorPacket`
Totodore Oct 28, 2023
c15bc61
feat(socketio/config): remove duplicated engine io path
Totodore Oct 28, 2023
c75e903
feat(socketio/adapter): move `Room` type from `String` to `Cow<'stati…
Totodore Oct 28, 2023
e2098d6
fix(socketio/packet): connect error layout
Totodore Oct 28, 2023
e494b93
fix(clippy): redundant closure
Totodore Oct 29, 2023
01890b9
fix(socketio/packet): connect error layout
Totodore Oct 29, 2023
e02a2a9
fix(socketio/packet): fix packet encode connect
Totodore Oct 29, 2023
4bd0b7b
fix(socketio/packet): remove useless connect error decode test
Totodore Oct 29, 2023
de543fa
Merge remote-tracking branch 'origin/main' into feat-cow-namespace-an…
Totodore Oct 29, 2023
2acd187
bench(socketio): fix bench to work with new packet impl
Totodore Oct 29, 2023
ceaaf67
feat(socketio/packet): preallocate a buffer when encoding a packet
Totodore Oct 29, 2023
e370fad
fix(socketio/packet): put a / in front of non /* nsps
Totodore Oct 29, 2023
d6f8f09
feat(socketio/packet): decode packets from bytes rather than char iter
Totodore Oct 29, 2023
21f0531
fix(clippy): `manual check for common ascii range`
Totodore Oct 29, 2023
576a1e0
test(socketio/packet): add a test for `get_size_hint`
Totodore Oct 29, 2023
c13fd50
fix(clippy): `single-character string constant used as pattern`
Totodore Oct 29, 2023
eb20f85
feat(socketio): switch from to manual impl for number insertion
Totodore Oct 29, 2023
53db020
Revert "feat(socketio): switch from to manual impl for number insert…
Totodore Oct 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions socketioxide/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! Other adapters can be made to share the state between multiple servers.

use std::{
borrow::Cow,
collections::{HashMap, HashSet},
convert::Infallible,
sync::{Arc, RwLock, Weak},
Expand All @@ -15,7 +16,6 @@ use futures::{
stream::{self, BoxStream},
StreamExt,
};
use itertools::Itertools;
use serde::de::DeserializeOwned;

use crate::{
Expand All @@ -28,7 +28,7 @@ use crate::{
};

/// A room identifier
pub type Room = String;
pub type Room = Cow<'static, str>;

/// Flags that can be used to modify the behavior of the broadcast methods.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
Expand All @@ -47,18 +47,18 @@ pub struct BroadcastOptions {
/// The flags to apply to the broadcast.
pub flags: HashSet<BroadcastFlags>,
/// The rooms to broadcast to.
pub rooms: Vec<Room>,
pub rooms: HashSet<Room>,
/// The rooms to exclude from the broadcast.
pub except: Vec<Room>,
pub except: HashSet<Room>,
/// The socket id of the sender.
pub sid: Option<Sid>,
}
impl BroadcastOptions {
pub fn new(sid: Option<Sid>) -> Self {
Self {
flags: HashSet::new(),
rooms: Vec::new(),
except: Vec::new(),
rooms: HashSet::new(),
except: HashSet::new(),
sid,
}
}
Expand Down Expand Up @@ -94,7 +94,7 @@ pub trait Adapter: std::fmt::Debug + Send + Sync + 'static {
/// Broadcast the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses.
fn broadcast_with_ack<V: DeserializeOwned>(
&self,
packet: Packet,
packet: Packet<'static>,
opts: BroadcastOptions,
) -> Result<BoxStream<'static, Result<AckResponse<V>, AckError>>, BroadcastError>;

Expand Down Expand Up @@ -208,7 +208,7 @@ impl Adapter for LocalAdapter {

fn broadcast_with_ack<V: DeserializeOwned>(
&self,
packet: Packet,
packet: Packet<'static>,
opts: BroadcastOptions,
) -> Result<BoxStream<'static, Result<AckResponse<V>, AckError>>, BroadcastError> {
let duration = opts.flags.iter().find_map(|flag| match flag {
Expand Down Expand Up @@ -241,7 +241,7 @@ impl Adapter for LocalAdapter {
}

//TODO: make this operation O(1)
fn socket_rooms(&self, sid: Sid) -> Result<Vec<String>, Infallible> {
fn socket_rooms(&self, sid: Sid) -> Result<Vec<Cow<'static, str>>, Infallible> {
let rooms_map = self.rooms.read().unwrap();
Ok(rooms_map
.iter()
Expand Down Expand Up @@ -300,7 +300,6 @@ impl LocalAdapter {
.iter()
.filter_map(|room| rooms_map.get(room))
.flatten()
.unique()
.filter(|sid| {
!except.contains(*sid)
&& (!opts.flags.contains(&BroadcastFlags::Broadcast)
Expand All @@ -321,7 +320,7 @@ impl LocalAdapter {
}
}

fn get_except_sids(&self, except: &Vec<Room>) -> HashSet<Sid> {
fn get_except_sids(&self, except: &HashSet<Room>) -> HashSet<Sid> {
let mut except_sids = HashSet::new();
let rooms_map = self.rooms.read().unwrap();
for room in except {
Expand All @@ -337,6 +336,12 @@ impl LocalAdapter {
mod test {
use super::*;

macro_rules! hash_set {
{$($v: expr),* $(,)?} => {
std::collections::HashSet::from([$($v,)*])
};
}

#[tokio::test]
async fn test_server_count() {
let ns = Namespace::new_dummy([]);
Expand Down Expand Up @@ -412,7 +417,7 @@ mod test {
adapter.add_all(socket, ["room1"]).unwrap();

let mut opts = BroadcastOptions::new(Some(socket));
opts.rooms = vec!["room1".to_string()];
opts.rooms = hash_set!["room1".into()];
adapter.add_sockets(opts, "room2").unwrap();
let rooms_map = adapter.rooms.read().unwrap();

Expand All @@ -429,7 +434,7 @@ mod test {
adapter.add_all(socket, ["room1"]).unwrap();

let mut opts = BroadcastOptions::new(Some(socket));
opts.rooms = vec!["room1".to_string()];
opts.rooms = hash_set!["room1".into()];
adapter.add_sockets(opts, "room2").unwrap();

{
Expand All @@ -441,7 +446,7 @@ mod test {
}

let mut opts = BroadcastOptions::new(Some(socket));
opts.rooms = vec!["room1".to_string()];
opts.rooms = hash_set!["room1".into()];
adapter.del_sockets(opts, "room2").unwrap();

{
Expand Down Expand Up @@ -498,7 +503,7 @@ mod test {
.unwrap();

let mut opts = BroadcastOptions::new(Some(socket0));
opts.rooms = vec!["room5".to_string()];
opts.rooms = hash_set!["room5".into()];
match adapter.disconnect_socket(opts) {
// todo it returns Ok, in previous commits it also returns Ok
Err(BroadcastError::SendError(_)) | Ok(_) => {}
Expand Down Expand Up @@ -531,15 +536,15 @@ mod test {

// socket 2 is the sender
let mut opts = BroadcastOptions::new(Some(socket2));
opts.rooms = vec!["room1".to_string()];
opts.except = vec!["room2".to_string()];
opts.rooms = hash_set!["room1".into()];
opts.except = hash_set!["room2".into()];
let sockets = adapter.fetch_sockets(opts).unwrap();
assert_eq!(sockets.len(), 1);
assert_eq!(sockets[0].id, socket1);

let mut opts = BroadcastOptions::new(Some(socket2));
opts.flags.insert(BroadcastFlags::Broadcast);
opts.except = vec!["room2".to_string()];
opts.except = hash_set!["room2".into()];
let sockets = adapter.fetch_sockets(opts).unwrap();
assert_eq!(sockets.len(), 1);

Expand Down
13 changes: 7 additions & 6 deletions socketioxide/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};

Expand All @@ -21,7 +22,7 @@ use crate::{ProtocolVersion, Socket};
#[derive(Debug)]
pub struct Client<A: Adapter> {
pub(crate) config: Arc<SocketIoConfig>,
ns: RwLock<HashMap<String, Arc<Namespace<A>>>>,
ns: RwLock<HashMap<Cow<'static, str>, Arc<Namespace<A>>>>,
}

impl<A: Adapter> Client<A> {
Expand Down Expand Up @@ -58,7 +59,7 @@ impl<A: Adapter> Client<A> {
fn sock_connect(
&self,
auth: Option<String>,
ns_path: String,
ns_path: &str,
esocket: &Arc<engineioxide::Socket<SocketData>>,
) -> Result<(), Error> {
#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -93,7 +94,7 @@ impl<A: Adapter> Client<A> {
}

/// Cache-in the socket data until all the binary payloads are received
fn sock_recv_bin_packet(&self, socket: &EIoSocket<SocketData>, packet: Packet) {
fn sock_recv_bin_packet(&self, socket: &EIoSocket<SocketData>, packet: Packet<'static>) {
socket
.data
.partial_bin_packet
Expand Down Expand Up @@ -132,7 +133,7 @@ impl<A: Adapter> Client<A> {
}

/// Add a new namespace handler
pub fn add_ns<C, F, V>(&self, path: String, callback: C)
pub fn add_ns<C, F, V>(&self, path: Cow<'static, str>, callback: C)
where
C: Fn(Arc<Socket<A>>, V) -> F + Send + Sync + 'static,
F: Future<Output = ()> + Send + 'static,
Expand Down Expand Up @@ -171,7 +172,7 @@ impl<A: Adapter> Client<A> {
pub struct SocketData {
/// Partial binary packet that is being received
/// Stored here until all the binary payloads are received
pub partial_bin_packet: Mutex<Option<Packet>>,
pub partial_bin_packet: Mutex<Option<Packet<'static>>>,

/// Channel used to notify the socket that it has been connected to a namespace
#[cfg(feature = "v5")]
Expand Down Expand Up @@ -245,7 +246,7 @@ impl<A: Adapter> EngineIoHandler for Client<A> {

let res: Result<(), Error> = match packet.inner {
PacketData::Connect(auth) => self
.sock_connect(auth, packet.ns, &socket)
.sock_connect(auth, &packet.ns, &socket)
.map_err(Into::into),
PacketData::BinaryEvent(_, _, _) | PacketData::BinaryAck(_, _) => {
self.sock_recv_bin_packet(&socket, packet);
Expand Down
2 changes: 1 addition & 1 deletion socketioxide/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl<A: Adapter> AckSender<A> {
/// Send the ack response to the client.
pub fn send(self, data: impl Serialize) -> Result<(), AckSenderError<A>> {
if let Some(ack_id) = self.ack_id {
let ns = self.socket.ns().clone();
let ns = self.socket.ns();
let data = match serde_json::to_value(&data) {
Err(err) => {
return Err(AckSenderError::SendError {
Expand Down
20 changes: 9 additions & 11 deletions socketioxide/src/io.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{sync::Arc, time::Duration};
use std::{borrow::Cow, sync::Arc, time::Duration};

use engineioxide::{
config::{EngineIoConfig, EngineIoConfigBuilder, TransportType},
Expand Down Expand Up @@ -55,16 +55,14 @@ impl Default for SocketIoConfig {
pub struct SocketIoBuilder {
config: SocketIoConfig,
engine_config_builder: EngineIoConfigBuilder,
req_path: String,
}

impl SocketIoBuilder {
/// Create a new [`SocketIoBuilder`] with default config
pub fn new() -> Self {
Self {
config: SocketIoConfig::default(),
engine_config_builder: EngineIoConfigBuilder::new(),
req_path: "/socket.io".to_string(),
engine_config_builder: EngineIoConfigBuilder::new().req_path("/socket.io".to_string()),
}
}

Expand All @@ -73,7 +71,7 @@ impl SocketIoBuilder {
/// Defaults to "/socket.io".
#[inline]
pub fn req_path(mut self, req_path: String) -> Self {
self.req_path = req_path;
self.engine_config_builder = self.engine_config_builder.req_path(req_path);
self
}

Expand Down Expand Up @@ -163,7 +161,7 @@ impl SocketIoBuilder {
///
/// The layer can be used as a tower layer
pub fn build_layer_with_adapter<A: Adapter>(mut self) -> (SocketIoLayer<A>, SocketIo<A>) {
self.config.engine_config = self.engine_config_builder.req_path(self.req_path).build();
self.config.engine_config = self.engine_config_builder.build();

let (layer, client) = SocketIoLayer::from_config(Arc::new(self.config));
(layer, SocketIo(client))
Expand All @@ -190,7 +188,7 @@ impl SocketIoBuilder {
pub fn build_svc_with_adapter<A: Adapter>(
mut self,
) -> (SocketIoService<A, NotFoundService>, SocketIo<A>) {
self.config.engine_config = self.engine_config_builder.req_path(self.req_path).build();
self.config.engine_config = self.engine_config_builder.build();

let (svc, client) =
SocketIoService::with_config_inner(NotFoundService, Arc::new(self.config));
Expand All @@ -215,7 +213,7 @@ impl SocketIoBuilder {
mut self,
svc: S,
) -> (SocketIoService<A, S>, SocketIo<A>) {
self.config.engine_config = self.engine_config_builder.req_path(self.req_path).build();
self.config.engine_config = self.engine_config_builder.build();

let (svc, client) = SocketIoService::with_config_inner(svc, Arc::new(self.config));
(svc, SocketIo(client))
Expand Down Expand Up @@ -320,7 +318,7 @@ impl<A: Adapter> SocketIo<A> {
///
/// ```
#[inline]
pub fn ns<C, F, V>(&self, path: impl Into<String>, callback: C)
pub fn ns<C, F, V>(&self, path: impl Into<Cow<'static, str>>, callback: C)
where
C: Fn(Arc<Socket<A>>, V) -> F + Send + Sync + 'static,
F: Future<Output = ()> + Send + 'static,
Expand Down Expand Up @@ -569,7 +567,7 @@ impl<A: Adapter> SocketIo<A> {
#[inline]
pub fn emit(
&self,
event: impl Into<String>,
event: impl Into<Cow<'static, str>>,
data: impl serde::Serialize,
) -> Result<(), serde_json::Error> {
self.get_default_op().emit(event, data)
Expand Down Expand Up @@ -608,7 +606,7 @@ impl<A: Adapter> SocketIo<A> {
#[inline]
pub fn emit_with_ack<V: DeserializeOwned + Send>(
&self,
event: impl Into<String>,
event: impl Into<Cow<'static, str>>,
data: impl serde::Serialize,
) -> Result<BoxStream<'static, Result<AckResponse<V>, AckError>>, BroadcastError> {
self.get_default_op().emit_with_ack(event, data)
Expand Down
11 changes: 6 additions & 5 deletions socketioxide/src/ns.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
borrow::Cow,
collections::HashMap,
sync::{Arc, RwLock},
};
Expand All @@ -17,14 +18,14 @@ use futures::Future;
use serde::de::DeserializeOwned;

pub struct Namespace<A: Adapter> {
pub path: String,
pub path: Cow<'static, str>,
pub(crate) adapter: A,
handler: BoxedNamespaceHandler<A>,
sockets: RwLock<HashMap<Sid, Arc<Socket<A>>>>,
}

impl<A: Adapter> Namespace<A> {
pub fn new<C, F, V>(path: String, callback: C) -> Arc<Self>
pub fn new<C, F, V>(path: Cow<'static, str>, callback: C) -> Arc<Self>
where
C: Fn(Arc<Socket<A>>, V) -> F + Send + Sync + 'static,
F: Future<Output = ()> + Send + 'static,
Expand Down Expand Up @@ -52,7 +53,7 @@ impl<A: Adapter> Namespace<A> {
self.sockets.write().unwrap().insert(sid, socket.clone());

let protocol = esocket.protocol.into();
if let Err(_e) = socket.send(Packet::connect(self.path.clone(), socket.id, protocol)) {
if let Err(_e) = socket.send(Packet::connect(&self.path, socket.id, protocol)) {
#[cfg(feature = "tracing")]
tracing::debug!("error sending connect packet: {:?}, closing conn", _e);
esocket.close(engineioxide::DisconnectReason::PacketParsingError);
Expand All @@ -78,7 +79,7 @@ impl<A: Adapter> Namespace<A> {
pub fn recv(&self, sid: Sid, packet: PacketData) -> Result<(), Error> {
match packet {
PacketData::Connect(_) => unreachable!("connect packets should be handled before"),
PacketData::ConnectError(_) => Ok(()),
PacketData::ConnectError => Err(Error::InvalidPacketType),
packet => self.get_socket(sid)?.recv(packet),
}
}
Expand Down Expand Up @@ -113,7 +114,7 @@ impl<A: Adapter> Namespace<A> {
#[cfg(test)]
impl<A: Adapter> Namespace<A> {
pub fn new_dummy<const S: usize>(sockets: [Sid; S]) -> Arc<Self> {
let ns = Namespace::new("/".to_string(), |_, _: ()| async {});
let ns = Namespace::new(Cow::Borrowed("/"), |_, _: ()| async {});
for sid in sockets {
ns.sockets
.write()
Expand Down
Loading
Loading