diff --git a/.github/workflows/adapter-ci/docker-compose.yml b/.github/workflows/adapter-ci/docker-compose.yml new file mode 100644 index 00000000..19d85d69 --- /dev/null +++ b/.github/workflows/adapter-ci/docker-compose.yml @@ -0,0 +1,112 @@ +name: adapter-setup +services: + valkey: + image: valkey/valkey + network_mode: host + healthcheck: + test: "valkey-cli ping" + interval: 2s + timeout: 5s + redis-node-0: + image: docker.io/bitnami/redis-cluster:7.0 + network_mode: host + healthcheck: + test: "redis-cli ping" + interval: 2s + timeout: 5s + environment: + - ALLOW_EMPTY_PASSWORD=yes + - REDIS_PORT_NUMBER=7000 + - REDIS_CLUSTER_ANNOUNCE_PORT=7000 + - REDIS_CLUSTER_ANNOUNCE_IP=127.0.0.1 # host ip address + - REDIS_CLUSTER_ANNOUNCE_BUS_PORT=17000 + - REDIS_CLUSTER_DYNAMIC_IPS=no + - REDIS_NODES=127.0.0.1:7000 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 127.0.0.1:7004 127.0.0.1:7005 + + redis-node-1: + image: docker.io/bitnami/redis-cluster:7.0 + network_mode: host + healthcheck: + test: "redis-cli ping" + interval: 2s + timeout: 5s + environment: + - ALLOW_EMPTY_PASSWORD=yes + - REDIS_PORT_NUMBER=7001 + - REDIS_CLUSTER_ANNOUNCE_PORT=7001 + - REDIS_CLUSTER_ANNOUNCE_BUS_PORT=17001 + - REDIS_CLUSTER_ANNOUNCE_IP=127.0.0.1 + - REDIS_CLUSTER_DYNAMIC_IPS=no + - REDIS_NODES=127.0.0.1:7000 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 127.0.0.1:7004 127.0.0.1:7005 + + redis-node-2: + image: docker.io/bitnami/redis-cluster:7.0 + network_mode: host + healthcheck: + test: "redis-cli ping" + interval: 2s + timeout: 5s + environment: + - ALLOW_EMPTY_PASSWORD=yes + - REDIS_PORT_NUMBER=7002 + - REDIS_CLUSTER_ANNOUNCE_PORT=7002 + - REDIS_CLUSTER_ANNOUNCE_BUS_PORT=17002 + - REDIS_CLUSTER_ANNOUNCE_IP=127.0.0.1 + - REDIS_CLUSTER_DYNAMIC_IPS=no + - REDIS_NODES=127.0.0.1:7000 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 127.0.0.1:7004 127.0.0.1:7005 + + redis-node-3: + image: docker.io/bitnami/redis-cluster:7.0 + network_mode: host + healthcheck: + test: "redis-cli ping" + interval: 2s + timeout: 5s + environment: + - ALLOW_EMPTY_PASSWORD=yes + - REDIS_PORT_NUMBER=7003 + - REDIS_CLUSTER_ANNOUNCE_PORT=7003 + - REDIS_CLUSTER_ANNOUNCE_BUS_PORT=17003 + - REDIS_CLUSTER_ANNOUNCE_IP=127.0.0.1 + - REDIS_CLUSTER_DYNAMIC_IPS=no + - REDIS_NODES=127.0.0.1:7000 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 127.0.0.1:7004 127.0.0.1:7005 + + redis-node-4: + image: docker.io/bitnami/redis-cluster:7.0 + network_mode: host + healthcheck: + test: "redis-cli ping" + interval: 2s + timeout: 5s + environment: + - ALLOW_EMPTY_PASSWORD=yes + - REDIS_PORT_NUMBER=7004 + - REDIS_CLUSTER_ANNOUNCE_PORT=7004 + - REDIS_CLUSTER_ANNOUNCE_BUS_PORT=17004 + - REDIS_CLUSTER_ANNOUNCE_IP=127.0.0.1 + - REDIS_CLUSTER_DYNAMIC_IPS=no + - REDIS_NODES=127.0.0.1:7000 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 127.0.0.1:7004 127.0.0.1:7005 + + redis-node-5: + image: docker.io/bitnami/redis-cluster:7.0 + network_mode: host + healthcheck: + test: "redis-cli ping" + interval: 2s + timeout: 5s + depends_on: + - redis-node-0 + - redis-node-1 + - redis-node-2 + - redis-node-3 + - redis-node-4 + environment: + - ALLOW_EMPTY_PASSWORD=yes + - REDIS_CLUSTER_REPLICAS=1 + - REDIS_PORT_NUMBER=7005 + - REDIS_CLUSTER_ANNOUNCE_PORT=7005 + - REDIS_CLUSTER_ANNOUNCE_BUS_PORT=17005 + - REDIS_CLUSTER_ANNOUNCE_IP=127.0.0.1 + - REDIS_CLUSTER_DYNAMIC_IPS=no + - REDIS_NODES=127.0.0.1:7000 127.0.0.1:7001 127.0.0.1:7002 127.0.0.1:7003 127.0.0.1:7004 127.0.0.1:7005 + - REDIS_CLUSTER_CREATOR=yes diff --git a/.github/workflows/fuzzing.yml b/.github/workflows/fuzzing.yml index 
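The compose file above provisions a Valkey instance and a six-node Redis cluster announced on 127.0.0.1:7000-7005 for the adapter e2e jobs. As a rough illustration of what the test binaries get to talk to, a connection probe could look like the sketch below. This uses the `redis` crate's cluster client purely as an assumption for illustration; the actual `fred-e2e`/`redis-e2e` binaries live elsewhere in the workspace and are not part of this diff.

```rust
use redis::cluster::ClusterClient; // requires the `redis` crate with its `cluster` feature
use redis::Commands;

fn main() -> redis::RedisResult<()> {
    // Node addresses match the REDIS_PORT_NUMBER / REDIS_CLUSTER_ANNOUNCE_* values above.
    let nodes: Vec<String> = (7000..=7005)
        .map(|port| format!("redis://127.0.0.1:{port}/"))
        .collect();
    let client = ClusterClient::new(nodes)?;
    let mut conn = client.get_connection()?;

    // Simple read/write probe to confirm the cluster created by REDIS_CLUSTER_CREATOR is usable.
    let _: () = conn.set("adapter-ci-probe", "ok")?;
    let value: String = conn.get("adapter-ci-probe")?;
    assert_eq!(value, "ok");
    Ok(())
}
```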
4ccf4d21..d372fdc4 100644 --- a/.github/workflows/fuzzing.yml +++ b/.github/workflows/fuzzing.yml @@ -34,7 +34,7 @@ jobs: - run: cargo install cargo-fuzz - name: download corpus data run: | - cd crates/${{ matrix.crate }} + cd crates/${{ matrix.crate }}/fuzz wget https://github.com/Totodore/socketioxide-fuzzing-corpus/archive/refs/tags/v$TEST_DATA_VERSION.zip unzip v$TEST_DATA_VERSION.zip rm v$TEST_DATA_VERSION.zip @@ -69,5 +69,5 @@ jobs: - run: cargo install cargo-fuzz - name: cargo fuzz run decode_packet run: | - cd crates/${{ matrix.crate }} + cd crates/${{ matrix.crate }}/fuzz cargo fuzz run ${{ matrix.target }} -- -timeout=5 -max_len=2048 -runs=2000000 -only_ascii=1 diff --git a/.github/workflows/github-ci.yml b/.github/workflows/github-ci.yml index f5e52bf5..20286f87 100644 --- a/.github/workflows/github-ci.yml +++ b/.github/workflows/github-ci.yml @@ -5,8 +5,6 @@ on: branches: - main pull_request: - branches: - - main jobs: format: @@ -58,10 +56,10 @@ jobs: target/ key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}-nightly - name: Check unused dependencies on default features - run: cargo udeps --workspace + run: RUSTFLAGS="--cfg fuzzing" cargo udeps - name: Check unused dependencies on all features - run: cargo udeps --all-features --workspace + run: RUSTFLAGS="--cfg fuzzing" cargo udeps --all-features msrv: runs-on: ubuntu-latest @@ -85,7 +83,7 @@ jobs: components: rustfmt, clippy - name: check crates - run: cargo check -p socketioxide -p engineioxide --all-features + run: cargo check --all-features feature_set: runs-on: ubuntu-latest @@ -110,7 +108,7 @@ jobs: key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - name: check --feature-powerset - run: cargo hack check --feature-powerset --no-dev-deps -p socketioxide -p engineioxide + run: cargo hack check --feature-powerset --no-dev-deps -p socketioxide -p engineioxide -p socketioxide-redis examples: runs-on: ubuntu-latest @@ -271,3 +269,57 @@ jobs: - name: Client output if: always() run: cat client.txt + adapter: + runs-on: ubuntu-latest + needs: [socket_io, engine_io] + strategy: + matrix: + socketio-version: [v4, v4-msgpack, v5, v5-msgpack] + adapter: [fred-e2e, redis-e2e, redis-cluster-e2e, fred-cluster-e2e] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-adapter + - uses: actions/setup-node@v4 + with: + node-version: 22 + - name: install adapter infra + uses: hoverkraft-tech/compose-action@v2.0.2 + with: + compose-file: ./.github/workflows/adapter-ci/docker-compose.yml + - run: cd e2e/adapter && npm install && npm install ts-node --location=global + - name: Install deps & run tests + run: | + PARSER=$(echo ${{ matrix.socketio-version }} | cut -d'-' -f2 -s) + VERSION=$(echo ${{ matrix.socketio-version }} | cut -d'-' -f1) + cargo build -p adapter-e2e --bin ${{ matrix.adapter }} --features $VERSION,$PARSER + cd e2e/adapter && CMD="cargo run -p adapter-e2e --bin ${{ matrix.adapter }} --features $VERSION,$PARSER" ts-node client.ts + - name: Server output + if: always() + run: cat e2e/adapter/*.log + all_passed: + runs-on: ubuntu-latest + needs: + [ + adapter, + feature_set, + format, + udeps, + msrv, + examples, + doctest, + rust-clippy-analyze, + ] + steps: + - name: All passed + run: echo "All tests passed" diff --git a/Cargo.toml b/Cargo.toml index 2b254972..ec022bc2 100644 --- 
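The `RUSTFLAGS="--cfg fuzzing"` passed to `cargo udeps` mirrors the `config.toml` files added to the fuzz crates further down: fuzzing-only items are now gated behind a compiler `--cfg` flag instead of a Cargo feature. A minimal sketch of how such a cfg gate behaves (the helper name is made up for illustration; the real gated items are the `Arbitrary` impls in `socketioxide-core`):

```rust
// Built only when the compiler is invoked with `--cfg fuzzing`,
// e.g. RUSTFLAGS="--cfg fuzzing" cargo build. In the workspace, the
// `[lints.rust] unexpected_cfgs` check-cfg entry added in this diff
// keeps the custom cfg name from triggering a lint warning.
#[cfg(fuzzing)]
fn fuzzing_only_marker() -> &'static str {
    "compiled with --cfg fuzzing"
}

fn main() {
    #[cfg(fuzzing)]
    println!("{}", fuzzing_only_marker());

    #[cfg(not(fuzzing))]
    println!("regular build: fuzzing-only code is compiled out");
}
```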
a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,9 @@ hyper-util.version = "0.1" hyper = "1.5" pin-project-lite = "0.2" matchit = "0.8" +rmp-serde = "1.3" +rmp = "0.8" +rustversion = "1" # Dev deps tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/crates/engineioxide/Cargo.toml b/crates/engineioxide/Cargo.toml index 803c4d21..72bd089d 100644 --- a/crates/engineioxide/Cargo.toml +++ b/crates/engineioxide/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "engineioxide" description = "Engine IO server implementation in rust as a Tower Service." -version.workspace = true +version = "0.15.1" edition.workspace = true rust-version.workspace = true authors.workspace = true diff --git a/crates/engineioxide/Readme.md b/crates/engineioxide/Readme.md index 68ba4b33..6fef31a2 100644 --- a/crates/engineioxide/Readme.md +++ b/crates/engineioxide/Readme.md @@ -53,10 +53,10 @@ impl EngineIoHandler for MyHandler { let cnt = self.user_cnt.fetch_sub(1, Ordering::Relaxed) - 1; socket.emit(cnt.to_string()).ok(); } - fn on_message(&self, msg: Str, socket: Arc>) { + fn on_message(self: &Arc, msg: Str, socket: Arc>) { *socket.data.id.lock().unwrap() = msg.into(); // bind a provided user id to a socket } - fn on_binary(&self, data: Bytes, socket: Arc>) { } + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { } } // Create a new engineio layer diff --git a/crates/engineioxide/src/config.rs b/crates/engineioxide/src/config.rs index 103dca5c..8a9eaeea 100644 --- a/crates/engineioxide/src/config.rs +++ b/crates/engineioxide/src/config.rs @@ -15,8 +15,8 @@ //! type Data = (); //! fn on_connect(self: Arc, socket: Arc>) { } //! fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { } -//! fn on_message(&self, msg: Str, socket: Arc>) { } -//! fn on_binary(&self, data: Bytes, socket: Arc>) { } +//! fn on_message(self: &Arc, msg: Str, socket: Arc>) { } +//! fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { } //! } //! //! let config = EngineIoConfig::builder() @@ -150,12 +150,12 @@ impl EngineIoConfigBuilder { /// println!("socket disconnect {}", socket.id); /// } /// - /// fn on_message(&self, msg: Str, socket: Arc>) { + /// fn on_message(self: &Arc, msg: Str, socket: Arc>) { /// println!("Ping pong message {:?}", msg); /// socket.emit(msg).unwrap(); /// } /// - /// fn on_binary(&self, data: Bytes, socket: Arc>) { + /// fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { /// println!("Ping pong binary message {:?}", data); /// socket.emit_binary(data).unwrap(); /// } diff --git a/crates/engineioxide/src/engine.rs b/crates/engineioxide/src/engine.rs index f1ba50e7..65eb8ed6 100644 --- a/crates/engineioxide/src/engine.rs +++ b/crates/engineioxide/src/engine.rs @@ -115,12 +115,12 @@ mod tests { println!("socket disconnect {} {:?}", socket.id, reason); } - fn on_message(&self, msg: Str, socket: Arc>) { + fn on_message(self: &Arc, msg: Str, socket: Arc>) { println!("Ping pong message {:?}", msg); socket.emit(msg).ok(); } - fn on_binary(&self, data: Bytes, socket: Arc>) { + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { println!("Ping pong binary message {:?}", data); socket.emit_binary(data).ok(); } diff --git a/crates/engineioxide/src/handler.rs b/crates/engineioxide/src/handler.rs index 9be6a9dc..f2f70f64 100644 --- a/crates/engineioxide/src/handler.rs +++ b/crates/engineioxide/src/handler.rs @@ -30,10 +30,10 @@ //! let cnt = self.user_cnt.fetch_sub(1, Ordering::Relaxed) - 1; //! socket.emit(cnt.to_string()).ok(); //! } -//! fn on_message(&self, msg: Str, socket: Arc>) { +//! 
fn on_message(self: &Arc, msg: Str, socket: Arc>) { //! *socket.data.id.lock().unwrap() = msg.into(); // bind a provided user id to a socket //! } -//! fn on_binary(&self, data: Bytes, socket: Arc>) { } +//! fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { } //! } //! //! // Create an engine io service with the given handler @@ -60,8 +60,8 @@ pub trait EngineIoHandler: std::fmt::Debug + Send + Sync + 'static { fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason); /// Called when a message is received from the client. - fn on_message(&self, msg: Str, socket: Arc>); + fn on_message(self: &Arc, msg: Str, socket: Arc>); /// Called when a binary message is received from the client. - fn on_binary(&self, data: Bytes, socket: Arc>); + fn on_binary(self: &Arc, data: Bytes, socket: Arc>); } diff --git a/crates/engineioxide/src/layer.rs b/crates/engineioxide/src/layer.rs index a3c62802..06657a55 100644 --- a/crates/engineioxide/src/layer.rs +++ b/crates/engineioxide/src/layer.rs @@ -15,8 +15,8 @@ //! type Data = (); //! fn on_connect(self: Arc, socket: Arc>) { } //! fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { } -//! fn on_message(&self, msg: Str, socket: Arc>) { } -//! fn on_binary(&self, data: Bytes, socket: Arc>) { } +//! fn on_message(self: &Arc, msg: Str, socket: Arc>) { } +//! fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { } //! } //! // Create a new engineio layer //! let layer = EngineIoLayer::new(Arc::new(MyHandler)); diff --git a/crates/engineioxide/src/service/mod.rs b/crates/engineioxide/src/service/mod.rs index f0379ccc..261fbb1b 100644 --- a/crates/engineioxide/src/service/mod.rs +++ b/crates/engineioxide/src/service/mod.rs @@ -17,8 +17,8 @@ //! type Data = (); //! fn on_connect(self: Arc, socket: Arc>) { } //! fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { } -//! fn on_message(&self, msg: Str, socket: Arc>) { } -//! fn on_binary(&self, data: Bytes, socket: Arc>) { } +//! fn on_message(self: &Arc, msg: Str, socket: Arc>) { } +//! fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { } //! } //! //! // Create a new engine.io service that will return a 404 not found response for other requests diff --git a/crates/engineioxide/src/sid.rs b/crates/engineioxide/src/sid.rs index c1bc938c..67a6ac15 100644 --- a/crates/engineioxide/src/sid.rs +++ b/crates/engineioxide/src/sid.rs @@ -16,13 +16,13 @@ pub struct Sid([u8; 16]); impl Sid { /// A zeroed session id pub const ZERO: Self = Self([0u8; 16]); - /// Generate a new random session id (base64 10 chars) + /// Generate a new random session id (base64 16 chars) pub fn new() -> Self { Self::default() } - /// Get the session id as a base64 10 chars string - pub fn as_str(&self) -> &str { + /// Get the session id as a base64 16 chars string + pub const fn as_str(&self) -> &str { // SAFETY: SID is always a base64 chars string unsafe { std::str::from_utf8_unchecked(&self.0) } } diff --git a/crates/engineioxide/src/socket.rs b/crates/engineioxide/src/socket.rs index 2c312069..8afc0b18 100644 --- a/crates/engineioxide/src/socket.rs +++ b/crates/engineioxide/src/socket.rs @@ -46,10 +46,10 @@ //! fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { //! let cnt = self.user_cnt.fetch_sub(1, Ordering::Relaxed) - 1; //! } -//! fn on_message(&self, msg: Str, socket: Arc>) { +//! fn on_message(self: &Arc, msg: Str, socket: Arc>) { //! *socket.data.id.lock().unwrap() = msg.into(); // bind a provided user id to a socket //! } -//! fn on_binary(&self, data: Bytes, socket: Arc>) { } +//! 
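The recurring doc change in these hunks is the new receiver on `on_message`/`on_binary`: they now take `self: &Arc<Self>` instead of `&self`, so a handler can hand an owned `Arc` of itself to background work without callers changing anything. A rough sketch of that pattern, assuming the 0.15.1 signatures shown above (module paths and handler contents are assumptions, not part of the diff):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use bytes::Bytes;
use engineioxide::handler::EngineIoHandler;
use engineioxide::socket::{DisconnectReason, Socket};
use engineioxide::Str;

#[derive(Debug, Default)]
struct EchoHandler {
    msg_count: AtomicUsize,
}

impl EngineIoHandler for EchoHandler {
    type Data = ();

    fn on_connect(self: Arc<Self>, _socket: Arc<Socket<()>>) {}
    fn on_disconnect(&self, _socket: Arc<Socket<()>>, _reason: DisconnectReason) {}

    fn on_message(self: &Arc<Self>, msg: Str, socket: Arc<Socket<()>>) {
        // The Arc receiver makes an owned handle to the handler one cheap clone away,
        // which is handy for moving it into a task or thread.
        let this = Arc::clone(self);
        std::thread::spawn(move || {
            this.msg_count.fetch_add(1, Ordering::Relaxed);
            socket.emit(msg).ok();
        });
    }

    fn on_binary(self: &Arc<Self>, data: Bytes, socket: Arc<Socket<()>>) {
        socket.emit_binary(data).ok();
    }
}
```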
fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { } //! } //! //! let svc = EngineIoService::new(Arc::new(MyHandler::default())); diff --git a/crates/engineioxide/src/str.rs b/crates/engineioxide/src/str.rs index da7a57dc..a38858d3 100644 --- a/crates/engineioxide/src/str.rs +++ b/crates/engineioxide/src/str.rs @@ -100,6 +100,11 @@ impl From for String { unsafe { String::from_utf8_unchecked(vec) } } } +impl From for Vec { + fn from(value: Str) -> Self { + Vec::from(value.0) + } +} impl Serialize for Str { fn serialize(&self, serializer: S) -> Result where diff --git a/crates/engineioxide/tests/disconnect_reason.rs b/crates/engineioxide/tests/disconnect_reason.rs index 56725674..7339f12d 100644 --- a/crates/engineioxide/tests/disconnect_reason.rs +++ b/crates/engineioxide/tests/disconnect_reason.rs @@ -39,12 +39,12 @@ impl EngineIoHandler for MyHandler { self.disconnect_tx.try_send(reason).unwrap(); } - fn on_message(&self, msg: Str, socket: Arc>) { + fn on_message(self: &Arc, msg: Str, socket: Arc>) { println!("Ping pong message {:?}", msg); socket.emit(msg).ok(); } - fn on_binary(&self, data: Bytes, socket: Arc>) { + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { println!("Ping pong binary message {:?}", data); socket.emit_binary(data).ok(); } diff --git a/crates/parser-common/Cargo.toml b/crates/parser-common/Cargo.toml index 83f9dd72..25092249 100644 --- a/crates/parser-common/Cargo.toml +++ b/crates/parser-common/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "socketioxide-parser-common" description = "Common parser for the socketioxide protocol" -version.workspace = true +version = "0.15.1" edition.workspace = true rust-version.workspace = true authors.workspace = true @@ -23,9 +23,6 @@ socketioxide-core = { version = "0.15.0", path = "../socketioxide-core" } criterion.workspace = true rmpv = { version = "1.3", features = ["with-serde"] } -[features] -fuzzing = ["socketioxide-core/fuzzing"] - [[bench]] name = "packet_encode" path = "benches/packet_encode.rs" diff --git a/crates/parser-common/fuzz/Cargo.toml b/crates/parser-common/fuzz/Cargo.toml index 2918e328..9a47f1d7 100644 --- a/crates/parser-common/fuzz/Cargo.toml +++ b/crates/parser-common/fuzz/Cargo.toml @@ -9,8 +9,8 @@ cargo-fuzz = true [dependencies] libfuzzer-sys = "0.4" -socketioxide-parser-common = { path = "..", features = ["fuzzing"] } -socketioxide-core = { path = "../../socketioxide-core", features = ["fuzzing"] } +socketioxide-parser-common = { path = ".." 
} +socketioxide-core = { path = "../../socketioxide-core" } serde_json.workspace = true bytes.workspace = true diff --git a/crates/parser-common/fuzz/config.toml b/crates/parser-common/fuzz/config.toml new file mode 100644 index 00000000..8ffb3e10 --- /dev/null +++ b/crates/parser-common/fuzz/config.toml @@ -0,0 +1,2 @@ +[build] +rustflags = ["--cfg fuzzing"] diff --git a/crates/parser-common/src/de.rs b/crates/parser-common/src/de.rs index 40e3a8d9..91f5f2b5 100644 --- a/crates/parser-common/src/de.rs +++ b/crates/parser-common/src/de.rs @@ -7,9 +7,7 @@ use socketioxide_core::{ Str, Value, }; -pub fn deserialize_packet( - data: Str, -) -> Result<(Packet, Option), ParseError> { +pub fn deserialize_packet(data: Str) -> Result<(Packet, Option), ParseError> { if data.is_empty() { return Err(ParseError::InvalidPacketType); } diff --git a/crates/parser-common/src/lib.rs b/crates/parser-common/src/lib.rs index efcfa064..7fd281df 100644 --- a/crates/parser-common/src/lib.rs +++ b/crates/parser-common/src/lib.rs @@ -45,7 +45,7 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; use socketioxide_core::{ packet::{Packet, PacketData}, - parser::{Parse, ParseError, ParserState}, + parser::{Parse, ParseError, ParserError, ParserState}, Str, Value, }; @@ -59,17 +59,11 @@ mod value; pub struct CommonParser; impl Parse for CommonParser { - type EncodeError = serde_json::Error; - type DecodeError = serde_json::Error; fn encode(self, packet: Packet) -> Value { ser::serialize_packet(packet) } - fn decode_str( - self, - state: &ParserState, - value: Str, - ) -> Result> { + fn decode_str(self, state: &ParserState, value: Str) -> Result { let (packet, incoming_binary_cnt) = de::deserialize_packet(value)?; if packet.inner.is_binary() { let incoming_binary_cnt = incoming_binary_cnt.ok_or(ParseError::InvalidAttachments)?; @@ -87,11 +81,7 @@ impl Parse for CommonParser { } } - fn decode_bin( - self, - state: &ParserState, - data: Bytes, - ) -> Result> { + fn decode_bin(self, state: &ParserState, data: Bytes) -> Result { let packet = &mut *state.partial_bin_packet.lock().unwrap(); match packet { Some(Packet { @@ -117,8 +107,8 @@ impl Parse for CommonParser { self, data: &T, event: Option<&str>, - ) -> Result { - value::to_value(data, event) + ) -> Result { + value::to_value(data, event).map_err(ParserError::new) } #[inline] @@ -126,32 +116,32 @@ impl Parse for CommonParser { self, value: &'de mut Value, with_event: bool, - ) -> Result { - value::from_value(value, with_event) + ) -> Result { + value::from_value(value, with_event).map_err(ParserError::new) } fn decode_default<'de, T: Deserialize<'de>>( self, value: Option<&'de Value>, - ) -> Result { + ) -> Result { if let Some(value) = value { let data = value .as_str() .expect("CommonParser only supports string values"); - serde_json::from_str(data) + serde_json::from_str(data).map_err(ParserError::new) } else { - serde_json::from_str("{}") + serde_json::from_str("{}").map_err(ParserError::new) } } - fn encode_default(self, data: &T) -> Result { - let value = serde_json::to_string(data)?; + fn encode_default(self, data: &T) -> Result { + let value = serde_json::to_string(data).map_err(ParserError::new)?; Ok(Value::Str(Str::from(value), None)) } #[inline] - fn read_event(self, value: &Value) -> Result<&str, Self::DecodeError> { - value::read_event(value) + fn read_event(self, value: &Value) -> Result<&str, ParserError> { + value::read_event(value).map_err(ParserError::new) } } @@ -590,7 +580,7 @@ mod test { fn decode_default_none() { // Common parser should 
deserialize by default to an empty map to match the behavior of the // socket.io client when deserializing incoming connect message without an auth payload. - let data: serde_json::Result> = CommonParser.decode_default(None); + let data = CommonParser.decode_default::>(None); assert!(matches!(data, Ok(d) if d.is_empty())); } @@ -598,8 +588,8 @@ mod test { fn decode_default_some() { // Common parser should deserialize by default to an empty map to match the behavior of the // socket.io client when deserializing incoming connect message without an auth payload. - let data: serde_json::Result = - CommonParser.decode_default(Some(&Value::Str("\"test\"".into(), None))); + let data = + CommonParser.decode_default::(Some(&Value::Str("\"test\"".into(), None))); assert!(matches!(data, Ok(d) if d == "test")); } diff --git a/crates/parser-common/src/value/ser.rs b/crates/parser-common/src/value/ser.rs index f4ff3ae8..8a5e8675 100644 --- a/crates/parser-common/src/value/ser.rs +++ b/crates/parser-common/src/value/ser.rs @@ -318,7 +318,6 @@ impl<'a, S: ser::Serializer> serde::Serializer for Serializer<'a, S> { } fn serialize_bytes(self, v: &[u8]) -> Result { - use serde::ser::SerializeMap; let num = { // SAFETY: the binary_payloads are only accessed in the context of the current serialization // in a sequential manner. The only mutation place is here, hence it remains safe. diff --git a/crates/parser-msgpack/Cargo.toml b/crates/parser-msgpack/Cargo.toml index 5dfb5aaf..66f28a9b 100644 --- a/crates/parser-msgpack/Cargo.toml +++ b/crates/parser-msgpack/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "socketioxide-parser-msgpack" description = "Msgpack parser for the socketioxide protocol" -version.workspace = true +version = "0.15.1" edition.workspace = true rust-version.workspace = true authors.workspace = true @@ -15,19 +15,14 @@ readme.workspace = true [dependencies] bytes.workspace = true serde.workspace = true -rmp = "0.8" -rmp-serde = "1.3" +rmp-serde.workspace = true +rmp.workspace = true socketioxide-core = { version = "0.15.0", path = "../socketioxide-core" } [dev-dependencies] serde_json.workspace = true criterion.workspace = true - -[features] -fuzzing = ["socketioxide-core/fuzzing"] - - [[bench]] name = "packet_encode" path = "benches/packet_encode.rs" diff --git a/crates/parser-msgpack/fuzz/Cargo.toml b/crates/parser-msgpack/fuzz/Cargo.toml index a530fe24..425eca13 100644 --- a/crates/parser-msgpack/fuzz/Cargo.toml +++ b/crates/parser-msgpack/fuzz/Cargo.toml @@ -10,8 +10,8 @@ cargo-fuzz = true [dependencies] libfuzzer-sys = "0.4" rmpv = { version = "1.3.0", features = ["with-serde"] } -socketioxide-parser-msgpack = { path = "..", features = ["fuzzing"] } -socketioxide-core = { path = "../../socketioxide-core", features = ["fuzzing"] } +socketioxide-parser-msgpack = { path = ".." 
} +socketioxide-core = { path = "../../socketioxide-core" } bytes.workspace = true [[bin]] diff --git a/crates/parser-msgpack/fuzz/config.toml b/crates/parser-msgpack/fuzz/config.toml new file mode 100644 index 00000000..8ffb3e10 --- /dev/null +++ b/crates/parser-msgpack/fuzz/config.toml @@ -0,0 +1,2 @@ +[build] +rustflags = ["--cfg fuzzing"] diff --git a/crates/parser-msgpack/src/de.rs b/crates/parser-msgpack/src/de.rs index 2b05c616..4d550f54 100644 --- a/crates/parser-msgpack/src/de.rs +++ b/crates/parser-msgpack/src/de.rs @@ -9,11 +9,11 @@ use rmp::{ use rmp_serde::decode::Error as DecodeError; use socketioxide_core::{ packet::{Packet, PacketData}, - parser::ParseError, + parser::{ParseError, ParserError}, Str, Value, }; -pub fn deserialize_packet(buff: Bytes) -> Result> { +pub fn deserialize_packet(buff: Bytes) -> Result { let mut reader = Cursor::new(buff); let maplen = read_map_len(&mut reader).map_err(|e| { use DecodeError::*; @@ -22,17 +22,17 @@ pub fn deserialize_packet(buff: Bytes) -> Result ValueReadError::InvalidDataRead(e) => InvalidDataRead(e), ValueReadError::TypeMismatch(e) => TypeMismatch(e), }; - ParseError::ParserError(e) + ParseError::ParserError(ParserError::new(e)) })?; // Bound check to prevent DoS attacks. // other implementations might add some other keys that we don't support // Therefore, we limit the number of keys to 20 if maplen == 0 || maplen > 20 { - Err(DecodeError::Uncategorized(format!( + Err(ParserError::new(DecodeError::Uncategorized(format!( "packet length too big or empty: {}", maplen - )))?; + ))))?; } let mut index = 0xff; @@ -41,7 +41,8 @@ pub fn deserialize_packet(buff: Bytes) -> Result let mut id = None; for _ in 0..maplen { - parse_key_value(&mut reader, &mut index, &mut nsp, &mut data_pos, &mut id)?; + parse_key_value(&mut reader, &mut index, &mut nsp, &mut data_pos, &mut id) + .map_err(ParserError::new)?; } let buff = reader.into_inner(); let mut data = buff.slice(data_pos.clone()); @@ -64,7 +65,8 @@ pub fn deserialize_packet(buff: Bytes) -> Result struct ErrorMessage { message: String, } - let ErrorMessage { message } = rmp_serde::decode::from_slice(&buff[data_pos])?; + let ErrorMessage { message } = + rmp_serde::decode::from_slice(&buff[data_pos]).map_err(ParserError::new)?; PacketData::ConnectError(message) } 5 => PacketData::BinaryEvent(data, id), diff --git a/crates/parser-msgpack/src/lib.rs b/crates/parser-msgpack/src/lib.rs index 57f44836..6c38b9f7 100644 --- a/crates/parser-msgpack/src/lib.rs +++ b/crates/parser-msgpack/src/lib.rs @@ -45,13 +45,15 @@ //! will be directly converted to the following msgpack binary format: //! 
`84 A4 74 79 70 65 02 A3 6E 73 70 A1 2F A4 64 61 74 61 92 A5 65 76 65 6E 74 A3 66 6F 6F A2 69 64 01` +use std::str; + use bytes::Bytes; use de::deserialize_packet; use ser::serialize_packet; use serde::Deserialize; use socketioxide_core::{ packet::Packet, - parser::{Parse, ParseError, ParserState}, + parser::{Parse, ParseError, ParserError, ParserState}, Str, Value, }; @@ -64,27 +66,16 @@ mod value; pub struct MsgPackParser; impl Parse for MsgPackParser { - type EncodeError = rmp_serde::encode::Error; - type DecodeError = rmp_serde::decode::Error; - fn encode(self, packet: Packet) -> socketioxide_core::Value { let data = serialize_packet(packet); Value::Bytes(data.into()) } - fn decode_str( - self, - _: &ParserState, - _data: Str, - ) -> Result> { + fn decode_str(self, _: &ParserState, _data: Str) -> Result { Err(ParseError::UnexpectedStringPacket) } - fn decode_bin( - self, - _: &ParserState, - bin: Bytes, - ) -> Result> { + fn decode_bin(self, _: &ParserState, bin: Bytes) -> Result { deserialize_packet(bin) } @@ -92,39 +83,38 @@ impl Parse for MsgPackParser { self, data: &T, event: Option<&str>, - ) -> Result { - value::to_value(data, event) + ) -> Result { + value::to_value(data, event).map_err(ParserError::new) } fn decode_value<'de, T: Deserialize<'de>>( self, value: &'de mut Value, with_event: bool, - ) -> Result { - value::from_value(value, with_event) + ) -> Result { + value::from_value(value, with_event).map_err(ParserError::new) } fn decode_default<'de, T: Deserialize<'de>>( self, value: Option<&'de Value>, - ) -> Result { + ) -> Result { if let Some(value) = value { let value = value.as_bytes().expect("value should be bytes"); - rmp_serde::from_slice(value) + rmp_serde::from_slice(value).map_err(ParserError::new) } else { - rmp_serde::from_slice(&[0xc0]) // nil value + rmp_serde::from_slice(&[0xc0]).map_err(ParserError::new) // nil value } } - fn encode_default( - self, - data: &T, - ) -> Result { - rmp_serde::to_vec_named(data).map(|b| Value::Bytes(b.into())) + fn encode_default(self, data: &T) -> Result { + rmp_serde::to_vec_named(data) + .map(|b| Value::Bytes(b.into())) + .map_err(ParserError::new) } - fn read_event(self, value: &Value) -> Result<&str, Self::DecodeError> { - value::read_event(value) + fn read_event(self, value: &Value) -> Result<&str, ParserError> { + value::read_event(value).map_err(ParserError::new) } } diff --git a/crates/socketioxide-core/Cargo.toml b/crates/socketioxide-core/Cargo.toml index 55039468..6124ebeb 100644 --- a/crates/socketioxide-core/Cargo.toml +++ b/crates/socketioxide-core/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "socketioxide-core" -description = "Core of the socketioxide library. Contains basic types and interfaces for the socketioxide crate and the parser sub-crates" -version.workspace = true +description = "Core of the socketioxide library. Contains basic types and interfaces for the socketioxide crate and all other related sub-crates." 
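With the associated `EncodeError`/`DecodeError` types gone, both parsers surface a single boxed `ParserError`, so calling the `Parse` methods looks the same regardless of the parser. A small round-trip sketch through `CommonParser` (hypothetical usage, not taken from the diff):

```rust
use socketioxide_core::parser::Parse;
use socketioxide_parser_common::CommonParser;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Encode an arbitrary serde value; the underlying serde_json error is erased into ParserError.
    let payload = serde_json::json!({ "hello": "world" });
    let encoded = CommonParser.encode_default(&payload)?;

    // Decode it back; failures come out as the same ParserError type.
    let decoded: serde_json::Value = CommonParser.decode_default(Some(&encoded))?;
    assert_eq!(decoded, payload);
    Ok(())
}
```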
+version = "0.15.1" edition.workspace = true rust-version.workspace = true authors.workspace = true @@ -14,10 +14,17 @@ readme.workspace = true [dependencies] bytes.workspace = true -engineioxide = { version = "0.15.0", path = "../engineioxide" } +engineioxide = { version = "0.15", path = "../engineioxide" } serde.workspace = true thiserror.workspace = true -arbitrary = { version = "1.3.2", features = ["derive"], optional = true } +futures-core.workspace = true +smallvec = { workspace = true, features = ["serde"] } -[features] -fuzzing = ["dep:arbitrary"] +[target."cfg(fuzzing)".dependencies] +arbitrary = { version = "1.3.2", features = ["derive"] } + +[dev-dependencies] +serde_json.workspace = true + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } diff --git a/crates/socketioxide-core/src/adapter.rs b/crates/socketioxide-core/src/adapter.rs new file mode 100644 index 00000000..f7439387 --- /dev/null +++ b/crates/socketioxide-core/src/adapter.rs @@ -0,0 +1,1062 @@ +//! The adapter module contains the [`CoreAdapter`] trait and other related types. +//! +//! It is used to implement communication between socket.io servers to share messages and state. +//! +//! The [`CoreLocalAdapter`] provide a local implementation that will allow any implementors to apply local +//! operations (`broadcast_with_ack`, `broadcast`, `rooms`, etc...). +use std::{ + borrow::Cow, + collections::{hash_map, hash_set, HashMap, HashSet}, + error::Error as StdError, + future::{self, Future}, + hash::Hash, + slice, + sync::{Arc, RwLock}, + time::Duration, +}; + +use engineioxide::{sid::Sid, Str}; +use futures_core::{FusedStream, Stream}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use smallvec::SmallVec; + +use crate::{ + errors::{AdapterError, BroadcastError, SocketError}, + packet::Packet, + parser::Parse, + Uid, Value, +}; + +/// A room identifier +pub type Room = Cow<'static, str>; + +/// Flags that can be used to modify the behavior of the broadcast methods. +#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] +pub enum BroadcastFlags { + /// Broadcast only to the current server + Local = 0x01, + /// Broadcast to all clients except the sender + Broadcast = 0x02, +} + +/// Options that can be used to modify the behavior of the broadcast methods. +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +pub struct BroadcastOptions { + /// The flags to apply to the broadcast represented as a bitflag. + flags: u8, + /// The rooms to broadcast to. + pub rooms: SmallVec<[Room; 4]>, + /// The rooms to exclude from the broadcast. + pub except: SmallVec<[Room; 4]>, + /// The socket id of the sender. + pub sid: Option, + /// The target server id can be used to optimize the broadcast. + /// More specifically when we use broadcasting to apply a single action on a remote socket. + /// We now the server_id of the remote socket, so we can send the action directly to the server. + pub server_id: Option, +} +impl BroadcastOptions { + /// Add any flags to the options. + pub fn add_flag(&mut self, flag: BroadcastFlags) { + self.flags |= flag as u8; + } + /// Check if the options have a flag. + pub fn has_flag(&self, flag: BroadcastFlags) -> bool { + self.flags & flag as u8 == flag as u8 + } + + /// get the flags of the options. + pub fn flags(&self) -> u8 { + self.flags + } + + /// Set the socket id of the sender. + pub fn new(sid: Sid) -> Self { + Self { + sid: Some(sid), + ..Default::default() + } + } + /// Create a new broadcast options from a remote socket data. 
+ pub fn new_remote(data: &RemoteSocketData) -> Self { + Self { + sid: Some(data.id), + server_id: Some(data.server_id), + ..Default::default() + } + } +} + +/// A trait for types that can be used as a room parameter. +/// +/// [`String`], [`Vec`], [`Vec<&str>`], [`&'static str`](str) and const arrays are implemented by default. +pub trait RoomParam: Send + 'static { + /// The type of the iterator returned by `into_room_iter`. + type IntoIter: Iterator; + + /// Convert `self` into an iterator of rooms. + fn into_room_iter(self) -> Self::IntoIter; +} + +impl RoomParam for Room { + type IntoIter = std::iter::Once; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + std::iter::once(self) + } +} +impl RoomParam for String { + type IntoIter = std::iter::Once; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + std::iter::once(Cow::Owned(self)) + } +} +impl RoomParam for Vec { + type IntoIter = std::iter::Map, fn(String) -> Room>; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Owned) + } +} +impl RoomParam for Vec<&'static str> { + type IntoIter = std::iter::Map, fn(&'static str) -> Room>; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Borrowed) + } +} + +impl RoomParam for Vec { + type IntoIter = std::vec::IntoIter; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter() + } +} +impl RoomParam for &'static str { + type IntoIter = std::iter::Once; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + std::iter::once(Cow::Borrowed(self)) + } +} +impl RoomParam for [&'static str; COUNT] { + type IntoIter = + std::iter::Map, fn(&'static str) -> Room>; + + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Borrowed) + } +} +impl RoomParam for &'static [&'static str] { + type IntoIter = + std::iter::Map, fn(&'static &'static str) -> Room>; + + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.iter().map(|i| Cow::Borrowed(*i)) + } +} +impl RoomParam for [String; COUNT] { + type IntoIter = std::iter::Map, fn(String) -> Room>; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + self.into_iter().map(Cow::Owned) + } +} +impl RoomParam for Sid { + type IntoIter = std::iter::Once; + #[inline(always)] + fn into_room_iter(self) -> Self::IntoIter { + std::iter::once(Cow::Owned(self.to_string())) + } +} + +/// A item yield by the ack stream. +pub type AckStreamItem = (Sid, Result); +/// The [`SocketEmitter`] will be implemented by the socketioxide library. +/// It is simply used as an abstraction to allow the adapter to communicate +/// with the socket server without the need to depend on the socketioxide lib. +pub trait SocketEmitter: Send + Sync + 'static { + /// An error that can occur when sending data an acknowledgment. + type AckError: StdError + Send + Serialize + DeserializeOwned + 'static; + /// A stream that emits the acknowledgments of multiple sockets. + type AckStream: Stream> + FusedStream + Send + 'static; + + /// Get all the socket ids in the namespace. + fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec; + /// Get the socket data that match the list of socket ids. + fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec; + /// Send data to the list of socket ids. 
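`RoomParam` and `BroadcastOptions` are the two building blocks every broadcast call goes through: the former abstracts over how rooms are passed, the latter carries the targeted rooms, exclusions, flags and sender. A short sketch of how they combine, illustrative only and not part of the diff:

```rust
use socketioxide_core::adapter::{BroadcastFlags, BroadcastOptions, Room, RoomParam};
use socketioxide_core::Sid;

// Rooms can be given as a single &str, arrays, Vec<String>, or even a Sid.
fn collect_rooms(rooms: impl RoomParam) -> Vec<Room> {
    rooms.into_room_iter().collect()
}

fn main() {
    assert_eq!(collect_rooms("room1").len(), 1);
    assert_eq!(collect_rooms(["room1", "room2"]).len(), 2);
    assert_eq!(collect_rooms(vec![String::from("room3")]).len(), 1);

    // Target room1, skip anyone in room2, and skip the sender itself.
    let sender = Sid::new();
    let mut opts = BroadcastOptions::new(sender);
    opts.rooms.push(Room::from("room1"));
    opts.except.push(Room::from("room2"));
    opts.add_flag(BroadcastFlags::Broadcast);
    assert!(opts.has_flag(BroadcastFlags::Broadcast));
    assert_eq!(opts.sid, Some(sender));
}
```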
+    fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec>;
+    /// Send data to the list of socket ids and get a stream of acks and the number of expected acks.
+    fn send_many_with_ack(
+        &self,
+        sids: BroadcastIter<'_>,
+        packet: Packet,
+        timeout: Option,
+    ) -> (Self::AckStream, u32);
+    /// Disconnect all the sockets in the list.
+    /// TODO: take a [`BroadcastIter`]. Currently it is impossible because it may create deadlocks
+    /// with Adapter::del_all call.
+    fn disconnect_many(&self, sids: Vec) -> Result<(), Vec>;
+    /// Get the path of the namespace.
+    fn path(&self) -> &Str;
+    /// Get the parser of the namespace.
+    fn parser(&self) -> impl Parse;
+    /// Get the unique server id.
+    fn server_id(&self) -> Uid;
+}
+
+/// For static namespaces, the init response will be managed by the user.
+/// However, for dynamic namespaces, the socket.io client will manage the response.
+/// As it does not know the type of the response, the [`Spawnable`] trait is used to spawn
+/// the response without the client having to know its type.
+pub trait Spawnable {
+    /// Spawn the response. Implementors should spawn the future with `tokio::spawn` if it is an async function.
+    /// They should also print a `tracing::error` log in case of an error.
+    fn spawn(self);
+}
+impl Spawnable for () {
+    fn spawn(self) {}
+}
+
+/// A trait to add a "defined" bound to adapter types.
+/// This allows the socket.io library to implement functions given a *defined* adapter
+/// and not a generic `A: Adapter`.
+///
+/// This is useful to force the user to handle the potential init response type [`CoreAdapter::InitRes`].
+pub trait DefinedAdapter {}
+
+/// An adapter is responsible for managing the state of the namespace.
+/// This adapter can be implemented to share the state between multiple servers.
+///
+/// A [`CoreLocalAdapter`] instance will be given when constructing this type; it will allow
+/// you to manipulate local sockets (emitting, fetching data, broadcasting).
+pub trait CoreAdapter: Sized + Send + Sync + 'static {
+    /// An error that can occur when using the adapter.
+    type Error: StdError + Into + Send + 'static;
+    /// A shared state between all the namespace [`CoreAdapter`].
+    /// This can be used to share a connection for example.
+    type State: Send + Sync + 'static;
+    /// A stream that emits the acknowledgments of multiple sockets.
+    type AckStream: Stream> + FusedStream + Send + 'static;
+    /// A named result type for the initialization of the adapter.
+    type InitRes: Spawnable + Send;
+
+    /// Creates a new adapter with the given state and local adapter.
+    ///
+    /// The state is used to share a common state between all your adapters, e.g. a connection to a remote system.
+    /// The local adapter is used to manipulate the local sockets.
+    fn new(state: &Self::State, local: CoreLocalAdapter) -> Self;
+
+    /// Initializes the adapter. The `on_success` callback should be called when the adapter is ready.
+    fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes;
+
+    /// Closes the adapter.
+    fn close(&self) -> impl Future> + Send {
+        future::ready(Ok(()))
+    }
+
+    /// Returns the number of servers.
+    fn server_count(&self) -> impl Future> + Send {
+        future::ready(Ok(1))
+    }
+
+    /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`].
+ fn broadcast( + &self, + packet: Packet, + opts: BroadcastOptions, + ) -> impl Future> + Send { + future::ready( + self.get_local() + .broadcast(packet, opts) + .map_err(BroadcastError::from), + ) + } + + /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] + /// and return a stream of ack responses. + /// + /// This method does not have default implementation because GAT cannot have default impls. + /// + fn broadcast_with_ack( + &self, + packet: Packet, + opts: BroadcastOptions, + timeout: Option, + ) -> impl Future> + Send; + + /// Adds the sockets that match the [`BroadcastOptions`] to the rooms. + fn add_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> impl Future> + Send { + self.get_local().add_sockets(opts, rooms); + future::ready(Ok(())) + } + + /// Removes the sockets that match the [`BroadcastOptions`] from the rooms. + fn del_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> impl Future> + Send { + self.get_local().del_sockets(opts, rooms); + future::ready(Ok(())) + } + + /// Disconnects the sockets that match the [`BroadcastOptions`]. + fn disconnect_socket( + &self, + opts: BroadcastOptions, + ) -> impl Future> + Send { + future::ready( + self.get_local() + .disconnect_socket(opts) + .map_err(BroadcastError::Socket), + ) + } + + /// Fetches rooms that match the [`BroadcastOptions`] + fn rooms( + &self, + opts: BroadcastOptions, + ) -> impl Future, Self::Error>> + Send { + future::ready(Ok(self.get_local().rooms(opts).into_iter().collect())) + } + + /// Fetches remote sockets that match the [`BroadcastOptions`]. + fn fetch_sockets( + &self, + opts: BroadcastOptions, + ) -> impl Future, Self::Error>> + Send { + future::ready(Ok(self.get_local().fetch_sockets(opts))) + } + + /// Returns the local adapter. Used to enable default behaviors. + fn get_local(&self) -> &CoreLocalAdapter; + + //TODO: implement + // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result; + // fn persist_session(&self, sid: i64); + // fn restore_session(&self, sid: i64) -> Session; +} + +/// The default adapter. Store the state in memory. +pub struct CoreLocalAdapter { + rooms: RwLock>>, + sockets: RwLock>>, + emitter: E, +} + +impl CoreLocalAdapter { + /// Create a new local adapter with the given sockets interface. + pub fn new(emitter: E) -> Self { + Self { + rooms: RwLock::new(HashMap::new()), + sockets: RwLock::new(HashMap::new()), + emitter, + } + } + + /// Clears all the rooms and sockets. + pub fn close(&self) { + let mut rooms = self.rooms.write().unwrap(); + rooms.clear(); + rooms.shrink_to_fit(); + } + + /// Adds the socket to all the rooms. + pub fn add_all(&self, sid: Sid, rooms: impl RoomParam) { + let mut rooms_map = self.rooms.write().unwrap(); + let mut socket_map = self.sockets.write().unwrap(); + for room in rooms.into_room_iter() { + rooms_map.entry(room.clone()).or_default().insert(sid); + socket_map.entry(sid).or_default().insert(room); + } + } + + /// Removes the socket from the rooms. + pub fn del(&self, sid: Sid, rooms: impl RoomParam) { + let mut rooms_map = self.rooms.write().unwrap(); + let mut socket_map = self.sockets.write().unwrap(); + for room in rooms.into_room_iter() { + remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || { + socket_map.entry(sid).and_modify(|r| { + r.remove(&room); + }); + }); + } + } + + /// Removes the socket from all the rooms. 
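Everything in `CoreAdapter` except `new`, `init`, `broadcast_with_ack` and `get_local` has a local default, so a minimal adapter can simply delegate to its `CoreLocalAdapter`. The sketch below shows such a pass-through adapter, loosely modeled on the crate's built-in local behavior; signatures are reconstructed from this hunk and should be treated as an assumption, not the shipped implementation.

```rust
use std::convert::Infallible;
use std::future::{self, Future};
use std::sync::Arc;
use std::time::Duration;

use socketioxide_core::adapter::{BroadcastOptions, CoreAdapter, CoreLocalAdapter, SocketEmitter};
use socketioxide_core::packet::Packet;

/// A pass-through adapter: no remote servers, everything goes to the local namespace.
pub struct PassThroughAdapter<E>(CoreLocalAdapter<E>);

impl<E: SocketEmitter> CoreAdapter<E> for PassThroughAdapter<E> {
    // Never fails locally; Infallible converts into AdapterError via the From impl in this diff.
    type Error = Infallible;
    type State = ();
    type AckStream = E::AckStream;
    type InitRes = ();

    fn new(_state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
        Self(local)
    }

    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
        // Nothing to connect to, so the adapter is ready immediately.
        on_success();
    }

    fn broadcast_with_ack(
        &self,
        packet: Packet,
        opts: BroadcastOptions,
        timeout: Option<Duration>,
    ) -> impl Future<Output = Result<Self::AckStream, Self::Error>> + Send {
        // Only local acks exist here, so the expected-ack count can be ignored.
        let (stream, _count) = self.0.broadcast_with_ack(packet, opts, timeout);
        future::ready(Ok(stream))
    }

    fn get_local(&self) -> &CoreLocalAdapter<E> {
        &self.0
    }
}
```

With only these four methods provided, the defaulted operations (`broadcast`, `rooms`, `add_sockets`, `disconnect_socket`, ...) fall back to the `CoreLocalAdapter` automatically.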
+ pub fn del_all(&self, sid: Sid) { + let mut rooms_map = self.rooms.write().unwrap(); + if let Some(rooms) = self.sockets.write().unwrap().remove(&sid) { + for room in rooms { + remove_and_clean_entry(rooms_map.entry(room.clone()), &sid, || ()); + } + } + } + + /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]. + pub fn broadcast( + &self, + packet: Packet, + opts: BroadcastOptions, + ) -> Result<(), Vec> { + let room_map = self.rooms.read().unwrap(); + let sids = self.apply_opts(&opts, &room_map); + + if sids.is_empty() { + return Ok(()); + } + + let data = self.emitter.parser().encode(packet); + self.emitter.send_many(sids, data) + } + + /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses. + /// Also returns the number of local expected aknowledgements to know when to stop waiting. + pub fn broadcast_with_ack( + &self, + packet: Packet, + opts: BroadcastOptions, + timeout: Option, + ) -> (E::AckStream, u32) { + let room_map = self.rooms.read().unwrap(); + let sids = self.apply_opts(&opts, &room_map); + // We cannot pre-serialize the packet because we need to change the ack id. + self.emitter.send_many_with_ack(sids, packet, timeout) + } + + /// Returns the sockets ids that match the [`BroadcastOptions`]. + pub fn sockets(&self, opts: BroadcastOptions) -> Vec { + self.apply_opts(&opts, &self.rooms.read().unwrap()) + .collect() + } + + /// Returns the sockets ids that match the [`BroadcastOptions`]. + pub fn fetch_sockets(&self, opts: BroadcastOptions) -> Vec { + let rooms = self.rooms.read().unwrap(); + let sids = self.apply_opts(&opts, &rooms); + self.emitter.get_remote_sockets(sids) + } + + /// Returns the rooms of the socket. + pub fn socket_rooms(&self, sid: Sid) -> HashSet { + self.sockets + .read() + .unwrap() + .get(&sid) + .cloned() + .unwrap_or_default() + } + + /// Adds the sockets that match the [`BroadcastOptions`] to the rooms. + pub fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) { + let rooms: Vec = rooms.into_room_iter().collect(); + let mut room_map = self.rooms.write().unwrap(); + let mut socket_map = self.sockets.write().unwrap(); + // Here we have to collect sids, because we are going to modify the rooms map. + let sids = self.apply_opts(&opts, &room_map).collect::>(); + for sid in &sids { + let entry = socket_map.entry(*sid).or_default(); + for room in &rooms { + entry.insert(room.clone()); + } + } + for room in rooms { + let entry = room_map.entry(room).or_default(); + for sid in &sids { + entry.insert(*sid); + } + } + } + + /// Removes the sockets that match the [`BroadcastOptions`] from the rooms. + pub fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) { + let rooms: Vec = rooms.into_room_iter().collect(); + let mut rooms_map = self.rooms.write().unwrap(); + let mut socket_map = self.sockets.write().unwrap(); + let sids = self.apply_opts(&opts, &rooms_map).collect::>(); + for room in rooms { + for sid in &sids { + remove_and_clean_entry(socket_map.entry(*sid), &room, || ()); + remove_and_clean_entry(rooms_map.entry(room.clone()), sid, || ()); + } + } + } + + /// Disconnects the sockets that match the [`BroadcastOptions`]. 
+ pub fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec> { + let sids = self + .apply_opts(&opts, &self.rooms.read().unwrap()) + .collect(); + self.emitter.disconnect_many(sids) + } + + /// Returns all the matching rooms + pub fn rooms(&self, opts: BroadcastOptions) -> HashSet { + let rooms = self.rooms.read().unwrap(); + let sockets = self.sockets.read().unwrap(); + let sids = self.apply_opts(&opts, &rooms); + sids.filter_map(|id| sockets.get(&id)) + .flatten() + .cloned() + .collect() + } + + /// Get the namespace path. + pub fn path(&self) -> &Str { + self.emitter.path() + } + + /// Get the parser of the namespace. + pub fn parser(&self) -> impl Parse + '_ { + self.emitter.parser() + } + /// Get the unique server identifier + pub fn server_id(&self) -> Uid { + self.emitter.server_id() + } +} + +/// The default broadcast iterator. +/// Extract, flatten and filter a list of sid from a room list +struct BroadcastRooms<'a> { + rooms: slice::Iter<'a, Room>, + rooms_map: &'a HashMap>, + except: HashSet, + flatten_iter: Option>, +} +impl<'a> BroadcastRooms<'a> { + fn new( + rooms: &'a [Room], + rooms_map: &'a HashMap>, + except: HashSet, + ) -> Self { + BroadcastRooms { + rooms: rooms.iter(), + rooms_map, + except, + flatten_iter: None, + } + } +} +impl Iterator for BroadcastRooms<'_> { + type Item = Sid; + fn next(&mut self) -> Option { + loop { + match self.flatten_iter.as_mut().and_then(Iterator::next) { + Some(sid) if !self.except.contains(sid) => return Some(*sid), + Some(_) => continue, + None => self.flatten_iter = None, + } + + let room = self.rooms.next()?; + self.flatten_iter = self.rooms_map.get(room).map(HashSet::iter); + } + } +} + +impl CoreLocalAdapter { + /// Applies the given `opts` and return the sockets that match. + fn apply_opts<'a>( + &self, + opts: &'a BroadcastOptions, + rooms: &'a HashMap>, + ) -> BroadcastIter<'a> { + let is_broadcast = opts.has_flag(BroadcastFlags::Broadcast); + + let mut except = get_except_sids(&opts.except, rooms); + // In case of broadcast flag + if the sender is set, + // we should not broadcast to it. + if is_broadcast && opts.sid.is_some() { + except.insert(opts.sid.unwrap()); + } + + if !opts.rooms.is_empty() { + let iter = BroadcastRooms::new(&opts.rooms, rooms, except); + InnerBroadcastIter::BroadcastRooms(iter).into() + } else if is_broadcast { + let sids = self.emitter.get_all_sids(|id| !except.contains(id)); + InnerBroadcastIter::GlobalBroadcast(sids.into_iter()).into() + } else if let Some(id) = opts.sid { + InnerBroadcastIter::Single(id).into() + } else { + InnerBroadcastIter::None.into() + } + } +} + +#[inline] +fn get_except_sids(except: &[Room], rooms: &HashMap>) -> HashSet { + let mut except_sids = HashSet::new(); + for room in except { + if let Some(sockets) = rooms.get(room) { + except_sids.extend(sockets); + } + } + except_sids +} + +/// Remove a field from a HashSet value and remove it if empty. +/// Call `cleanup` fn if the entry exists +#[inline] +fn remove_and_clean_entry( + entry: hash_map::Entry<'_, K, HashSet>, + el: &T, + cleanup: impl FnOnce(), +) { + //TODO: use hashmap raw entry when stabilized to avoid entry clone. + // https://github.com/rust-lang/rust/issues/56167 + match entry { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().remove(el); + if entry.get().is_empty() { + entry.remove_entry(); + } + cleanup(); + } + hash_map::Entry::Vacant(_) => (), + } +} + +/// An iterator that yields the socket ids that match the broadcast options. +/// Used with the [`SocketEmitter`] interface. 
+pub struct BroadcastIter<'a> { + inner: InnerBroadcastIter<'a>, +} +enum InnerBroadcastIter<'a> { + BroadcastRooms(BroadcastRooms<'a>), + GlobalBroadcast( as IntoIterator>::IntoIter), + Single(Sid), + None, +} +impl BroadcastIter<'_> { + fn is_empty(&self) -> bool { + matches!(self.inner, InnerBroadcastIter::None) + } +} +impl<'a> From> for BroadcastIter<'a> { + fn from(inner: InnerBroadcastIter<'a>) -> Self { + BroadcastIter { inner } + } +} + +impl Iterator for BroadcastIter<'_> { + type Item = Sid; + + #[inline(always)] + fn next(&mut self) -> Option { + self.inner.next() + } +} +impl Iterator for InnerBroadcastIter<'_> { + type Item = Sid; + + fn next(&mut self) -> Option { + match self { + InnerBroadcastIter::BroadcastRooms(inner) => inner.next(), + InnerBroadcastIter::GlobalBroadcast(inner) => inner.next(), + InnerBroadcastIter::Single(sid) => { + let sid = *sid; + *self = InnerBroadcastIter::None; + Some(sid) + } + InnerBroadcastIter::None => None, + } + } +} + +/// Represent the data of a remote socket. +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Default, Clone)] +pub struct RemoteSocketData { + /// The id of the remote socket. + pub id: Sid, + /// The server id this socket is connected to. + pub server_id: Uid, + /// The namespace this socket is connected to. + pub ns: Str, +} + +#[cfg(test)] +mod test { + + use smallvec::smallvec; + use std::{ + array, + pin::Pin, + task::{Context, Poll}, + }; + + use super::*; + + struct StubSockets { + sockets: HashSet, + path: Str, + } + impl StubSockets { + fn new(sockets: &[Sid]) -> Self { + let sockets = HashSet::from_iter(sockets.iter().copied()); + Self { + sockets, + path: Str::from("/"), + } + } + } + + struct StubAckStream; + impl Stream for StubAckStream { + type Item = (Sid, Result); + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } + } + impl FusedStream for StubAckStream { + fn is_terminated(&self) -> bool { + true + } + } + #[derive(Debug, Serialize, Deserialize)] + struct StubError; + impl std::fmt::Display for StubError { + fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } + } + impl std::error::Error for StubError {} + + impl SocketEmitter for StubSockets { + type AckError = StubError; + type AckStream = StubAckStream; + fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec { + self.sockets + .iter() + .copied() + .filter(|id| filter(id)) + .collect() + } + + fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec { + sids.map(|id| RemoteSocketData { + id, + server_id: Uid::ZERO, + ns: self.path.clone(), + }) + .collect() + } + + fn send_many(&self, _: BroadcastIter<'_>, _: Value) -> Result<(), Vec> { + Ok(()) + } + + fn send_many_with_ack( + &self, + _: BroadcastIter<'_>, + _: Packet, + _: Option, + ) -> (Self::AckStream, u32) { + (StubAckStream, 0) + } + + fn disconnect_many(&self, _: Vec) -> Result<(), Vec> { + Ok(()) + } + + fn path(&self) -> &Str { + &self.path + } + fn parser(&self) -> impl Parse { + crate::parser::test::StubParser + } + fn server_id(&self) -> Uid { + Uid::ZERO + } + } + + fn create_adapter(sockets: [Sid; S]) -> CoreLocalAdapter { + CoreLocalAdapter::new(StubSockets::new(&sockets)) + } + + #[test] + fn add_all() { + let socket = Sid::new(); + let adapter = create_adapter([socket]); + adapter.add_all(socket, ["room1", "room2"]); + let rooms_map = adapter.rooms.read().unwrap(); + let socket_map = adapter.sockets.read().unwrap(); + assert_eq!(rooms_map.len(), 2); + assert_eq!(socket_map.len(), 1); + 
assert_eq!(rooms_map.get("room1").unwrap().len(), 1); + assert_eq!(rooms_map.get("room2").unwrap().len(), 1); + + let rooms = socket_map.get(&socket).unwrap(); + assert!(rooms.contains("room1")); + assert!(rooms.contains("room2")); + } + + #[test] + fn del() { + let socket = Sid::new(); + let adapter = create_adapter([socket]); + adapter.add_all(socket, ["room1", "room2"]); + { + let rooms_map = adapter.rooms.read().unwrap(); + assert_eq!(rooms_map.len(), 2); + assert_eq!(rooms_map.get("room1").unwrap().len(), 1); + assert_eq!(rooms_map.get("room2").unwrap().len(), 1); + let socket_map = adapter.sockets.read().unwrap(); + let rooms = socket_map.get(&socket).unwrap(); + assert!(rooms.contains("room1")); + assert!(rooms.contains("room2")); + } + adapter.del(socket, "room1"); + let rooms_map = adapter.rooms.read().unwrap(); + let socket_map = adapter.sockets.read().unwrap(); + assert_eq!(rooms_map.len(), 1); + assert!(rooms_map.get("room1").is_none()); + assert_eq!(rooms_map.get("room2").unwrap().len(), 1); + assert_eq!(socket_map.get(&socket).unwrap().len(), 1); + } + #[test] + fn del_all() { + let socket = Sid::new(); + let adapter = create_adapter([socket]); + adapter.add_all(socket, ["room1", "room2"]); + + { + let rooms_map = adapter.rooms.read().unwrap(); + assert_eq!(rooms_map.len(), 2); + assert_eq!(rooms_map.get("room1").unwrap().len(), 1); + assert_eq!(rooms_map.get("room2").unwrap().len(), 1); + + let socket_map = adapter.sockets.read().unwrap(); + let rooms = socket_map.get(&socket).unwrap(); + assert!(rooms.contains("room1")); + assert!(rooms.contains("room2")); + } + + adapter.del_all(socket); + + { + let rooms_map = adapter.rooms.read().unwrap(); + assert_eq!(rooms_map.len(), 0); + + let socket_map = adapter.sockets.read().unwrap(); + assert!(socket_map.get(&socket).is_none()); + } + } + + #[test] + fn socket_room() { + let sid1 = Sid::new(); + let sid2 = Sid::new(); + let sid3 = Sid::new(); + let adapter = create_adapter([sid1, sid2, sid3]); + adapter.add_all(sid1, ["room1", "room2"]); + adapter.add_all(sid2, ["room1"]); + adapter.add_all(sid3, ["room2"]); + assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room1"))); + assert!(adapter.socket_rooms(sid1).contains(&Cow::Borrowed("room2"))); + assert_eq!( + adapter.socket_rooms(sid2).into_iter().collect::>(), + ["room1"] + ); + assert_eq!( + adapter.socket_rooms(sid3).into_iter().collect::>(), + ["room2"] + ); + } + + #[test] + fn add_socket() { + let socket = Sid::new(); + let adapter = create_adapter([socket]); + adapter.add_all(socket, ["room1"]); + + let mut opts = BroadcastOptions::new(socket); + opts.rooms = smallvec!["room1".into()]; + adapter.add_sockets(opts, "room2"); + let rooms_map = adapter.rooms.read().unwrap(); + + assert_eq!(rooms_map.len(), 2); + assert!(rooms_map.get("room1").unwrap().contains(&socket)); + assert!(rooms_map.get("room2").unwrap().contains(&socket)); + } + + #[test] + fn del_socket() { + let socket = Sid::new(); + let adapter = create_adapter([socket]); + adapter.add_all(socket, ["room1"]); + + let mut opts = BroadcastOptions::new(socket); + opts.rooms = smallvec!["room1".into()]; + adapter.add_sockets(opts, "room2"); + + { + let rooms_map = adapter.rooms.read().unwrap(); + + assert_eq!(rooms_map.len(), 2); + assert!(rooms_map.get("room1").unwrap().contains(&socket)); + assert!(rooms_map.get("room2").unwrap().contains(&socket)); + } + + let mut opts = BroadcastOptions::new(socket); + opts.rooms = smallvec!["room1".into()]; + adapter.del_sockets(opts, "room2"); + + { + let rooms_map = 
adapter.rooms.read().unwrap(); + + assert_eq!(rooms_map.len(), 1); + assert!(rooms_map.get("room1").unwrap().contains(&socket)); + assert!(rooms_map.get("room2").is_none()); + } + } + + #[test] + fn sockets() { + let socket0 = Sid::new(); + let socket1 = Sid::new(); + let socket2 = Sid::new(); + let adapter = create_adapter([socket0, socket1, socket2]); + adapter.add_all(socket0, ["room1", "room2"]); + adapter.add_all(socket1, ["room1", "room3"]); + adapter.add_all(socket2, ["room2", "room3"]); + + let mut opts = BroadcastOptions { + rooms: smallvec!["room1".into()], + ..Default::default() + }; + let sockets = adapter.sockets(opts.clone()); + assert_eq!(sockets.len(), 2); + assert!(sockets.contains(&socket0)); + assert!(sockets.contains(&socket1)); + + opts.rooms = smallvec!["room2".into()]; + let sockets = adapter.sockets(opts.clone()); + assert_eq!(sockets.len(), 2); + assert!(sockets.contains(&socket0)); + assert!(sockets.contains(&socket2)); + + opts.rooms = smallvec!["room3".into()]; + let sockets = adapter.sockets(opts.clone()); + assert_eq!(sockets.len(), 2); + assert!(sockets.contains(&socket1)); + assert!(sockets.contains(&socket2)); + } + + #[test] + fn disconnect_socket() { + let socket0 = Sid::new(); + let socket1 = Sid::new(); + let socket2 = Sid::new(); + let adapter = create_adapter([socket0, socket1, socket2]); + adapter.add_all(socket0, ["room1", "room2", "room4"]); + adapter.add_all(socket1, ["room1", "room3", "room5"]); + adapter.add_all(socket2, ["room2", "room3", "room6"]); + + let mut opts = BroadcastOptions::new(socket0); + opts.rooms = smallvec!["room5".into()]; + adapter.disconnect_socket(opts).unwrap(); + + let mut opts = BroadcastOptions::default(); + opts.rooms.push("room2".into()); + let sockets = adapter.sockets(opts.clone()); + assert_eq!(sockets.len(), 2); + assert!(sockets.contains(&socket2)); + assert!(sockets.contains(&socket0)); + } + #[test] + fn disconnect_empty_opts() { + let adapter = create_adapter([]); + let opts = BroadcastOptions::default(); + adapter.disconnect_socket(opts).unwrap(); + } + #[test] + fn rooms() { + let socket0 = Sid::new(); + let socket1 = Sid::new(); + let socket2 = Sid::new(); + let adapter = create_adapter([socket0, socket1, socket2]); + adapter.add_all(socket0, ["room1", "room2", "room4"]); + adapter.add_all(socket1, ["room1", "room3", "room5"]); + adapter.add_all(socket2, ["room2", "room3", "room6"]); + + let mut opts = BroadcastOptions::new(socket0); + opts.rooms = smallvec!["room5".into()]; + opts.add_flag(BroadcastFlags::Broadcast); + let rooms = adapter.rooms(opts); + assert_eq!(rooms.len(), 3); + assert!(rooms.contains(&Cow::Borrowed("room1"))); + assert!(rooms.contains(&Cow::Borrowed("room3"))); + assert!(rooms.contains(&Cow::Borrowed("room5"))); + + let mut opts = BroadcastOptions::default(); + opts.rooms.push("room2".into()); + let rooms = adapter.rooms(opts.clone()); + assert_eq!(rooms.len(), 5); + assert!(rooms.contains(&Cow::Borrowed("room1"))); + assert!(rooms.contains(&Cow::Borrowed("room2"))); + assert!(rooms.contains(&Cow::Borrowed("room3"))); + assert!(rooms.contains(&Cow::Borrowed("room4"))); + assert!(rooms.contains(&Cow::Borrowed("room6"))); + } + + #[test] + fn apply_opts() { + let mut sockets: [Sid; 3] = array::from_fn(|_| Sid::new()); + sockets.sort(); + let adapter = create_adapter(sockets); + + adapter.add_all(sockets[0], ["room1", "room2"]); + adapter.add_all(sockets[1], ["room1", "room3"]); + adapter.add_all(sockets[2], ["room1", "room2", "room3"]); + + // socket 2 is the sender + let mut opts = 
BroadcastOptions::new(sockets[2]); + opts.rooms = smallvec!["room1".into()]; + opts.except = smallvec!["room2".into()]; + let sids = adapter + .apply_opts(&opts, &adapter.rooms.read().unwrap()) + .collect::>(); + assert_eq!(sids, [sockets[1]]); + + let mut opts = BroadcastOptions::new(sockets[2]); + opts.add_flag(BroadcastFlags::Broadcast); + let mut sids = adapter + .apply_opts(&opts, &adapter.rooms.read().unwrap()) + .collect::>(); + sids.sort(); + assert_eq!(sids, [sockets[0], sockets[1]]); + + let mut opts = BroadcastOptions::new(sockets[2]); + opts.add_flag(BroadcastFlags::Broadcast); + opts.except = smallvec!["room2".into()]; + let sids = adapter + .apply_opts(&opts, &adapter.rooms.read().unwrap()) + .collect::>(); + assert_eq!(sids.len(), 1); + + let opts = BroadcastOptions::new(sockets[2]); + let sids = adapter + .apply_opts(&opts, &adapter.rooms.read().unwrap()) + .collect::>(); + assert_eq!(sids.len(), 1); + assert_eq!(sids[0], sockets[2]); + + let opts = BroadcastOptions::new(Sid::new()); + let sids = adapter + .apply_opts(&opts, &adapter.rooms.read().unwrap()) + .collect::>(); + assert_eq!(sids.len(), 1); + } +} diff --git a/crates/socketioxide-core/src/errors.rs b/crates/socketioxide-core/src/errors.rs new file mode 100644 index 00000000..5fe4d521 --- /dev/null +++ b/crates/socketioxide-core/src/errors.rs @@ -0,0 +1,62 @@ +//! All the errors that can be returned by the crate. Mostly when using the [adapter](crate::adapter) module. +use std::{convert::Infallible, fmt}; + +use serde::{Deserialize, Serialize}; + +use crate::parser::ParserError; + +/// Error type when using the underlying engine.io socket +#[derive(Debug, thiserror::Error, Serialize, Deserialize, Clone)] +pub enum SocketError { + /// The socket channel is full. + /// You might need to increase the channel size with the [`SocketIoBuilder::max_buffer_size`] method. + /// + /// [`SocketIoBuilder::max_buffer_size`]: https://docs.rs/socketioxide/latest/socketioxide/struct.SocketIoBuilder.html#method.max_buffer_size + #[error("internal channel full error")] + InternalChannelFull, + + /// The socket is already closed + #[error("socket closed")] + Closed, +} + +/// Error type for the [`CoreAdapter`](crate::adapter::CoreAdapter) trait. +#[derive(Debug, thiserror::Error)] +pub struct AdapterError(#[from] pub Box); +impl fmt::Display for AdapterError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} +impl From for AdapterError { + fn from(_: Infallible) -> Self { + panic!("Infallible should never be constructed, this is a bug") + } +} + +/// Error type for broadcast operations. +#[derive(thiserror::Error, Debug)] +pub enum BroadcastError { + // This type should never constructed with an empty vector! + /// An error occurred while sending packets. + #[error("Error sending data through the engine.io socket: {0:?}")] + Socket(Vec), + + /// An error occurred while serializing the packet. + #[error("Error serializing packet: {0:?}")] + Serialize(#[from] ParserError), + + /// An error occured while broadcasting to other nodes. 
+ #[error("Adapter error: {0}")] + Adapter(#[from] AdapterError), +} + +impl From> for BroadcastError { + fn from(value: Vec) -> Self { + assert!( + !value.is_empty(), + "Cannot construct a BroadcastError from an empty vec of SocketError" + ); + Self::Socket(value) + } +} diff --git a/crates/socketioxide-core/src/lib.rs b/crates/socketioxide-core/src/lib.rs index 246d3e2f..4a96e822 100644 --- a/crates/socketioxide-core/src/lib.rs +++ b/crates/socketioxide-core/src/lib.rs @@ -32,12 +32,39 @@ //! This crate is the core of the socketioxide crate. //! It contains basic types and interfaces for the socketioxide crate and the parser sub-crates. +pub mod adapter; +pub mod errors; pub mod packet; pub mod parser; -use std::collections::VecDeque; +use std::{collections::VecDeque, ops::Deref}; +use bytes::Bytes; pub use engineioxide::{sid::Sid, Str}; +use serde::{Deserialize, Serialize}; + +/// Represents a unique identifier for a server. +#[derive(Clone, Serialize, Deserialize, Debug, Copy, PartialEq, Eq, Default)] +pub struct Uid(Sid); +impl Deref for Uid { + type Target = Sid; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl std::fmt::Display for Uid { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} +impl Uid { + /// A zeroed server id. + pub const ZERO: Self = Self(Sid::ZERO); + /// Create a new unique identifier. + pub fn new() -> Self { + Self(Sid::new()) + } +} /// Represents a value that can be sent over the engine.io wire as an engine.io packet /// or the data that can be outputed by a binary parser (e.g. [`MsgPackParser`](../socketioxide_parser_msgpack/index.html)) @@ -53,7 +80,37 @@ pub enum Value { Bytes(bytes::Bytes), } -#[cfg(feature = "fuzzing")] +/// Custom implementation to serialize enum variant as u8. +impl Serialize for Value { + fn serialize(&self, serializer: S) -> Result { + let raw = match self { + Value::Str(data, bins) => (0u8, data.as_bytes(), bins), + Value::Bytes(data) => (1u8, data.as_ref(), &None), + }; + raw.serialize(serializer) + } +} +impl<'de> Deserialize<'de> for Value { + fn deserialize>(deserializer: D) -> Result { + let (idx, data, bins): (u8, Vec, Option>) = + Deserialize::deserialize(deserializer)?; + let res = match idx { + 0 => Value::Str( + Str::from(String::from_utf8(data).map_err(serde::de::Error::custom)?), + bins, + ), + 1 => Value::Bytes(Bytes::from(data)), + i => Err(serde::de::Error::custom(format!( + "invalid value type: {}", + i + )))?, + }; + Ok(res) + } +} + +#[cfg(fuzzing)] +#[doc(hidden)] impl arbitrary::Arbitrary<'_> for Value { fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result { let res = match u.arbitrary::()? 
{ @@ -98,3 +155,46 @@ impl Value { self.len() == 0 } } + +#[cfg(test)] +mod tests { + use super::{Str, Value}; + use bytes::Bytes; + use std::collections::VecDeque; + + fn assert_serde_value(value: Value) { + let serialized = serde_json::to_string(&value).unwrap(); + let deserialized: Value = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(value, deserialized); + } + + #[test] + fn value_serde_str_with_bins() { + let mut bins = VecDeque::new(); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + bins.push_back(Bytes::from_static(&[5, 6, 7, 8])); + + let value = Value::Str(Str::from("value".to_string()), Some(bins)); + assert_serde_value(value); + } + + #[test] + fn value_serde_bytes() { + let value = Value::Bytes(Bytes::from_static(&[1, 2, 3, 4])); + assert_serde_value(value); + } + + #[test] + fn value_serde_str_without_bins() { + let value = Value::Str(Str::from("value_no_bins".to_string()), None); + assert_serde_value(value); + } + + #[test] + fn value_serde_invalid_type() { + let invalid_data = "[2, [1,2,3,4], null]"; + let result: Result = serde_json::from_str(invalid_data); + assert!(result.is_err()); + } +} diff --git a/crates/socketioxide-core/src/packet.rs b/crates/socketioxide-core/src/packet.rs index bfb9159f..ba02eb14 100644 --- a/crates/socketioxide-core/src/packet.rs +++ b/crates/socketioxide-core/src/packet.rs @@ -139,9 +139,66 @@ pub struct ConnectPacket { pub sid: Sid, } +impl Serialize for Packet { + fn serialize(&self, serializer: S) -> Result { + #[derive(Serialize)] + struct RawPacket<'a> { + ns: &'a Str, + r#type: u8, + data: Option<&'a Value>, + ack: Option, + error: Option<&'a String>, + } + let (r#type, data, ack, error) = match &self.inner { + PacketData::Connect(v) => (0, v.as_ref(), None, None), + PacketData::Disconnect => (1, None, None, None), + PacketData::Event(v, ack) => (2, Some(v), *ack, None), + PacketData::EventAck(v, ack) => (3, Some(v), Some(*ack), None), + PacketData::ConnectError(e) => (4, None, None, Some(e)), + PacketData::BinaryEvent(v, ack) => (5, Some(v), *ack, None), + PacketData::BinaryAck(v, ack) => (6, Some(v), Some(*ack), None), + }; + let raw = RawPacket { + ns: &self.ns, + data, + ack, + error, + r#type, + }; + raw.serialize(serializer) + } +} +impl<'de> Deserialize<'de> for Packet { + fn deserialize>(deserializer: D) -> Result { + #[derive(Deserialize)] + struct RawPacket { + ns: Str, + r#type: u8, + data: Option, + ack: Option, + error: Option, + } + let raw = RawPacket::deserialize(deserializer)?; + let err = |field| serde::de::Error::custom(format!("missing field: {}", field)); + let inner = match raw.r#type { + 0 => PacketData::Connect(raw.data), + 1 => PacketData::Disconnect, + 2 => PacketData::Event(raw.data.ok_or(err("data"))?, raw.ack), + 3 => PacketData::EventAck(raw.data.ok_or(err("data"))?, raw.ack.ok_or(err("ack"))?), + 4 => PacketData::ConnectError(raw.error.ok_or(err("error"))?), + 5 => PacketData::BinaryEvent(raw.data.ok_or(err("data"))?, raw.ack), + 6 => PacketData::BinaryAck(raw.data.ok_or(err("data"))?, raw.ack.ok_or(err("ack"))?), + i => return Err(serde::de::Error::custom(format!("invalid packet type {i}"))), + }; + Ok(Self { inner, ns: raw.ns }) + } +} + #[cfg(test)] mod tests { + use std::collections::VecDeque; + use super::{Packet, PacketData, Value}; use bytes::Bytes; @@ -181,4 +238,99 @@ mod tests { Packet::ack("/", val1.clone(), 120).inner, PacketData::EventAck(v, 120) if v == val1)); } + + fn assert_serde_packet(packet: Packet) { + let serialized = serde_json::to_string(&packet).unwrap(); + let 
deserialized: Packet = serde_json::from_str(&serialized).unwrap(); + assert_eq!(packet, deserialized); + } + #[test] + fn packet_serde_connect() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::Connect(Some(Value::Str("test_data".into(), None))), + }; + assert_serde_packet(packet); + } + + #[test] + fn packet_serde_disconnect() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::Disconnect, + }; + assert_serde_packet(packet); + } + + #[test] + fn packet_serde_event() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::Event(Value::Str("event_data".into(), None), None), + }; + assert_serde_packet(packet); + + let mut bins = VecDeque::new(); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + let packet = Packet { + ns: "/".into(), + inner: PacketData::Event(Value::Str("event_data".into(), Some(bins)), Some(12)), + }; + assert_serde_packet(packet); + } + + #[test] + fn packet_serde_event_ack() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::EventAck(Value::Str("event_ack_data".into(), None), 42), + }; + assert_serde_packet(packet); + } + + #[test] + fn packet_serde_connect_error() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::ConnectError("connection_error".into()), + }; + assert_serde_packet(packet); + } + + #[test] + fn packet_serde_binary_event() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::BinaryEvent(Value::Str("binary_event_data".into(), None), None), + }; + assert_serde_packet(packet); + + let mut bins = VecDeque::new(); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + let packet = Packet { + ns: "/".into(), + inner: PacketData::BinaryEvent(Value::Str("event_data".into(), Some(bins)), Some(12)), + }; + assert_serde_packet(packet); + } + + #[test] + fn packet_serde_binary_ack() { + let packet = Packet { + ns: "/".into(), + inner: PacketData::BinaryAck(Value::Str("binary_ack_data".into(), None), 99), + }; + assert_serde_packet(packet); + + let mut bins = VecDeque::new(); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + bins.push_back(Bytes::from_static(&[1, 2, 3, 4])); + let packet = Packet { + ns: "/".into(), + inner: PacketData::BinaryAck(Value::Str("binary_ack_data".into(), Some(bins)), 99), + }; + assert_serde_packet(packet); + } } diff --git a/crates/socketioxide-core/src/parser.rs b/crates/socketioxide-core/src/parser.rs index a11e9842..40666c63 100644 --- a/crates/socketioxide-core/src/parser.rs +++ b/crates/socketioxide-core/src/parser.rs @@ -20,6 +20,7 @@ //! along with other metadata (namespace, ack ID, etc.), into a fully serialized packet ready to be sent. use std::{ + error::Error as StdError, fmt, sync::{atomic::AtomicUsize, Mutex}, }; @@ -46,29 +47,16 @@ pub struct ParserState { /// All socket.io parser should implement this trait. /// Parsers should be stateless. pub trait Parse: Default + Copy { - /// The error produced when encoding a packet - type EncodeError: std::error::Error; - /// The error produced when decoding a packet - type DecodeError: std::error::Error; - /// Convert a packet into multiple payloads to be sent. fn encode(self, packet: Packet) -> Value; /// Parse a given input string. If the payload needs more adjacent binary packet, /// the partial packet will be kept and a [`ParseError::NeedsMoreBinaryData`] will be returned. 
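The hunks below rework the `Parse` trait so every parser surfaces a single boxed `ParserError` instead of parser-specific associated error types. As a rough illustration of how a caller drives the new signatures, here is a sketch; `feed_text` and its logging are illustrative assumptions, not code from this diff:

```rust
// Sketch only: driving a `Parse` implementation with the unified error type.
use socketioxide_core::{
    packet::Packet,
    parser::{Parse, ParseError, ParserState},
    Str,
};

fn feed_text<P: Parse>(parser: P, state: &ParserState, data: Str) -> Option<Packet> {
    match parser.decode_str(state, data) {
        Ok(packet) => Some(packet),
        // The partial packet stays in `state`; the adjacent binary frames
        // must be fed through `decode_bin` before a packet is produced.
        Err(ParseError::NeedsMoreBinaryData) => None,
        // Any parser-specific failure is now a boxed `ParserError`,
        // so callers no longer need to be generic over the error type.
        Err(e) => {
            eprintln!("decode error: {e}");
            None
        }
    }
}
```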
- fn decode_str( - self, - state: &ParserState, - data: Str, - ) -> Result>; + fn decode_str(self, state: &ParserState, data: Str) -> Result; /// Parse a given input binary. If the payload needs more adjacent binary packet, /// the partial packet is still kept and a [`ParseError::NeedsMoreBinaryData`] will be returned. - fn decode_bin( - self, - state: &ParserState, - bin: Bytes, - ) -> Result>; + fn decode_bin(self, state: &ParserState, bin: Bytes) -> Result; /// Convert any serializable data to a generic [`Value`] to be later included as a payload in the packet. /// @@ -80,7 +68,7 @@ pub trait Parse: Default + Copy { self, data: &T, event: Option<&str>, - ) -> Result; + ) -> Result; /// Convert any generic [`Value`] to a deserializable type. /// It should always be an array (according to the serde model). @@ -94,27 +82,63 @@ pub trait Parse: Default + Copy { self, value: &'de mut Value, with_event: bool, - ) -> Result; + ) -> Result; /// Convert any generic [`Value`] to a type with the default serde impl without binary + event tricks. /// This is mainly used for connect payloads. fn decode_default<'de, T: Deserialize<'de>>( self, value: Option<&'de Value>, - ) -> Result; + ) -> Result; /// Convert any type to a generic [`Value`] with the default serde impl without binary + event tricks. /// This is mainly used for connect payloads. - fn encode_default(self, data: &T) -> Result; + fn encode_default(self, data: &T) -> Result; /// Try to read the event name from the given payload data. /// The event name should be the first element of the provided array according to the serde model. - fn read_event(self, value: &Value) -> Result<&str, Self::DecodeError>; + fn read_event(self, value: &Value) -> Result<&str, ParserError>; +} + +/// A parser error that wraps any error that can occur during parsing. +/// +/// E.g. `serde_json::Error`, `rmp_serde::Error`... 
+#[derive(Debug)] +pub struct ParserError { + inner: Box, } +impl fmt::Display for ParserError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} +impl std::error::Error for ParserError {} +impl Serialize for ParserError { + fn serialize(&self, serializer: S) -> Result { + self.inner.to_string().serialize(serializer) + } +} +impl<'de> Deserialize<'de> for ParserError { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + #[derive(Debug, thiserror::Error)] + #[error("remote err: {0:?}")] + struct RemoteErr(String); + Ok(Self::new(RemoteErr(s))) + } +} +impl ParserError { + /// Create a new parser error from any error that implements [`std::error::Error`] + pub fn new(inner: E) -> Self { + Self { + inner: Box::new(inner), + } + } +} /// Errors when parsing/serializing socket.io packets #[derive(thiserror::Error, Debug)] -pub enum ParseError { +pub enum ParseError { /// Invalid packet type #[error("invalid packet type")] InvalidPacketType, @@ -161,24 +185,7 @@ pub enum ParseError { /// The inner parser error #[error("parser error: {0:?}")] - ParserError(#[from] E), -} -impl ParseError { - /// Wrap the [`ParseError::ParserError`] variant with a new error type - pub fn wrap_err(self, f: impl FnOnce(E) -> E1) -> ParseError { - match self { - Self::ParserError(e) => ParseError::ParserError(f(e)), - ParseError::InvalidPacketType => ParseError::InvalidPacketType, - ParseError::InvalidAckId => ParseError::InvalidAckId, - ParseError::InvalidEventName => ParseError::InvalidEventName, - ParseError::InvalidData => ParseError::InvalidData, - ParseError::InvalidNamespace => ParseError::InvalidNamespace, - ParseError::InvalidAttachments => ParseError::InvalidAttachments, - ParseError::UnexpectedBinaryPacket => ParseError::UnexpectedBinaryPacket, - ParseError::UnexpectedStringPacket => ParseError::UnexpectedStringPacket, - ParseError::NeedsMoreBinaryData => ParseError::NeedsMoreBinaryData, - } - } + ParserError(#[from] ParserError), } /// A seed that can be used to deserialize only the 1st element of a sequence @@ -467,8 +474,10 @@ pub fn is_de_tuple<'de, T: serde::Deserialize<'de>>() -> bool { Err(IsTupleSerdeError(v)) => v, } } + +#[doc(hidden)] #[cfg(test)] -mod test { +pub mod test { use super::*; use serde::{Deserialize, Serialize}; @@ -493,4 +502,75 @@ mod test { assert!(!is_ser_tuple(&UnitStruct)); assert!(!is_de_tuple::()); } + + /// A stub parser that always returns an error. Only used for testing. + #[derive(Debug, Default, Clone, Copy)] + pub struct StubParser; + + /// A stub error that is used for testing. 
+ #[derive(Serialize, Deserialize)] + pub struct StubError; + + impl std::fmt::Debug for StubError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("StubError") + } + } + impl std::fmt::Display for StubError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("StubError") + } + } + impl std::error::Error for StubError {} + + fn stub_err() -> ParserError { + ParserError { + inner: Box::new(StubError), + } + } + /// === impl StubParser === + impl Parse for StubParser { + fn encode(self, _: Packet) -> Value { + Value::Bytes(Bytes::new()) + } + + fn decode_str(self, _: &ParserState, _: Str) -> Result { + Err(ParseError::ParserError(stub_err())) + } + + fn decode_bin(self, _: &ParserState, _: bytes::Bytes) -> Result { + Err(ParseError::ParserError(stub_err())) + } + + fn encode_value( + self, + _: &T, + _: Option<&str>, + ) -> Result { + Err(stub_err()) + } + + fn decode_value<'de, T: serde::Deserialize<'de>>( + self, + _: &'de mut Value, + _: bool, + ) -> Result { + Err(stub_err()) + } + + fn decode_default<'de, T: serde::Deserialize<'de>>( + self, + _: Option<&'de Value>, + ) -> Result { + Err(stub_err()) + } + + fn encode_default(self, _: &T) -> Result { + Err(stub_err()) + } + + fn read_event(self, _: &Value) -> Result<&str, ParserError> { + Ok("") + } + } } diff --git a/crates/socketioxide-redis/Cargo.toml b/crates/socketioxide-redis/Cargo.toml new file mode 100644 index 00000000..c7fdd92a --- /dev/null +++ b/crates/socketioxide-redis/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "socketioxide-redis" +description = "Redis adapter for the socket.io protocol" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme.workspace = true + +[features] +redis = ["dep:redis"] +redis-cluster = ["redis", "redis/cluster-async"] +fred = ["dep:fred"] +default = ["redis"] + +[dependencies] +socketioxide-core = { version = "0.15", path = "../socketioxide-core" } +futures-core.workspace = true +futures-util.workspace = true +pin-project-lite.workspace = true +serde.workspace = true +smallvec = { workspace = true, features = ["serde"] } +tokio = { workspace = true, features = ["macros", "time", "rt", "sync"] } +rmp-serde.workspace = true +rmp.workspace = true +bytes.workspace = true +tracing.workspace = true +thiserror.workspace = true + +# Redis implementation +fred = { version = "10", features = [ + "subscriber-client", + "i-pubsub", +], default-features = false, optional = true } +redis = { version = "0.28", features = [ + "aio", + "tokio-comp", + "streams", +], default-features = false, optional = true } + +[dev-dependencies] +tokio = { workspace = true, features = [ + "macros", + "parking_lot", + "rt-multi-thread", +] } +socketioxide = { path = "../socketioxide", features = [ + "tracing", + "__test_harness", +] } +tracing-subscriber.workspace = true diff --git a/crates/socketioxide-redis/src/drivers/fred.rs b/crates/socketioxide-redis/src/drivers/fred.rs new file mode 100644 index 00000000..1206d9ad --- /dev/null +++ b/crates/socketioxide-redis/src/drivers/fred.rs @@ -0,0 +1,170 @@ +use std::{ + collections::HashMap, + fmt, + sync::{Arc, RwLock}, +}; + +use tokio::sync::{broadcast, mpsc}; + +use super::{ChanItem, Driver, MessageStream}; + +use fred::{ + interfaces::PubsubInterface, + prelude::{ClientLike, EventInterface, FredResult}, + 
types::Message, +}; + +pub use fred as fred_client; + +/// An error type for the fred driver. +#[derive(Debug)] +pub struct FredError(fred::error::Error); + +impl From for FredError { + fn from(e: fred::error::Error) -> Self { + Self(e) + } +} +impl fmt::Display for FredError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} +impl std::error::Error for FredError {} + +type HandlerMap = HashMap>; + +/// Return the channel, data and an optional req_id from a message. +fn read_msg(msg: Message) -> Option { + let chan = msg.channel.to_string(); + let data = msg.value.into_owned_bytes()?; + Some((chan, data)) +} + +/// Pipe messages from the fred client to the handlers. +async fn msg_handler(mut rx: broadcast::Receiver, handlers: Arc>) { + loop { + match rx.recv().await { + Ok(msg) => { + if let Some((chan, data)) = read_msg(msg) { + if let Some(tx) = handlers.read().unwrap().get(&chan) { + tx.try_send((chan, data)).unwrap(); + } else { + tracing::warn!(chan, "no handler for channel"); + } + } + } + // From the fred docs, even if the connection closed, the receiver will not be closed. + // Therefore if it happens, we should just return. + Err(broadcast::error::RecvError::Closed) => return, + Err(broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("fred driver pubsub channel lagged by {}", n); + } + } + } +} + +/// A driver implementation for the [fred](docs.rs/fred) pub/sub backend. +#[derive(Clone)] +pub struct FredDriver { + handlers: Arc>, + conn: fred::clients::SubscriberClient, +} + +impl FredDriver { + /// Create a new redis driver from a redis client. + pub async fn new(client: fred::clients::SubscriberClient) -> FredResult { + let handlers = Arc::new(RwLock::new(HashMap::new())); + tokio::spawn(msg_handler(client.message_rx(), handlers.clone())); + client.init().await?; + + Ok(Self { + conn: client, + handlers, + }) + } +} + +impl Driver for FredDriver { + type Error = FredError; + + async fn publish(&self, chan: String, val: Vec) -> Result<(), Self::Error> { + // We could use the receiver count from here. This would avoid a call to `server_cnt`. 
+ self.conn.spublish::(chan, val).await?; + Ok(()) + } + + async fn subscribe( + &self, + chan: String, + size: usize, + ) -> Result, Self::Error> { + self.conn.clone().ssubscribe(chan.as_str()).await?; + let (tx, rx) = mpsc::channel(size); + self.handlers.write().unwrap().insert(chan, tx); + Ok(MessageStream::new(rx)) + } + + async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> { + self.handlers.write().unwrap().remove(&chan); + self.conn.sunsubscribe(chan).await?; + Ok(()) + } + + async fn num_serv(&self, chan: &str) -> Result { + let (_, num): (String, u16) = self.conn.pubsub_shardnumsub(chan).await?; + Ok(num) + } +} + +#[cfg(test)] +mod tests { + + use fred::{ + prelude::Server, + types::{MessageKind, Value}, + }; + use std::time::Duration; + use tokio::time; + const TIMEOUT: Duration = Duration::from_millis(100); + + use super::*; + #[tokio::test] + async fn watch_handle_message() { + let mut handlers = HashMap::new(); + let (tx, mut rx) = mpsc::channel(1); + let (tx1, rx1) = broadcast::channel(1); + handlers.insert("test".to_string(), tx); + tokio::spawn(msg_handler(rx1, Arc::new(RwLock::new(handlers)))); + let msg = Message { + channel: "test".into(), + kind: MessageKind::Message, + value: "foo".into(), + server: Server::new("0.0.0.0", 0), + }; + tx1.send(msg).unwrap(); + let (chan, data) = time::timeout(TIMEOUT, rx.recv()).await.unwrap().unwrap(); + assert_eq!(chan, "test"); + assert_eq!(data, "foo".as_bytes()); + } + + #[tokio::test] + async fn watch_handler_pattern() { + let mut handlers = HashMap::new(); + + let (tx, mut rx) = mpsc::channel(1); + handlers.insert("test-response#namespace#uid#".to_string(), tx); + let (tx1, rx1) = broadcast::channel(1); + tokio::spawn(msg_handler(rx1, Arc::new(RwLock::new(handlers)))); + let msg = Message { + channel: "test-response#namespace#uid#".into(), + kind: MessageKind::Message, + value: Value::from_static(b"foo"), + server: Server::new("0.0.0.0", 0), + }; + tx1.send(msg).unwrap(); + let (chan, data) = time::timeout(TIMEOUT, rx.recv()).await.unwrap().unwrap(); + assert_eq!(chan, "test-response#namespace#uid#"); + assert_eq!(data, "foo".as_bytes()); + } +} diff --git a/crates/socketioxide-redis/src/drivers/mod.rs b/crates/socketioxide-redis/src/drivers/mod.rs new file mode 100644 index 00000000..efb4129f --- /dev/null +++ b/crates/socketioxide-redis/src/drivers/mod.rs @@ -0,0 +1,80 @@ +use std::{future::Future, pin::Pin, task}; + +use futures_core::Stream; +use pin_project_lite::pin_project; +use tokio::sync::mpsc; + +/// A driver implementation for the [redis](docs.rs/redis) pub/sub backend. +#[cfg(feature = "redis")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis")))] +pub mod redis; + +/// A driver implementation for the [fred](docs.rs/fred) pub/sub backend. +#[cfg(feature = "fred")] +#[cfg_attr(docsrs, doc(cfg(feature = "fred")))] +pub mod fred; + +pin_project! { + /// A stream of raw messages received from a channel. + /// Messages are encoded with msgpack. + #[derive(Debug)] + pub struct MessageStream { + #[pin] + rx: mpsc::Receiver, + } +} + +impl MessageStream { + /// Create a new empty message stream. + pub fn new_empty() -> Self { + // mpsc bounded channel requires buffer > 0 + let (_, rx) = mpsc::channel(1); + Self { rx } + } + /// Create a new message stream from a receiver. 
+ pub fn new(rx: mpsc::Receiver) -> Self { + Self { rx } + } +} + +impl Stream for MessageStream { + type Item = T; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + self.project().rx.poll_recv(cx) + } +} + +/// A message item that can be returned from a channel. +pub type ChanItem = (String, Vec); + +/// The driver trait can be used to support different pub/sub backends. +/// It must share handlers/connection between its clones. +pub trait Driver: Clone + Send + Sync + 'static { + /// The error type for the driver. + type Error: std::error::Error + Send + 'static; + + /// Publish a message to a channel. + fn publish( + &self, + chan: String, + val: Vec, + ) -> impl Future> + Send; + + /// Subscribe to a channel, it will return a stream of messages. + /// The size parameter is the buffer size of the channel. + fn subscribe( + &self, + chan: String, + size: usize, + ) -> impl Future, Self::Error>> + Send; + + /// Unsubscribe from a channel. + fn unsubscribe(&self, pat: String) -> impl Future> + Send; + + /// Returns the number of socket.io servers. + fn num_serv(&self, chan: &str) -> impl Future> + Send; +} diff --git a/crates/socketioxide-redis/src/drivers/redis.rs b/crates/socketioxide-redis/src/drivers/redis.rs new file mode 100644 index 00000000..031bfc86 --- /dev/null +++ b/crates/socketioxide-redis/src/drivers/redis.rs @@ -0,0 +1,241 @@ +use std::{ + collections::HashMap, + fmt, + sync::{Arc, RwLock}, +}; + +use redis::{aio::MultiplexedConnection, AsyncCommands, FromRedisValue, PushInfo, RedisResult}; +use tokio::sync::mpsc; + +use super::{ChanItem, Driver, MessageStream}; + +pub use redis as redis_client; + +/// An error type for the redis driver. +#[derive(Debug)] +pub struct RedisError(redis::RedisError); + +impl From for RedisError { + fn from(e: redis::RedisError) -> Self { + Self(e) + } +} +impl fmt::Display for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} +impl std::error::Error for RedisError {} + +type HandlerMap = HashMap>; +/// A driver implementation for the [redis](docs.rs/redis) pub/sub backend. +#[derive(Clone)] +pub struct RedisDriver { + handlers: Arc>, + conn: MultiplexedConnection, +} + +/// A driver implementation for the [redis](docs.rs/redis) pub/sub backend. +#[cfg(feature = "redis-cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))] +#[derive(Clone)] +pub struct ClusterDriver { + handlers: Arc>, + conn: redis::cluster_async::ClusterConnection, +} + +fn read_msg(msg: redis::PushInfo) -> RedisResult)>> { + match msg.kind { + redis::PushKind::Message | redis::PushKind::SMessage => { + if msg.data.len() < 2 { + return Ok(None); + } + let mut iter = msg.data.into_iter(); + let channel: String = FromRedisValue::from_owned_redis_value(iter.next().unwrap())?; + let message = FromRedisValue::from_owned_redis_value(iter.next().unwrap())?; + Ok(Some((channel, message))) + } + _ => Ok(None), + } +} + +fn handle_msg(msg: PushInfo, handlers: Arc>) { + match read_msg(msg) { + Ok(Some((chan, msg))) => { + if let Some(tx) = handlers.read().unwrap().get(&chan) { + if let Err(e) = tx.try_send((chan, msg)) { + tracing::warn!("redis pubsub channel full {e}"); + } + } else { + tracing::warn!(chan, "no handler for channel"); + } + } + Ok(_) => {} + Err(e) => { + tracing::error!("error reading message from redis: {e}"); + } + } +} +impl RedisDriver { + /// Create a new redis driver from a redis client. 
+ pub async fn new(client: &redis::Client) -> Result { + let handlers = Arc::new(RwLock::new(HashMap::new())); + let handlers_clone = handlers.clone(); + let config = redis::AsyncConnectionConfig::new().set_push_sender(move |msg| { + handle_msg(msg, handlers_clone.clone()); + Ok::<(), std::convert::Infallible>(()) + }); + + let conn = client + .get_multiplexed_async_connection_with_config(&config) + .await?; + + Ok(Self { conn, handlers }) + } +} + +#[cfg(feature = "redis-cluster")] +impl ClusterDriver { + /// Create a new redis driver from a redis cluster client. + #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))] + pub async fn new( + client_builder: redis::cluster::ClusterClientBuilder, + ) -> Result { + let handlers = Arc::new(RwLock::new(HashMap::new())); + let handlers_clone = handlers.clone(); + let conn = client_builder + .push_sender(move |msg| { + handle_msg(msg, handlers_clone.clone()); + Ok::<(), std::convert::Infallible>(()) + }) + .build() + .unwrap() + .get_async_connection() + .await?; + + Ok(Self { conn, handlers }) + } +} + +impl Driver for RedisDriver { + type Error = RedisError; + + async fn publish(&self, chan: String, val: Vec) -> Result<(), Self::Error> { + self.conn + .clone() + .publish::<_, _, redis::Value>(chan, val) + .await?; + Ok(()) + } + + async fn subscribe( + &self, + chan: String, + size: usize, + ) -> Result, Self::Error> { + self.conn.clone().subscribe(chan.as_str()).await?; + let (tx, rx) = mpsc::channel(size); + self.handlers.write().unwrap().insert(chan, tx); + Ok(MessageStream::new(rx)) + } + + async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> { + self.handlers.write().unwrap().remove(&chan); + self.conn.clone().unsubscribe(chan).await?; + Ok(()) + } + + async fn num_serv(&self, chan: &str) -> Result { + let mut conn = self.conn.clone(); + let (_, count): (String, u16) = redis::cmd("PUBSUB") + .arg("NUMSUB") + .arg(chan) + .query_async(&mut conn) + .await?; + Ok(count) + } +} + +#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))] +#[cfg(feature = "redis-cluster")] +impl Driver for ClusterDriver { + type Error = RedisError; + + async fn publish(&self, chan: String, val: Vec) -> Result<(), Self::Error> { + self.conn + .clone() + .spublish::<_, _, redis::Value>(chan, val) + .await?; + Ok(()) + } + + async fn subscribe( + &self, + chan: String, + size: usize, + ) -> Result, Self::Error> { + self.conn.clone().ssubscribe(chan.as_str()).await?; + let (tx, rx) = mpsc::channel(size); + self.handlers.write().unwrap().insert(chan, tx); + Ok(MessageStream::new(rx)) + } + + async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> { + self.handlers.write().unwrap().remove(&chan); + self.conn.clone().sunsubscribe(chan).await?; + Ok(()) + } + + async fn num_serv(&self, chan: &str) -> Result { + let mut conn = self.conn.clone(); + let (_, count): (String, u16) = redis::cmd("PUBSUB") + .arg("SHARDNUMSUB") + .arg(chan) + .query_async(&mut conn) + .await?; + Ok(count) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + #[test] + fn watch_handle_message() { + let mut handlers = HashMap::new(); + + let (tx, mut rx) = mpsc::channel(1); + handlers.insert("test".to_string(), tx); + let msg = redis::PushInfo { + kind: redis::PushKind::Message, + data: vec![ + redis::Value::BulkString("test".into()), + redis::Value::BulkString("foo".into()), + ], + }; + super::handle_msg(msg, Arc::new(RwLock::new(handlers))); + let (chan, data) = rx.try_recv().unwrap(); + assert_eq!(chan, "test"); + assert_eq!(data, "foo".as_bytes()); + } + + 
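As an aside (not part of the diff): `read_msg` only forwards `Message`/`SMessage` push frames, so subscribe confirmations and other push kinds never reach a handler. A sketch of an extra unit test that would pin that behavior down, written in the same style as the surrounding `tests` module and relying on its `use super::*;` imports:

```rust
#[test]
fn ignores_non_message_push() {
    // Push frames other than (S)Message (e.g. subscribe confirmations)
    // are dropped by `handle_msg`, so no handler should receive anything.
    let mut handlers = HashMap::new();
    let (tx, mut rx) = mpsc::channel(1);
    handlers.insert("test".to_string(), tx);
    let msg = redis::PushInfo {
        kind: redis::PushKind::Subscribe,
        data: vec![
            redis::Value::BulkString("test".into()),
            redis::Value::Int(1),
        ],
    };
    super::handle_msg(msg, Arc::new(RwLock::new(handlers)));
    assert!(rx.try_recv().is_err());
}
```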
#[test] + fn watch_handler_pattern() { + let mut handlers = HashMap::new(); + + let (tx1, mut rx1) = mpsc::channel(1); + handlers.insert("test-response#namespace#uid#".to_string(), tx1); + let msg = redis::PushInfo { + kind: redis::PushKind::Message, + data: vec![ + redis::Value::BulkString("test-response#namespace#uid#".into()), + redis::Value::BulkString("foo".into()), + ], + }; + super::handle_msg(msg, Arc::new(RwLock::new(handlers))); + let (chan, data) = rx1.try_recv().unwrap(); + assert_eq!(chan, "test-response#namespace#uid#"); + assert_eq!(data, "foo".as_bytes()); + } +} diff --git a/crates/socketioxide-redis/src/lib.rs b/crates/socketioxide-redis/src/lib.rs new file mode 100644 index 00000000..09adacf4 --- /dev/null +++ b/crates/socketioxide-redis/src/lib.rs @@ -0,0 +1,1109 @@ +#![cfg_attr(docsrs, feature(doc_cfg))] +#![warn( + clippy::all, + clippy::todo, + clippy::empty_enum, + clippy::mem_forget, + clippy::unused_self, + clippy::filter_map_next, + clippy::needless_continue, + clippy::needless_borrow, + clippy::match_wildcard_for_single_variants, + clippy::if_let_mutex, + clippy::await_holding_lock, + clippy::match_on_vec_items, + clippy::imprecise_flops, + clippy::suboptimal_flops, + clippy::lossy_float_literal, + clippy::rest_pat_in_fully_bound_structs, + clippy::fn_params_excessive_bools, + clippy::exit, + clippy::inefficient_to_string, + clippy::linkedlist, + clippy::macro_use_imports, + clippy::option_option, + clippy::verbose_file_reads, + clippy::unnested_or_patterns, + rust_2018_idioms, + future_incompatible, + nonstandard_style, + missing_docs +)] + +//! # A redis/valkey adapter implementation for the socketioxide crate. +//! The adapter is used to communicate with other nodes of the same application. +//! This allows to broadcast messages to sockets connected on other servers, +//! to get the list of rooms, to add or remove sockets from rooms, etc. +//! +//! To achieve this, the adapter uses a [pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/) system +//! through Redis to communicate with other servers. +//! +//! The [`Driver`] abstraction allows the use of any pub/sub client. +//! Three implementations are provided: +//! * [`RedisDriver`](crate::drivers::redis::RedisDriver) for the [`redis`] crate with a standalone redis. +//! * [`ClusterDriver`](crate::drivers::redis::ClusterDriver) for the [`redis`] crate with a redis cluster. +//! * [`FredDriver`](crate::drivers::fred::FredDriver) for the [`fred`] crate with a standalone/cluster redis. +//! +//! When using redis clusters, the drivers employ [sharded pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/#sharded-pubsub) +//! to distribute the load across Redis nodes. +//! +//! You can also implement your own driver by implementing the [`Driver`] trait. +//! +//!
+//! The provided driver implementations use RESP3 for efficiency. +//! Make sure your redis server supports it (redis v7 and above). +//! If it does not, you can implement your own driver using the RESP2 protocol. +//!
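Since the documentation above points at implementing the `Driver` trait for other backends, here is a minimal sketch of what such an implementation could look like, assuming the `Driver`, `MessageStream`, and `ChanItem` items exported by the `drivers` module. This in-memory driver is purely illustrative and not part of the diff:

```rust
use std::{
    collections::HashMap,
    convert::Infallible,
    sync::{Arc, RwLock},
};

use socketioxide_redis::drivers::{ChanItem, Driver, MessageStream};
use tokio::sync::mpsc;

/// A process-local "pub/sub" backend: handy for tests, useless across servers.
#[derive(Clone, Default)]
struct LocalDriver {
    handlers: Arc<RwLock<HashMap<String, mpsc::Sender<ChanItem>>>>,
}

impl Driver for LocalDriver {
    type Error = Infallible;

    async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
        // Route the payload to the local subscriber of this channel, if any.
        if let Some(tx) = self.handlers.read().unwrap().get(&chan) {
            let _ = tx.try_send((chan.clone(), val));
        }
        Ok(())
    }

    async fn subscribe(
        &self,
        chan: String,
        size: usize,
    ) -> Result<MessageStream<ChanItem>, Self::Error> {
        let (tx, rx) = mpsc::channel(size);
        self.handlers.write().unwrap().insert(chan, tx);
        Ok(MessageStream::new(rx))
    }

    async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
        self.handlers.write().unwrap().remove(&chan);
        Ok(())
    }

    async fn num_serv(&self, _chan: &str) -> Result<u16, Self::Error> {
        // A single process means a single "server".
        Ok(1)
    }
}
```

A driver like this would then be handed to `RedisAdapterCtr::new_with_driver` together with a `RedisAdapterConfig`, just as the built-in drivers are.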
+//! +//! ## Example with the [`redis`] driver +//! ```rust +//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter}; +//! # use socketioxide_redis::{RedisAdapterCtr, RedisAdapter}; +//! # async fn doc_main() -> Result<(), Box> { +//! async fn on_connect(socket: SocketRef) { +//! socket.join("room1"); +//! socket.on("event", on_event); +//! let _ = socket.broadcast().emit("hello", "world").await.ok(); +//! } +//! async fn on_event(socket: SocketRef, Data(data): Data) {} +//! +//! let client = redis::Client::open("redis://127.0.0.1:6379?protocol=RESP3")?; +//! let adapter = RedisAdapterCtr::new_with_redis(&client).await?; +//! let (layer, io) = SocketIo::builder() +//! .with_adapter::>(adapter) +//! .build_layer(); +//! Ok(()) +//! # } +//! ``` +//! +//! +//! ## Example with the [`fred`] driver +//! ```rust +//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter}; +//! # use socketioxide_redis::{RedisAdapterCtr, FredAdapter}; +//! # use fred::types::RespVersion; +//! # async fn doc_main() -> Result<(), Box> { +//! async fn on_connect(socket: SocketRef) { +//! socket.join("room1"); +//! socket.on("event", on_event); +//! let _ = socket.broadcast().emit("hello", "world").await.ok(); +//! } +//! async fn on_event(socket: SocketRef, Data(data): Data) {} +//! +//! let mut config = fred::prelude::Config::from_url("redis://127.0.0.1:6379?protocol=resp3")?; +//! // We need to manually set the RESP3 version because +//! // the fred crate does not parse the protocol query parameter. +//! config.version = RespVersion::RESP3; +//! let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?; +//! let adapter = RedisAdapterCtr::new_with_fred(client).await?; +//! let (layer, io) = SocketIo::builder() +//! .with_adapter::>(adapter) +//! .build_layer(); +//! Ok(()) +//! # } +//! ``` +//! +//! +//! ## Example with the [`redis`] cluster driver +//! ```rust +//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter}; +//! # use socketioxide_redis::{RedisAdapterCtr, ClusterAdapter}; +//! # async fn doc_main() -> Result<(), Box> { +//! async fn on_connect(socket: SocketRef) { +//! socket.join("room1"); +//! socket.on("event", on_event); +//! let _ = socket.broadcast().emit("hello", "world").await.ok(); +//! } +//! async fn on_event(socket: SocketRef, Data(data): Data) {} +//! +//! // single node cluster +//! let builder = redis::cluster::ClusterClient::builder(std::iter::once( +//! "redis://127.0.0.1:6379?protocol=resp3", +//! )); +//! let adapter = RedisAdapterCtr::new_with_cluster(builder).await?; +//! +//! let (layer, io) = SocketIo::builder() +//! .with_adapter::>(adapter) +//! .build_layer(); +//! Ok(()) +//! # } +//! ``` +//! +//! ## How does it work? +//! +//! An adapter is created for each created namespace and it takes a corresponding [`CoreLocalAdapter`]. +//! The [`CoreLocalAdapter`] allows to manage the local rooms and local sockets. The default `LocalAdapter` +//! is simply a wrapper around this [`CoreLocalAdapter`]. +//! +//! The adapter is then initialized with the [`RedisAdapter::init`] method. +//! This will subscribe to 3 channels: +//! * `"{prefix}-request#{namespace}#"`: A global channel to receive broadcasted requests. +//! * `"{prefix}-request#{namespace}#{uid}#"`: A specific channel to receive requests only for this server. +//! * `"{prefix}-response#{namespace}#{uid}#"`: A specific channel to receive responses only for this server. +//! 
Messages sent to this channel will be always in the form `[req_id, data]`. This will allow the adapter to extract the request id +//! and route the response to the approriate stream before deserializing the data. +//! +//! All messages are encoded with msgpack. +//! +//! There are 7 types of requests: +//! * Broadcast a packet to all the matching sockets. +//! * Broadcast a packet to all the matching sockets and wait for a stream of acks. +//! * Disconnect matching sockets. +//! * Get all the rooms. +//! * Add matching sockets to rooms. +//! * Remove matching sockets to rooms. +//! * Fetch all the remote sockets matching the options. +//! +//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request, +//! and then send the acks as they are received (more details in [`RedisAdapter::broadcast_with_ack`] fn). +//! +//! On the other side, each time an action has to be performed on the local server, the adapter will +//! first broadcast a request to all the servers and then perform the action locally. + +use std::{ + borrow::Cow, + collections::HashMap, + fmt, + future::{self, Future}, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::Duration, +}; + +use drivers::{ChanItem, Driver, MessageStream}; +use futures_core::Stream; +use futures_util::StreamExt; +use request::{ + read_req_id, RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, +}; +use serde::{de::DeserializeOwned, Serialize}; +use socketioxide_core::{ + adapter::{ + BroadcastFlags, BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, + RemoteSocketData, Room, RoomParam, SocketEmitter, Spawnable, + }, + errors::{AdapterError, BroadcastError}, + packet::Packet, + Sid, Uid, +}; +use stream::{AckStream, DropStream}; +use tokio::{sync::mpsc, time}; + +/// Drivers are an abstraction over the pub/sub backend used by the adapter. +/// You can use the provided implementation or implement your own. +pub mod drivers; + +mod request; +mod stream; + +/// Represent any error that might happen when using this adapter. +#[derive(thiserror::Error)] +pub enum Error { + /// Redis driver error + #[error("driver error: {0}")] + Driver(R::Error), + /// Packet encoding error + #[error("packet encoding error: {0}")] + Decode(#[from] rmp_serde::decode::Error), + /// Packet decoding error + #[error("packet decoding error: {0}")] + Encode(#[from] rmp_serde::encode::Error), +} + +impl Error { + fn from_driver(err: R::Error) -> Self { + Self::Driver(err) + } +} +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Driver(err) => write!(f, "Driver error: {:?}", err), + Self::Decode(err) => write!(f, "Decode error: {:?}", err), + Self::Encode(err) => write!(f, "Encode error: {:?}", err), + } + } +} + +impl From> for AdapterError { + fn from(err: Error) -> Self { + AdapterError::from(Box::new(err) as Box) + } +} + +/// The configuration of the [`RedisAdapter`]. +#[derive(Debug, Clone)] +pub struct RedisAdapterConfig { + /// The request timeout. It is mainly used when expecting response such as when using + /// `broadcast_with_ack` or `rooms`. Default is 5 seconds. + pub request_timeout: Duration, + + /// The prefix used for the channels. Default is "socket.io". + pub prefix: Cow<'static, str>, + + /// The channel size used to receive ack responses. Default is 255. 
+ /// + /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster + /// than you poll them with the returned stream, you might want to increase this value. + pub ack_response_buffer: usize, + + /// The channel size used to receive messages. Default is 1024. + /// + /// If your server is under heavy load, you might want to increase this value. + pub stream_buffer: usize, +} +impl RedisAdapterConfig { + /// Create a new config. + pub fn new() -> Self { + Self::default() + } + /// Set the request timeout. Default is 5 seconds. + pub fn with_request_timeout(mut self, timeout: Duration) -> Self { + self.request_timeout = timeout; + self + } + + /// Set the prefix used for the channels. Default is "socket.io". + pub fn with_prefix(mut self, prefix: impl Into>) -> Self { + self.prefix = prefix.into(); + self + } + + /// Set the channel size used to send ack responses. Default is 255. + /// + /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster + /// than you poll them with the returned stream, you might want to increase this value. + pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self { + assert!(buffer > 0, "buffer size must be greater than 0"); + self.ack_response_buffer = buffer; + self + } + + /// Set the channel size used to receive messages. Default is 1024. + /// + /// If your server is under heavy load, you might want to increase this value. + pub fn with_stream_buffer(mut self, buffer: usize) -> Self { + assert!(buffer > 0, "buffer size must be greater than 0"); + self.stream_buffer = buffer; + self + } +} + +impl Default for RedisAdapterConfig { + fn default() -> Self { + Self { + request_timeout: Duration::from_secs(5), + prefix: Cow::Borrowed("socket.io"), + ack_response_buffer: 255, + stream_buffer: 1024, + } + } +} + +/// The adapter constructor. For each namespace you define, a new adapter instance is created +/// from this constructor. +#[derive(Debug)] +pub struct RedisAdapterCtr { + driver: R, + config: RedisAdapterConfig, +} + +#[cfg(feature = "redis")] +impl RedisAdapterCtr { + /// Create a new adapter constructor with the [`redis`] driver and a default config. + #[cfg_attr(docsrs, doc(cfg(feature = "redis")))] + pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult { + Self::new_with_redis_config(client, RedisAdapterConfig::default()).await + } + /// Create a new adapter constructor with the [`redis`] driver and a custom config. + #[cfg_attr(docsrs, doc(cfg(feature = "redis")))] + pub async fn new_with_redis_config( + client: &redis::Client, + config: RedisAdapterConfig, + ) -> redis::RedisResult { + let driver = drivers::redis::RedisDriver::new(client).await?; + Ok(Self::new_with_driver(driver, config)) + } +} +#[cfg(feature = "redis-cluster")] +impl RedisAdapterCtr { + /// Create a new adapter constructor with the [`redis`] driver and a default config. + #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))] + pub async fn new_with_cluster( + builder: redis::cluster::ClusterClientBuilder, + ) -> redis::RedisResult { + Self::new_with_cluster_config(builder, RedisAdapterConfig::default()).await + } + + /// Create a new adapter constructor with the [`redis`] driver and a default config. 
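The builder methods above cover the request timeout, channel prefix and buffer sizes. A short sketch of wiring a tuned configuration into the constructor (the redis URL and values are placeholders, not recommendations):

```rust
// Illustrative sketch: building an adapter constructor with a custom config.
use std::time::Duration;

use socketioxide_redis::{drivers::redis::RedisDriver, RedisAdapterConfig, RedisAdapterCtr};

async fn build_ctr() -> redis::RedisResult<RedisAdapterCtr<RedisDriver>> {
    let config = RedisAdapterConfig::new()
        .with_request_timeout(Duration::from_secs(3))
        .with_prefix("my-app")
        .with_stream_buffer(4096);
    let client = redis::Client::open("redis://127.0.0.1:6379?protocol=resp3")?;
    RedisAdapterCtr::new_with_redis_config(&client, config).await
}
```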
+ #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))] + pub async fn new_with_cluster_config( + builder: redis::cluster::ClusterClientBuilder, + config: RedisAdapterConfig, + ) -> redis::RedisResult { + let driver = drivers::redis::ClusterDriver::new(builder).await?; + Ok(Self::new_with_driver(driver, config)) + } +} +#[cfg(feature = "fred")] +impl RedisAdapterCtr { + /// Create a new adapter constructor with the default [`fred`] driver and a default config. + #[cfg_attr(docsrs, doc(cfg(feature = "fred")))] + pub async fn new_with_fred( + client: fred::clients::SubscriberClient, + ) -> fred::prelude::FredResult { + Self::new_with_fred_config(client, RedisAdapterConfig::default()).await + } + /// Create a new adapter constructor with the default [`fred`] driver and a custom config. + #[cfg_attr(docsrs, doc(cfg(feature = "fred")))] + pub async fn new_with_fred_config( + client: fred::clients::SubscriberClient, + config: RedisAdapterConfig, + ) -> fred::prelude::FredResult { + let driver = drivers::fred::FredDriver::new(client).await?; + Ok(Self::new_with_driver(driver, config)) + } +} +impl RedisAdapterCtr { + /// Create a new adapter constructor with a custom redis/valkey driver and a config. + /// + /// You can implement your own driver by implementing the [`Driver`] trait with any redis/valkey client. + /// Check the [`drivers`] module for more information. + pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr { + RedisAdapterCtr { driver, config } + } +} + +pub(crate) type ResponseHandlers = HashMap>>; + +/// The redis adapter with the fred driver. +#[cfg_attr(docsrs, doc(cfg(feature = "fred")))] +#[cfg(feature = "fred")] +pub type FredAdapter = CustomRedisAdapter; + +/// The redis adapter with the redis driver. +#[cfg_attr(docsrs, doc(cfg(feature = "redis")))] +#[cfg(feature = "redis")] +pub type RedisAdapter = CustomRedisAdapter; + +/// The redis adapter with the redis cluster driver. +#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))] +#[cfg(feature = "redis-cluster")] +pub type ClusterAdapter = CustomRedisAdapter; + +/// The redis adapter implementation. +/// It is generic over the [`Driver`] used to communicate with the redis server. +/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to +/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates. +pub struct CustomRedisAdapter { + /// The driver used by the adapter. This is used to communicate with the redis server. + /// All the redis adapter instances share the same driver. + driver: R, + /// The configuration of the adapter. + config: RedisAdapterConfig, + /// A unique identifier for the adapter to identify itself in the redis server. + uid: Uid, + /// The local adapter, used to manage local rooms and socket stores. + local: CoreLocalAdapter, + /// The request channel used to broadcast requests to all the servers. + /// format: `{prefix}-request#{path}#`. + req_chan: String, + /// A map of response handlers used to await for responses from the remote servers. 
+ responses: Arc>, +} + +impl DefinedAdapter for CustomRedisAdapter {} +impl CoreAdapter for CustomRedisAdapter { + type Error = Error; + type State = RedisAdapterCtr; + type AckStream = AckStream; + type InitRes = InitRes; + + fn new(state: &Self::State, local: CoreLocalAdapter) -> Self { + let req_chan = format!("{}-request#{}#", state.config.prefix, local.path()); + let uid = local.server_id(); + Self { + local, + req_chan, + uid, + driver: state.driver.clone(), + config: state.config.clone(), + responses: Arc::new(Mutex::new(HashMap::new())), + } + } + + fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { + let fut = async move { + check_ns(self.local.path())?; + let global_stream = self.subscribe(self.req_chan.clone()).await?; + let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?; + let response_chan = format!( + "{}-response#{}#{}#", + &self.config.prefix, + self.local.path(), + self.uid + ); + + let response_stream = self.subscribe(response_chan.clone()).await?; + let stream = futures_util::stream::select(global_stream, specific_stream); + let stream = futures_util::stream::select(stream, response_stream); + tokio::spawn(self.pipe_stream(stream, response_chan)); + on_success(); + Ok(()) + }; + InitRes(Box::pin(fut)) + } + + async fn close(&self) -> Result<(), Self::Error> { + let response_chan = format!( + "{}-response#{}#{}#", + &self.config.prefix, + self.local.path(), + self.uid + ); + tokio::try_join!( + self.driver.unsubscribe(self.req_chan.clone()), + self.driver.unsubscribe(self.get_req_chan(Some(self.uid))), + self.driver.unsubscribe(response_chan) + ) + .map_err(Error::from_driver)?; + + Ok(()) + } + + /// Get the number of servers by getting the number of subscribers to the request channel. + async fn server_count(&self) -> Result { + let count = self + .driver + .num_serv(&self.req_chan) + .await + .map_err(Error::from_driver)?; + + Ok(count) + } + + /// Broadcast a packet to all the servers to send them through their sockets. + async fn broadcast( + &self, + packet: Packet, + opts: BroadcastOptions, + ) -> Result<(), BroadcastError> { + if !is_local_op(self.uid, &opts) { + let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts); + self.send_req(req, opts.server_id) + .await + .map_err(AdapterError::from)?; + } + + self.local.broadcast(packet, opts)?; + Ok(()) + } + + /// Broadcast a packet to all the servers to send them through their sockets. + /// + /// Returns a Stream that is a combination of the local ack stream and a remote [`MessageStream`]. + /// Here is a specific protocol in order to know how many message the server expect to close + /// the stream at the right time: + /// * Get the number `n` of remote servers. + /// * Send the broadcast request. + /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses. + /// * Expect `sum(m)` broadcast counts sent by the servers. 
+ /// + /// Example with 3 remote servers (n = 3): + /// ```text + /// +---+ +---+ +---+ + /// | A | | B | | C | + /// +---+ +---+ +---+ + /// | | | + /// |---BroadcastWithAck--->| | + /// |---BroadcastWithAck--------------------------->| + /// | | | + /// |<-BroadcastAckCount(2)-| (n = 2; m = 2) | + /// |<-BroadcastAckCount(2)-------(n = 2; m = 4)----| + /// | | | + /// |<----------------Ack---------------------------| + /// |<----------------Ack---| | + /// | | | + /// |<----------------Ack---------------------------| + /// |<----------------Ack---| | + async fn broadcast_with_ack( + &self, + packet: Packet, + opts: BroadcastOptions, + timeout: Option, + ) -> Result { + if is_local_op(self.uid, &opts) { + let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); + let stream = AckStream::new_local(local); + return Ok(stream); + } + let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts); + let req_id = req.id; + + let remote_serv_cnt = self.server_count().await?.saturating_sub(1); + + let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize); + self.responses.lock().unwrap().insert(req_id, tx); + let remote = MessageStream::new(rx); + + self.send_req(req, opts.server_id).await?; + let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout); + + Ok(AckStream::new( + local, + remote, + self.config.request_timeout, + remote_serv_cnt, + req_id, + self.responses.clone(), + )) + } + + async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> { + if !is_local_op(self.uid, &opts) { + let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts); + self.send_req(req, opts.server_id) + .await + .map_err(AdapterError::from)?; + } + self.local + .disconnect_socket(opts) + .map_err(BroadcastError::Socket)?; + + Ok(()) + } + + async fn rooms(&self, opts: BroadcastOptions) -> Result, Self::Error> { + const PACKET_IDX: u8 = 2; + + if is_local_op(self.uid, &opts) { + return Ok(self.local.rooms(opts).into_iter().collect()); + } + let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts); + let req_id = req.id; + + // First get the remote stream because redis might send + // the responses before subscription is done. 
+ let stream = self.get_res::<()>(req_id, PACKET_IDX).await?; + self.send_req(req, opts.server_id).await?; + let local = self.local.rooms(opts); + let rooms = stream + .filter_map(|item| future::ready(item.into_rooms())) + .fold(local, |mut acc, item| async move { + acc.extend(item); + acc + }) + .await; + Ok(Vec::from_iter(rooms)) + } + + async fn add_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> Result<(), Self::Error> { + let rooms: Vec = rooms.into_room_iter().collect(); + if !is_local_op(self.uid, &opts) { + let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts); + self.send_req(req, opts.server_id).await?; + } + self.local.add_sockets(opts, rooms); + Ok(()) + } + + async fn del_sockets( + &self, + opts: BroadcastOptions, + rooms: impl RoomParam, + ) -> Result<(), Self::Error> { + let rooms: Vec = rooms.into_room_iter().collect(); + if !is_local_op(self.uid, &opts) { + let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts); + self.send_req(req, opts.server_id).await?; + } + self.local.del_sockets(opts, rooms); + Ok(()) + } + + async fn fetch_sockets( + &self, + opts: BroadcastOptions, + ) -> Result, Self::Error> { + if is_local_op(self.uid, &opts) { + return Ok(self.local.fetch_sockets(opts)); + } + const PACKET_IDX: u8 = 3; + let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts); + let req_id = req.id; + // First get the remote stream because redis might send + // the responses before subscription is done. + let remote = self.get_res::(req_id, PACKET_IDX).await?; + + self.send_req(req, opts.server_id).await?; + let local = self.local.fetch_sockets(opts); + let sockets = remote + .filter_map(|item| future::ready(item.into_fetch_sockets())) + .fold(local, |mut acc, item| async move { + acc.extend(item); + acc + }) + .await; + Ok(sockets) + } + + fn get_local(&self) -> &CoreLocalAdapter { + &self.local + } +} + +/// Error that can happen when initializing the adapter. +#[derive(thiserror::Error)] +pub enum InitError { + /// Driver error. + #[error("driver error: {0}")] + Driver(D::Error), + /// Malformed namespace path. + #[error("malformed namespace path, it must not contain '#'")] + MalformedNamespace, +} +impl fmt::Debug for InitError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Driver(err) => fmt::Debug::fmt(err, f), + Self::MalformedNamespace => write!(f, "Malformed namespace path"), + } + } +} +/// The result of the init future. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct InitRes(futures_core::future::BoxFuture<'static, Result<(), InitError>>); + +impl Future for InitRes { + type Output = Result<(), InitError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.0.as_mut().poll(cx) + } +} +impl Spawnable for InitRes { + fn spawn(self) { + tokio::spawn(async move { + if let Err(e) = self.0.await { + tracing::error!("error initializing adapter: {e}"); + } + }); + } +} + +impl CustomRedisAdapter { + /// Build a response channel for a request. + /// + /// The uid is used to identify the server that sent the request. + /// The req_id is used to identify the request. + fn get_res_chan(&self, uid: Uid) -> String { + let path = self.local.path(); + let prefix = &self.config.prefix; + format!("{}-response#{}#{}#", prefix, path, uid) + } + /// Build a request channel for a request. + /// + /// If we know the target server id, we can build a channel specific to this server. 
+ /// Otherwise, we use the default request channel that will broadcast the request to all the servers. + fn get_req_chan(&self, node_id: Option) -> String { + match node_id { + Some(uid) => format!("{}{}#", self.req_chan, uid), + None => self.req_chan.clone(), + } + } + + async fn pipe_stream( + self: Arc, + mut stream: impl Stream + Unpin, + response_chan: String, + ) { + while let Some((chan, item)) = stream.next().await { + if chan.starts_with(&self.req_chan) { + if let Err(e) = self.recv_req(item) { + let ns = self.local.path(); + let uid = self.uid; + tracing::warn!(?uid, ?ns, "request handler error: {e}"); + } + } else if chan == response_chan { + let req_id = read_req_id(&item); + tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid"); + let handlers = self.responses.lock().unwrap(); + if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) { + if let Err(e) = tx.try_send(item) { + tracing::warn!("error sending response to handler: {e}"); + } + } else { + tracing::warn!(?req_id, "could not find req handler"); + } + } else { + tracing::warn!("unexpected message/channel: {chan}"); + } + } + } + + /// Handle a generic request received from the request channel. + fn recv_req(self: &Arc, item: Vec) -> Result<(), Error> { + let req: RequestIn = rmp_serde::from_slice(&item)?; + if req.node_id == self.uid { + return Ok(()); + } + + tracing::trace!(?req, "handling request"); + + match req.r#type { + RequestTypeIn::Broadcast(p) => self.recv_broadcast(req.opts, p), + RequestTypeIn::BroadcastWithAck(_) => self.clone().recv_broadcast_with_ack(req), + RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(req), + RequestTypeIn::AllRooms => self.recv_rooms(req), + RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(req.opts, rooms), + RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(req.opts, rooms), + RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req), + }; + Ok(()) + } + + fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) { + if let Err(e) = self.local.broadcast(packet, opts) { + let ns = self.local.path(); + tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e); + } + } + + fn recv_disconnect_sockets(&self, req: RequestIn) { + if let Err(e) = self.local.disconnect_socket(req.opts) { + let ns = self.local.path(); + tracing::warn!( + ?self.uid, + ?ns, + "remote request disconnect sockets handler: {:?}", + e + ); + } + } + + fn recv_broadcast_with_ack(self: Arc, req: RequestIn) { + let packet = match req.r#type { + RequestTypeIn::BroadcastWithAck(p) => p, + _ => unreachable!(), + }; + let (stream, count) = self.local.broadcast_with_ack(packet, req.opts, None); + tokio::spawn(async move { + let on_err = |err| { + let ns = self.local.path(); + tracing::warn!( + ?self.uid, + ?ns, + "remote request broadcast with ack handler errors: {:?}", + err + ); + }; + // First send the count of expected acks to the server that sent the request. + // This is used to keep track of the number of expected acks. + let res = Response { + r#type: ResponseType::<()>::BroadcastAckCount(count), + node_id: self.uid, + }; + if let Err(err) = self.send_res(req.node_id, req.id, res).await { + on_err(err); + return; + } + + // Then send the acks as they are received. 
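
For reference, this is roughly how the channel names built by `get_req_chan` and `get_res_chan` interact with the routing done in `pipe_stream`: requests are matched by prefix (so a server-specific request channel still hits the request handler), while responses require an exact match on this server's response channel. The response format is taken from `get_res_chan` above; the request-channel layout is an assumption made for illustration only:

```rust
fn main() {
    let prefix = "socket.io";
    let path = "/";
    let uid = "G3fM2"; // an illustrative server id

    // Assumed generic request channel, plus the server-specific variant (see `get_req_chan`).
    let req_chan = format!("{prefix}-request#{path}#");
    let req_chan_for_uid = format!("{req_chan}{uid}#");
    // Response channel dedicated to this server (format taken from `get_res_chan`).
    let res_chan = format!("{prefix}-response#{path}#{uid}#");

    // Routing rule applied to every incoming (channel, payload) pair.
    let route = |chan: &str| {
        if chan.starts_with(&req_chan) {
            "request handler"
        } else if chan == res_chan {
            "response handler"
        } else {
            "unexpected"
        }
    };
    assert_eq!(route(&req_chan), "request handler");
    assert_eq!(route(&req_chan_for_uid), "request handler");
    assert_eq!(route(&res_chan), "response handler");
}
```
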
+ futures_util::pin_mut!(stream); + while let Some(ack) = stream.next().await { + let res = Response { + r#type: ResponseType::BroadcastAck(ack), + node_id: self.uid, + }; + if let Err(err) = self.send_res(req.node_id, req.id, res).await { + on_err(err); + return; + } + } + }); + } + + fn recv_rooms(&self, req: RequestIn) { + let rooms = self.local.rooms(req.opts); + let res = Response { + r#type: ResponseType::<()>::AllRooms(rooms), + node_id: self.uid, + }; + let fut = self.send_res(req.node_id, req.id, res); + let ns = self.local.path().clone(); + let uid = self.uid; + tokio::spawn(async move { + if let Err(err) = fut.await { + tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err); + } + }); + } + + fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec) { + self.local.add_sockets(opts, rooms); + } + + fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec) { + self.local.del_sockets(opts, rooms); + } + fn recv_fetch_sockets(&self, req: RequestIn) { + let sockets = self.local.fetch_sockets(req.opts); + let res = Response { + node_id: self.uid, + r#type: ResponseType::FetchSockets(sockets), + }; + let fut = self.send_res(req.node_id, req.id, res); + let ns = self.local.path().clone(); + let uid = self.uid; + tokio::spawn(async move { + if let Err(err) = fut.await { + tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err); + } + }); + } + + async fn send_req(&self, req: RequestOut<'_>, target_uid: Option) -> Result<(), Error> { + tracing::trace!(?req, "sending request"); + let req = rmp_serde::to_vec(&req)?; + let chan = self.get_req_chan(target_uid); + self.driver + .publish(chan, req) + .await + .map_err(Error::from_driver)?; + + Ok(()) + } + + fn send_res( + &self, + req_node_id: Uid, + req_id: Sid, + res: Response, + ) -> impl Future>> + Send + 'static { + let chan = self.get_res_chan(req_node_id); + tracing::trace!(?res, "sending response to {}", &chan); + // We send the req_id separated from the response object. + // This allows to partially decode the response and route by the req_id + // before fully deserializing it. + let res = rmp_serde::to_vec(&(req_id, res)); + let driver = self.driver.clone(); + async move { + driver + .publish(chan, res?) + .await + .map_err(Error::from_driver)?; + Ok(()) + } + } + + /// Await for all the responses from the remote servers. + async fn get_res( + &self, + req_id: Sid, + response_idx: u8, + ) -> Result>, Error> { + let remote_serv_cnt = self.server_count().await?.saturating_sub(1) as usize; + let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1)); + self.responses.lock().unwrap().insert(req_id, tx); + let stream = MessageStream::new(rx) + .filter_map(|item| { + let data = match rmp_serde::from_slice::<(Sid, Response)>(&item) { + Ok((_, data)) => Some(data), + Err(e) => { + tracing::warn!("error decoding response: {e}"); + None + } + }; + future::ready(data) + }) + .filter(move |item| future::ready(item.r#type.to_u8() == response_idx)) + .take(remote_serv_cnt) + .take_until(time::sleep(self.config.request_timeout)); + let stream = DropStream::new(stream, self.responses.clone(), req_id); + Ok(stream) + } + + /// Little wrapper to map the error type. 
+ #[inline] + async fn subscribe(&self, pat: String) -> Result, InitError> { + tracing::trace!(?pat, "subscribing to"); + self.driver + .subscribe(pat, self.config.stream_buffer) + .await + .map_err(InitError::Driver) + } +} + +/// A local operator is either something that is flagged as local or a request that should be specifically +/// sent to the current server. +#[inline] +fn is_local_op(uid: Uid, opts: &BroadcastOptions) -> bool { + opts.has_flag(BroadcastFlags::Local) + || (!opts.has_flag(BroadcastFlags::Broadcast) + && opts.server_id == Some(uid) + && opts.rooms.is_empty() + && opts.sid.is_some()) +} + +/// Checks if the namespace path is valid +/// Panics if the path is empty or contains a `#` +fn check_ns(path: &str) -> Result<(), InitError> { + if path.is_empty() || path.contains('#') { + Err(InitError::MalformedNamespace) + } else { + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures_util::stream::{self, FusedStream, StreamExt}; + use socketioxide_core::{adapter::AckStreamItem, Str, Value}; + use std::convert::Infallible; + + #[derive(Clone)] + struct StubDriver; + impl Driver for StubDriver { + type Error = Infallible; + + async fn publish(&self, _: String, _: Vec) -> Result<(), Self::Error> { + Ok(()) + } + + async fn subscribe( + &self, + _: String, + _: usize, + ) -> Result, Self::Error> { + Ok(MessageStream::new_empty()) + } + + async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> { + Ok(()) + } + + async fn num_serv(&self, _: &str) -> Result { + Ok(0) + } + } + fn new_stub_ack_stream( + remote: MessageStream>, + timeout: Duration, + ) -> AckStream>> { + AckStream::new( + stream::empty::>(), + remote, + timeout, + 2, + Sid::new(), + Arc::new(Mutex::new(HashMap::new())), + ) + } + + //TODO: test weird behaviours, packets out of orders, etc + #[tokio::test] + async fn ack_stream() { + let (tx, rx) = tokio::sync::mpsc::channel(255); + let remote = MessageStream::new(rx); + let stream = new_stub_ack_stream(remote, Duration::from_secs(10)); + let node_id = Uid::new(); + let req_id = Sid::new(); + + // The two servers will send 2 acks each. + let ack_cnt_res = Response::<()> { + node_id, + r#type: ResponseType::BroadcastAckCount(2), + }; + tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap()) + .unwrap(); + tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap()) + .unwrap(); + + let ack_res = Response:: { + node_id, + r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))), + }; + for _ in 0..4 { + tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap()) + .unwrap(); + } + futures_util::pin_mut!(stream); + for _ in 0..4 { + assert!(stream.next().await.is_some()); + } + assert!(stream.is_terminated()); + } + + #[tokio::test] + async fn ack_stream_timeout() { + let (tx, rx) = tokio::sync::mpsc::channel(255); + let remote = MessageStream::new(rx); + let stream = new_stub_ack_stream(remote, Duration::from_millis(50)); + let node_id = Uid::new(); + let req_id = Sid::new(); + // There will be only one ack count and then the stream will timeout. 
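
The tests around here and the `send_res`/`read_req_id` pair rely on the same framing: the request id is encoded as the first element of a MessagePack array, so it can be read back cheaply to route the message before the payload is deserialized. A minimal standalone sketch of that framing, assuming the `rmp`, `rmp-serde` and `serde` (with the derive feature) crates, with a plain `String` standing in for `Sid`:

```rust
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct Payload {
    kind: u8,
    data: Vec<u8>,
}

// Read only the leading id of a `[id, payload]` MessagePack array.
fn read_id(mut buf: &[u8]) -> Option<String> {
    let len = rmp::decode::read_array_len(&mut buf).ok()?;
    if len < 1 {
        return None;
    }
    let mut str_buf = [0u8; 64];
    let id = rmp::decode::read_str(&mut buf, &mut str_buf).ok()?;
    Some(id.to_owned())
}

fn main() {
    let payload = Payload { kind: 1, data: vec![1, 2, 3] };
    let framed = rmp_serde::to_vec(&("AbCd1234".to_string(), &payload)).unwrap();

    // Route on the id alone...
    assert_eq!(read_id(&framed).as_deref(), Some("AbCd1234"));
    // ...and only fully decode once the handler is known.
    let (_, decoded): (String, Payload) = rmp_serde::from_slice(&framed).unwrap();
    assert_eq!(decoded, payload);
}
```
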
+ let ack_cnt_res = Response::<()> { + node_id, + r#type: ResponseType::BroadcastAckCount(2), + }; + tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap()) + .unwrap(); + + futures_util::pin_mut!(stream); + tokio::time::sleep(Duration::from_millis(50)).await; + assert!(stream.next().await.is_none()); + assert!(stream.is_terminated()); + } + + #[tokio::test] + async fn ack_stream_drop() { + let (tx, rx) = tokio::sync::mpsc::channel(255); + let remote = MessageStream::new(rx); + let handlers = Arc::new(Mutex::new(HashMap::new())); + let id = Sid::new(); + handlers.lock().unwrap().insert(id, tx); + let stream = AckStream::new( + stream::empty::>(), + remote, + Duration::from_secs(10), + 2, + id, + handlers.clone(), + ); + drop(stream); + assert!(handlers.lock().unwrap().is_empty(),); + } + + #[test] + fn test_is_local_op() { + let server_id = Uid::new(); + let remote = RemoteSocketData { + id: Sid::new(), + server_id, + ns: "/".into(), + }; + let opts = BroadcastOptions::new_remote(&remote); + assert!(is_local_op(server_id, &opts)); + assert!(!is_local_op(Uid::new(), &opts)); + let opts = BroadcastOptions::new(Sid::new()); + assert!(!is_local_op(Uid::new(), &opts)); + } + + #[test] + fn check_ns_error() { + assert!(matches!( + check_ns::("#"), + Err(InitError::MalformedNamespace) + )); + assert!(matches!( + check_ns::(""), + Err(InitError::MalformedNamespace) + )); + } +} diff --git a/crates/socketioxide-redis/src/request.rs b/crates/socketioxide-redis/src/request.rs new file mode 100644 index 00000000..aea972fa --- /dev/null +++ b/crates/socketioxide-redis/src/request.rs @@ -0,0 +1,395 @@ +//! Custom request and response types for the Redis adapter. +//! Custom serialization/deserialization to reduce the size of the messages. +use std::{collections::HashSet, str::FromStr}; + +use serde::{de::SeqAccess, Deserialize, Serialize}; +use socketioxide_core::{ + adapter::{BroadcastOptions, Room}, + packet::Packet, + Sid, Uid, Value, +}; + +#[derive(Debug, PartialEq)] +pub enum RequestTypeOut<'a> { + /// Broadcast a packet to matching sockets. + Broadcast(&'a Packet), + /// Broadcast a packet to matching sockets and wait for acks. + BroadcastWithAck(&'a Packet), + /// Disconnect matching sockets. + DisconnectSockets, + /// Get all the rooms server. + AllRooms, + /// Add matching sockets to the rooms. + AddSockets(&'a Vec), + /// Remove matching sockets from the rooms. + DelSockets(&'a Vec), + /// Fetch socket data. + FetchSockets, +} +impl RequestTypeOut<'_> { + fn to_u8(&self) -> u8 { + match self { + Self::Broadcast(_) => 0, + Self::BroadcastWithAck(_) => 1, + Self::DisconnectSockets => 2, + Self::AllRooms => 3, + Self::AddSockets(_) => 4, + Self::DelSockets(_) => 5, + Self::FetchSockets => 6, + } + } +} + +#[derive(Debug)] +pub enum RequestTypeIn { + /// Broadcast a packet to matching sockets. + Broadcast(Packet), + /// Broadcast a packet to matching sockets and wait for acks. + BroadcastWithAck(Packet), + /// Disconnect matching sockets. + DisconnectSockets, + /// Get all the rooms server. + AllRooms, + /// Add matching sockets to the rooms. + AddSockets(Vec), + /// Remove matching sockets from the rooms. + DelSockets(Vec), + /// Fetch socket data. 
+ FetchSockets, +} + +#[derive(Debug, PartialEq)] +pub struct RequestOut<'a> { + pub node_id: Uid, + pub id: Sid, + pub r#type: RequestTypeOut<'a>, + pub opts: &'a BroadcastOptions, +} +impl<'a> RequestOut<'a> { + pub fn new(node_id: Uid, r#type: RequestTypeOut<'a>, opts: &'a BroadcastOptions) -> Self { + Self { + node_id, + id: Sid::new(), + r#type, + opts, + } + } +} + +/// Custom implementation to serialize enum variant as u8. +impl<'a> Serialize for RequestOut<'a> { + fn serialize(&self, serializer: S) -> Result { + #[derive(Debug, Serialize)] + struct RawRequest<'a> { + node_id: Uid, + id: Sid, + r#type: u8, + packet: Option<&'a Packet>, + rooms: Option<&'a Vec>, + opts: &'a BroadcastOptions, + } + let raw = RawRequest::<'a> { + node_id: self.node_id, + id: self.id, + r#type: self.r#type.to_u8(), + packet: match &self.r#type { + RequestTypeOut::Broadcast(p) | RequestTypeOut::BroadcastWithAck(p) => Some(p), + _ => None, + }, + rooms: match &self.r#type { + RequestTypeOut::AddSockets(r) | RequestTypeOut::DelSockets(r) => Some(r), + _ => None, + }, + opts: self.opts, + }; + raw.serialize(serializer) + } +} + +#[derive(Debug)] +pub struct RequestIn { + pub node_id: Uid, + pub id: Sid, + pub r#type: RequestTypeIn, + pub opts: BroadcastOptions, +} +impl<'de> Deserialize<'de> for RequestIn { + fn deserialize>(deserializer: D) -> Result { + #[derive(Debug, Deserialize)] + struct RawRequest { + node_id: Uid, + id: Sid, + r#type: u8, + packet: Option, + rooms: Option>, + opts: BroadcastOptions, + } + let raw = RawRequest::deserialize(deserializer)?; + let err = |field| serde::de::Error::custom(format!("missing field: {}", field)); + let r#type = match raw.r#type { + 0 => RequestTypeIn::Broadcast(raw.packet.ok_or(err("packet"))?), + 1 => RequestTypeIn::BroadcastWithAck(raw.packet.ok_or(err("packet"))?), + 2 => RequestTypeIn::DisconnectSockets, + 3 => RequestTypeIn::AllRooms, + 4 => RequestTypeIn::AddSockets(raw.rooms.ok_or(err("room"))?), + 5 => RequestTypeIn::DelSockets(raw.rooms.ok_or(err("room"))?), + 6 => RequestTypeIn::FetchSockets, + _ => return Err(serde::de::Error::custom("invalid request type")), + }; + Ok(Self { + node_id: raw.node_id, + id: raw.id, + r#type, + opts: raw.opts, + }) + } +} + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +pub struct Response { + pub node_id: Uid, + pub r#type: ResponseType, +} + +#[derive(Debug, PartialEq)] +pub enum ResponseType { + BroadcastAck((Sid, Result)), + BroadcastAckCount(u32), + AllRooms(HashSet), + FetchSockets(Vec), +} +impl ResponseType { + pub fn to_u8(&self) -> u8 { + match self { + Self::BroadcastAck(_) => 0, + Self::BroadcastAckCount(_) => 1, + Self::AllRooms(_) => 2, + Self::FetchSockets(_) => 3, + } + } +} +impl Serialize for ResponseType { + fn serialize(&self, serializer: S) -> Result { + match self { + Self::BroadcastAck((sid, res)) => (0, (sid, res)).serialize(serializer), + Self::BroadcastAckCount(count) => (1, count).serialize(serializer), + Self::AllRooms(rooms) => (2, rooms).serialize(serializer), + Self::FetchSockets(sockets) => (3, sockets).serialize(serializer), + } + } +} +impl<'de, D: Deserialize<'de>> Deserialize<'de> for ResponseType { + fn deserialize>(deserializer: DE) -> Result { + struct TupleVisitor(std::marker::PhantomData); + impl<'de, D: Deserialize<'de>> serde::de::Visitor<'de> for TupleVisitor { + type Value = ResponseType; + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a tuple of u8 and D") + } + fn visit_seq>(self, mut seq: A) -> Result { + fn 
deser<'de, T: Deserialize<'de>, A: SeqAccess<'de>>( + seq: &mut A, + ) -> Result { + seq.next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &"")) + } + + let el = match deser::(&mut seq)? { + 0 => ResponseType::BroadcastAck(deser(&mut seq)?), + 1 => ResponseType::BroadcastAckCount(deser(&mut seq)?), + 2 => ResponseType::AllRooms(deser(&mut seq)?), + 3 => ResponseType::FetchSockets(deser(&mut seq)?), + _ => return Err(serde::de::Error::custom("invalid response type")), + }; + Ok(el) + } + } + + deserializer.deserialize_tuple(2, TupleVisitor::(std::marker::PhantomData)) + } +} +impl Response { + pub fn into_rooms(self) -> Option> { + match self.r#type { + ResponseType::AllRooms(rooms) => Some(rooms), + _ => None, + } + } + pub fn into_fetch_sockets(self) -> Option> { + match self.r#type { + ResponseType::FetchSockets(sockets) => Some(sockets), + _ => None, + } + } +} + +/// Extract the request id from a data encoded as `[Sid, ...]` +pub fn read_req_id(data: &[u8]) -> Option { + let mut rd = data; + let len = rmp::decode::read_array_len(&mut rd).ok()?; + if len < 1 { + return None; + } + + let mut buff = [0u8; Sid::ZERO.as_str().len()]; + let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?; + Sid::from_str(str).ok() +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::*; + + impl<'a> From<&'a RequestIn> for RequestOut<'a> { + fn from(req: &'a RequestIn) -> Self { + Self { + node_id: req.node_id, + id: req.id, + opts: &req.opts, + r#type: match &req.r#type { + RequestTypeIn::Broadcast(p) => RequestTypeOut::Broadcast(p), + RequestTypeIn::BroadcastWithAck(p) => RequestTypeOut::BroadcastWithAck(p), + RequestTypeIn::DisconnectSockets => RequestTypeOut::DisconnectSockets, + RequestTypeIn::AllRooms => RequestTypeOut::AllRooms, + RequestTypeIn::AddSockets(r) => RequestTypeOut::AddSockets(r), + RequestTypeIn::DelSockets(r) => RequestTypeOut::DelSockets(r), + RequestTypeIn::FetchSockets => RequestTypeOut::FetchSockets, + }, + } + } + } + + fn assert_request_serde(value: RequestOut<'_>) { + let serialized = rmp_serde::to_vec(&value).unwrap(); + let deserialized: RequestIn = rmp_serde::from_slice(&serialized).unwrap(); + assert_eq!(value, (&deserialized).into()) + } + + #[test] + fn request_broadcast_serde() { + let packet = Packet::event("foo", Value::Str("bar".into(), None)); + let opts = BroadcastOptions::new(Sid::new()); + let req = RequestOut::new(Uid::new(), RequestTypeOut::Broadcast(&packet), &opts); + assert_request_serde(req); + } + + #[test] + fn request_broadcast_with_ack_serde() { + let packet = Packet::event("foo", Value::Str("bar".into(), None)); + let opts = BroadcastOptions::new(Sid::new()); + let req = RequestOut::new(Uid::new(), RequestTypeOut::BroadcastWithAck(&packet), &opts); + assert_request_serde(req); + } + + #[test] + fn request_add_sockets_serde() { + let opts = BroadcastOptions::new(Sid::new()); + let rooms = vec!["foo".into(), "bar".into()]; + let req = RequestOut::new(Uid::new(), RequestTypeOut::AddSockets(&rooms), &opts); + assert_request_serde(req); + } + + #[test] + fn request_del_sockets_serde() { + let opts = BroadcastOptions::new(Sid::new()); + let rooms = vec!["foo".into(), "bar".into()]; + let req = RequestOut::new(Uid::new(), RequestTypeOut::DelSockets(&rooms), &opts); + assert_request_serde(req); + } + + #[test] + fn request_disconnect_sockets_serde() { + let opts = BroadcastOptions::new(Sid::new()); + let req = RequestOut::new(Uid::new(), RequestTypeOut::DisconnectSockets, &opts); + assert_request_serde(req); + } + + #[test] + 
fn request_fetch_sockets_serde() { + let opts = BroadcastOptions::new(Sid::new()); + let req = RequestOut::new(Uid::new(), RequestTypeOut::FetchSockets, &opts); + assert_request_serde(req); + } + + #[test] + fn response_serde_broadcast_ack() { + let res = Response { + node_id: Uid::new(), + r#type: ResponseType::BroadcastAck(( + Sid::new(), + Ok(Value::Bytes(Bytes::from_static(b"test"))), + )), + }; + let serialized = rmp_serde::to_vec(&res).unwrap(); + let deserialized: Response = rmp_serde::from_slice(&serialized).unwrap(); + assert_eq!(res, deserialized); + } + #[test] + fn response_serde_broadcast_ack_count() { + let res = Response { + node_id: Uid::new(), + r#type: ResponseType::BroadcastAckCount(42), + }; + let serialized = rmp_serde::to_vec(&res).unwrap(); + let deserialized: Response = rmp_serde::from_slice(&serialized).unwrap(); + assert_eq!(res, deserialized); + } + + #[test] + fn response_serde_all_rooms() { + let rooms = ["foo".into(), "bar".into()]; + let res = Response { + node_id: Uid::new(), + r#type: ResponseType::AllRooms(rooms.iter().cloned().collect()), + }; + let serialized = rmp_serde::to_vec(&res).unwrap(); + let deserialized: Response = rmp_serde::from_slice(&serialized).unwrap(); + assert_eq!(res, deserialized); + } + + #[test] + fn response_serde_fetch_sockets() { + let sockets = vec![Sid::new(), Sid::new()]; + let res = Response { + node_id: Uid::new(), + r#type: ResponseType::FetchSockets(sockets), + }; + let serialized = rmp_serde::to_vec(&res).unwrap(); + let deserialized: Response = rmp_serde::from_slice(&serialized).unwrap(); + assert_eq!(res, deserialized); + } + + #[test] + fn read_req_id() { + let sid = Sid::new(); + let buff: [u8; 4] = [0, 1, 3, 4]; + let data = rmp_serde::to_vec(&(sid, buff)).unwrap(); + let req_id = super::read_req_id(&data); + assert_eq!(req_id, Some(sid)); + } + + #[test] + fn read_req_bad_id() { + let buff: [u8; 4] = [0, 1, 3, 4]; + let data = rmp_serde::to_vec(&("test", buff)).unwrap(); + let req_id = super::read_req_id(&data); + assert_eq!(req_id, None); + } + + #[test] + fn read_req_id_not_array() { + let sid = Sid::new(); + let data = rmp_serde::to_vec(&sid).unwrap(); + let req_id = super::read_req_id(&data); + assert_eq!(req_id, None); + } + + #[test] + fn read_req_id_empty_array() { + let data = rmp_serde::to_vec::<[u8; 0]>(&[]).unwrap(); + let req_id = super::read_req_id(&data); + assert_eq!(req_id, None); + } +} diff --git a/crates/socketioxide-redis/src/stream.rs b/crates/socketioxide-redis/src/stream.rs new file mode 100644 index 00000000..b0580b9a --- /dev/null +++ b/crates/socketioxide-redis/src/stream.rs @@ -0,0 +1,217 @@ +use std::{ + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{self, Poll}, + time::Duration, +}; + +use futures_core::{FusedStream, Stream}; +use futures_util::{stream::TakeUntil, StreamExt}; +use pin_project_lite::pin_project; +use serde::de::DeserializeOwned; +use socketioxide_core::{adapter::AckStreamItem, Sid}; +use tokio::time; + +use crate::{ + drivers::MessageStream, + request::{Response, ResponseType}, + ResponseHandlers, +}; + +pin_project! { + /// A stream of acknowledgement messages received from the local and remote servers. + /// It merges the local ack stream with the remote ack stream from all the servers. + // The server_cnt is the number of servers that are expected to send a AckCount message. + // It is decremented each time a AckCount message is received. + // + // The ack_cnt is the number of acks that are expected to be received. It is the sum of all the the ack counts. 
+ // And it is decremented each time an ack is received. + // + // Therefore an exhausted stream correspond to `ack_cnt == 0` and `server_cnt == 0`. + pub struct AckStream { + #[pin] + local: S, + #[pin] + remote: DropStream>, time::Sleep>>, + ack_cnt: u32, + total_ack_cnt: usize, + serv_cnt: u16, + } +} + +impl AckStream { + pub fn new( + local: S, + remote: MessageStream>, + timeout: Duration, + serv_cnt: u16, + req_id: Sid, + handlers: Arc>, + ) -> Self { + let remote = remote.take_until(time::sleep(timeout)); + let remote = DropStream::new(remote, handlers, req_id); + Self { + local, + remote, + ack_cnt: 0, + total_ack_cnt: 0, + serv_cnt, + } + } + pub fn new_local(local: S) -> Self { + let handlers = Arc::new(Mutex::new(ResponseHandlers::new())); + let remote = MessageStream::new_empty().take_until(time::sleep(Duration::ZERO)); + let remote = DropStream::new(remote, handlers, Sid::ZERO); + Self { + local, + remote, + ack_cnt: 0, + total_ack_cnt: 0, + serv_cnt: 0, + } + } +} +impl AckStream +where + Err: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + /// Poll the remote stream. First the count of acks is received, then the acks are received. + /// We expect `serv_cnt` of `BroadcastAckCount` messages to be received, then we expect + /// `ack_cnt` of `BroadcastAck` messages. + fn poll_remote( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + ) -> Poll>> { + // remote stream is not fused, so we need to check if it is terminated + if FusedStream::is_terminated(&self) { + return Poll::Ready(None); + } + + let projection = self.as_mut().project(); + match projection.remote.poll_next(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(item)) => { + let res = rmp_serde::from_slice::<(Sid, Response)>(&item); + match res { + Ok(( + req_id, + Response { + node_id: uid, + r#type: ResponseType::BroadcastAckCount(count), + }, + )) if *projection.serv_cnt > 0 => { + tracing::trace!(?uid, ?req_id, "receiving broadcast ack count {count}"); + *projection.ack_cnt += count; + *projection.total_ack_cnt += count as usize; + *projection.serv_cnt -= 1; + self.poll_remote(cx) + } + Ok(( + req_id, + Response { + node_id: uid, + r#type: ResponseType::BroadcastAck((sid, res)), + }, + )) if *projection.ack_cnt > 0 => { + tracing::trace!(?uid, ?req_id, "receiving broadcast ack {sid} {:?}", res); + *projection.ack_cnt -= 1; + Poll::Ready(Some((sid, res))) + } + Ok((req_id, Response { node_id: uid, .. })) => { + tracing::warn!(?uid, ?req_id, ?self, "unexpected response type"); + self.poll_remote(cx) + } + Err(e) => { + tracing::warn!("error decoding ack response: {e}"); + self.poll_remote(cx) + } + } + } + } + } +} +impl Stream for AckStream +where + E: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + type Item = AckStreamItem; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + match self.as_mut().project().local.poll_next(cx) { + Poll::Pending | Poll::Ready(None) => self.poll_remote(cx), + Poll::Ready(Some(item)) => Poll::Ready(Some(item)), + } + } + fn size_hint(&self) -> (usize, Option) { + let (lower, upper) = self.local.size_hint(); + (lower, upper.map(|upper| upper + self.total_ack_cnt)) + } +} + +impl FusedStream for AckStream +where + Err: DeserializeOwned + fmt::Debug, + S: Stream> + FusedStream, +{ + /// The stream is terminated if: + /// * The local stream is terminated. + /// * All the servers have sent the expected ack count. + /// * We have received all the expected acks. 
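
The termination rule described above boils down to two counters: how many remote servers still owe a `BroadcastAckCount`, and how many acks remain once those counts are known. A minimal standalone model of that bookkeeping (the enum and counter names are illustrative, not the crate's types):

```rust
enum RemoteMsg {
    // A remote server announces how many acks it will forward.
    AckCount(u32),
    // A single ack forwarded by a remote server.
    Ack(&'static str),
}

fn main() {
    // Two remote servers, each broadcasting to 2 sockets.
    let msgs = [
        RemoteMsg::AckCount(2), // first server: expect 2 acks
        RemoteMsg::AckCount(2), // second server: expect 2 more acks
        RemoteMsg::Ack("b1"),
        RemoteMsg::Ack("c1"),
        RemoteMsg::Ack("b2"),
        RemoteMsg::Ack("c2"),
    ];

    let mut serv_cnt: u16 = 2; // servers that still owe an AckCount
    let mut ack_cnt: u32 = 0;  // acks still expected
    for msg in msgs {
        match msg {
            RemoteMsg::AckCount(n) => {
                serv_cnt -= 1;
                ack_cnt += n;
            }
            RemoteMsg::Ack(id) => {
                ack_cnt -= 1;
                println!("got ack from {id}");
            }
        }
    }
    // Exhausted: every server reported its count and every counted ack arrived.
    assert!(serv_cnt == 0 && ack_cnt == 0);
}
```
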
+ fn is_terminated(&self) -> bool { + // remote stream is terminated if the timeout is reached + let remote_term = (self.ack_cnt == 0 && self.serv_cnt == 0) || self.remote.is_terminated(); + self.local.is_terminated() && remote_term + } +} +impl fmt::Debug for AckStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AckStream") + .field("ack_cnt", &self.ack_cnt) + .field("total_ack_cnt", &self.total_ack_cnt) + .field("serv_cnt", &self.serv_cnt) + .finish() + } +} + +pin_project! { + /// A stream that unsubscribes from its source channel when dropped. + pub struct DropStream { + #[pin] + stream: S, + req_id: Sid, + handlers: Arc> + } + impl PinnedDrop for DropStream { + fn drop(this: Pin<&mut Self>) { + let stream = this.project(); + let chan = stream.req_id; + tracing::debug!(?chan, "dropping stream"); + stream.handlers.lock().unwrap().remove(chan); + } + } +} +impl DropStream { + pub fn new(stream: S, handlers: Arc>, req_id: Sid) -> Self { + Self { + stream, + handlers, + req_id, + } + } +} +impl Stream for DropStream { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + self.project().stream.poll_next(cx) + } +} +impl FusedStream for DropStream { + fn is_terminated(&self) -> bool { + self.stream.is_terminated() + } +} diff --git a/crates/socketioxide-redis/tests/broadcast.rs b/crates/socketioxide-redis/tests/broadcast.rs new file mode 100644 index 00000000..41d8e6e3 --- /dev/null +++ b/crates/socketioxide-redis/tests/broadcast.rs @@ -0,0 +1,150 @@ +use std::time::Duration; + +use socketioxide::{adapter::Adapter, extract::SocketRef}; +mod fixture; + +#[tokio::test] +pub async fn broadcast() { + async fn handler(socket: SocketRef) { + // delay to ensure all socket/servers are connected + tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; + socket.broadcast().emit("test", &2).await.unwrap(); + } + + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", handler).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test",2]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test",2]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn broadcast_rooms() { + let [io1, io2, io3] = fixture::spawn_servers(); + let handler = |room: &'static str, to: &'static str| { + move |socket: SocketRef<_>| async move { + // delay to ensure all socket/servers are connected + socket.join(room); + tokio::time::sleep(tokio::time::Duration::from_millis(5)).await; + socket.to(to).emit("test", room).await.unwrap(); + } + }; + + io1.ns("/", handler("room1", "room2")).await.unwrap(); + io2.ns("/", handler("room2", "room3")).await.unwrap(); + io3.ns("/", handler("room3", "room1")).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + // socket 1 is receiving a packet from io3 + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test","room3"]"#); + // socket 2 is receiving a packet from io2 + assert_eq!(timeout_rcv!(&mut rx2), 
r#"42["test","room1"]"#); + // socket 3 is receiving a packet from io1 + assert_eq!(timeout_rcv!(&mut rx3), r#"42["test","room2"]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} + +#[tokio::test] +pub async fn broadcast_with_ack() { + use futures_util::stream::StreamExt; + + async fn handler(socket: SocketRef) { + // delay to ensure all socket/servers are connected + tokio::time::sleep(tokio::time::Duration::from_millis(1)).await; + socket + .broadcast() + .emit_with_ack::<_, String>("test", "bar") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res).unwrap(); + async move {} + }) + .await; + } + + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + + let ((_tx1, mut rx1), (tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","bar"]"#); + let packet_res = r#"431["foo"]"#.to_string().try_into().unwrap(); + tx2.try_send(packet_res).unwrap(); + assert_eq!(timeout_rcv!(&mut rx1), r#"42["ack_res",{"Ok":"foo"}]"#); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn broadcast_with_ack_timeout() { + use futures_util::StreamExt; + const TIMEOUT: Duration = Duration::from_millis(50); + + async fn handler(socket: SocketRef) { + socket + .broadcast() + .emit_with_ack::<_, String>("test", "bar") + .await + .unwrap() + .for_each(|(_, res)| { + socket.emit("ack_res", &res).unwrap(); + async move {} + }) + .await; + socket.emit("ack_res", "timeout").unwrap(); + } + + let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT); + + io1.ns("/", handler).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + + let now = std::time::Instant::now(); + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","bar"]"#); // emit with ack message + // We do not answer + assert_eq!( + timeout_rcv!(&mut rx1, TIMEOUT.as_millis() as u64 + 100), + r#"42["ack_res","timeout"]"# + ); + assert!(now.elapsed() >= TIMEOUT); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} diff --git a/crates/socketioxide-redis/tests/fixture.rs b/crates/socketioxide-redis/tests/fixture.rs new file mode 100644 index 00000000..af646377 --- /dev/null +++ b/crates/socketioxide-redis/tests/fixture.rs @@ -0,0 +1,177 @@ +#![allow(dead_code)] + +use std::{ + collections::HashMap, + future::Future, + sync::{Arc, RwLock}, + time::Duration, +}; +use tokio::sync::mpsc; + +use socketioxide::{adapter::Emitter, SocketIo}; +use socketioxide_redis::{ + drivers::{Driver, MessageStream}, + CustomRedisAdapter, RedisAdapterConfig, RedisAdapterCtr, +}; + +/// Spawns a number of servers with a stub driver for testing. +/// Every server will be connected to every other server. 
+pub fn spawn_servers() -> [SocketIo>; N] { + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + + [0; N].map(|_| { + let (driver, mut rx, tx) = StubDriver::new(N as u16); + + // pipe messages to all other servers + sync_buff.write().unwrap().push(tx); + let sync_buff = sync_buff.clone(); + tokio::spawn(async move { + while let Some((chan, data)) = rx.recv().await { + tracing::debug!("received data to broadcast {}", chan); + for tx in sync_buff.read().unwrap().iter() { + tracing::debug!("sending data for {}", chan); + tx.try_send((chan.clone(), data.clone())).unwrap(); + } + } + }); + + let adapter = RedisAdapterCtr::new_with_driver(driver, RedisAdapterConfig::default()); + let (_svc, io) = SocketIo::builder() + .with_adapter::>(adapter) + .build_svc(); + io + }) +} + +/// Spawns a number of servers with a stub driver for testing. +/// The internal server count is set to N + 2 to trigger a timeout when expecting N responses. +pub fn spawn_buggy_servers( + timeout: Duration, +) -> [SocketIo>; N] { + let sync_buff = Arc::new(RwLock::new(Vec::with_capacity(N))); + + [0; N].map(|_| { + let (driver, mut rx, tx) = StubDriver::new(N as u16 + 2); // Fake server count to trigger timeout + + // pipe messages to all other servers + sync_buff.write().unwrap().push(tx); + let sync_buff = sync_buff.clone(); + tokio::spawn(async move { + while let Some((chan, data)) = rx.recv().await { + tracing::debug!("received data to broadcast {}", chan); + for tx in sync_buff.read().unwrap().iter() { + tracing::debug!("sending data for {}", chan); + tx.try_send((chan.clone(), data.clone())).unwrap(); + } + } + }); + + let config = RedisAdapterConfig::new().with_request_timeout(timeout); + let adapter = RedisAdapterCtr::new_with_driver(driver, config); + let (_svc, io) = SocketIo::builder() + .with_adapter::>(adapter) + .build_svc(); + io + }) +} + +type ChanItem = (String, Vec); +type ResponseHandlers = HashMap>; +#[derive(Debug, Clone)] +pub struct StubDriver { + tx: mpsc::Sender, + handlers: Arc>, + num_serv: u16, +} + +async fn pipe_handlers(mut rx: mpsc::Receiver, handlers: Arc>) { + while let Some((chan, data)) = rx.recv().await { + let handlers = handlers.read().unwrap(); + tracing::debug!(?handlers, "received data to broadcast {}", chan); + if let Some(tx) = handlers.get(&chan) { + tx.try_send((chan, data)).unwrap(); + } + } +} +impl StubDriver { + pub fn new(num_serv: u16) -> (Self, mpsc::Receiver, mpsc::Sender) { + let (tx, rx) = mpsc::channel(255); // driver emitter + let (tx1, rx1) = mpsc::channel(255); // driver receiver + let handlers = Arc::new(RwLock::new(HashMap::new())); + + tokio::spawn(pipe_handlers(rx1, handlers.clone())); + + let driver = Self { + tx, + num_serv, + handlers, + }; + (driver, rx, tx1) + } + pub fn handler_cnt(&self) -> usize { + self.handlers.read().unwrap().len() + } +} + +impl Driver for StubDriver { + type Error = std::convert::Infallible; + + fn publish( + &self, + chan: String, + val: Vec, + ) -> impl Future> + Send { + self.tx.try_send((chan, val)).unwrap(); + async move { Ok(()) } + } + + async fn subscribe( + &self, + pat: String, + _: usize, + ) -> Result, Self::Error> { + let (tx, rx) = mpsc::channel(255); + self.handlers.write().unwrap().insert(pat, tx); + Ok(MessageStream::new(rx)) + } + + async fn unsubscribe(&self, pat: String) -> Result<(), Self::Error> { + self.handlers.write().unwrap().remove(&pat); + Ok(()) + } + + async fn num_serv(&self, _chan: &str) -> Result { + Ok(self.num_serv) + } +} + +#[macro_export] +macro_rules! 
timeout_rcv_err { + ($srx:expr) => { + tokio::time::timeout(std::time::Duration::from_millis(10), $srx.recv()) + .await + .unwrap_err(); + }; +} + +#[macro_export] +macro_rules! timeout_rcv { + ($srx:expr) => { + TryInto::::try_into( + tokio::time::timeout(std::time::Duration::from_millis(10), $srx.recv()) + .await + .unwrap() + .unwrap(), + ) + .unwrap() + }; + ($srx:expr, $t:expr) => { + TryInto::::try_into( + tokio::time::timeout(std::time::Duration::from_millis($t), $srx.recv()) + .await + .unwrap() + .unwrap(), + ) + .unwrap() + }; +} diff --git a/crates/socketioxide-redis/tests/local.rs b/crates/socketioxide-redis/tests/local.rs new file mode 100644 index 00000000..88b7b00e --- /dev/null +++ b/crates/socketioxide-redis/tests/local.rs @@ -0,0 +1,32 @@ +//! Check that each adapter function with a broadcast options that is [`Local`] returns an immediate future +mod fixture; + +macro_rules! assert_now { + ($fut:expr) => { + #[allow(unused_must_use)] + futures_util::FutureExt::now_or_never($fut) + .expect("Returned future should be sync") + .unwrap() + }; +} + +#[tokio::test] +async fn test_local_fns() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", || ()).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + assert_now!(io1.local().emit("test", "test")); + assert_now!(io1.local().emit_with_ack::<_, ()>("test", "test")); + assert_now!(io1.local().join("test")); + assert_now!(io1.local().leave("test")); + assert_now!(io1.local().disconnect()); + assert_now!(io1.local().fetch_sockets()); +} diff --git a/crates/socketioxide-redis/tests/rooms.rs b/crates/socketioxide-redis/tests/rooms.rs new file mode 100644 index 00000000..ad82f1a6 --- /dev/null +++ b/crates/socketioxide-redis/tests/rooms.rs @@ -0,0 +1,115 @@ +use std::time::Duration; + +use socketioxide::extract::SocketRef; + +mod fixture; + +#[tokio::test] +pub async fn all_rooms() { + let [io1, io2, io3] = fixture::spawn_servers(); + let handler = |rooms: &'static [&'static str]| move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + const ROOMS: [&str; 3] = ["room1", "room2", "room3"]; + for io in [io1, io2, io3] { + let mut rooms = io.rooms().await.unwrap(); + rooms.sort(); + assert_eq!(rooms, ROOMS); + } + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} + +#[tokio::test] +pub async fn all_rooms_timeout() { + const TIMEOUT: Duration = Duration::from_millis(50); + let [io1, io2, io3] = fixture::spawn_buggy_servers(TIMEOUT); + let handler = |rooms: &'static [&'static str]| move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2), (_tx3, mut rx3)) = tokio::join!( + 
io1.new_dummy_sock("/", ()), + io2.new_dummy_sock("/", ()), + io3.new_dummy_sock("/", ()) + ); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + timeout_rcv!(&mut rx3); // Connect "/" packet + + const ROOMS: [&str; 3] = ["room1", "room2", "room3"]; + for io in [io1, io2, io3] { + let now = std::time::Instant::now(); + let mut rooms = io.rooms().await.unwrap(); + assert!(now.elapsed() >= TIMEOUT); // timeout time + rooms.sort(); + assert_eq!(rooms, ROOMS); + } + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); + timeout_rcv_err!(&mut rx3); +} +#[tokio::test] +pub async fn add_sockets() { + let handler = |room: &'static str| move |socket: SocketRef<_>| socket.join(room); + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler("room1")).await.unwrap(); + io2.ns("/", handler("room3")).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + io1.broadcast().join("room2").await.unwrap(); + let mut rooms = io1.rooms().await.unwrap(); + rooms.sort(); + assert_eq!(rooms, ["room1", "room2", "room3"]); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} + +#[tokio::test] +pub async fn del_sockets() { + let handler = |rooms: &'static [&'static str]| move |socket: SocketRef<_>| socket.join(rooms); + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room3", "room2"])).await.unwrap(); + + let ((_tx1, mut rx1), (_tx2, mut rx2)) = + tokio::join!(io1.new_dummy_sock("/", ()), io2.new_dummy_sock("/", ())); + + timeout_rcv!(&mut rx1); // Connect "/" packet + timeout_rcv!(&mut rx2); // Connect "/" packet + + io1.broadcast().leave("room2").await.unwrap(); + + let mut rooms = io1.rooms().await.unwrap(); + rooms.sort(); + assert_eq!(rooms, ["room1", "room3"]); + + timeout_rcv_err!(&mut rx1); + timeout_rcv_err!(&mut rx2); +} diff --git a/crates/socketioxide-redis/tests/sockets.rs b/crates/socketioxide-redis/tests/sockets.rs new file mode 100644 index 00000000..1ff24b0e --- /dev/null +++ b/crates/socketioxide-redis/tests/sockets.rs @@ -0,0 +1,169 @@ +use std::{str::FromStr, time::Duration}; + +use socketioxide::{ + adapter::Adapter, extract::SocketRef, operators::BroadcastOperators, socket::RemoteSocket, + SocketIo, +}; +use socketioxide_core::{adapter::RemoteSocketData, Sid, Str}; +use tokio::time::Instant; + +mod fixture; +fn extract_sid(data: &str) -> Sid { + let data = data + .split("\"sid\":\"") + .nth(1) + .and_then(|s| s.split('"').next()) + .unwrap(); + Sid::from_str(data).unwrap() +} +async fn fetch_sockets_data(op: BroadcastOperators) -> Vec { + let mut sockets = op + .fetch_sockets() + .await + .unwrap() + .into_iter() + .map(RemoteSocket::into_data) + .collect::>(); + sockets.sort_by(|a, b| a.id.cmp(&b.id)); + sockets +} +fn create_expected_sockets( + ids: [Sid; N], + ios: [&SocketIo; N], +) -> [RemoteSocketData; N] { + let mut i = 0; + let mut sockets = ios.map(|io| { + let id = ids[i]; + i += 1; + RemoteSocketData { + id, + server_id: io.config().server_id, + ns: Str::from("/"), + } + }); + sockets.sort_by(|a, b| a.id.cmp(&b.id)); + sockets +} + +#[tokio::test] +pub async fn fetch_sockets() { + let [io1, io2, io3] = fixture::spawn_servers::<3>(); + + io1.ns("/", || ()).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + io3.ns("/", || ()).await.unwrap(); + + 
let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + let (_, mut rx3) = io3.new_dummy_sock("/", ()).await; + + let id1 = extract_sid(&timeout_rcv!(&mut rx1)); + let id2 = extract_sid(&timeout_rcv!(&mut rx2)); + let id3 = extract_sid(&timeout_rcv!(&mut rx3)); + + let mut expected_sockets = create_expected_sockets([id1, id2, id3], [&io1, &io2, &io3]); + expected_sockets.sort_by(|a, b| a.id.cmp(&b.id)); + + let sockets = fetch_sockets_data(io1.broadcast()).await; + assert_eq!(sockets, expected_sockets); + + let sockets = fetch_sockets_data(io2.broadcast()).await; + assert_eq!(sockets, expected_sockets); + + let sockets = fetch_sockets_data(io3.broadcast()).await; + assert_eq!(sockets, expected_sockets); +} + +#[tokio::test] +pub async fn fetch_sockets_with_rooms() { + let [io1, io2, io3] = fixture::spawn_servers::<3>(); + let handler = |rooms: &'static [&'static str]| move |socket: SocketRef<_>| socket.join(rooms); + + io1.ns("/", handler(&["room1", "room2"])).await.unwrap(); + io2.ns("/", handler(&["room2", "room3"])).await.unwrap(); + io3.ns("/", handler(&["room3", "room1"])).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + let (_, mut rx3) = io3.new_dummy_sock("/", ()).await; + + let id1 = extract_sid(&timeout_rcv!(&mut rx1)); + let id2 = extract_sid(&timeout_rcv!(&mut rx2)); + let id3 = extract_sid(&timeout_rcv!(&mut rx3)); + + let sockets = fetch_sockets_data(io1.to("room1")).await; + assert_eq!(sockets, create_expected_sockets([id1, id3], [&io1, &io3])); + + let sockets = fetch_sockets_data(io1.to("room2")).await; + assert_eq!(sockets, create_expected_sockets([id1, id2], [&io1, &io2])); + + let sockets = fetch_sockets_data(io1.to("room3")).await; + assert_eq!(sockets, create_expected_sockets([id2, id3], [&io2, &io3])); +} + +#[tokio::test] +pub async fn fetch_sockets_timeout() { + const TIMEOUT: Duration = Duration::from_millis(50); + let [io1, io2] = fixture::spawn_buggy_servers(TIMEOUT); + + io1.ns("/", || ()).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + let now = Instant::now(); + io1.fetch_sockets().await.unwrap(); + assert!(now.elapsed() >= TIMEOUT); +} + +#[tokio::test] +pub async fn remote_socket_emit() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", || ()).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + let sockets = io1.fetch_sockets().await.unwrap(); + for socket in sockets { + socket.emit("test", "hello").await.unwrap(); + } + + assert_eq!(timeout_rcv!(&mut rx1), r#"42["test","hello"]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"42["test","hello"]"#); +} + +#[tokio::test] +pub async fn remote_socket_emit_with_ack() { + let [io1, io2] = fixture::spawn_servers(); + + io1.ns("/", || ()).await.unwrap(); + io2.ns("/", || ()).await.unwrap(); + + let (_, mut rx1) = io1.new_dummy_sock("/", ()).await; + let (_, mut rx2) = io2.new_dummy_sock("/", ()).await; + + timeout_rcv!(&mut rx1); // connect packet + timeout_rcv!(&mut rx2); // connect packet + + let sockets = io1.fetch_sockets().await.unwrap(); + for socket 
in sockets { + #[allow(unused_must_use)] + socket + .emit_with_ack::<_, ()>("test", "hello") + .await + .unwrap(); + } + + assert_eq!(timeout_rcv!(&mut rx1), r#"421["test","hello"]"#); + assert_eq!(timeout_rcv!(&mut rx2), r#"421["test","hello"]"#); +} diff --git a/crates/socketioxide/Cargo.toml b/crates/socketioxide/Cargo.toml index a7535392..0bdc49b7 100644 --- a/crates/socketioxide/Cargo.toml +++ b/crates/socketioxide/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "socketioxide" description = "Socket IO server implementation in rust as a Tower Service." -version.workspace = true +version = "0.15.1" edition.workspace = true rust-version.workspace = true authors.workspace = true @@ -13,8 +13,8 @@ license.workspace = true readme.workspace = true [dependencies] -engineioxide = { path = "../engineioxide", version = "0.15.1" } -socketioxide-core = { path = "../socketioxide-core", version = "0.15.1" } +engineioxide = { path = "../engineioxide", version = "0.15" } +socketioxide-core = { path = "../socketioxide-core", version = "0.15" } bytes.workspace = true futures-core.workspace = true @@ -29,8 +29,7 @@ thiserror.workspace = true hyper.workspace = true matchit.workspace = true pin-project-lite.workspace = true - -rustversion = "1.0.18" +rustversion.workspace = true # Parsers socketioxide-parser-common = { path = "../parser-common", version = "0.15.1" } diff --git a/crates/socketioxide/docs/operators/broadcast.md b/crates/socketioxide/docs/operators/broadcast.md index 01121aef..b8dad6d6 100644 --- a/crates/socketioxide/docs/operators/broadcast.md +++ b/crates/socketioxide/docs/operators/broadcast.md @@ -1,5 +1,5 @@ # Broadcast to all sockets without any filtering (except the current socket). -If you want to include the current socket use emit operators from the [`io`] global context. +If you want to include the current socket use the broadcast operators from the [`io`] global context. [`io`]: crate::SocketIo @@ -7,11 +7,11 @@ If you want to include the current socket use emit operators from the [`io`] glo ```rust # use socketioxide::{SocketIo, extract::*}; # use serde_json::Value; -fn handler(io: SocketIo, socket: SocketRef, Data(data): Data::) { +async fn handler(io: SocketIo, socket: SocketRef, Data(data): Data::) { // This message will be broadcast to all sockets in this namespace except this one. - socket.broadcast().emit("test", &data); + socket.broadcast().emit("test", &data).await; // This message will be broadcast to all sockets in this namespace, including this one. - io.emit("test", &data); + io.emit("test", &data).await; } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/docs/operators/disconnect.md b/crates/socketioxide/docs/operators/disconnect.md index e7af7cac..e6e23b1a 100644 --- a/crates/socketioxide/docs/operators/disconnect.md +++ b/crates/socketioxide/docs/operators/disconnect.md @@ -1,11 +1,25 @@ # Disconnect all sockets selected with the previous operators. -# Example +This will return a `Future` that must be awaited because socket.io may communicate with remote instances +if you use horizontal scaling through remote adapters. + +# Example from a socket +```rust +# use socketioxide::{SocketIo, extract::*}; +async fn handler(socket: SocketRef) { + // Disconnect all sockets in the room1 and room3 rooms, except for room2. 
+ socket.within("room1").within("room3").except("room2").disconnect().await.unwrap(); +} +let (_, io) = SocketIo::new_svc(); +io.ns("/", |s: SocketRef| s.on("test", handler)); +``` + +# Example from the io struct ```rust # use socketioxide::{SocketIo, extract::*}; -fn handler(socket: SocketRef) { +async fn handler(socket: SocketRef, io: SocketIo) { // Disconnect all sockets in the room1 and room3 rooms, except for room2. - socket.within("room1").within("room3").except("room2").disconnect().unwrap(); + io.within("room1").within("room3").except("room2").disconnect().await.unwrap(); } let (_, io) = SocketIo::new_svc(); io.ns("/", |s: SocketRef| s.on("test", handler)); diff --git a/crates/socketioxide/docs/operators/emit.md b/crates/socketioxide/docs/operators/emit.md index d9423cfa..9e81681c 100644 --- a/crates/socketioxide/docs/operators/emit.md +++ b/crates/socketioxide/docs/operators/emit.md @@ -44,9 +44,7 @@ fn handler(socket: SocketRef, Data(data): Data::) { } let (_, io) = SocketIo::new_svc(); -io.ns("/", |socket: SocketRef| { - socket.on("test", handler); -}); +io.ns("/", |socket: SocketRef| socket.on("test", handler)); ``` # Single-socket binary example with the `bytes` crate @@ -66,27 +64,29 @@ fn handler(socket: SocketRef, Data(data): Data::<(String, Bytes, Bytes)>) { } let (_, io) = SocketIo::new_svc(); -io.ns("/", |socket: SocketRef| { - socket.on("test", handler); -}); +io.ns("/", |socket: SocketRef| socket.on("test", handler)); ``` # Broadcast example + +Here the emit method will return a `Future` that must be awaited because socket.io may communicate +with remote instances if you use horizontal scaling through remote adapters. + ```rust # use socketioxide::{SocketIo, extract::*}; # use serde_json::Value; # use std::sync::Arc; # use bytes::Bytes; -fn handler(socket: SocketRef, Data(data): Data::<(String, Bytes, Bytes)>) { +async fn handler(socket: SocketRef, Data(data): Data::<(String, Bytes, Bytes)>) { // Emit a test message in the room1 and room3 rooms, except for room2, with the received binary payload - socket.to("room1").to("room3").except("room2").emit("test", &data); + socket.to("room1").to("room3").except("room2").emit("test", &data).await; // Emit a test message with multiple arguments to the client - socket.to("room1").emit("test", &("world", "hello", 1)).ok(); + socket.to("room1").emit("test", &("world", "hello", 1)).await; // Emit a test message with an array as the first argument let arr = [1, 2, 3, 4]; - socket.to("room2").emit("test", &[arr]).ok(); + socket.to("room2").emit("test", &[arr]).await; } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/docs/operators/emit_with_ack.md b/crates/socketioxide/docs/operators/emit_with_ack.md index 1f6e9e67..8338b4f7 100644 --- a/crates/socketioxide/docs/operators/emit_with_ack.md +++ b/crates/socketioxide/docs/operators/emit_with_ack.md @@ -1,4 +1,4 @@ -# Emit a message to the client and wait for one or more acknowledgments. +# Emit a message to one or many clients and wait for one or more acknowledgments. See [`emit()`](#method.emit) for more details on emitting messages. @@ -10,7 +10,7 @@ To receive acknowledgments, an [`AckStream`] is returned. It can be used in two * As a [`Future`]: This will yield the first acknowledgment response received from the client, useful when expecting only one acknowledgment. # Errors -If packet encoding fails, an [`EncodeError`] is **immediately** returned. +If packet encoding fails, an [`ParserError`] is **immediately** returned. 
If the socket is full or if it is closed before receiving the acknowledgment, a [`SendError::Socket`] will be **immediately** returned, and the value to send will be given back. @@ -28,7 +28,7 @@ an [`AckError::Decode`] will be yielded. [`AckError::Socket`]: crate::AckError::Socket [`AckError::Socket(SocketError::Closed)`]: crate::SocketError::Closed [`SendError::Socket`]: crate::SendError::Socket -[`EncodeError`]: crate::EncodeError +[`ParserError`]: crate::ParserError [`io::get_socket()`]: crate::SocketIo#method.get_socket # Single-socket example @@ -45,9 +45,7 @@ async fn handler(socket: SocketRef, Data(data): Data::) { } let (_, io) = SocketIo::new_svc(); -io.ns("/", |socket: SocketRef| { - socket.on("test", handler); -}); +io.ns("/", |socket: SocketRef| socket.on("test", handler)); ``` # Single-socket example with custom acknowledgment timeout @@ -65,12 +63,14 @@ async fn handler(socket: SocketRef, Data(data): Data::) { } let (_, io) = SocketIo::new_svc(); -io.ns("/", |socket: SocketRef| { - socket.on("test", handler); -}); +io.ns("/", |socket: SocketRef| socket.on("test", handler)); ``` # Broadcast example + +Here the emit method will return a `Future` that must be awaited because socket.io may communicate +with remote instances if you use horizontal scaling through remote adapters. + ```rust # use socketioxide::{SocketIo, extract::*}; # use serde_json::Value; @@ -82,6 +82,7 @@ async fn handler(socket: SocketRef, Data(data): Data::) { .to("room3") .except("room2") .emit_with_ack::<_, String>("message-back", &data) + .await .unwrap(); ack_stream.for_each(|(id, ack)| async move { match ack { diff --git a/crates/socketioxide/docs/operators/except.md b/crates/socketioxide/docs/operators/except.md index 6b2e20df..c7692dea 100644 --- a/crates/socketioxide/docs/operators/except.md +++ b/crates/socketioxide/docs/operators/except.md @@ -4,16 +4,16 @@ ```rust # use socketioxide::{SocketIo, extract::*}; # use serde_json::Value; -fn handler(socket: SocketRef, Data(data): Data::) { +async fn handler(socket: SocketRef, Data(data): Data::) { // This message will be broadcast to all sockets in the namespace, // except for those in room1 and the current socket - socket.broadcast().except("room1").emit("test", &data); + socket.broadcast().except("room1").emit("test", &data).await; } let (_, io) = SocketIo::new_svc(); io.ns("/", |socket: SocketRef| { - socket.on("register1", |s: SocketRef| s.join("room1").unwrap()); - socket.on("register2", |s: SocketRef| s.join("room2").unwrap()); + socket.on("register1", |s: SocketRef| s.join("room1")); + socket.on("register2", |s: SocketRef| s.join("room2")); socket.on("test", handler); }); ``` diff --git a/crates/socketioxide/docs/operators/fetch_sockets.md b/crates/socketioxide/docs/operators/fetch_sockets.md new file mode 100644 index 00000000..9eaed38f --- /dev/null +++ b/crates/socketioxide/docs/operators/fetch_sockets.md @@ -0,0 +1,43 @@ +# Get all the local and remote sockets selected with the previous operators. + +
+ Use sockets() if you only have a single node. +
+ +Avoid using this method if you want to immediately perform actions on the sockets. +Instead, directly apply the actions using operators: + +## Correct Approach +```rust +# use socketioxide::{SocketIo, extract::*}; +# async fn doc_main() { +# let (_, io) = SocketIo::new_svc(); +io.within("room1").emit("foo", "bar").await.unwrap(); +io.within("room1").disconnect().await.unwrap(); +# } +``` + +## Incorrect Approach +```rust +# use socketioxide::{SocketIo, extract::*}; +# async fn doc_main() { +# let (_, io) = SocketIo::new_svc(); +let sockets = io.within("room1").fetch_sockets().await.unwrap(); +for socket in sockets { + socket.emit("test", &"Hello").await.unwrap(); + socket.leave("room1").await.unwrap(); +} +# } +``` + +# Example +```rust +# use socketioxide::{SocketIo, extract::*}; +# async fn doc_main() { +let (_, io) = SocketIo::new_svc(); +let sockets = io.within("room1").fetch_sockets().await.unwrap(); +for socket in sockets { + println!("Socket ID: {:?}", socket.data().id); +} +# } +``` diff --git a/crates/socketioxide/docs/operators/get_socket.md b/crates/socketioxide/docs/operators/get_socket.md index 7c3887a6..715f57e8 100644 --- a/crates/socketioxide/docs/operators/get_socket.md +++ b/crates/socketioxide/docs/operators/get_socket.md @@ -1 +1,5 @@ -# Get a [`SocketRef`] by the specified [`Sid`]. +# Get a local [`SocketRef`] by the specified [`Sid`]. + +
+ This will only work for local sockets. Use fetch_sockets to get remote sockets. +
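As a minimal sketch of that distinction (room and event names below are placeholders, and the handler style mirrors the other operator docs in this patch), the local `get_socket` lookup and the adapter-aware `fetch_sockets` call can be contrasted like this:

```rust
# use socketioxide::{SocketIo, extract::*};
async fn handler(socket: SocketRef, io: SocketIo) {
    // Local lookup: only resolves sockets connected to this node.
    if let Some(local) = io.get_socket(socket.id) {
        println!("local socket: {}", local.id);
    }
    // Adapter-aware fetch: may also return remote sockets when a
    // distributed adapter is configured.
    for s in io.within("room1").fetch_sockets().await.unwrap() {
        println!("socket (local or remote): {:?}", s.data().id);
    }
}
let (_, io) = SocketIo::new_svc();
io.ns("/", |s: SocketRef| s.on("test", handler));
```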
diff --git a/crates/socketioxide/docs/operators/join.md index e6f860cd..8249f14d 100644 --- a/crates/socketioxide/docs/operators/join.md +++ b/crates/socketioxide/docs/operators/join.md @@ -1,12 +1,16 @@ # Add all sockets selected with the previous operators to the specified room(s). +This will return a `Future` that must be awaited because socket.io may communicate with remote instances +if you use horizontal scaling through remote adapters. + # Example ```rust # use socketioxide::{SocketIo, extract::*}; -fn handler(socket: SocketRef) { +async fn handler(socket: SocketRef) { // Add all sockets that are in room1 and room3 to room4 and room5 - socket.within("room1").within("room3").join(["room4", "room5"]).unwrap(); - let sockets = socket.within("room4").within("room5").sockets().unwrap(); + socket.within("room1").within("room3").join(["room4", "room5"]).await.unwrap(); + // We should retrieve all the local sockets that are in room4 and room5 + let sockets = socket.within("room4").within("room5").sockets(); } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/docs/operators/leave.md index 6a8db35c..53251cf1 100644 --- a/crates/socketioxide/docs/operators/leave.md +++ b/crates/socketioxide/docs/operators/leave.md @@ -1,12 +1,14 @@ # Remove all sockets selected with the previous operators from the specified room(s). +This will return a `Future` that must be awaited because socket.io may communicate with remote instances +if you use horizontal scaling through remote adapters. + # Example ```rust # use socketioxide::{SocketIo, extract::*}; -fn handler(socket: SocketRef) { +async fn handler(socket: SocketRef) { // Remove all sockets that are in room1 and room3 from room4 and room5 - socket.within("room1").within("room3").leave(["room4", "room5"]).unwrap(); - let sockets = socket.within("room4").within("room5").sockets().unwrap(); + socket.within("room1").within("room3").leave(["room4", "room5"]).await.unwrap(); } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/docs/operators/local.md index 2c3cff4c..cfec3cd7 100644 --- a/crates/socketioxide/docs/operators/local.md +++ b/crates/socketioxide/docs/operators/local.md @@ -5,10 +5,10 @@ When using the default in-memory adapter, this operator is a no-op. ```rust # use socketioxide::{SocketIo, extract::*}; # use serde_json::Value; -fn handler(socket: SocketRef, Data(data): Data::) { +async fn handler(socket: SocketRef, Data(data): Data::) { // This message will be broadcast to all sockets in this // namespace that are connected to this node - socket.local().emit("test", &data); + socket.local().emit("test", &data).await; } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/docs/operators/rooms.md index 461f5b9e..84cec496 100644 --- a/crates/socketioxide/docs/operators/rooms.md +++ b/crates/socketioxide/docs/operators/rooms.md @@ -1,11 +1,14 @@ -# Get all room names in the current namespace. +# Get all the rooms selected with the previous operators. + +This will return a `Future` that must be awaited because socket.io may communicate with remote instances +if you use horizontal scaling through remote adapters. 
# Example ```rust # use socketioxide::{SocketIo, extract::SocketRef}; -fn handler(socket: SocketRef, io: SocketIo) { +async fn handler(socket: SocketRef, io: SocketIo) { println!("Socket connected to the / namespace with id: {}", socket.id); - let rooms = io.rooms().unwrap(); + let rooms = io.rooms().await.unwrap(); println!("All rooms in the / namespace: {:?}", rooms); } diff --git a/crates/socketioxide/docs/operators/sockets.md b/crates/socketioxide/docs/operators/sockets.md index 6e3721b0..54554d34 100644 --- a/crates/socketioxide/docs/operators/sockets.md +++ b/crates/socketioxide/docs/operators/sockets.md @@ -1,15 +1,19 @@ -# Get all sockets selected with the previous operators. +# Get all the *local* sockets selected with the previous operators. This can be used to retrieve any extension data (with the `extensions` feature enabled) from the sockets or to make certain sockets join other rooms. +
+ This will only work for local sockets. Use fetch_sockets to get remote sockets. +
+ # Example ```rust # use socketioxide::{SocketIo, extract::*}; -fn handler(socket: SocketRef) { +async fn handler(socket: SocketRef) { // Find extension data in each socket in the room1 and room3 rooms, except for room2 - let sockets = socket.within("room1").within("room3").except("room2").sockets().unwrap(); + let sockets = socket.within("room1").within("room3").except("room2").sockets(); for socket in sockets { - println!("Socket custom string: {:?}", socket.extensions.get::()); + println!("Socket extension: {:?}", socket.extensions.get::()); } } diff --git a/crates/socketioxide/docs/operators/timeout.md b/crates/socketioxide/docs/operators/timeout.md index 114f796a..f3cf443b 100644 --- a/crates/socketioxide/docs/operators/timeout.md +++ b/crates/socketioxide/docs/operators/timeout.md @@ -17,6 +17,7 @@ async fn handler(socket: SocketRef, Data(data): Data::) { .except("room2") .timeout(Duration::from_secs(5)) .emit_with_ack::<_, Value>("message-back", &data) + .await .unwrap() .for_each(|(id, ack)| async move { match ack { diff --git a/crates/socketioxide/docs/operators/to.md b/crates/socketioxide/docs/operators/to.md index 48eabeaf..0af12fa0 100644 --- a/crates/socketioxide/docs/operators/to.md +++ b/crates/socketioxide/docs/operators/to.md @@ -15,13 +15,15 @@ async fn handler(socket: SocketRef, io: SocketIo, Data(data): Data::) { socket .to("room1") .to(["room2", "room3"]) - .emit("test", &data); + .emit("test", &data) + .await; // Emit a message to all sockets in room1, room2, room3, and room4, including the current socket io .to("room1") .to(["room2", "room3"]) - .emit("test", &data); + .emit("test", &data) + .await; } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/docs/operators/within.md b/crates/socketioxide/docs/operators/within.md index dba3e8d3..d457e1b2 100644 --- a/crates/socketioxide/docs/operators/within.md +++ b/crates/socketioxide/docs/operators/within.md @@ -18,7 +18,8 @@ async fn handler(socket: SocketRef, Data(data): Data::) { .within("room1") .within(["room2", "room3"]) .within(vec![other_rooms]) - .emit("test", &data); + .emit("test", &data) + .await; } let (_, io) = SocketIo::new_svc(); diff --git a/crates/socketioxide/src/ack.rs b/crates/socketioxide/src/ack.rs index d1e8aff3..de281634 100644 --- a/crates/socketioxide/src/ack.rs +++ b/crates/socketioxide/src/ack.rs @@ -1,10 +1,7 @@ //! Acknowledgement related types and functions. -//! -//! Here is the main type: -//! -//! - [`AckStream`]: A [`Stream`]/[`Future`] of data received from the client. use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, time::Duration, }; @@ -16,19 +13,19 @@ use serde::de::DeserializeOwned; use tokio::{sync::oneshot::Receiver, time::Timeout}; use crate::{ - adapter::Adapter, - errors::{AckError, SocketError}, - extract::SocketRef, - packet::Packet, + adapter::{Adapter, LocalAdapter}, + errors::AckError, parser::Parser, + socket::Socket, + SocketError, }; -use socketioxide_core::{parser::Parse, Value}; +use socketioxide_core::{packet::Packet, parser::Parse, Value}; pub(crate) type AckResult = Result; pin_project_lite::pin_project! { /// A [`Future`] of [`AckResponse`] received from the client with its corresponding [`Sid`]. /// It is used internally by [`AckStream`] and **should not** be used directly. - pub struct AckResultWithId { + struct AckResultWithId { id: Sid, #[pin] result: Timeout>>, @@ -94,24 +91,24 @@ pin_project_lite::pin_project! 
{ /// /// // We apply the `for_each` StreamExt fn to the AckStream /// socket.broadcast().emit_with_ack::<_, String>("test", "test") + /// .await /// .unwrap() /// .for_each(|(id, ack)| async move { println!("Ack: {} {:?}", id, ack); }).await; /// }); /// ``` #[must_use = "futures and streams do nothing unless you `.await` or poll them"] - pub struct AckStream { + pub struct AckStream { #[pin] - inner: AckInnerStream, + inner: A::AckStream, parser: Parser, _marker: std::marker::PhantomData, } } pin_project_lite::pin_project! { - #[allow(missing_docs)] + #[doc(hidden)] #[project = InnerProj] - /// An internal stream used by [`AckStream`]. It should not be used directly except when implementing the - /// [`Adapter`](crate::adapter::Adapter) trait. + /// An internal stream used by [`AckStream`]. pub enum AckInnerStream { Stream { #[pin] @@ -129,33 +126,37 @@ pin_project_lite::pin_project! { // ==== impl AckInnerStream ==== impl AckInnerStream { + /// Creates a new empty [`AckInnerStream`] that will yield no value. + pub fn empty() -> Self { + AckInnerStream::Stream { + rxs: FuturesUnordered::new(), + } + } + /// Creates a new [`AckInnerStream`] from a [`Packet`] and a list of sockets. /// The [`Packet`] is sent to all the sockets and the [`AckInnerStream`] will wait /// for an acknowledgement from each socket. /// /// The [`AckInnerStream`] will wait for the default timeout specified in the config /// (5s by default) if no custom timeout is specified. - pub fn broadcast( + pub fn broadcast<'a, A: Adapter>( packet: Packet, - sockets: Vec>, - duration: Option, - ) -> Self { + sockets: impl Iterator>>, + duration: Duration, + ) -> (Self, u32) { let rxs = FuturesUnordered::new(); - - if sockets.is_empty() { - return AckInnerStream::Stream { rxs }; - } - - let duration = - duration.unwrap_or_else(|| sockets.first().unwrap().get_io().config().ack_timeout); + let mut count = 0; for socket in sockets { let rx = socket.send_with_ack(packet.clone()); rxs.push(AckResultWithId { result: tokio::time::timeout(duration, rx), id: socket.id, }); + count += 1; } - AckInnerStream::Stream { rxs } + #[cfg(feature = "tracing")] + tracing::debug!("broadcast with ack to {count} sockets"); + (AckInnerStream::Stream { rxs }, count) } /// Creates a new [`AckInnerStream`] from a [`oneshot::Receiver`](tokio) corresponding to the acknowledgement @@ -209,33 +210,9 @@ impl FusedStream for AckInnerStream { } } -impl Future for AckInnerStream { - type Output = AckResult; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.as_mut().poll_next(cx) { - Poll::Ready(Some(v)) => Poll::Ready(v.1), - Poll::Pending => Poll::Pending, - Poll::Ready(None) => { - unreachable!("stream should at least yield 1 value") - } - } - } -} - -impl FusedFuture for AckInnerStream { - fn is_terminated(&self) -> bool { - use AckInnerStream::*; - match self { - Stream { rxs, .. } => rxs.is_terminated(), - Fut { polled, .. 
} => *polled, - } - } -} - // ==== impl AckStream ==== -impl AckStream { - pub(crate) fn new(inner: AckInnerStream, parser: Parser) -> Self { +impl AckStream { + pub(crate) fn new(inner: A::AckStream, parser: Parser) -> Self { AckStream { inner, parser, @@ -244,7 +221,7 @@ impl AckStream { } } -impl Stream for AckStream { +impl Stream for AckStream { type Item = (Sid, AckResult); #[inline] @@ -262,30 +239,33 @@ impl Stream for AckStream { } } -impl FusedStream for AckStream { +impl FusedStream for AckStream { #[inline(always)] fn is_terminated(&self) -> bool { FusedStream::is_terminated(&self.inner) } } -impl Future for AckStream { +impl Future for AckStream { type Output = AckResult; #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let parser = self.parser; - self.project() - .inner - .poll(cx) - .map(|v| map_ack_response(v, parser)) + match self.project().inner.poll_next(cx) { + Poll::Ready(Some(v)) => Poll::Ready(map_ack_response(v.1, parser)), + Poll::Pending => Poll::Pending, + Poll::Ready(None) => { + unreachable!("stream should at least yield 1 value") + } + } } } -impl FusedFuture for AckStream { +impl FusedFuture for AckStream { #[inline(always)] fn is_terminated(&self) -> bool { - FusedFuture::is_terminated(&self.inner) + FusedStream::is_terminated(&self.inner) } } @@ -319,11 +299,12 @@ mod test { fn value(data: impl serde::Serialize) -> Value { CommonParser.encode_value(&data, None).unwrap() } - impl From for AckStream { + impl From for AckStream { fn from(val: AckInnerStream) -> Self { Self::new(val, Parser::default()) } } + const TIMEOUT: Duration = Duration::from_secs(5); #[tokio::test] async fn broadcast_ack() { @@ -331,8 +312,11 @@ mod test { let socket2 = create_socket(); let mut packet = get_packet(); packet.inner.set_ack_id(1); - let socks = vec![socket.clone().into(), socket2.clone().into()]; - let stream: AckStream = AckInnerStream::broadcast(packet, socks, None).into(); + let socks = vec![&socket, &socket2]; + let stream: AckStream = + AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT) + .0 + .into(); let res_packet = Packet::ack("test", value("test"), 1); socket.recv(res_packet.inner.clone()).unwrap(); @@ -349,7 +333,7 @@ mod test { async fn ack_stream() { let (tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(value("test"))).unwrap(); @@ -363,7 +347,7 @@ mod test { async fn ack_fut() { let (tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(value("test"))).unwrap(); @@ -376,8 +360,11 @@ mod test { let socket2 = create_socket(); let mut packet = get_packet(); packet.inner.set_ack_id(1); - let socks = vec![socket.clone().into(), socket2.clone().into()]; - let stream: AckStream = AckInnerStream::broadcast(packet, socks, None).into(); + let socks = vec![&socket, &socket2]; + let stream: AckStream = + AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT) + .0 + .into(); let res_packet = Packet::ack("test", value(132), 1); socket.recv(res_packet.inner.clone()).unwrap(); @@ -400,7 +387,7 @@ mod test { async fn ack_stream_with_deserialize_error() { let (tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); 
tx.send(Ok(value(true))).unwrap(); assert_eq!(stream.size_hint().0, 1); @@ -419,7 +406,7 @@ mod test { async fn ack_fut_with_deserialize_error() { let (tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); tx.send(Ok(value(true))).unwrap(); @@ -432,8 +419,11 @@ mod test { let socket2 = create_socket(); let mut packet = get_packet(); packet.inner.set_ack_id(1); - let socks = vec![socket.clone().into(), socket2.clone().into()]; - let stream: AckStream = AckInnerStream::broadcast(packet, socks, None).into(); + let socks = vec![&socket, &socket2]; + let stream: AckStream = + AckInnerStream::broadcast(packet, socks.into_iter(), TIMEOUT) + .0 + .into(); let res_packet = Packet::ack("test", value("test"), 1); socket.clone().recv(res_packet.inner.clone()).unwrap(); @@ -455,7 +445,7 @@ mod test { async fn ack_stream_with_closed_socket() { let (tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); drop(tx); @@ -471,7 +461,7 @@ mod test { async fn ack_fut_with_closed_socket() { let (tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_secs(1), sid).into(); drop(tx); @@ -487,9 +477,11 @@ mod test { let socket2 = create_socket(); let mut packet = get_packet(); packet.inner.set_ack_id(1); - let socks = vec![socket.clone().into(), socket2.clone().into()]; - let stream: AckStream = - AckInnerStream::broadcast(packet, socks, Some(Duration::from_millis(10))).into(); + let socks = vec![&socket, &socket2]; + let stream: AckStream = + AckInnerStream::broadcast(packet, socks.into_iter(), Duration::from_millis(10)) + .0 + .into(); socket .recv(Packet::ack("test", value("test"), 1).inner) @@ -509,7 +501,7 @@ mod test { async fn ack_stream_with_timeout() { let (_tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_millis(10), sid).into(); futures_util::pin_mut!(stream); @@ -524,7 +516,7 @@ mod test { async fn ack_fut_with_timeout() { let (_tx, rx) = tokio::sync::oneshot::channel(); let sid = Sid::new(); - let stream: AckStream = + let stream: AckStream = AckInnerStream::send(rx, Duration::from_millis(10), sid).into(); assert!(matches!(stream.await.unwrap_err(), AckError::Timeout)); diff --git a/crates/socketioxide/src/adapter.rs b/crates/socketioxide/src/adapter.rs index 67fee0bc..ef923d3f 100644 --- a/crates/socketioxide/src/adapter.rs +++ b/crates/socketioxide/src/adapter.rs @@ -3,585 +3,50 @@ //! The default adapter is the [`LocalAdapter`], which stores the state in memory. //! Other adapters can be made to share the state between multiple servers. 
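As a minimal sketch of how an adapter is now selected (assuming the `with_adapter(adapter_state)` builder signature changed later in this patch, in `io.rs`): the in-memory `LocalAdapter` carries no shared state, so its `State` is `()` and picking it explicitly behaves like the default.

```rust
use socketioxide::{adapter::LocalAdapter, SocketIo};

// Explicitly select the in-memory adapter. Its shared state is `()`,
// so this is equivalent to the plain `SocketIo::new_svc()` default.
let (_svc, _io) = SocketIo::builder()
    .with_adapter::<LocalAdapter>(())
    .build_svc();
```

A distributed adapter would plug in the same way, with its own adapter type and state value passed to `with_adapter`.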
-use std::{ - borrow::Cow, - collections::{HashMap, HashSet}, - convert::Infallible, - sync::{RwLock, Weak}, - time::Duration, -}; - -use engineioxide::sid::Sid; - -use crate::{ - ack::AckInnerStream, - errors::{AdapterError, BroadcastError}, - extract::SocketRef, - ns::Namespace, - operators::RoomParam, +use socketioxide_core::{ + adapter::{BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, SocketEmitter}, packet::Packet, - DisconnectError, }; +use std::{convert::Infallible, sync::Arc, time::Duration}; -/// A room identifier -pub type Room = Cow<'static, str>; - -/// Flags that can be used to modify the behavior of the broadcast methods. -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub enum BroadcastFlags { - /// Broadcast only to the current server - Local, - /// Broadcast to all clients except the sender - Broadcast, -} - -/// Options that can be used to modify the behavior of the broadcast methods. -#[derive(Clone, Debug, Default)] -pub struct BroadcastOptions { - /// The flags to apply to the broadcast. - pub flags: HashSet, - /// The rooms to broadcast to. - pub rooms: HashSet, - /// The rooms to exclude from the broadcast. - pub except: HashSet, - /// The socket id of the sender. - pub sid: Option, -} -//TODO: Make an AsyncAdapter trait -/// An adapter is responsible for managing the state of the server. +pub use crate::ns::Emitter; +pub use socketioxide_core::errors::AdapterError; +/// An adapter is responsible for managing the state of the namespace. /// This adapter can be implemented to share the state between multiple servers. /// The default adapter is the [`LocalAdapter`], which stores the state in memory. -pub trait Adapter: std::fmt::Debug + Send + Sync + 'static { - /// An error that can occur when using the adapter. The default [`LocalAdapter`] has an [`Infallible`] error. - type Error: std::error::Error + Into + Send + Sync + 'static; - - /// Create a new adapter and give the namespace ref to retrieve sockets. - fn new(ns: Weak>) -> Self - where - Self: Sized; - - /// Initializes the adapter. - fn init(&self) -> Result<(), Self::Error>; - /// Closes the adapter. - fn close(&self) -> Result<(), Self::Error>; - - /// Returns the number of servers. - fn server_count(&self) -> Result; - - /// Adds the socket to all the rooms. - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; - /// Removes the socket from the rooms. - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Self::Error>; - /// Removes the socket from all the rooms. - fn del_all(&self, sid: Sid) -> Result<(), Self::Error>; - - /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`]. - fn broadcast(&self, packet: Packet, opts: BroadcastOptions) -> Result<(), BroadcastError>; - - /// Broadcasts the packet to the sockets that match the [`BroadcastOptions`] and return a stream of ack responses. - fn broadcast_with_ack( - &self, - packet: Packet, - opts: BroadcastOptions, - timeout: Option, - ) -> AckInnerStream; - - /// Returns the sockets ids that match the [`BroadcastOptions`]. - fn sockets(&self, rooms: impl RoomParam) -> Result, Self::Error>; - - /// Returns the rooms of the socket. - fn socket_rooms(&self, sid: Sid) -> Result, Self::Error>; - - /// Returns the sockets that match the [`BroadcastOptions`]. - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Self::Error> - where - Self: Sized; +pub trait Adapter: CoreAdapter + Sized {} +impl> Adapter for T {} - /// Adds the sockets that match the [`BroadcastOptions`] to the rooms. 
- fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; - /// Removes the sockets that match the [`BroadcastOptions`] from the rooms. - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) - -> Result<(), Self::Error>; - - /// Disconnects the sockets that match the [`BroadcastOptions`]. - fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec>; - - /// Returns all the rooms for this adapter. - fn rooms(&self) -> Result, Self::Error>; - - //TODO: implement - // fn server_side_emit(&self, packet: Packet, opts: BroadcastOptions) -> Result; - // fn persist_session(&self, sid: i64); - // fn restore_session(&self, sid: i64) -> Session; -} +// === LocalAdapter impls === /// The default adapter. Store the state in memory. -#[derive(Debug)] -pub struct LocalAdapter { - rooms: RwLock>>, - ns: Weak>, -} +pub struct LocalAdapter(CoreLocalAdapter); -impl From for AdapterError { - fn from(_: Infallible) -> AdapterError { - unreachable!() - } -} - -impl Adapter for LocalAdapter { +impl CoreAdapter for LocalAdapter { type Error = Infallible; + type State = (); + type AckStream = ::AckStream; + type InitRes = (); - fn new(ns: Weak>) -> Self { - Self { - rooms: HashMap::new().into(), - ns, - } - } - - fn init(&self) -> Result<(), Infallible> { - Ok(()) + fn new(_state: &Self::State, local: CoreLocalAdapter) -> Self { + Self(local) } - fn close(&self) -> Result<(), Infallible> { - #[cfg(feature = "tracing")] - tracing::debug!("closing local adapter: {}", self.ns.upgrade().unwrap().path); - let mut rooms = self.rooms.write().unwrap(); - rooms.clear(); - rooms.shrink_to_fit(); - Ok(()) + fn init(self: Arc, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes { + on_success(); } - fn server_count(&self) -> Result { - Ok(1) - } - - fn add_all(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { - let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { - rooms_map.entry(room).or_default().insert(sid); - } - Ok(()) - } - - fn del(&self, sid: Sid, rooms: impl RoomParam) -> Result<(), Infallible> { - let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms.into_room_iter() { - if let Some(room) = rooms_map.get_mut(&room) { - room.remove(&sid); - } - } - Ok(()) - } - - fn del_all(&self, sid: Sid) -> Result<(), Infallible> { - let mut rooms_map = self.rooms.write().unwrap(); - for room in rooms_map.values_mut() { - room.remove(&sid); - } - Ok(()) - } - - fn broadcast(&self, packet: Packet, opts: BroadcastOptions) -> Result<(), BroadcastError> { - use socketioxide_core::parser::Parse; - let sockets = self.apply_opts(opts); - - #[cfg(feature = "tracing")] - tracing::debug!("broadcasting packet to {} sockets", sockets.len()); - let parser = match sockets.first() { - Some(socket) => socket.parser(), - None => return Ok(()), - }; - let data = parser.encode(packet); - let errors: Vec<_> = sockets - .into_iter() - .filter_map(|socket| socket.send_raw(data.clone()).err()) - .collect(); - if errors.is_empty() { - Ok(()) - } else { - Err(errors.into()) - } - } - - fn broadcast_with_ack( + async fn broadcast_with_ack( &self, packet: Packet, opts: BroadcastOptions, timeout: Option, - ) -> AckInnerStream { - let sockets = self.apply_opts(opts); - #[cfg(feature = "tracing")] - tracing::debug!( - "broadcasting packet to {} sockets: {:?}", - sockets.len(), - sockets.iter().map(|s| s.id).collect::>() - ); - AckInnerStream::broadcast(packet, sockets, timeout) - } - - fn sockets(&self, rooms: 
impl RoomParam) -> Result, Infallible> { - let mut opts = BroadcastOptions::default(); - opts.rooms.extend(rooms.into_room_iter()); - Ok(self - .apply_opts(opts) - .into_iter() - .map(|socket| socket.id) - .collect()) - } - - //TODO: make this operation O(1) - fn socket_rooms(&self, sid: Sid) -> Result>, Infallible> { - let rooms_map = self.rooms.read().unwrap(); - Ok(rooms_map - .iter() - .filter(|(_, sockets)| sockets.contains(&sid)) - .map(|(room, _)| room.clone()) - .collect()) - } - - fn fetch_sockets(&self, opts: BroadcastOptions) -> Result>, Infallible> { - Ok(self.apply_opts(opts)) - } - - fn add_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { - let rooms: Vec = rooms.into_room_iter().collect(); - for socket in self.apply_opts(opts) { - self.add_all(socket.id, rooms.clone()).unwrap(); - } - Ok(()) - } - - fn del_sockets(&self, opts: BroadcastOptions, rooms: impl RoomParam) -> Result<(), Infallible> { - let rooms: Vec = rooms.into_room_iter().collect(); - for socket in self.apply_opts(opts) { - self.del(socket.id, rooms.clone()).unwrap(); - } - Ok(()) + ) -> Result { + Ok(self.get_local().broadcast_with_ack(packet, opts, timeout).0) } - fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), Vec> { - let mut errors: Vec<_> = Vec::new(); - - for sock in self.apply_opts(opts) { - if let Err(e) = sock.disconnect() { - errors.push(e); - } - } - - if errors.is_empty() { - Ok(()) - } else { - Err(errors) - } - } - - fn rooms(&self) -> Result, Self::Error> { - Ok(self.rooms.read().unwrap().keys().cloned().collect()) - } -} - -impl LocalAdapter { - /// Applies the given `opts` and return the sockets that match. - fn apply_opts(&self, opts: BroadcastOptions) -> Vec> { - let rooms = opts.rooms; - - let except = self.get_except_sids(&opts.except); - let ns = self.ns.upgrade().unwrap(); - if !rooms.is_empty() { - let rooms_map = self.rooms.read().unwrap(); - rooms - .iter() - .filter_map(|room| rooms_map.get(room)) - .flatten() - .filter(|sid| { - !except.contains(*sid) - && (!opts.flags.contains(&BroadcastFlags::Broadcast) - || opts.sid.map(|s| s != **sid).unwrap_or(true)) - }) - .filter_map(|sid| ns.get_socket(*sid).ok()) - .map(SocketRef::from) - .collect() - } else if opts.flags.contains(&BroadcastFlags::Broadcast) { - let sockets = ns.get_sockets(); - sockets - .into_iter() - .filter(|socket| { - !except.contains(&socket.id) && opts.sid.map(|s| s != socket.id).unwrap_or(true) - }) - .map(SocketRef::from) - .collect() - } else if let Some(sock) = opts.sid.and_then(|sid| ns.get_socket(sid).ok()) { - vec![sock.into()] - } else { - vec![] - } - } - - fn get_except_sids(&self, except: &HashSet) -> HashSet { - let mut except_sids = HashSet::new(); - let rooms_map = self.rooms.read().unwrap(); - for room in except { - if let Some(sockets) = rooms_map.get(room) { - except_sids.extend(sockets); - } - } - except_sids - } -} - -#[cfg(test)] -mod test { - use super::*; - use std::sync::Arc; - - macro_rules! 
hash_set { - {$($v: expr),* $(,)?} => { - std::collections::HashSet::from([$($v,)*]) - }; - } - - #[tokio::test] - async fn test_server_count() { - let ns = Namespace::new_dummy([]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - assert_eq!(adapter.server_count().unwrap(), 1); - } - - #[tokio::test] - async fn test_add_all() { - let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); - let rooms_map = adapter.rooms.read().unwrap(); - assert_eq!(rooms_map.len(), 2); - assert_eq!(rooms_map.get("room1").unwrap().len(), 1); - assert_eq!(rooms_map.get("room2").unwrap().len(), 1); - } - - #[tokio::test] - async fn test_del() { - let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); - adapter.del(socket, "room1").unwrap(); - let rooms_map = adapter.rooms.read().unwrap(); - assert_eq!(rooms_map.len(), 2); - assert_eq!(rooms_map.get("room1").unwrap().len(), 0); - assert_eq!(rooms_map.get("room2").unwrap().len(), 1); - } - - #[tokio::test] - async fn test_del_all() { - let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1", "room2"]).unwrap(); - adapter.del_all(socket).unwrap(); - let rooms_map = adapter.rooms.read().unwrap(); - assert_eq!(rooms_map.len(), 2); - assert_eq!(rooms_map.get("room1").unwrap().len(), 0); - assert_eq!(rooms_map.get("room2").unwrap().len(), 0); - } - - #[tokio::test] - async fn test_socket_room() { - let sid1 = Sid::new(); - let sid2 = Sid::new(); - let sid3 = Sid::new(); - let ns = Namespace::new_dummy([sid1, sid2, sid3]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(sid1, ["room1", "room2"]).unwrap(); - adapter.add_all(sid2, ["room1"]).unwrap(); - adapter.add_all(sid3, ["room2"]).unwrap(); - assert!(adapter - .socket_rooms(sid1) - .unwrap() - .contains(&"room1".into())); - assert!(adapter - .socket_rooms(sid1) - .unwrap() - .contains(&"room2".into())); - assert_eq!(adapter.socket_rooms(sid2).unwrap(), ["room1"]); - assert_eq!(adapter.socket_rooms(sid3).unwrap(), ["room2"]); - } - - #[tokio::test] - async fn test_add_socket() { - let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); - - let mut opts = BroadcastOptions { - sid: Some(socket), - ..Default::default() - }; - opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); - let rooms_map = adapter.rooms.read().unwrap(); - - assert_eq!(rooms_map.len(), 2); - assert!(rooms_map.get("room1").unwrap().contains(&socket)); - assert!(rooms_map.get("room2").unwrap().contains(&socket)); - } - - #[tokio::test] - async fn test_del_socket() { - let socket = Sid::new(); - let ns = Namespace::new_dummy([socket]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket, ["room1"]).unwrap(); - - let mut opts = BroadcastOptions { - sid: Some(socket), - ..Default::default() - }; - opts.rooms = hash_set!["room1".into()]; - adapter.add_sockets(opts, "room2").unwrap(); - - { - let rooms_map = adapter.rooms.read().unwrap(); - - assert_eq!(rooms_map.len(), 2); - assert!(rooms_map.get("room1").unwrap().contains(&socket)); - assert!(rooms_map.get("room2").unwrap().contains(&socket)); - } - 
- let mut opts = BroadcastOptions { - sid: Some(socket), - ..Default::default() - }; - opts.rooms = hash_set!["room1".into()]; - adapter.del_sockets(opts, "room2").unwrap(); - - { - let rooms_map = adapter.rooms.read().unwrap(); - - assert_eq!(rooms_map.len(), 2); - assert!(rooms_map.get("room1").unwrap().contains(&socket)); - assert!(rooms_map.get("room2").unwrap().is_empty()); - } - } - - #[tokio::test] - async fn test_sockets() { - let socket0 = Sid::new(); - let socket1 = Sid::new(); - let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); - adapter.add_all(socket2, ["room2", "room3"]).unwrap(); - - let sockets = adapter.sockets("room1").unwrap(); - assert_eq!(sockets.len(), 2); - assert!(sockets.contains(&socket0)); - assert!(sockets.contains(&socket1)); - - let sockets = adapter.sockets("room2").unwrap(); - assert_eq!(sockets.len(), 2); - assert!(sockets.contains(&socket0)); - assert!(sockets.contains(&socket2)); - - let sockets = adapter.sockets("room3").unwrap(); - assert_eq!(sockets.len(), 2); - assert!(sockets.contains(&socket1)); - assert!(sockets.contains(&socket2)); - } - - #[tokio::test] - async fn test_disconnect_socket() { - let socket0 = Sid::new(); - let socket1 = Sid::new(); - let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - adapter - .add_all(socket0, ["room1", "room2", "room4"]) - .unwrap(); - adapter - .add_all(socket1, ["room1", "room3", "room5"]) - .unwrap(); - adapter - .add_all(socket2, ["room2", "room3", "room6"]) - .unwrap(); - - let mut opts = BroadcastOptions { - sid: Some(socket0), - ..Default::default() - }; - opts.rooms = hash_set!["room5".into()]; - adapter.disconnect_socket(opts).unwrap(); - - let sockets = adapter.sockets("room2").unwrap(); - assert_eq!(sockets.len(), 2); - assert!(sockets.contains(&socket2)); - assert!(sockets.contains(&socket0)); - } - #[tokio::test] - async fn test_apply_opts() { - let socket0 = Sid::new(); - let socket1 = Sid::new(); - let socket2 = Sid::new(); - let ns = Namespace::new_dummy([socket0, socket1, socket2]); - let adapter = LocalAdapter::new(Arc::downgrade(&ns)); - // Add socket 0 to room1 and room2 - adapter.add_all(socket0, ["room1", "room2"]).unwrap(); - // Add socket 1 to room1 and room3 - adapter.add_all(socket1, ["room1", "room3"]).unwrap(); - // Add socket 2 to room2 and room3 - adapter - .add_all(socket2, ["room1", "room2", "room3"]) - .unwrap(); - - // socket 2 is the sender - let mut opts = BroadcastOptions { - sid: Some(socket2), - ..Default::default() - }; - opts.rooms = hash_set!["room1".into()]; - opts.except = hash_set!["room2".into()]; - let sockets = adapter.fetch_sockets(opts).unwrap(); - assert_eq!(sockets.len(), 1); - assert_eq!(sockets[0].id, socket1); - - let mut opts = BroadcastOptions { - sid: Some(socket2), - ..Default::default() - }; - opts.flags.insert(BroadcastFlags::Broadcast); - let sockets = adapter.fetch_sockets(opts).unwrap(); - assert_eq!(sockets.len(), 2); - sockets.iter().for_each(|s| { - assert!(s.id == socket0 || s.id == socket1); - }); - - let mut opts = BroadcastOptions { - sid: Some(socket2), - ..Default::default() - }; - opts.flags.insert(BroadcastFlags::Broadcast); - opts.except = hash_set!["room2".into()]; - let sockets = adapter.fetch_sockets(opts).unwrap(); - 
assert_eq!(sockets.len(), 1); - - let opts = BroadcastOptions { - sid: Some(socket2), - ..Default::default() - }; - let sockets = adapter.fetch_sockets(opts).unwrap(); - assert_eq!(sockets.len(), 1); - assert_eq!(sockets[0].id, socket2); - - let opts = BroadcastOptions { - sid: Some(Sid::new()), - ..Default::default() - }; - let sockets = adapter.fetch_sockets(opts).unwrap(); - assert_eq!(sockets.len(), 0); + fn get_local(&self) -> &CoreLocalAdapter { + &self.0 } } +impl DefinedAdapter for LocalAdapter {} diff --git a/crates/socketioxide/src/client.rs b/crates/socketioxide/src/client.rs index 460d5a93..c1285e04 100644 --- a/crates/socketioxide/src/client.rs +++ b/crates/socketioxide/src/client.rs @@ -10,37 +10,37 @@ use futures_util::{FutureExt, TryFutureExt}; use engineioxide::sid::Sid; use matchit::{Match, Router}; +use socketioxide_core::packet::{Packet, PacketData}; use socketioxide_core::parser::{Parse, ParserState}; use socketioxide_core::Value; use tokio::sync::oneshot; -use crate::adapter::Adapter; -use crate::handler::ConnectHandler; -use crate::ns::NamespaceCtr; -use crate::parser::{ParseError, Parser}; -use crate::socket::DisconnectReason; use crate::{ + adapter::Adapter, errors::Error, - ns::Namespace, - packet::{Packet, PacketData}, - SocketIoConfig, + handler::ConnectHandler, + ns::{Namespace, NamespaceCtr}, + parser::{ParseError, Parser}, + socket::DisconnectReason, + ProtocolVersion, SocketIo, SocketIoConfig, }; -use crate::{ProtocolVersion, SocketIo}; pub struct Client { pub(crate) config: SocketIoConfig, nsps: RwLock>>>, router: RwLock>>, + adapter_state: A::State, #[cfg(feature = "state")] pub(crate) state: state::TypeMap![Send + Sync], } -/// ==== impl Client ==== +// ==== impl Client ==== impl Client
{ pub fn new( config: SocketIoConfig, + adapter_state: A::State, #[cfg(feature = "state")] mut state: state::TypeMap![Send + Sync], ) -> Self { #[cfg(feature = "state")] @@ -50,6 +50,7 @@ impl Client { config, nsps: RwLock::new(HashMap::new()), router: RwLock::new(Router::new()), + adapter_state, #[cfg(feature = "state")] state, } @@ -57,7 +58,7 @@ impl Client { /// Called when a socket connects to a new namespace fn sock_connect( - &self, + self: &Arc, auth: Option, ns_path: Str, esocket: &Arc>>, @@ -81,9 +82,16 @@ impl Client { // We have to create a new `Str` otherwise, we would keep a ref to the original connect packet // for the entire lifetime of the Namespace. let path = Str::copy_from_slice(&ns_path); - let ns = ns_ctr.get_new_ns(path.clone()); - self.nsps.write().unwrap().insert(path, ns.clone()); - tokio::spawn(connect(ns, esocket.clone())); + let ns = ns_ctr.get_new_ns(path.clone(), &self.adapter_state, &self.config); + let this = self.clone(); + let esocket = esocket.clone(); + let adapter = ns.adapter.clone(); + let on_success = move || { + this.nsps.write().unwrap().insert(path, ns.clone()); + tokio::spawn(connect(ns, esocket)); + }; + // We "ask" the adapter implementation to manage the init response itself + socketioxide_core::adapter::Spawnable::spawn(adapter.init(on_success)); } else if protocol == ProtocolVersion::V4 && ns_path == "/" { #[cfg(feature = "tracing")] tracing::error!( @@ -136,16 +144,21 @@ impl Client { } /// Adds a new namespace handler - pub fn add_ns(&self, path: Cow<'static, str>, callback: C) + pub fn add_ns(self: Arc, path: Cow<'static, str>, callback: C) -> A::InitRes where C: ConnectHandler, T: Send + Sync + 'static, { #[cfg(feature = "tracing")] tracing::debug!("adding namespace {}", path); - let path = Str::from(path); - let ns = Namespace::new(path.clone(), callback); - self.nsps.write().unwrap().insert(path, ns); + + let ns_path = Str::from(&path); + let ns = Namespace::new(ns_path.clone(), callback, &self.adapter_state, &self.config); + let adapter = ns.adapter.clone(); + let on_success = move || { + self.nsps.write().unwrap().insert(ns_path, ns); + }; + adapter.init(on_success) } pub fn add_dyn_ns(&self, path: String, callback: C) -> Result<(), matchit::InsertError> @@ -254,23 +267,16 @@ impl EngineIoHandler for Client { .filter_map(|ns| ns.get_socket(socket.id).ok()) .collect(); - let _res: Result, _> = socks + let _cnt = socks .into_iter() .map(|s| s.close(reason.clone().into())) - .collect(); + .count(); #[cfg(feature = "tracing")] - match _res { - Ok(vec) => { - tracing::debug!("disconnect handle spawned for {} namespaces", vec.len()) - } - Err(_e) => { - tracing::debug!("error while disconnecting socket: {}", _e) - } - } + tracing::debug!("disconnect handle spawned for {_cnt} namespaces"); } - fn on_message(&self, msg: Str, socket: Arc>>) { + fn on_message(self: &Arc, msg: Str, socket: Arc>>) { #[cfg(feature = "tracing")] tracing::debug!("received message: {:?}", msg); let packet = match self.parser().decode_str(&socket.data.parser_state, msg) { @@ -309,7 +315,7 @@ impl EngineIoHandler for Client { /// When a binary payload is received from a socket, it is applied to the partial binary packet /// /// If the packet is complete, it is propagated to the namespace - fn on_binary(&self, data: Bytes, socket: Arc>>) { + fn on_binary(self: &Arc, data: Bytes, socket: Arc>>) { #[cfg(feature = "tracing")] tracing::debug!("received binary: {:?}", &data); let packet = match self.parser().decode_bin(&socket.data.parser_state, data) { @@ -420,19 +426,21 
@@ mod test { connect_timeout: CONNECT_TIMEOUT, ..Default::default() }; - let client = Client::::new( + let client = Client::new( config, + (), #[cfg(feature = "state")] Default::default(), ); - client.add_ns("/".into(), || {}); - Arc::new(client) + let client = Arc::new(client); + client.clone().add_ns("/".into(), || {}); + client } #[tokio::test] async fn get_ns() { let client = create_client(); - let ns = Namespace::new(Str::from("/"), || {}); + let ns = Namespace::new(Str::from("/"), || {}, &client.adapter_state, &client.config); client.nsps.write().unwrap().insert(Str::from("/"), ns); assert!(client.get_ns("/").is_some()); } diff --git a/crates/socketioxide/src/errors.rs b/crates/socketioxide/src/errors.rs index fc91f428..1506d1b5 100644 --- a/crates/socketioxide/src/errors.rs +++ b/crates/socketioxide/src/errors.rs @@ -1,10 +1,12 @@ use engineioxide::{sid::Sid, socket::DisconnectReason as EIoDisconnectReason}; -use std::fmt::{Debug, Display}; -use tokio::{sync::mpsc::error::TrySendError, time::error::Elapsed}; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; +use tokio::time::error::Elapsed; pub use matchit::InsertError as NsInsertError; -pub use crate::parser::{DecodeError, EncodeError}; +pub use crate::parser::ParserError; +pub use socketioxide_core::errors::{AdapterError, BroadcastError, SocketError}; /// Error type for socketio #[derive(thiserror::Error, Debug)] @@ -28,113 +30,42 @@ pub enum Error { pub(crate) struct ConnectFail; /// Error type for ack operations. -#[derive(thiserror::Error, Debug)] +#[derive(thiserror::Error, Debug, Serialize, Deserialize)] pub enum AckError { /// The ack response cannot be parsed #[error("cannot deserialize packet from ack response: {0:?}")] - Decode(#[from] DecodeError), + Decode(#[from] ParserError), /// The ack response timed out #[error("ack timeout error")] Timeout, - /// An error happened while broadcasting to other socket.io nodes - #[error("adapter error: {0}")] - Adapter(#[from] AdapterError), - /// Error sending/receiving data through the engine.io socket #[error("Error sending data through the engine.io socket: {0:?}")] Socket(#[from] SocketError), } -/// Error type for broadcast operations. -#[derive(thiserror::Error, Debug)] -pub enum BroadcastError { - /// An error occurred while sending packets. - #[error("Error sending data through the engine.io socket: {0:?}")] - Socket(Vec), - - /// An error occurred while serializing the packet. - #[error("Error serializing packet: {0:?}")] - Serialize(#[from] EncodeError), - - /// An error occured while broadcasting to other nodes. - #[error("Adapter error: {0}")] - Adapter(#[from] AdapterError), -} /// Error type for sending operations. #[derive(thiserror::Error, Debug)] pub enum SendError { /// An error occurred while serializing the packet. #[error("Error serializing packet: {0:?}")] - Serialize(#[from] EncodeError), + Serialize(#[from] ParserError), /// Error sending/receiving data through the engine.io socket #[error("Error sending data through the engine.io socket: {0:?}")] Socket(#[from] SocketError), } -/// Error type when using the underlying engine.io socket -#[derive(Debug, thiserror::Error)] -pub enum SocketError { - /// The socket channel is full. - /// You might need to increase the channel size with the [`SocketIoBuilder::max_buffer_size`] method. 
- /// - /// [`SocketIoBuilder::max_buffer_size`]: crate::SocketIoBuilder#method.max_buffer_size - #[error("internal channel full error")] - InternalChannelFull, - - /// The socket is already closed - #[error("socket closed")] - Closed, -} - -/// Error type for sending operations. +/// Error type for the [`emit_with_ack`](crate::operators::BroadcastOperators::emit_with_ack) method. #[derive(thiserror::Error, Debug)] -pub enum DisconnectError { - /// The socket channel is full. - /// You might need to increase the channel size with the [`SocketIoBuilder::max_buffer_size`] method. - /// - /// [`SocketIoBuilder::max_buffer_size`]: crate::SocketIoBuilder#method.max_buffer_size - #[error("internal channel full error")] - InternalChannelFull, - - /// An error occured while broadcasting to other nodes. - #[error("adapter error: {0:?}")] - Adapter(#[from] AdapterError), -} - -/// Error type for the [`Adapter`](crate::adapter::Adapter) trait. -#[derive(Debug, thiserror::Error)] -pub struct AdapterError(#[from] pub Box); -impl Display for AdapterError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(&self.0, f) - } -} - -impl From> for SocketError { - fn from(value: TrySendError) -> Self { - match value { - TrySendError::Full(_) => Self::InternalChannelFull, - TrySendError::Closed(_) => Self::Closed, - } - } -} - -impl From> for BroadcastError { - /// Converts a vector of `SendError` into a `BroadcastError`. - /// - /// # Arguments - /// - /// * `value` - A vector of `SendError` representing the sending errors. - /// - /// # Returns - /// - /// A `BroadcastError` containing the sending errors. - fn from(value: Vec) -> Self { - Self::Socket(value) - } +pub enum EmitWithAckError { + /// An error occurred while encoding the data. + #[error("Error encoding data: {0:?}")] + Encode(#[from] ParserError), + /// An error occurred while broadcasting to other nodes. + #[error("Adapter error: {0:?}")] + Adapter(#[from] Box), } impl From for AckError { diff --git a/crates/socketioxide/src/extract/data.rs b/crates/socketioxide/src/extract/data.rs index 2251a22f..8e907530 100644 --- a/crates/socketioxide/src/extract/data.rs +++ b/crates/socketioxide/src/extract/data.rs @@ -2,10 +2,9 @@ use std::convert::Infallible; use std::sync::Arc; use crate::handler::{FromConnectParts, FromMessageParts}; -use crate::parser::DecodeError; use crate::{adapter::Adapter, socket::Socket}; use serde::de::DeserializeOwned; -use socketioxide_core::parser::Parse; +use socketioxide_core::parser::{Parse, ParserError}; use socketioxide_core::Value; /// An Extractor that returns the deserialized data without checking errors. 
@@ -37,10 +36,9 @@ where T: DeserializeOwned, A: Adapter, { - type Error = DecodeError; + type Error = ParserError; fn from_connect_parts(s: &Arc>, auth: &Option) -> Result { - let parser = s.parser(); - parser.decode_default(auth.as_ref()).map(Data) + s.parser.decode_default(auth.as_ref()).map(Data) } } @@ -49,14 +47,13 @@ where T: DeserializeOwned, A: Adapter, { - type Error = DecodeError; + type Error = ParserError; fn from_message_parts( s: &Arc>, v: &mut Value, _: &Option, ) -> Result { - let parser = s.parser(); - parser.decode_value(v, true).map(Data) + s.parser.decode_value(v, true).map(Data) } } @@ -81,7 +78,7 @@ where /// /// [`rmpv::Value`]: https://docs.rs/rmpv /// [`serde_json::Value`]: https://docs.rs/serde_json/latest/serde_json/value/ -pub struct TryData(pub Result); +pub struct TryData(pub Result); impl FromConnectParts for TryData where @@ -90,8 +87,7 @@ where { type Error = Infallible; fn from_connect_parts(s: &Arc>, auth: &Option) -> Result { - let parser = s.parser(); - Ok(TryData(parser.decode_default(auth.as_ref()))) + Ok(TryData(s.parser.decode_default(auth.as_ref()))) } } impl FromMessageParts for TryData @@ -105,10 +101,9 @@ where v: &mut Value, _: &Option, ) -> Result { - let parser = s.parser(); - Ok(TryData(parser.decode_value(v, true))) + Ok(TryData(s.parser.decode_value(v, true))) } } -super::__impl_deref!(TryData: Result); +super::__impl_deref!(TryData: Result); super::__impl_deref!(Data); diff --git a/crates/socketioxide/src/extract/mod.rs b/crates/socketioxide/src/extract/mod.rs index 6313b8ad..297a57b1 100644 --- a/crates/socketioxide/src/extract/mod.rs +++ b/crates/socketioxide/src/extract/mod.rs @@ -1,5 +1,4 @@ -//! ### Extractors for [`ConnectHandler`], [`ConnectMiddleware`], -//! [`MessageHandler`] and [`DisconnectHandler`](crate::handler::DisconnectHandler). +//! ### Extractors for [`ConnectHandler`], [`ConnectMiddleware`], [`MessageHandler`] and [`DisconnectHandler`](crate::handler::DisconnectHandler). //! //! They can be used to extract data from the context of the handler and get specific params. Here are some examples of extractors: //! * [`Data`]: extracts and deserialize from any receieved data, if a deserialization error occurs the handler won't be called: @@ -10,29 +9,34 @@ //! * [`TryData`]: extracts and deserialize from the any received data but with a `Result` type in case of error: //! - for [`ConnectHandler`] and [`ConnectMiddleware`]: extracts and deserialize from the incoming auth data //! - for [`MessageHandler`]: extracts and deserialize from the incoming message data -//! * [`SocketRef`]: extracts a reference to the [`Socket`](crate::socket::Socket) +//! * [`SocketRef`]: extracts a reference to the [`Socket`](crate::socket::Socket). //! * [`SocketIo`](crate::SocketIo): extracts a reference to the whole socket.io server context. -//! * [`AckSender`]: Can be used to send an ack response to the current message event -//! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version -//! * [`TransportType`](crate::TransportType): extracts the transport type -//! * [`DisconnectReason`](crate::socket::DisconnectReason): extracts the reason of the disconnection +//! * [`AckSender`]: Can be used to send an ack response to the current message event. +//! * [`ProtocolVersion`](crate::ProtocolVersion): extracts the protocol version. +//! * [`TransportType`](crate::TransportType): extracts the transport type. +//! * [`DisconnectReason`](crate::socket::DisconnectReason): extracts the reason of the disconnection. //! 
* [`State`]: extracts a [`Clone`] of a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder). //! * [`Extension`]: extracts an extension of the given type stored on the called socket by cloning it. -//! * [`MaybeExtension`]: extracts an extension of the given type if it exists or [`None`] otherwise -//! * [`HttpExtension`]: extracts an http extension of the given type coming from the request. -//! (Similar to axum's [`extract::Extension`](https://docs.rs/axum/latest/axum/struct.Extension.html) +//! * [`MaybeExtension`]: extracts an extension of the given type if it exists or [`None`] otherwise. +//! * [`HttpExtension`]: extracts an http extension of the given type coming from the request +//! (Similar to axum's [`extract::Extension`](https://docs.rs/axum/latest/axum/struct.Extension.html). //! * [`MaybeHttpExtension`]: extracts an http extension of the given type if it exists or [`None`] otherwise. //! -//! ### You can also implement your own Extractor with the [`FromConnectParts`], [`FromMessageParts`] and -//! [`FromDisconnectParts`] traits +//! ### You can also implement your own Extractor! +//! Implement the [`FromConnectParts`], [`FromMessageParts`], [`FromMessage`] and [`FromDisconnectParts`] traits +//! on any type to extract data from the context of the handler. +//! //! When implementing these traits, if you clone the [`Arc`](crate::socket::Socket) make sure //! that it is dropped at least when the socket is disconnected. //! Otherwise it will create a memory leak. It is why the [`SocketRef`] extractor is used instead of cloning //! the socket for common usage. -//! If you want to deserialize the [`Value`](socketioxide_core::Value) data you must manually call the `Data` extractor to deserialize it. +//! +//! If you want to deserialize the [`Value`](socketioxide_core::Value) data you must manually call +//! the `Data` extractor to deserialize it. //! //! [`FromConnectParts`]: crate::handler::FromConnectParts //! [`FromMessageParts`]: crate::handler::FromMessageParts +//! [`FromMessage`]: crate::handler::FromMessage //! [`FromDisconnectParts`]: crate::handler::FromDisconnectParts //! [`ConnectHandler`]: crate::handler::ConnectHandler //! [`ConnectMiddleware`]: crate::handler::ConnectMiddleware @@ -42,14 +46,12 @@ //! #### Example that extracts a user id from the query params //! ```rust //! # use bytes::Bytes; -//! # use socketioxide::handler::{FromConnectParts, FromMessageParts}; +//! # use socketioxide::handler::{FromConnectParts, FromMessageParts, Value}; //! # use socketioxide::adapter::Adapter; //! # use socketioxide::socket::Socket; //! # use std::sync::Arc; //! # use std::convert::Infallible; //! # use socketioxide::SocketIo; -//! # use socketioxide_core::Value; -//! //! struct UserId(String); //! //! 
#[derive(Debug)] diff --git a/crates/socketioxide/src/extract/socket.rs b/crates/socketioxide/src/extract/socket.rs index 9c6807ef..aeb05b24 100644 --- a/crates/socketioxide/src/extract/socket.rs +++ b/crates/socketioxide/src/extract/socket.rs @@ -1,19 +1,19 @@ use std::convert::Infallible; use std::sync::Arc; -use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}; use crate::{ adapter::{Adapter, LocalAdapter}, - errors::{DisconnectError, SendError}, - packet::Packet, + handler::{FromConnectParts, FromDisconnectParts, FromMessageParts}, socket::{DisconnectReason, Socket}, - SocketIo, + SendError, SocketIo, }; use serde::Serialize; -use socketioxide_core::parser::Parse; -use socketioxide_core::Value; +use socketioxide_core::{errors::SocketError, packet::Packet, parser::Parse, Value}; /// An Extractor that returns a reference to a [`Socket`]. +/// +/// It is generic over the [`Adapter`] type. If you plan to use it with another adapter than the default, +/// make sure to have a handler that is [generic over the adapter type](crate#adapters). #[derive(Debug)] pub struct SocketRef(Arc>); @@ -71,13 +71,16 @@ impl SocketRef { /// /// It will also call the disconnect handler if it is set. #[inline(always)] - pub fn disconnect(self) -> Result<(), DisconnectError> { + pub fn disconnect(self) -> Result<(), SocketError> { self.0.disconnect() } } /// An Extractor to send an ack response corresponding to the current event. /// If the client sent a normal message without expecting an ack, the ack callback will do nothing. +/// +/// It is generic over the [`Adapter`] type. If you plan to use it with another adapter than the default, +/// make sure to have a handler that is [generic over the adapter type](crate#adapters). #[derive(Debug)] pub struct AckSender { socket: Arc>, @@ -111,9 +114,9 @@ impl AckSender { } }; let ns = self.socket.ns.path.clone(); - let data = self.socket.parser().encode_value(data, None)?; + let data = self.socket.parser.encode_value(data, None)?; let packet = Packet::ack(ns, data, ack_id); - permit.send(packet, self.socket.parser()); + permit.send(packet, self.socket.parser); Ok(()) } else { Ok(()) diff --git a/crates/socketioxide/src/handler/connect.rs b/crates/socketioxide/src/handler/connect.rs index efa31345..19ca42e5 100644 --- a/crates/socketioxide/src/handler/connect.rs +++ b/crates/socketioxide/src/handler/connect.rs @@ -171,6 +171,7 @@ pub trait FromConnectParts: Sized { note = "This function is not a ConnectMiddleware. Check that: * It is a clonable sync or async `FnOnce` that returns `Result<(), E> where E: Display`. * All its arguments are valid connect extractors. +* If you use a custom adapter, it must be generic over the adapter type. See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details.\n", label = "Invalid ConnectMiddleware" ) @@ -200,6 +201,7 @@ pub trait ConnectMiddleware: Sized + Clone + Send + Sync + 'stati note = "This function is not a ConnectHandler. Check that: * It is a clonable sync or async `FnOnce` that returns nothing. * All its arguments are valid connect extractors. +* If you use a custom adapter, it must be generic over the adapter type. 
See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details.\n", label = "Invalid ConnectHandler" ) diff --git a/crates/socketioxide/src/handler/disconnect.rs b/crates/socketioxide/src/handler/disconnect.rs index ed94182d..2ccb2841 100644 --- a/crates/socketioxide/src/handler/disconnect.rs +++ b/crates/socketioxide/src/handler/disconnect.rs @@ -127,6 +127,7 @@ pub trait FromDisconnectParts: Sized { note = "This function is not a DisconnectHandler. Check that: * It is a clonable sync or async `FnOnce` that returns nothing. * All its arguments are valid disconnect extractors. +* If you use a custom adapter, it must be generic over the adapter type. See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details.\n", label = "Invalid DisconnectHandler" ) diff --git a/crates/socketioxide/src/handler/message.rs b/crates/socketioxide/src/handler/message.rs index 846af566..d6af1fe5 100644 --- a/crates/socketioxide/src/handler/message.rs +++ b/crates/socketioxide/src/handler/message.rs @@ -96,6 +96,7 @@ pub(crate) trait ErasedMessageHandler: Send + Sync + 'static { note = "This function is not a MessageHandler. Check that: * It is a clonable sync or async `FnOnce` that returns nothing. * All its arguments are valid message extractors. +* If you use a custom adapter, it must be generic over the adapter type. See `https://docs.rs/socketioxide/latest/socketioxide/extract/index.html` for details.\n", label = "Invalid MessageHandler" ) diff --git a/crates/socketioxide/src/io.rs b/crates/socketioxide/src/io.rs index 48f13562..ae243072 100644 --- a/crates/socketioxide/src/io.rs +++ b/crates/socketioxide/src/io.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, sync::Arc, time::Duration}; +use std::{borrow::Cow, fmt, sync::Arc, time::Duration}; use engineioxide::{ config::{EngineIoConfig, EngineIoConfigBuilder}, @@ -7,21 +7,26 @@ use engineioxide::{ TransportType, }; use serde::Serialize; +use socketioxide_core::{ + adapter::{DefinedAdapter, Room, RoomParam}, + Uid, +}; use socketioxide_parser_common::CommonParser; #[cfg(feature = "msgpack")] use socketioxide_parser_msgpack::MsgPackParser; use crate::{ ack::AckStream, - adapter::{Adapter, LocalAdapter, Room}, + adapter::{Adapter, LocalAdapter}, client::Client, extract::SocketRef, handler::ConnectHandler, layer::SocketIoLayer, - operators::{BroadcastOperators, RoomParam}, - parser::{self, Parser}, + operators::BroadcastOperators, + parser::Parser, service::SocketIoService, - BroadcastError, DisconnectError, + socket::RemoteSocket, + BroadcastError, EmitWithAckError, }; /// The parser to use to encode and decode socket.io packets @@ -62,6 +67,9 @@ pub struct SocketIoConfig { /// The parser to use to encode and decode socket.io packets pub(crate) parser: Parser, + + /// A global server identifier + pub server_id: Uid, } impl Default for SocketIoConfig { @@ -74,6 +82,7 @@ impl Default for SocketIoConfig { ack_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(45), parser: Parser::default(), + server_id: Uid::new(), } } } @@ -84,23 +93,24 @@ impl Default for SocketIoConfig { pub struct SocketIoBuilder { config: SocketIoConfig, engine_config_builder: EngineIoConfigBuilder, - adapter: std::marker::PhantomData, + adapter_state: A::State, #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], } -impl SocketIoBuilder { +impl SocketIoBuilder { /// Creates a new [`SocketIoBuilder`] with default config pub fn new() -> Self { Self { engine_config_builder: 
EngineIoConfigBuilder::new().req_path("/socket.io".to_string()), config: SocketIoConfig::default(), - adapter: std::marker::PhantomData, + adapter_state: (), #[cfg(feature = "state")] state: std::default::Default::default(), } } - +} +impl SocketIoBuilder { /// The path to listen for socket.io requests on. /// /// Defaults to "/socket.io". @@ -199,11 +209,11 @@ impl SocketIoBuilder { } /// Set a custom [`Adapter`] for this [`SocketIoBuilder`] - pub fn with_adapter(self) -> SocketIoBuilder { + pub fn with_adapter(self, adapter_state: B::State) -> SocketIoBuilder { SocketIoBuilder { config: self.config, engine_config_builder: self.engine_config_builder, - adapter: std::marker::PhantomData, + adapter_state, #[cfg(feature = "state")] state: self.state, } @@ -220,39 +230,40 @@ impl SocketIoBuilder { self.state.set(state); self } +} - /// Build a [`SocketIoLayer`] and a [`SocketIo`] instance - /// - /// The layer can be used as a tower layer +impl SocketIoBuilder { + /// Build a [`SocketIoLayer`] and a [`SocketIo`] instance that can be used as a [`tower_layer::Layer`]. pub fn build_layer(mut self) -> (SocketIoLayer, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); let (layer, client) = SocketIoLayer::from_config( self.config, + self.adapter_state, #[cfg(feature = "state")] self.state, ); (layer, SocketIo(client)) } - /// Build a [`SocketIoService`] and a [`SocketIo`] instance + /// Build a [`SocketIoService`] and a [`SocketIo`] instance that + /// can be used as a [`hyper::service::Service`] or a [`tower_service::Service`]. /// /// This service will be a _standalone_ service that return a 404 error for every non-socket.io request - /// It can be used as a hyper service pub fn build_svc(mut self) -> (SocketIoService, SocketIo) { self.config.engine_config = self.engine_config_builder.build(); let (svc, client) = SocketIoService::with_config_inner( NotFoundService, self.config, + self.adapter_state, #[cfg(feature = "state")] self.state, ); (svc, SocketIo(client)) } - /// Build a [`SocketIoService`] and a [`SocketIo`] instance with an inner service - /// - /// It can be used as a hyper service + /// Build a [`SocketIoService`] and a [`SocketIo`] instance with an inner service that + /// can be used as a [`hyper::service::Service`] or a [`tower_service::Service`]. pub fn build_with_inner_svc( mut self, svc: S, @@ -262,6 +273,7 @@ impl SocketIoBuilder { let (svc, client) = SocketIoService::with_config_inner( svc, self.config, + self.adapter_state, #[cfg(feature = "state")] self.state, ); @@ -279,7 +291,9 @@ impl Default for SocketIoBuilder { /// It can be used as the main handle to access the whole socket.io context. /// /// You can also use it as an extractor for all your [`handlers`](crate::handler). -#[derive(Debug)] +/// +/// It is generic over the [`Adapter`] type. If you plan to use it with another adapter than the default, +/// make sure to have a handler that is [generic over the adapter type](crate#adapters). pub struct SocketIo(Arc>); impl SocketIo { @@ -319,95 +333,6 @@ impl SocketIo { &self.0.config } - /// # Register a [`ConnectHandler`] for the given namespace - /// - /// * See the [`connect`](crate::handler::connect) module doc for more details on connect handler. - /// * See the [`extract`](crate::extract) module doc for more details on available extractors. 
- /// - /// # Simple example with a sync closure: - /// ``` - /// # use socketioxide::{SocketIo, extract::*}; - /// # use serde::{Serialize, Deserialize}; - /// #[derive(Debug, Serialize, Deserialize)] - /// struct MyData { - /// name: String, - /// age: u8, - /// } - /// - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef| { - /// // Register a handler for the "test" event and extract the data as a `MyData` struct - /// // With the Data extractor, the handler is called only if the data can be deserialized as a `MyData` struct - /// // If you want to manage errors yourself you can use the TryData extractor - /// socket.on("test", |socket: SocketRef, Data::(data)| { - /// println!("Received a test message {:?}", data); - /// socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client - /// }); - /// }); - /// - /// ``` - /// - /// # Example with a closure and an acknowledgement + binary data: - /// ``` - /// # use socketioxide::{SocketIo, extract::*}; - /// # use serde_json::Value; - /// # use serde::{Serialize, Deserialize}; - /// #[derive(Debug, Serialize, Deserialize)] - /// struct MyData { - /// name: String, - /// age: u8, - /// } - /// - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef| { - /// // Register an async handler for the "test" event and extract the data as a `MyData` struct - /// // Extract the binary payload as a `Vec` with the Bin extractor. - /// // It should be the last extractor because it consumes the request - /// socket.on("test", |socket: SocketRef, Data::(data), ack: AckSender| async move { - /// println!("Received a test message {:?}", data); - /// tokio::time::sleep(std::time::Duration::from_secs(1)).await; - /// ack.send(&data).ok(); // The data received is sent back to the client through the ack - /// socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client - /// }); - /// }); - /// ``` - /// # Example with a closure and an authentication process: - /// ``` - /// # use socketioxide::{SocketIo, extract::{SocketRef, Data}}; - /// # use serde::{Serialize, Deserialize}; - /// #[derive(Debug, Deserialize)] - /// struct MyAuthData { - /// token: String, - /// } - /// #[derive(Debug, Serialize, Deserialize)] - /// struct MyData { - /// name: String, - /// age: u8, - /// } - /// - /// let (_, io) = SocketIo::new_svc(); - /// io.ns("/", |socket: SocketRef, Data(auth): Data| { - /// if auth.token.is_empty() { - /// println!("Invalid token, disconnecting"); - /// socket.disconnect().ok(); - /// return; - /// } - /// socket.on("test", |socket: SocketRef, Data::(data)| async move { - /// println!("Received a test message {:?}", data); - /// socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client - /// }); - /// }); - /// - /// ``` - #[inline] - pub fn ns(&self, path: impl Into>, callback: C) - where - C: ConnectHandler, - T: Send + Sync + 'static, - { - self.0.add_ns(path.into(), callback); - } - /// # Register a [`ConnectHandler`] for the given dynamic namespace. /// /// You can specify dynamic parts in the path by using the `{name}` syntax. @@ -415,7 +340,8 @@ impl SocketIo { /// /// For more info about namespace routing, see the [matchit] router documentation. /// - /// The dynamic namespace will create a child namespace for any path that matches the given pattern with the given handler. 
+ /// The dynamic namespace will create a child namespace for any path that matches the given pattern + /// with the given handler. /// /// * See the [`connect`](crate::handler::connect) module doc for more details on connect handler. /// * See the [`extract`](crate::extract) module doc for more details on available extractors. @@ -460,7 +386,8 @@ impl SocketIo { /// /// # Panics /// If the v4 protocol (legacy) is enabled and the namespace to delete is the default namespace "/". - /// For v4, the default namespace cannot be deleted. See [official doc](https://socket.io/docs/v3/namespaces/#main-namespace) for more informations. + /// For v4, the default namespace cannot be deleted. + /// See [official doc](https://socket.io/docs/v3/namespaces/#main-namespace) for more informations. #[inline] pub fn delete_ns(&self, path: impl AsRef) { self.0.delete_ns(path.as_ref()); @@ -468,7 +395,8 @@ impl SocketIo { /// # Gracefully close all the connections and drop every sockets /// - /// Any `on_disconnect` handler will called with [`DisconnectReason::ClosingServer`](crate::socket::DisconnectReason::ClosingServer) + /// Any `on_disconnect` handler will called with + /// [`DisconnectReason::ClosingServer`](crate::socket::DisconnectReason::ClosingServer) #[inline] pub async fn close(&self) { self.0.close().await; @@ -490,9 +418,11 @@ impl SocketIo { /// /// // Later in your code you can select the custom_ns namespace /// // and show all sockets connected to it - /// let sockets = io.of("custom_ns").unwrap().sockets().unwrap(); - /// for socket in sockets { - /// println!("found socket on /custom_ns namespace with id: {}", socket.id); + /// async fn test(io: SocketIo) { + /// let sockets = io.of("custom_ns").unwrap().sockets(); + /// for socket in sockets { + /// println!("found socket on /custom_ns namespace with id: {}", socket.id); + /// } /// } /// ``` #[inline] @@ -538,57 +468,64 @@ impl SocketIo { /// _Alias for `io.of("/").unwrap().emit()`_. If the **default namespace "/" is not found** this fn will panic! #[doc = include_str!("../docs/operators/emit.md")] #[inline] - pub fn emit( + pub async fn emit( &self, event: impl AsRef, data: &T, ) -> Result<(), BroadcastError> { - self.get_default_op().emit(event, data) + self.get_default_op().emit(event, data).await } /// _Alias for `io.of("/").unwrap().emit_with_ack()`_. If the **default namespace "/" is not found** this fn will panic! #[doc = include_str!("../docs/operators/emit_with_ack.md")] #[inline] - pub fn emit_with_ack( + pub async fn emit_with_ack( &self, event: impl AsRef, data: &T, - ) -> Result, parser::EncodeError> { - self.get_default_op().emit_with_ack(event, data) + ) -> Result, EmitWithAckError> { + self.get_default_op().emit_with_ack(event, data).await } /// _Alias for `io.of("/").unwrap().sockets()`_. If the **default namespace "/" is not found** this fn will panic! #[doc = include_str!("../docs/operators/sockets.md")] #[inline] - pub fn sockets(&self) -> Result>, A::Error> { + pub fn sockets(&self) -> Vec> { self.get_default_op().sockets() } + /// _Alias for `io.of("/").unwrap().fetch_sockets()`_. If the **default namespace "/" is not found** this fn will panic! + #[doc = include_str!("../docs/operators/fetch_sockets.md")] + #[inline] + pub async fn fetch_sockets(&self) -> Result>, A::Error> { + self.get_default_op().fetch_sockets().await + } + /// _Alias for `io.of("/").unwrap().disconnect()`_. If the **default namespace "/" is not found** this fn will panic! 
#[doc = include_str!("../docs/operators/disconnect.md")] #[inline] - pub fn disconnect(&self) -> Result<(), Vec> { - self.get_default_op().disconnect() + pub async fn disconnect(&self) -> Result<(), BroadcastError> { + self.get_default_op().disconnect().await } /// _Alias for `io.of("/").unwrap().join()`_. If the **default namespace "/" is not found** this fn will panic! #[doc = include_str!("../docs/operators/join.md")] #[inline] - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.get_default_op().join(rooms) + pub async fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + self.get_default_op().join(rooms).await } /// _Alias for `io.of("/").unwrap().rooms()`_. If the **default namespace "/" is not found** this fn will panic! #[doc = include_str!("../docs/operators/rooms.md")] - pub fn rooms(&self) -> Result, A::Error> { - self.get_default_op().rooms() + pub async fn rooms(&self) -> Result, A::Error> { + self.get_default_op().rooms().await } /// _Alias for `io.of("/").unwrap().rooms()`_. If the **default namespace "/" is not found** this fn will panic! #[doc = include_str!("../docs/operators/leave.md")] #[inline] - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.get_default_op().leave(rooms) + pub async fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + self.get_default_op().leave(rooms).await } /// _Alias for `io.of("/").unwrap().get_socket()`_. If the **default namespace "/" is not found** this fn will panic! @@ -602,7 +539,7 @@ impl SocketIo { #[doc = include_str!("../docs/operators/broadcast.md")] #[inline] pub fn broadcast(&self) -> BroadcastOperators { - self.get_default_op().broadcast() + self.get_default_op() } #[cfg(feature = "state")] @@ -630,6 +567,123 @@ impl SocketIo { } } +// This private impl is used to ensure that the following methods +// are only available on a *defined* adapter. +#[allow(private_bounds)] +impl SocketIo { + /// # Register a [`ConnectHandler`] for the given namespace + /// + /// * See the [`connect`](crate::handler::connect) module doc for more details on connect handler. + /// * See the [`extract`](crate::extract) module doc for more details on available extractors. 
+ /// + /// # Simple example with a sync closure: + /// ``` + /// # use socketioxide::{SocketIo, extract::*}; + /// # use serde::{Serialize, Deserialize}; + /// #[derive(Debug, Serialize, Deserialize)] + /// struct MyData { + /// name: String, + /// age: u8, + /// } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", |socket: SocketRef| { + /// // Register a handler for the "test" event and extract the data as a `MyData` struct + /// // With the Data extractor, the handler is called only if the data can be deserialized as a `MyData` struct + /// // If you want to manage errors yourself you can use the TryData extractor + /// socket.on("test", |socket: SocketRef, Data::(data)| { + /// println!("Received a test message {:?}", data); + /// socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client + /// }); + /// }); + /// + /// ``` + /// + /// # Example with a closure and an acknowledgement + binary data: + /// ``` + /// # use socketioxide::{SocketIo, extract::*}; + /// # use serde_json::Value; + /// # use serde::{Serialize, Deserialize}; + /// #[derive(Debug, Serialize, Deserialize)] + /// struct MyData { + /// name: String, + /// age: u8, + /// } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", |socket: SocketRef| { + /// // Register an async handler for the "test" event and extract the data as a `MyData` struct + /// // Extract the binary payload as a `Vec` with the Bin extractor. + /// // It should be the last extractor because it consumes the request + /// socket.on("test", |socket: SocketRef, Data::(data), ack: AckSender| async move { + /// println!("Received a test message {:?}", data); + /// tokio::time::sleep(std::time::Duration::from_secs(1)).await; + /// ack.send(&data).ok(); // The data received is sent back to the client through the ack + /// socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client + /// }); + /// }); + /// ``` + /// # Example with a closure and an authentication process: + /// ``` + /// # use socketioxide::{SocketIo, extract::{SocketRef, Data}}; + /// # use serde::{Serialize, Deserialize}; + /// #[derive(Debug, Deserialize)] + /// struct MyAuthData { + /// token: String, + /// } + /// #[derive(Debug, Serialize, Deserialize)] + /// struct MyData { + /// name: String, + /// age: u8, + /// } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", |socket: SocketRef, Data(auth): Data| { + /// if auth.token.is_empty() { + /// println!("Invalid token, disconnecting"); + /// socket.disconnect().ok(); + /// return; + /// } + /// socket.on("test", |socket: SocketRef, Data::(data)| async move { + /// println!("Received a test message {:?}", data); + /// socket.emit("test-test", &MyData { name: "Test".to_string(), age: 8 }).ok(); // Emit a message to the client + /// }); + /// }); + /// + /// ``` + /// + /// # With remote adapters, this method is only available on a defined adapter: + /// ```compile_fail + /// # use socketioxide::{SocketIo}; + /// // The SocketIo instance is generic over the adapter type. + /// fn test(io: SocketIo) { + /// io.ns("/", || ()); + /// } + /// ``` + /// ``` + /// # use socketioxide::{SocketIo, adapter::LocalAdapter}; + /// // The SocketIo instance is not generic over the adapter type. 
+ /// fn test(io: SocketIo) { + /// io.ns("/", || ()); + /// } + /// fn test_default_adapter(io: SocketIo) { + /// io.ns("/", || ()); + /// } + /// ``` + pub fn ns(&self, path: impl Into>, callback: C) -> A::InitRes + where + C: ConnectHandler, + T: Send + Sync + 'static, + { + self.0.clone().add_ns(path.into(), callback) + } +} + +impl fmt::Debug for SocketIo { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SocketIo").field("client", &self.0).finish() + } +} impl Clone for SocketIo { fn clone(&self) -> Self { Self(self.0.clone()) @@ -667,7 +721,7 @@ mod tests { #[test] fn get_default_op() { - let (_, io) = SocketIo::builder().build_svc(); + let (_, io) = SocketIo::new_svc(); io.ns("/", || {}); let _ = io.get_default_op(); } @@ -675,13 +729,13 @@ mod tests { #[test] #[should_panic(expected = "default namespace not found")] fn get_default_op_panic() { - let (_, io) = SocketIo::builder().build_svc(); + let (_, io) = SocketIo::new_svc(); let _ = io.get_default_op(); } #[test] fn get_op() { - let (_, io) = SocketIo::builder().build_svc(); + let (_, io) = SocketIo::new_svc(); io.ns("test", || {}); assert!(io.get_op("test").is_some()); assert!(io.get_op("test2").is_none()); @@ -691,7 +745,7 @@ mod tests { async fn get_socket_by_sid() { use engineioxide::Socket; let sid = Sid::new(); - let (_, io) = SocketIo::builder().build_svc(); + let (_, io) = SocketIo::new_svc(); io.ns("/", || {}); let socket = Socket::>::new_dummy(sid, Box::new(|_, _| {})); socket.data.io.set(io.clone()).unwrap(); diff --git a/crates/socketioxide/src/layer.rs b/crates/socketioxide/src/layer.rs index 0266f101..eca8b3c6 100644 --- a/crates/socketioxide/src/layer.rs +++ b/crates/socketioxide/src/layer.rs @@ -43,10 +43,12 @@ impl Clone for SocketIoLayer { impl SocketIoLayer { pub(crate) fn from_config( config: SocketIoConfig, + adapter_state: A::State, #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], ) -> (Self, Arc>) { let client = Arc::new(Client::new( config, + adapter_state, #[cfg(feature = "state")] state, )); diff --git a/crates/socketioxide/src/lib.rs b/crates/socketioxide/src/lib.rs index 8810c911..fb129d13 100644 --- a/crates/socketioxide/src/lib.rs +++ b/crates/socketioxide/src/lib.rs @@ -105,7 +105,8 @@ //! } //! ``` //! ## Initialisation -//! The [`SocketIo`] struct is the main entry point of the library. It is used to create a [`Layer`](tower_layer::Layer) or a [`Service`](tower_service::Service). +//! The [`SocketIo`] struct is the main entry point of the library. It is used to create +//! a [`Layer`](tower_layer::Layer) or a [`Service`](tower_service::Service). //! Later it can be used as the equivalent of the `io` object in the JS API. //! //! When creating your [`SocketIo`] instance, you can use the builder pattern to configure it with the [`SocketIoBuilder`] struct. @@ -170,11 +171,15 @@ //! * [`SocketIo`]: extracts a reference to the [`SocketIo`] handle //! //! ### Extractor order -//! Extractors are run in the order of their declaration in the handler signature. If an extractor returns an error, the handler won't be called and a `tracing::error!` call will be emitted if the `tracing` feature is enabled. +//! Extractors are run in the order of their declaration in the handler signature. +//! If an extractor returns an error, the handler won't be called and a `tracing::error!` call +//! will be emitted if the `tracing` feature is enabled. //! -//! 
For the [`MessageHandler`], some extractors require to _consume_ the event and therefore only implement the [`FromMessage`](handler::FromMessage) trait. +//! For the [`MessageHandler`], some extractors require to _consume_ the event and therefore +//! only implement the [`FromMessage`](handler::FromMessage) trait. //! -//! Note that any extractors that implement the [`FromMessageParts`] also implement by default the [`FromMessage`](handler::FromMessage) trait. +//! Note that any extractors that implement the [`FromMessageParts`] also implement by default +//! the [`FromMessage`](handler::FromMessage) trait. //! //! ## Events //! There are three types of events: @@ -233,7 +238,7 @@ //! [`serde_json::Value`]: https://docs.rs/serde_json/latest/serde_json/value //! //! #### Emit errors -//! If the data can't be serialized, an [`EncodeError`] will be returned. +//! If the data can't be serialized, a [`ParserError`] will be returned. //! //! If the socket is disconnected or the internal channel is full, a [`SendError`] will be returned. //! Moreover, a tracing log will be emitted if the `tracing` feature is enabled. @@ -266,7 +271,7 @@ //! [`BroadcastOperators::emit_with_ack`]: crate::operators::BroadcastOperators#method.emit_with_ack //! [`SocketIo::emit_with_ack`]: SocketIo#method.emit_with_ack //! [`AckStream`]: crate::ack::AckStream -//! [`EncodeError`]: crate::EncodeError +//! [`ParserError`]: crate::parser::ParserError //! //! ## [State management](#state-management) //! There are two ways to manage the state of the server: @@ -283,15 +288,43 @@ //! You can enable the `state` feature and use [`SocketIoBuilder::with_state`](SocketIoBuilder) method to set //! multiple global states for the server. You can then access them from any handler with the [`State`] extractor. //! -//! The state is stored in the [`SocketIo`] handle and is shared between all the sockets. The only limitation is that all the provided state types must be clonable. +//! The state is stored in the [`SocketIo`] handle and is shared between all the sockets. The only limitation is that all the +//! provided state types must be clonable. //! Therefore it is recommended to use the [`Arc`](std::sync::Arc) type to share the state between the handlers. //! //! You can then use the [`State`] extractor to access the state in the handlers. //! //! ## Adapters -//! This library is designed to work with clustering. It uses the [`Adapter`] trait to abstract the underlying storage. -//! By default it uses the [`LocalAdapter`] which is a simple in-memory adapter. -//! Currently there is no other adapters available but more will be added in the future. +//! This library is designed to support clustering through the use of adapters. +//! Adapters enable broadcasting messages and managing socket room memberships across nodes +//! without requiring changes to your code. The [`Adapter`] trait abstracts the underlying system, +//! making it easy to integrate with different implementations. +//! +//! Adapters typically interact with third-party systems like Redis, Postgres, Kafka, etc., +//! to facilitate message exchange between nodes. +//! +//! The default adapter is the [`LocalAdapter`], a simple in-memory implementation. If you intend +//! to use a different adapter, ensure that extractors are either generic over the adapter type +//! or explicitly specify the adapter type for each extractor that requires it. +//! +//! #### Write this: +//! ``` +//! # use socketioxide::{SocketIo, adapter::Adapter, extract::SocketRef}; +//! 
fn my_handler(s: SocketRef, io: SocketIo) { } +//! let (layer, io) = SocketIo::new_layer(); +//! io.ns("/", my_handler); +//! ``` +//! #### Instead of that: +//! ``` +//! # use socketioxide::{SocketIo, adapter::Adapter, extract::SocketRef}; +//! fn my_handler(s: SocketRef, io: SocketIo) { } +//! let (layer, io) = SocketIo::new_layer(); +//! io.ns("/", my_handler); +//! ``` +//! +//! Refer to the [README](https://github.com/totodore/socketioxide) for a list of available adapters and +//! the [examples](https://github.com/totodore/socketioxide/tree/main/examples) for detailed usage guidance. +//! You can also consult specific adapter crate documentation for more information. //! //! ## Parsers //! This library uses the socket.io common parser which is the default for all the socket.io implementations. @@ -324,13 +357,13 @@ //! [`AckSender`]: extract::AckSender //! [`AckSender::send`]: extract::AckSender#method.send //! [`io`]: SocketIo -pub mod adapter; #[cfg_attr(docsrs, doc(cfg(feature = "extensions")))] #[cfg(feature = "extensions")] pub mod extensions; pub mod ack; +pub mod adapter; pub mod extract; pub mod handler; pub mod layer; @@ -340,11 +373,10 @@ pub mod socket; pub use engineioxide::TransportType; pub use errors::{ - AckError, AdapterError, BroadcastError, DecodeError, DisconnectError, EncodeError, - NsInsertError, SendError, SocketError, + AckError, AdapterError, BroadcastError, EmitWithAckError, NsInsertError, ParserError, + SendError, SocketError, }; pub use io::{ParserConfig, SocketIo, SocketIoBuilder, SocketIoConfig}; -pub use socketioxide_core::packet; mod client; mod errors; diff --git a/crates/socketioxide/src/ns.rs b/crates/socketioxide/src/ns.rs index f82c428c..59048ba8 100644 --- a/crates/socketioxide/src/ns.rs +++ b/crates/socketioxide/src/ns.rs @@ -1,19 +1,27 @@ use std::{ collections::HashMap, - sync::{Arc, RwLock}, + sync::{Arc, RwLock, Weak}, + time::Duration, }; use crate::{ + ack::AckInnerStream, adapter::Adapter, + client::SocketData, errors::{ConnectFail, Error}, handler::{BoxedConnectHandler, ConnectHandler, MakeErasedHandler}, - packet::{ConnectPacket, Packet, PacketData}, + parser::Parser, socket::{DisconnectReason, Socket}, - ProtocolVersion, + ProtocolVersion, SocketIoConfig, }; -use crate::{client::SocketData, errors::AdapterError}; use engineioxide::{sid::Sid, Str}; -use socketioxide_core::{parser::Parse, Value}; +use socketioxide_core::{ + adapter::{BroadcastIter, CoreLocalAdapter, RemoteSocketData, SocketEmitter}, + errors::SocketError, + packet::{ConnectPacket, Packet, PacketData}, + parser::Parse, + Uid, Value, +}; /// A [`Namespace`] constructor used for dynamic namespaces /// A namespace constructor only hold a common handler that will be cloned @@ -23,7 +31,8 @@ pub struct NamespaceCtr { } pub struct Namespace { pub path: Str, - pub(crate) adapter: A, + pub(crate) adapter: Arc, + parser: Parser, handler: BoxedConnectHandler, sockets: RwLock>>>, } @@ -39,27 +48,50 @@ impl NamespaceCtr { handler: MakeErasedHandler::new_ns_boxed(handler), } } - pub fn get_new_ns(&self, path: Str) -> Arc> { - Arc::new_cyclic(|ns| Namespace { - path, - handler: self.handler.boxed_clone(), - sockets: HashMap::new().into(), - adapter: A::new(ns.clone()), - }) + pub fn get_new_ns( + &self, + path: Str, + adapter_state: &A::State, + config: &SocketIoConfig, + ) -> Arc> { + let handler = self.handler.boxed_clone(); + Namespace::new_boxed(path, handler, adapter_state, config) } } impl Namespace { - pub fn new(path: Str, handler: C) -> Arc + pub(crate) fn new( + path: Str, + 
handler: C, + adapter_state: &A::State, + config: &SocketIoConfig, + ) -> Arc where C: ConnectHandler + Send + Sync + 'static, T: Send + Sync + 'static, { + let handler = MakeErasedHandler::new_ns_boxed(handler); + Self::new_boxed(path, handler, adapter_state, config) + } + + fn new_boxed( + path: Str, + handler: BoxedConnectHandler, + adapter_state: &A::State, + config: &SocketIoConfig, + ) -> Arc { + let parser = config.parser; + let ack_timeout = config.ack_timeout; + let uid = config.server_id; Arc::new_cyclic(|ns| Self { - path, - handler: MakeErasedHandler::new_ns_boxed(handler), + path: path.clone(), + handler, + parser, sockets: HashMap::new().into(), - adapter: A::new(ns.clone()), + adapter: Arc::new(A::new( + adapter_state, + CoreLocalAdapter::new(Emitter::new(ns.clone(), parser, path, ack_timeout, uid)), + )), }) } @@ -74,7 +106,8 @@ impl Namespace { esocket: Arc>>, auth: Option, ) -> Result<(), ConnectFail> { - let socket: Arc> = Socket::new(sid, self.clone(), esocket.clone()).into(); + let socket: Arc> = + Socket::new(sid, self.clone(), esocket.clone(), self.parser).into(); if let Err(e) = self.handler.call_middleware(socket.clone(), &auth).await { #[cfg(feature = "tracing")] @@ -96,7 +129,7 @@ impl Namespace { let protocol = esocket.protocol.into(); let payload = ConnectPacket { sid: socket.id }; let payload = match protocol { - ProtocolVersion::V5 => Some(socket.parser().encode_default(&payload).unwrap()), + ProtocolVersion::V5 => Some(self.parser.encode_default(&payload).unwrap()), ProtocolVersion::V4 => None, }; if let Err(_e) = socket.send(Packet::connect(self.path.clone(), payload)) { @@ -112,15 +145,13 @@ impl Namespace { Ok(()) } - /// Removes a socket from a namespace and propagate the event to the adapter - pub fn remove_socket(&self, sid: Sid) -> Result<(), AdapterError> { + /// Removes a socket from a namespace + pub fn remove_socket(&self, sid: Sid) { #[cfg(feature = "tracing")] tracing::trace!(?sid, ?self.path, "removing socket from namespace"); self.sockets.write().unwrap().remove(&sid); - self.adapter - .del_all(sid) - .map_err(|err| AdapterError(Box::new(err))) + self.adapter.get_local().del_all(sid); } pub fn has(&self, sid: Sid) -> bool { @@ -169,17 +200,13 @@ impl Namespace { } else { for s in sockets.into_values() { let _sid = s.id; - let _err = s.close(reason); - #[cfg(feature = "tracing")] - if let Err(err) = _err { - tracing::debug!(?_sid, ?err, "error closing socket"); - } + s.close(reason); } } #[cfg(feature = "tracing")] tracing::debug!(?self.path, "all sockets in namespace closed"); - let _err = self.adapter.close(); + let _err = self.adapter.close().await; #[cfg(feature = "tracing")] if let Err(err) = _err { tracing::debug!(?err, ?self.path, "could not close adapter"); @@ -187,11 +214,184 @@ impl Namespace { } } +/// A type erased emitter to discard the adapter type parameter `A`. +/// Otherwise it creates a cyclic dependency between the namespace, the emitter and the adapter. +trait InnerEmitter: Send + Sync + 'static { + /// Get the remote socket data from the socket ids. + fn get_remote_sockets(&self, sids: BroadcastIter<'_>, uid: Uid) -> Vec; + /// Get all the socket ids in the namespace. + fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec; + /// Send data to the list of socket ids. + fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec>; + /// Send data to the list of socket ids and get a stream of acks. 
+ fn send_many_with_ack( + &self, + sids: BroadcastIter<'_>, + packet: Packet, + timeout: Duration, + ) -> (AckInnerStream, u32); + /// Disconnect all the sockets in the list. + fn disconnect_many(&self, sids: Vec) -> Result<(), Vec>; +} + +impl InnerEmitter for Namespace { + fn get_remote_sockets(&self, sids: BroadcastIter<'_>, uid: Uid) -> Vec { + let sockets = self.sockets.read().unwrap(); + sids.filter_map(|sid| sockets.get(&sid)) + .map(|socket| RemoteSocketData { + id: socket.id, + ns: self.path.clone(), + server_id: uid, + }) + .collect() + } + fn get_all_sids(&self, filter: &dyn Fn(&Sid) -> bool) -> Vec { + self.sockets + .read() + .unwrap() + .keys() + .filter(|id| filter(id)) + .copied() + .collect() + } + + fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec> { + let sockets = self.sockets.read().unwrap(); + let errs: Vec = sids + .filter_map(|sid| sockets.get(&sid)) + .filter_map(|socket| socket.send_raw(data.clone()).err()) + .collect(); + if errs.is_empty() { + Ok(()) + } else { + Err(errs) + } + } + + fn send_many_with_ack( + &self, + sids: BroadcastIter<'_>, + packet: Packet, + timeout: Duration, + ) -> (AckInnerStream, u32) { + let sockets_map = self.sockets.read().unwrap(); + let sockets = sids.filter_map(|sid| sockets_map.get(&sid)); + AckInnerStream::broadcast(packet, sockets, timeout) + } + + fn disconnect_many(&self, sids: Vec) -> Result<(), Vec> { + if sids.is_empty() { + return Ok(()); + } + // Here we can't take a ref because this would cause a deadlock. + // Ideally the disconnect / closing process should be refactored to avoid this. + let sockets = { + let sock_map = self.sockets.read().unwrap(); + sids.into_iter() + .filter_map(|sid| sock_map.get(&sid)) + .cloned() + .collect::>() + }; + + let errs = sockets + .into_iter() + .filter_map(|socket| socket.disconnect().err()) + .collect::>(); + if errs.is_empty() { + Ok(()) + } else { + Err(errs) + } + } +} + +/// Internal interface implementor to apply global operations on a namespace. +#[doc(hidden)] +pub struct Emitter { + /// This `Weak` allows to break the cyclic dependency between the namespace and the emitter. 
+ ns: Weak, + parser: Parser, + path: Str, + ack_timeout: Duration, + uid: Uid, +} + +impl Emitter { + fn new( + ns: Weak>, + parser: Parser, + path: Str, + ack_timeout: Duration, + uid: Uid, + ) -> Self { + Self { + ns, + parser, + path, + ack_timeout, + uid, + } + } +} + +impl SocketEmitter for Emitter { + type AckError = crate::AckError; + type AckStream = AckInnerStream; + + fn get_all_sids(&self, filter: impl Fn(&Sid) -> bool) -> Vec { + self.ns + .upgrade() + .map(|ns| ns.get_all_sids(&filter)) + .unwrap_or_default() + } + fn get_remote_sockets(&self, sids: BroadcastIter<'_>) -> Vec { + self.ns + .upgrade() + .map(|ns| ns.get_remote_sockets(sids, self.uid)) + .unwrap_or_default() + } + + fn send_many(&self, sids: BroadcastIter<'_>, data: Value) -> Result<(), Vec> { + match self.ns.upgrade() { + Some(ns) => ns.send_many(sids, data), + None => Ok(()), + } + } + + fn send_many_with_ack( + &self, + sids: BroadcastIter<'_>, + packet: Packet, + timeout: Option, + ) -> (Self::AckStream, u32) { + self.ns + .upgrade() + .map(|ns| ns.send_many_with_ack(sids, packet, timeout.unwrap_or(self.ack_timeout))) + .unwrap_or((AckInnerStream::empty(), 0)) + } + + fn disconnect_many(&self, sids: Vec) -> Result<(), Vec> { + match self.ns.upgrade() { + Some(ns) => ns.disconnect_many(sids), + None => Ok(()), + } + } + fn parser(&self) -> impl Parse { + self.parser + } + fn server_id(&self) -> Uid { + self.uid + } + fn path(&self) -> &Str { + &self.path + } +} + #[doc(hidden)] #[cfg(feature = "__test_harness")] -impl Namespace { +impl Namespace { pub fn new_dummy(sockets: [Sid; S]) -> Arc { - let ns = Namespace::new("/".into(), || {}); + let ns = Namespace::new("/".into(), || {}, &(), &SocketIoConfig::default()); for sid in sockets { ns.sockets .write() @@ -206,11 +406,10 @@ impl Namespace { } } -impl std::fmt::Debug for Namespace { +impl std::fmt::Debug for Namespace { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Namespace") .field("path", &self.path) - .field("adapter", &self.adapter) .field("sockets", &self.sockets) .finish() } diff --git a/crates/socketioxide/src/operators.rs b/crates/socketioxide/src/operators.rs index 53eedc1b..cb041671 100644 --- a/crates/socketioxide/src/operators.rs +++ b/crates/socketioxide/src/operators.rs @@ -6,103 +6,26 @@ //! There are two types of operators: //! * [`ConfOperators`]: Chainable operators to configure the message to be sent. //! * [`BroadcastOperators`]: Chainable operators to select sockets to send a message to and to configure the message to be sent. -use std::borrow::Cow; -use std::{sync::Arc, time::Duration}; +use std::{future::Future, sync::Arc, time::Duration}; use engineioxide::sid::Sid; use serde::Serialize; -use socketioxide_core::parser::Parse; - -use crate::ack::{AckInnerStream, AckStream}; -use crate::adapter::LocalAdapter; -use crate::errors::{BroadcastError, DisconnectError}; -use crate::extract::SocketRef; -use crate::parser::{self, Parser}; -use crate::socket::Socket; -use crate::SendError; + use crate::{ - adapter::{Adapter, BroadcastFlags, BroadcastOptions, Room}, + ack::{AckInnerStream, AckStream}, + adapter::{Adapter, LocalAdapter}, + extract::SocketRef, ns::Namespace, - packet::Packet, + parser::Parser, + socket::{RemoteSocket, Socket}, + BroadcastError, EmitWithAckError, SendError, }; -/// A trait for types that can be used as a room parameter. -/// -/// [`String`], [`Vec`], [`Vec<&str>`], [`&'static str`](str) and const arrays are implemented by default. 
-pub trait RoomParam: 'static { - /// The type of the iterator returned by `into_room_iter`. - type IntoIter: Iterator; - - /// Convert `self` into an iterator of rooms. - fn into_room_iter(self) -> Self::IntoIter; -} - -impl RoomParam for Room { - type IntoIter = std::iter::Once; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - std::iter::once(self) - } -} -impl RoomParam for String { - type IntoIter = std::iter::Once; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - std::iter::once(Cow::Owned(self)) - } -} -impl RoomParam for Vec { - type IntoIter = std::iter::Map, fn(String) -> Room>; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - self.into_iter().map(Cow::Owned) - } -} -impl RoomParam for Vec<&'static str> { - type IntoIter = std::iter::Map, fn(&'static str) -> Room>; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - self.into_iter().map(Cow::Borrowed) - } -} - -impl RoomParam for Vec { - type IntoIter = std::vec::IntoIter; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - self.into_iter() - } -} -impl RoomParam for &'static str { - type IntoIter = std::iter::Once; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - std::iter::once(Cow::Borrowed(self)) - } -} -impl RoomParam for [&'static str; COUNT] { - type IntoIter = - std::iter::Map, fn(&'static str) -> Room>; - - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - self.into_iter().map(Cow::Borrowed) - } -} -impl RoomParam for [String; COUNT] { - type IntoIter = std::iter::Map, fn(String) -> Room>; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - self.into_iter().map(Cow::Owned) - } -} -impl RoomParam for Sid { - type IntoIter = std::iter::Once; - #[inline(always)] - fn into_room_iter(self) -> Self::IntoIter { - std::iter::once(Cow::Owned(self.to_string())) - } -} +use socketioxide_core::{ + adapter::{BroadcastFlags, BroadcastOptions, Room, RoomParam}, + packet::Packet, + parser::{Parse, ParserError}, +}; /// Chainable operators to configure the message to be sent. 
pub struct ConfOperators<'a, A: Adapter = LocalAdapter> { @@ -119,14 +42,11 @@ pub struct BroadcastOperators { impl From> for BroadcastOperators { fn from(conf: ConfOperators<'_, A>) -> Self { - let opts = BroadcastOptions { - sid: Some(conf.socket.id), - ..Default::default() - }; + let opts = BroadcastOptions::new(conf.socket.id); Self { timeout: conf.timeout, ns: conf.socket.ns.clone(), - parser: conf.socket.parser(), + parser: conf.socket.parser, opts, } } @@ -181,8 +101,8 @@ impl ConfOperators<'_, A> { event: impl AsRef, data: &T, ) -> Result<(), SendError> { - use crate::errors::SocketError; use crate::socket::PermitExt; + use crate::SocketError; if !self.socket.connected() { return Err(SendError::Socket(SocketError::Closed)); } @@ -195,7 +115,7 @@ impl ConfOperators<'_, A> { } }; let packet = self.get_packet(event, data)?; - permit.send(packet, self.socket.parser()); + permit.send(packet, self.socket.parser); Ok(()) } @@ -206,7 +126,7 @@ impl ConfOperators<'_, A> { event: impl AsRef, data: &T, ) -> Result, SendError> { - use crate::errors::SocketError; + use crate::SocketError; if !self.socket.connected() { return Err(SendError::Socket(SocketError::Closed)); } @@ -224,33 +144,28 @@ impl ConfOperators<'_, A> { let packet = self.get_packet(event, data)?; let rx = self.socket.send_with_ack_permit(packet, permit); let stream = AckInnerStream::send(rx, timeout, self.socket.id); - Ok(AckStream::::new(stream, self.socket.parser())) + Ok(AckStream::::new(stream, self.socket.parser)) } #[doc = include_str!("../docs/operators/join.md")] - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub fn join(self, rooms: impl RoomParam) { self.socket.join(rooms) } #[doc = include_str!("../docs/operators/leave.md")] - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { + pub async fn leave(self, rooms: impl RoomParam) { self.socket.leave(rooms) } - /// Gets all room names for a given namespace - pub fn rooms(self) -> Result, A::Error> { - self.socket.rooms() - } - /// Creates a packet with the given event and data. 
fn get_packet( &mut self, event: impl AsRef, data: &T, - ) -> Result { + ) -> Result { let ns = self.socket.ns.path.clone(); let event = event.as_ref(); - let data = self.socket.parser().encode_value(&data, Some(event))?; + let data = self.socket.parser.encode_value(&data, Some(event))?; Ok(Packet::event(ns, data)) } } @@ -269,10 +184,7 @@ impl BroadcastOperators { timeout: None, ns, parser, - opts: BroadcastOptions { - sid: Some(sid), - ..Default::default() - }, + opts: BroadcastOptions::new(sid), } } @@ -296,13 +208,13 @@ impl BroadcastOperators { #[doc = include_str!("../docs/operators/local.md")] pub fn local(mut self) -> Self { - self.opts.flags.insert(BroadcastFlags::Local); + self.opts.add_flag(BroadcastFlags::Local); self } #[doc = include_str!("../docs/operators/broadcast.md")] pub fn broadcast(mut self) -> Self { - self.opts.flags.insert(BroadcastFlags::Broadcast); + self.opts.add_flag(BroadcastFlags::Broadcast); self } @@ -320,14 +232,20 @@ impl BroadcastOperators { mut self, event: impl AsRef, data: &T, - ) -> Result<(), BroadcastError> { - let packet = self.get_packet(event, data)?; - if let Err(e) = self.ns.adapter.broadcast(packet, self.opts) { - #[cfg(feature = "tracing")] - tracing::debug!("broadcast error: {e:?}"); - return Err(e); + ) -> impl Future> + Send { + let packet = self.get_packet(event, data); + async move { + self.ns + .adapter + .broadcast(packet?, self.opts) + .await + .map_err(|e| { + #[cfg(feature = "tracing")] + tracing::debug!("broadcast error: {e}"); + e + })?; + Ok(()) } - Ok(()) } #[doc = include_str!("../docs/operators/emit_with_ack.md")] @@ -335,38 +253,62 @@ impl BroadcastOperators { mut self, event: impl AsRef, data: &T, - ) -> Result, parser::EncodeError> { - let packet = self.get_packet(event, data)?; - let stream = self - .ns - .adapter - .broadcast_with_ack(packet, self.opts, self.timeout); - Ok(AckStream::new(stream, self.parser)) + ) -> impl Future, EmitWithAckError>> + Send { + let packet = self.get_packet(event, data); + async move { + let stream = self + .ns + .adapter + .broadcast_with_ack(packet?, self.opts, self.timeout) + .await + .map_err(|e| EmitWithAckError::Adapter(Box::new(e)))?; + Ok(AckStream::new(stream, self.parser)) + } } #[doc = include_str!("../docs/operators/sockets.md")] - pub fn sockets(self) -> Result>, A::Error> { - self.ns.adapter.fetch_sockets(self.opts) + pub fn sockets(self) -> Vec> { + let ids = self.ns.adapter.get_local().sockets(self.opts); + + ids.into_iter() + .filter_map(|id| self.ns.get_socket(id).ok()) + .map(SocketRef::from) + .collect() + } + + #[doc = include_str!("../docs/operators/fetch_sockets.md")] + pub async fn fetch_sockets(self) -> Result>, A::Error> { + let sockets = self + .ns + .adapter + .fetch_sockets(self.opts) + .await? 
+ .into_iter() + .map(|data| RemoteSocket::new(data, &self.ns.adapter, self.parser)) + .collect(); + Ok(sockets) } #[doc = include_str!("../docs/operators/disconnect.md")] - pub fn disconnect(self) -> Result<(), Vec> { - self.ns.adapter.disconnect_socket(self.opts) + pub async fn disconnect(self) -> Result<(), BroadcastError> { + self.ns.adapter.disconnect_socket(self.opts).await } #[doc = include_str!("../docs/operators/join.md")] - pub fn join(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_sockets(self.opts, rooms) + #[allow(clippy::manual_async_fn)] // related to issue: https://github.com/rust-lang/rust-clippy/issues/12664 + pub fn join(self, rooms: impl RoomParam) -> impl Future> + Send { + async move { self.ns.adapter.add_sockets(self.opts, rooms).await } } #[doc = include_str!("../docs/operators/leave.md")] - pub fn leave(self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del_sockets(self.opts, rooms) + #[allow(clippy::manual_async_fn)] // related to issue: https://github.com/rust-lang/rust-clippy/issues/12664 + pub fn leave(self, rooms: impl RoomParam) -> impl Future> + Send { + async move { self.ns.adapter.del_sockets(self.opts, rooms).await } } #[doc = include_str!("../docs/operators/rooms.md")] - pub fn rooms(self) -> Result, A::Error> { - self.ns.adapter.rooms() + pub async fn rooms(self) -> Result, A::Error> { + self.ns.adapter.rooms(self.opts).await } #[doc = include_str!("../docs/operators/get_socket.md")] @@ -379,7 +321,7 @@ impl BroadcastOperators { &mut self, event: impl AsRef, data: &T, - ) -> Result { + ) -> Result { let ns = self.ns.path.clone(); let data = self.parser.encode_value(data, Some(event.as_ref()))?; Ok(Packet::event(ns, data)) diff --git a/crates/socketioxide/src/parser.rs b/crates/socketioxide/src/parser.rs index 713f8157..d6c036c0 100644 --- a/crates/socketioxide/src/parser.rs +++ b/crates/socketioxide/src/parser.rs @@ -16,6 +16,9 @@ use socketioxide_parser_common::CommonParser; #[cfg(feature = "msgpack")] use socketioxide_parser_msgpack::MsgPackParser; +pub(crate) use socketioxide_core::parser::ParseError; +pub use socketioxide_core::parser::ParserError; + /// All the parser available. /// It also implements the [`Parse`] trait and therefore the /// parser implementation is done over enum delegation. 
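The broadcast operators above now return futures, so any action that may cross server boundaries must be awaited, while purely local room bookkeeping on a socket stays synchronous. A minimal sketch of a handler driving this API, assuming the signatures shown in this diff; the handler shape, room name, and event names are illustrative only:

```rust
use socketioxide::{extract::SocketRef, SocketIo};

// Hypothetical handler: local room ops are sync, adapter-backed broadcasts are awaited.
async fn handler(s: SocketRef, io: SocketIo) {
    // Local membership update, no adapter round-trip.
    s.join("room1");
    // Broadcast to the whole namespace through the adapter (now async).
    io.emit("hello", &"world").await.ok();
    // Fetch every socket known to the cluster for this namespace.
    if let Ok(sockets) = io.fetch_sockets().await {
        for socket in sockets {
            println!("socket in cluster: {}", socket.data().id);
        }
    }
}
```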
@@ -35,41 +38,7 @@ impl Default for Parser { } } -/// Errors that can occur during value encoding -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum EncodeError { - /// Common parser error - #[error("common parser: {0}")] - Common(::EncodeError), - /// MsgPack parser error - #[cfg_attr(docsrs, doc(cfg(feature = "msgpack")))] - #[cfg(feature = "msgpack")] - #[error("msgpack parser: {0}")] - MsgPack(::EncodeError), -} - -/// Errors that can occur during packet decoding or value decoding -#[derive(Debug, thiserror::Error)] -#[non_exhaustive] -pub enum DecodeError { - /// Common parser error - #[error("common parser: {0}")] - Common(::DecodeError), - /// MsgPack parser error - #[cfg_attr(docsrs, doc(cfg(feature = "msgpack")))] - #[cfg(feature = "msgpack")] - #[error("msgpack parser: {0}")] - MsgPack(::DecodeError), -} - -/// Parse errors occurring during packet parsing -pub(crate) type ParseError = socketioxide_core::parser::ParseError; - impl Parse for Parser { - type EncodeError = EncodeError; - type DecodeError = DecodeError; - fn encode(self, packet: Packet) -> Value { let value = match self { Parser::Common(p) => p.encode(packet), @@ -90,13 +59,9 @@ impl Parse for Parser { tracing::trace!(?state, "decoding bin payload: {:X}", bin); let packet = match self { - Parser::Common(p) => p - .decode_bin(state, bin) - .map_err(|e| e.wrap_err(DecodeError::Common)), + Parser::Common(p) => p.decode_bin(state, bin), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => p - .decode_bin(state, bin) - .map_err(|e| e.wrap_err(DecodeError::MsgPack)), + Parser::MsgPack(p) => p.decode_bin(state, bin), }?; #[cfg(feature = "tracing")] @@ -108,13 +73,9 @@ impl Parse for Parser { tracing::trace!(?data, ?state, "decoding str payload:"); let packet = match self { - Parser::Common(p) => p - .decode_str(state, data) - .map_err(|e| e.wrap_err(DecodeError::Common)), + Parser::Common(p) => p.decode_str(state, data), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => p - .decode_str(state, data) - .map_err(|e| e.wrap_err(DecodeError::MsgPack)), + Parser::MsgPack(p) => p.decode_str(state, data), }?; #[cfg(feature = "tracing")] @@ -126,11 +87,11 @@ impl Parse for Parser { self, data: &T, event: Option<&str>, - ) -> Result { + ) -> Result { let value = match self { - Parser::Common(p) => p.encode_value(data, event).map_err(EncodeError::Common), + Parser::Common(p) => p.encode_value(data, event), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => p.encode_value(data, event).map_err(EncodeError::MsgPack), + Parser::MsgPack(p) => p.encode_value(data, event), }; #[cfg(feature = "tracing")] tracing::trace!(?value, "value encoded:"); @@ -141,44 +102,40 @@ impl Parse for Parser { self, value: &'de mut Value, with_event: bool, - ) -> Result { + ) -> Result { #[cfg(feature = "tracing")] tracing::trace!(?value, "decoding value:"); match self { - Parser::Common(p) => p - .decode_value(value, with_event) - .map_err(DecodeError::Common), + Parser::Common(p) => p.decode_value(value, with_event), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => p - .decode_value(value, with_event) - .map_err(DecodeError::MsgPack), + Parser::MsgPack(p) => p.decode_value(value, with_event), } } fn decode_default<'de, T: Deserialize<'de>>( self, value: Option<&'de Value>, - ) -> Result { + ) -> Result { match self { - Parser::Common(p) => p.decode_default(value).map_err(DecodeError::Common), + Parser::Common(p) => p.decode_default(value), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => 
p.decode_default(value).map_err(DecodeError::MsgPack), + Parser::MsgPack(p) => p.decode_default(value), } } - fn encode_default(self, data: &T) -> Result { + fn encode_default(self, data: &T) -> Result { match self { - Parser::Common(p) => p.encode_default(data).map_err(EncodeError::Common), + Parser::Common(p) => p.encode_default(data), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => p.encode_default(data).map_err(EncodeError::MsgPack), + Parser::MsgPack(p) => p.encode_default(data), } } - fn read_event(self, value: &Value) -> Result<&str, Self::DecodeError> { + fn read_event(self, value: &Value) -> Result<&str, ParserError> { match self { - Parser::Common(p) => p.read_event(value).map_err(DecodeError::Common), + Parser::Common(p) => p.read_event(value), #[cfg(feature = "msgpack")] - Parser::MsgPack(p) => p.read_event(value).map_err(DecodeError::MsgPack), + Parser::MsgPack(p) => p.read_event(value), } } } diff --git a/crates/socketioxide/src/service.rs b/crates/socketioxide/src/service.rs index abd30a73..b9a03577 100644 --- a/crates/socketioxide/src/service.rs +++ b/crates/socketioxide/src/service.rs @@ -120,11 +120,13 @@ impl SocketIoService { pub(crate) fn with_config_inner( inner: S, config: SocketIoConfig, + adapter_state: A::State, #[cfg(feature = "state")] state: state::TypeMap![Send + Sync], ) -> (Self, Arc>) { let engine_config = config.engine_config.clone(); let client = Arc::new(Client::new( config, + adapter_state, #[cfg(feature = "state")] state, )); diff --git a/crates/socketioxide/src/socket.rs b/crates/socketioxide/src/socket.rs index 9a74e6ba..8f317070 100644 --- a/crates/socketioxide/src/socket.rs +++ b/crates/socketioxide/src/socket.rs @@ -3,7 +3,7 @@ use std::{ borrow::Cow, collections::HashMap, - fmt::Debug, + fmt::{self, Debug}, sync::{ atomic::{AtomicBool, AtomicI64, Ordering}, Arc, Mutex, RwLock, @@ -13,29 +13,31 @@ use std::{ use engineioxide::socket::{DisconnectReason as EIoDisconnectReason, Permit}; use serde::Serialize; -use tokio::sync::oneshot::{self, Receiver}; +use tokio::sync::{ + mpsc::error::TrySendError, + oneshot::{self, Receiver}, +}; #[cfg(feature = "extensions")] use crate::extensions::Extensions; use crate::{ ack::{AckInnerStream, AckResult, AckStream}, - adapter::{Adapter, LocalAdapter, Room}, - errors::{DisconnectError, Error, SendError}, + adapter::{Adapter, LocalAdapter}, + client::SocketData, + errors::Error, handler::{ BoxedDisconnectHandler, BoxedMessageHandler, DisconnectHandler, MakeErasedHandler, MessageHandler, }, ns::Namespace, - operators::{BroadcastOperators, ConfOperators, RoomParam}, + operators::{BroadcastOperators, ConfOperators}, parser::Parser, - AckError, SocketIo, -}; -use crate::{ - client::SocketData, - errors::{AdapterError, SocketError}, + AckError, SendError, SocketError, SocketIo, }; use socketioxide_core::{ + adapter::{BroadcastOptions, RemoteSocketData, Room, RoomParam}, + errors::{AdapterError, BroadcastError}, packet::{Packet, PacketData}, parser::Parse, Value, @@ -126,6 +128,144 @@ impl<'a> PermitExt<'a> for Permit<'a> { } } +/// A RemoteSocket is a [`Socket`] that is remotely connected on another server. +/// It implements a subset of the [`Socket`] API. 
+#[derive(Clone)] +pub struct RemoteSocket { + adapter: Arc, + parser: Parser, + data: RemoteSocketData, +} + +impl RemoteSocket { + pub(crate) fn new(data: RemoteSocketData, adapter: &Arc, parser: Parser) -> Self { + Self { + data, + adapter: adapter.clone(), + parser, + } + } + /// Consume the [`RemoteSocket`] and return its underlying data + #[inline] + pub fn into_data(self) -> RemoteSocketData { + self.data + } + /// Get a ref to the underlying data of the socket + #[inline] + pub fn data(&self) -> &RemoteSocketData { + &self.data + } +} +impl RemoteSocket { + /// # Emit a message to a client that is remotely connected on another server. + /// + /// See [`Socket::emit`] for more info. + pub async fn emit( + &self, + event: impl AsRef, + data: &T, + ) -> Result<(), RemoteActionError> { + let opts = self.get_opts(); + let data = self.parser.encode_value(data, Some(event.as_ref()))?; + let packet = Packet::event(self.data.ns.clone(), data); + self.adapter.broadcast(packet, opts).await?; + Ok(()) + } + + /// # Emit a message to a client that is remotely connected on another server and wait for an acknowledgement. + /// + /// See [`Socket::emit_with_ack`] for more info. + pub async fn emit_with_ack( + &self, + event: impl AsRef, + data: &T, + ) -> Result, RemoteActionError> { + let opts = self.get_opts(); + let data = self.parser.encode_value(data, Some(event.as_ref()))?; + let packet = Packet::event(self.data.ns.clone(), data); + let stream = self + .adapter + .broadcast_with_ack(packet, opts, None) + .await + .map_err(Into::::into)?; + Ok(AckStream::new(stream, self.parser)) + } + + /// # Get all room names this remote socket is connected to. + /// + /// See [`Socket::rooms`] for more info. + #[inline] + pub async fn rooms(&self) -> Result, A::Error> { + self.adapter.rooms(self.get_opts()).await + } + + /// # Add the remote socket to the specified room(s). + /// + /// See [`Socket::join`] for more info. + #[inline] + pub async fn join(&self, rooms: impl RoomParam) -> Result<(), A::Error> { + self.adapter.add_sockets(self.get_opts(), rooms).await + } + + /// # Remove the remote socket from the specified room(s). + /// + /// See [`Socket::leave`] for more info. + #[inline] + pub async fn leave(&self, rooms: impl RoomParam) -> Result<(), A::Error> { + self.adapter.del_sockets(self.get_opts(), rooms).await + } + + /// # Disconnect the remote socket from the current namespace, + /// + /// See [`Socket::disconnect`] for more info. + #[inline] + pub async fn disconnect(self) -> Result<(), RemoteActionError> { + self.adapter.disconnect_socket(self.get_opts()).await?; + Ok(()) + } + + #[inline(always)] + fn get_opts(&self) -> BroadcastOptions { + BroadcastOptions::new_remote(&self.data) + } +} +impl fmt::Debug for RemoteSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RemoteSocket") + .field("id", &self.data.id) + .field("server_id", &self.data.server_id) + .field("ns", &self.data.ns) + .finish() + } +} + +/// A error that can occur when emitting a message to a remote socket. +#[derive(Debug, thiserror::Error)] +pub enum RemoteActionError { + /// The message data could not be encoded. + #[error("cannot encode data: {0}")] + Serialize(#[from] crate::parser::ParserError), + /// The remote socket is, in fact, a local socket and we should not emit to it. + #[error("cannot send the message to the local socket: {0}")] + Socket(crate::SocketError), + /// The message could not be sent to the remote server. 
+ #[error("cannot propagate the request to the server: {0}")] + Adapter(#[from] AdapterError), +} +impl From for RemoteActionError { + fn from(value: BroadcastError) -> Self { + // This conversion assumes that we broadcast to a single (remote or not) socket. + match value { + BroadcastError::Socket(s) if !s.is_empty() => RemoteActionError::Socket(s[0].clone()), + BroadcastError::Socket(_) => { + panic!("BroadcastError with an empty socket vec is not permitted") + } + BroadcastError::Adapter(e) => e.into(), + BroadcastError::Serialize(e) => e.into(), + } + } +} + /// A Socket represents a client connected to a namespace. /// It is used to send and receive messages from the client, join and leave rooms, etc. /// The socket struct itself should not be used directly, but through a [`SocketRef`](crate::extract::SocketRef). @@ -136,6 +276,7 @@ pub struct Socket { ack_message: Mutex>>>, ack_counter: AtomicI64, connected: AtomicBool, + pub(crate) parser: Parser, /// The socket id pub id: Sid, @@ -155,6 +296,7 @@ impl Socket { sid: Sid, ns: Arc>, esocket: Arc>>, + parser: Parser, ) -> Self { Self { ns, @@ -163,6 +305,7 @@ impl Socket { ack_message: Mutex::new(HashMap::new()), ack_counter: AtomicI64::new(0), connected: AtomicBool::new(false), + parser, id: sid, #[cfg(feature = "extensions")] extensions: Extensions::new(), @@ -304,9 +447,9 @@ impl Socket { }; let ns = self.ns.path.clone(); - let data = self.parser().encode_value(data, Some(event.as_ref()))?; + let data = self.parser.encode_value(data, Some(event.as_ref()))?; - permit.send(Packet::event(ns, data), self.parser()); + permit.send(Packet::event(ns, data), self.parser); Ok(()) } @@ -328,37 +471,78 @@ impl Socket { } }; let ns = self.ns.path.clone(); - let data = self.parser().encode_value(data, Some(event.as_ref()))?; + let data = self.parser.encode_value(data, Some(event.as_ref()))?; let packet = Packet::event(ns, data); let rx = self.send_with_ack_permit(packet, permit); let stream = AckInnerStream::send(rx, self.get_io().config().ack_timeout, self.id); - Ok(AckStream::::new(stream, self.parser())) + Ok(AckStream::::new(stream, self.parser)) } // Room actions - #[doc = include_str!("../docs/operators/join.md")] - pub fn join(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.add_all(self.id, rooms) + /// # Add the current socket to the specified room(s). + /// + /// # Example + /// ```rust + /// # use socketioxide::{SocketIo, extract::*}; + /// async fn handler(socket: SocketRef) { + /// // Add all sockets that are in room1 and room3 to room4 and room5 + /// socket.join(["room4", "room5"]); + /// // We should retrieve all the local sockets that are in room3 and room5 + /// let sockets = socket.within("room4").within("room5").sockets(); + /// } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", |s: SocketRef| s.on("test", handler)); + /// ``` + pub fn join(&self, rooms: impl RoomParam) { + self.ns.adapter.get_local().add_all(self.id, rooms) } - #[doc = include_str!("../docs/operators/leave.md")] - pub fn leave(&self, rooms: impl RoomParam) -> Result<(), A::Error> { - self.ns.adapter.del(self.id, rooms) + /// # Remove the current socket from the specified room(s). 
+ /// + /// # Example + /// ```rust + /// # use socketioxide::{SocketIo, extract::*}; + /// async fn handler(socket: SocketRef) { + /// // Remove all sockets that are in room1 and room3 from room4 and room5 + /// socket.within("room1").within("room3").leave(["room4", "room5"]); + /// } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", |s: SocketRef| s.on("test", handler)); + /// ``` + pub fn leave(&self, rooms: impl RoomParam) { + self.ns.adapter.get_local().del(self.id, rooms) } - /// # Leave all rooms where the socket is connected. - /// - /// ## Errors - /// When using a distributed adapter, it can return an [`Adapter::Error`] which is mostly related to network errors. - /// For the default [`LocalAdapter`] it is always an [`Infallible`](std::convert::Infallible) error - pub fn leave_all(&self) -> Result<(), A::Error> { - self.ns.adapter.del_all(self.id) + /// # Remove the current socket from all its rooms. + pub fn leave_all(&self) { + self.ns.adapter.get_local().del_all(self.id); } - #[doc = include_str!("../docs/operators/rooms.md")] - pub fn rooms(&self) -> Result, A::Error> { - self.ns.adapter.socket_rooms(self.id) + /// # Get all room names this socket is connected to. + /// + /// # Example + /// ```rust + /// # use socketioxide::{SocketIo, extract::SocketRef}; + /// async fn handler(socket: SocketRef) { + /// println!("Socket connected to the / namespace with id: {}", socket.id); + /// socket.join(["room1", "room2"]); + /// let rooms = socket.rooms(); + /// println!("All rooms in the / namespace: {:?}", rooms); + /// } + /// + /// let (_, io) = SocketIo::new_svc(); + /// io.ns("/", handler); + /// ``` + pub fn rooms(&self) -> Vec { + self.ns + .adapter + .get_local() + .socket_rooms(self.id) + .into_iter() + .collect() } /// # Return true if the socket is connected to the namespace. 
@@ -373,22 +557,22 @@ impl Socket { #[doc = include_str!("../docs/operators/to.md")] pub fn to(&self, rooms: impl RoomParam) -> BroadcastOperators { - BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser()).to(rooms) + BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).to(rooms) } #[doc = include_str!("../docs/operators/within.md")] pub fn within(&self, rooms: impl RoomParam) -> BroadcastOperators { - BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser()).within(rooms) + BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).within(rooms) } #[doc = include_str!("../docs/operators/except.md")] pub fn except(&self, rooms: impl RoomParam) -> BroadcastOperators { - BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser()).except(rooms) + BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).except(rooms) } #[doc = include_str!("../docs/operators/local.md")] pub fn local(&self) -> BroadcastOperators { - BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser()).local() + BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).local() } #[doc = include_str!("../docs/operators/timeout.md")] @@ -398,7 +582,7 @@ impl Socket { #[doc = include_str!("../docs/operators/broadcast.md")] pub fn broadcast(&self) -> BroadcastOperators { - BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser()).broadcast() + BroadcastOperators::from_sock(self.ns.clone(), self.id, self.parser).broadcast() } /// # Get the [`SocketIo`] context related to this socket @@ -413,13 +597,12 @@ impl Socket { /// # Disconnect the socket from the current namespace, /// /// It will also call the disconnect handler if it is set with a [`DisconnectReason::ServerNSDisconnect`]. - pub fn disconnect(self: Arc) -> Result<(), DisconnectError> { + pub fn disconnect(self: Arc) -> Result<(), SocketError> { let res = self.send(Packet::disconnect(self.ns.path.clone())); if let Err(SocketError::InternalChannelFull) = res { - return Err(DisconnectError::InternalChannelFull); + return Err(SocketError::InternalChannelFull); } - - self.close(DisconnectReason::ServerNSDisconnect)?; + self.close(DisconnectReason::ServerNSDisconnect); Ok(()) } @@ -482,18 +665,17 @@ impl Socket { self.connected.store(connected, Ordering::SeqCst); } - #[inline] - pub(crate) fn parser(&self) -> Parser { - self.get_io().config().parser - } - pub(crate) fn reserve(&self) -> Result, SocketError> { - Ok(self.esocket.reserve()?) + match self.esocket.reserve() { + Ok(permit) => Ok(permit), + Err(TrySendError::Full(_)) => Err(SocketError::InternalChannelFull), + Err(TrySendError::Closed(_)) => Err(SocketError::Closed), + } } pub(crate) fn send(&self, packet: Packet) -> Result<(), SocketError> { let permit = self.reserve()?; - permit.send(packet, self.parser()); + permit.send(packet, self.parser); Ok(()) } pub(crate) fn send_raw(&self, value: Value) -> Result<(), SocketError> { @@ -511,7 +693,7 @@ impl Socket { let ack = self.ack_counter.fetch_add(1, Ordering::SeqCst) + 1; packet.inner.set_ack_id(ack); - permit.send(packet, self.parser()); + permit.send(packet, self.parser); self.ack_message.lock().unwrap().insert(ack, tx); rx } @@ -535,7 +717,7 @@ impl Socket { /// Called when the socket is gracefully disconnected from the server or the client /// /// It maybe also close when the underlying transport is closed or failed. 
- pub(crate) fn close(self: Arc, reason: DisconnectReason) -> Result<(), AdapterError> { + pub(crate) fn close(self: Arc, reason: DisconnectReason) { self.set_connected(false); let handler = { self.disconnect_handler.lock().unwrap().take() }; @@ -546,8 +728,7 @@ impl Socket { handler.call(self.clone(), reason); } - self.ns.remove_socket(self.id)?; - Ok(()) + self.ns.remove_socket(self.id); } /// Receive data from client @@ -555,15 +736,16 @@ impl Socket { match packet { PacketData::Event(d, ack) | PacketData::BinaryEvent(d, ack) => self.recv_event(d, ack), PacketData::EventAck(d, ack) | PacketData::BinaryAck(d, ack) => self.recv_ack(d, ack), - PacketData::Disconnect => self - .close(DisconnectReason::ClientNSDisconnect) - .map_err(Error::from), + PacketData::Disconnect => { + self.close(DisconnectReason::ClientNSDisconnect); + Ok(()) + } _ => unreachable!(), } } fn recv_event(self: Arc, data: Value, ack: Option) -> Result<(), Error> { - let event = self.parser().read_event(&data).map_err(|_e| { + let event = self.parser.read_event(&data).map_err(|_e| { #[cfg(feature = "tracing")] tracing::debug!(?_e, "failed to read event"); Error::InvalidEventName @@ -602,20 +784,26 @@ impl PartialEq for Socket { #[doc(hidden)] #[cfg(feature = "__test_harness")] -impl Socket { +impl Socket { /// Creates a dummy socket for testing purposes - pub fn new_dummy(sid: Sid, ns: Arc>) -> Socket { + pub fn new_dummy(sid: Sid, ns: Arc>) -> Socket { use crate::client::Client; use crate::io::SocketIoConfig; let close_fn = Box::new(move |_, _| ()); let config = SocketIoConfig::default(); - let io = SocketIo::from(Arc::new(Client::::new( + let io = SocketIo::from(Arc::new(Client::new( config, + (), #[cfg(feature = "state")] std::default::Default::default(), ))); - let s = Socket::new(sid, ns, engineioxide::Socket::new_dummy(sid, close_fn)); + let s = Socket::new( + sid, + ns, + engineioxide::Socket::new_dummy(sid, close_fn), + Parser::default(), + ); s.esocket.data.io.set(io).unwrap(); s.set_connected(true); s diff --git a/crates/socketioxide/tests/acknowledgement.rs b/crates/socketioxide/tests/acknowledgement.rs index 928b32e6..1707239c 100644 --- a/crates/socketioxide/tests/acknowledgement.rs +++ b/crates/socketioxide/tests/acknowledgement.rs @@ -4,8 +4,8 @@ mod utils; use engineioxide::Packet::*; use futures_util::StreamExt; use socketioxide::extract::SocketRef; -use socketioxide::packet::PacketData; use socketioxide::SocketIo; +use socketioxide_core::packet::PacketData; use socketioxide_core::parser::Parse; use socketioxide_parser_common::CommonParser; use tokio::sync::mpsc; @@ -53,25 +53,25 @@ pub async fn broadcast_with_ack() { let (tx, mut rx) = mpsc::channel::<[String; 1]>(100); io.ns("/", move |socket: SocketRef, io: SocketIo| async move { - let res = io.emit_with_ack::<_, [String; 1]>("test", "foo"); - let sockets = io.sockets().unwrap(); + let res = io.emit_with_ack::<_, [String; 1]>("test", "foo").await; let res = assert_ok!(res); res.for_each(|(id, res)| { let ack = assert_ok!(res); assert_ok!(tx.try_send(ack)); - assert_some!(sockets.iter().find(|s| s.id == id)); + assert_some!(io.sockets().iter().find(|s| s.id == id)); async move {} }) .await; let res = io .timeout(Duration::from_millis(500)) - .emit_with_ack::<_, [String; 1]>("test", "foo"); + .emit_with_ack::<_, [String; 1]>("test", "foo") + .await; let res = assert_ok!(res); res.for_each(|(id, res)| { let ack = assert_ok!(res); assert_ok!(tx.try_send(ack)); - assert_some!(sockets.iter().find(|s| s.id == id)); + 
assert_some!(io.sockets().iter().find(|s| s.id == id)); async move {} }) .await; @@ -79,12 +79,13 @@ pub async fn broadcast_with_ack() { let res = socket .broadcast() .timeout(Duration::from_millis(500)) - .emit_with_ack::<_, [String; 1]>("test", "foo"); + .emit_with_ack::<_, [String; 1]>("test", "foo") + .await; let res = assert_ok!(res); res.for_each(|(id, res)| { let ack = assert_ok!(res); assert_ok!(tx.try_send(ack)); - assert_some!(sockets.iter().find(|s| s.id == id)); + assert_some!(io.sockets().iter().find(|s| s.id == id)); async move {} }) .await; diff --git a/crates/socketioxide/tests/connect.rs b/crates/socketioxide/tests/connect.rs index a023c32e..da11c419 100644 --- a/crates/socketioxide/tests/connect.rs +++ b/crates/socketioxide/tests/connect.rs @@ -3,10 +3,8 @@ mod utils; use bytes::Bytes; use engineioxide::Packet::*; use serde::Serialize; -use socketioxide::{ - extract::SocketRef, handler::ConnectHandler, packet::Packet, SendError, SocketError, SocketIo, -}; -use socketioxide_core::{parser::Parse, Value}; +use socketioxide::{extract::SocketRef, handler::ConnectHandler, SendError, SocketError, SocketIo}; +use socketioxide_core::{packet::Packet, parser::Parse, Value}; use socketioxide_parser_common::CommonParser; use tokio::sync::mpsc; diff --git a/crates/socketioxide/tests/disconnect_reason.rs b/crates/socketioxide/tests/disconnect_reason.rs index 646098be..25f81eb2 100644 --- a/crates/socketioxide/tests/disconnect_reason.rs +++ b/crates/socketioxide/tests/disconnect_reason.rs @@ -94,7 +94,7 @@ pub async fn ws_transport_closed() { stream.send(Message::Text("1".into())).await.unwrap(); - let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + let data = tokio::time::timeout(Duration::from_millis(10), rx.recv()) .await .expect("timeout waiting for DisconnectReason::TransportClose") .unwrap(); @@ -189,7 +189,7 @@ pub async fn client_ns_disconnect() { stream.send(Message::Text("41".into())).await.unwrap(); - let data = tokio::time::timeout(Duration::from_millis(1), rx.recv()) + let data = tokio::time::timeout(Duration::from_millis(10), rx.recv()) .await .expect("timeout waiting for DisconnectReason::ClientNSDisconnect") .unwrap(); @@ -208,7 +208,7 @@ pub async fn server_ns_disconnect() { tokio::spawn(async move { tokio::time::sleep(Duration::from_millis(10)).await; - let s = io.sockets().unwrap().into_iter().next().unwrap(); + let s = io.sockets().into_iter().next().unwrap(); s.disconnect().unwrap(); }); diff --git a/crates/socketioxide/tests/extractors.rs b/crates/socketioxide/tests/extractors.rs index 1fbb3ec5..265ee2a6 100644 --- a/crates/socketioxide/tests/extractors.rs +++ b/crates/socketioxide/tests/extractors.rs @@ -5,15 +5,15 @@ use std::time::Duration; use serde_json::json; use socketioxide::extract::{Data, Extension, MaybeExtension, SocketRef, State, TryData}; use socketioxide::handler::ConnectHandler; -use socketioxide::DecodeError; +use socketioxide::ParserError; use socketioxide_core::parser::Parse; use socketioxide_core::Value; use socketioxide_parser_common::CommonParser; use tokio::sync::mpsc; use engineioxide::Packet as EioPacket; -use socketioxide::packet::Packet; use socketioxide::SocketIo; +use socketioxide_core::packet::Packet; mod fixture; mod utils; @@ -104,7 +104,7 @@ pub async fn data_extractor() { #[tokio::test] pub async fn try_data_extractor() { let (_, io) = SocketIo::new_svc(); - let (tx, mut rx) = mpsc::channel::>(4); + let (tx, mut rx) = mpsc::channel::>(4); io.ns("/", move |s: SocketRef, TryData(data): TryData| { 
assert_ok!(tx.try_send(data)); s.on("test", move |TryData(data): TryData| { diff --git a/e2e/adapter/.gitignore b/e2e/adapter/.gitignore new file mode 100644 index 00000000..eb03e3e1 --- /dev/null +++ b/e2e/adapter/.gitignore @@ -0,0 +1,2 @@ +node_modules +*.log diff --git a/e2e/adapter/Cargo.toml b/e2e/adapter/Cargo.toml new file mode 100644 index 00000000..065c8fb6 --- /dev/null +++ b/e2e/adapter/Cargo.toml @@ -0,0 +1,51 @@ +[package] +name = "adapter-e2e" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +publish = false + +[features] +v4 = ["socketioxide/v4"] +v5 = [] +msgpack = ["socketioxide/msgpack"] +default = ["v5"] + +[dependencies] +socketioxide = { path = "../../crates/socketioxide", default-features = false, features = [ + "tracing", +] } +socketioxide-redis = { path = "../../crates/socketioxide-redis", features = [ + "redis", + "redis-cluster", + "fred", +] } +hyper-util = { workspace = true, features = ["tokio"] } +hyper = { workspace = true, features = ["server", "http1"] } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +futures-util.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true + + +[[bin]] +name = "redis-e2e" +path = "src/bins/redis.rs" + +[[bin]] +name = "redis-cluster-e2e" +path = "src/bins/redis_cluster.rs" + +[[bin]] +name = "fred-e2e" +path = "src/bins/fred.rs" + +[[bin]] +name = "fred-cluster-e2e" +path = "src/bins/fred_cluster.rs" diff --git a/e2e/adapter/client.ts b/e2e/adapter/client.ts new file mode 100644 index 00000000..76675468 --- /dev/null +++ b/e2e/adapter/client.ts @@ -0,0 +1,117 @@ +import { spawn_servers, spawn_sockets, TEST, timeout } from "./fixture"; +import assert from "assert"; + +assert(!!process.env.CMD, "CMD env var must be set"); + +// * Spawn 10 sockets on 3 servers +// * Call a `broadcast` event on each socket +// * Expect the socket to broadcast a message to all other sockets +async function broadcast() { + const sockets = await spawn_sockets([3000, 3001, 3002], 10); + for (const socket of sockets) { + let msgs: string[] = []; + const prom = new Promise((resolve) => { + for (const socket of sockets) { + socket.once("broadcast", (data: string) => { + msgs.push(data); + if (msgs.length === sockets.length) resolve(null); + }); + } + }); + socket.emit("broadcast"); + await timeout(prom); + assert.equal(Object.values(msgs).length, sockets.length); + for (const msg of msgs) { + assert.deepStrictEqual(msg, `hello from ${socket.id}`); + } + } + return sockets; +} + +async function broadcastWithAck() { + const sockets = await spawn_sockets([3000, 3001, 3002], 10); + const expected = sockets.map((s) => `ack from ${s.id}`).sort(); + for (const socket of sockets) { + socket.on("broadcast_with_ack", (_data, ack) => { + ack(`ack from ${socket.id}`); + }); + } + for (const socket of sockets) { + const res: string[] = await timeout( + socket.emitWithAck("broadcast_with_ack"), + ); + assert.deepStrictEqual(res.sort(), expected); + } + return sockets; +} + +async function disconnectSocket() { + const sockets = await spawn_sockets([3000, 3001, 3002], 10); + let cnt = 0; + const prom = new Promise((resolve) => { + for (const socket of sockets) { + socket.on("disconnect", () => { + cnt++; + if (cnt === sockets.length) resolve(null); + }); + } + }); + sockets[0].emit("disconnect_socket"); + await timeout(prom); + for 
(const socket of sockets) { + assert(!socket.connected); + } +} + +async function rooms() { + const sockets = await spawn_sockets([3000, 3001, 3002], 10); + const expected = [ + "room1", + "room2", + "room4", + "room5", + ...sockets.map((s) => s.id), + ].sort(); + for (const socket of sockets) { + const rooms: string[] = await timeout(socket.emitWithAck("rooms")); + assert.deepStrictEqual(rooms.sort(), expected); + } + return sockets; +} + +// * Spawn 10 sockets on 3 servers +// * Call a `fetch_sockets` event on each socket +// * Get the list of sockets and compare it to the expected list +async function fetchSockets() { + type SocketData = { id: string; ns: string }; + const sockets = await spawn_sockets([3000, 3001, 3002], 10); + const expected = sockets + .map((socket) => ({ + id: socket.id, + ns: "/", + })) + .sort((a, b) => a.id.localeCompare(b.id)); + + for (const socket of sockets) { + const data: SocketData[] = await timeout( + socket.emitWithAck("fetch_sockets"), + ); + const sorted = data + ?.map((data) => ({ id: data.id, ns: data.ns })) + ?.sort((a, b) => a.id.localeCompare(b.id)); + assert.deepStrictEqual(sorted, expected); + } + + return sockets; +} + +async function main() { + await spawn_servers([3000, 3001, 3002]); + await TEST(broadcast); + await TEST(broadcastWithAck); + await TEST(fetchSockets); + await TEST(disconnectSocket); + await TEST(rooms); + process.exit(); +} +main(); diff --git a/e2e/adapter/fixture.ts b/e2e/adapter/fixture.ts new file mode 100644 index 00000000..3c055f2e --- /dev/null +++ b/e2e/adapter/fixture.ts @@ -0,0 +1,88 @@ +import io, { Socket } from "socket.io-client"; +import msgpackParser from "socket.io-msgpack-parser"; +import { ChildProcess, exec, spawn } from "child_process"; +import assert from "assert"; +import { open } from "fs/promises"; + +export async function timeout_recv( + fn: (resolve: (value: T) => void) => any, + duration = 500, +) { + return new Promise((resolve, reject) => { + fn(resolve); + setTimeout(() => reject("timeout"), duration); + }); +} +export async function timeout( + promise: Promise, + duration = 500, +): Promise { + return new Promise((resolve, reject) => { + setTimeout(() => reject("timeout"), duration); + promise.then(resolve, reject); + }); +} + +export async function spawn_servers(ports: number[]) { + const args = process.env.CMD!.split(" "); + const bin = args.shift(); + const servers: [ChildProcess, number][] = []; + const logs: Record = {}; + for (const port of ports) { + exec(`kill $(lsof -t -i:${port})`); + console.log("spawning server on port", port); + const file = (await open(`${port}.log`, "w")).createWriteStream(); + console.log(`EXEC PORT=${port} ${bin} ${args.join(" ")}`); + const server = spawn(bin, args, { + shell: true, + env: { + ...process.env, + PORT: port.toString(), + }, + }); + logs[server.pid] = ""; + server.stdout.pipe(file); + server.stderr.pipe(file); + } + process.on("exit", () => { + for (const [server, port] of servers) { + console.log("killing", server.pid); + server.kill(); + exec(`kill $(lsof -t -i:${port})`); + } + }); + process.on("SIGINT", () => process.exit()); // catch ctrl-c + process.on("SIGTERM", () => process.exit()); // catch kill + await new Promise((resolve) => setTimeout(resolve, 1000)); + return servers; +} + +// Spawn a number of distributed sockets on a list of ports +export async function spawn_sockets(ports: number[], len: number) { + let sockets: Socket[] = []; + const parser = process.env.CMD.includes("msgpack") ? 
msgpackParser : null; + for (let i = 0; i < len; i++) { + const socket = io(`http://localhost:${ports[i % ports.length]}`, { + parser, + }); + assert( + await timeout_recv((resolve) => + socket.on("connect", () => resolve(true)), + ), + ); + sockets.push(socket); + } + return sockets; +} +export async function TEST( + fn: () => Promise, +): Promise { + console.log(`RUN ${fn.name}`); + const sockets = await fn(); + if (sockets) { + for (const socket of sockets) { + socket.disconnect(); + } + } + console.log(`OK ${fn.name}`); +} diff --git a/e2e/adapter/package-lock.json b/e2e/adapter/package-lock.json new file mode 100644 index 00000000..40f140a9 --- /dev/null +++ b/e2e/adapter/package-lock.json @@ -0,0 +1,348 @@ +{ + "name": "adapter-e2e", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "adapter-e2e", + "dependencies": { + "@types/node": "^22", + "socket.io-client": "^4.8.1", + "socket.io-msgpack-parser": "^3.0.2", + "ts-node": "^10.9.2" + } + }, + "node_modules/@cspotcode/source-map-support": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz", + "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==", + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "0.3.9" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "license": "MIT" + }, + "node_modules/@jridgewell/trace-mapping": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz", + "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==", + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.0.3", + "@jridgewell/sourcemap-codec": "^1.4.10" + } + }, + "node_modules/@socket.io/component-emitter": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz", + "integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA==", + "license": "MIT" + }, + "node_modules/@tsconfig/node10": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.11.tgz", + "integrity": "sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==", + "license": "MIT" + }, + "node_modules/@tsconfig/node12": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz", + "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==", + "license": "MIT" + }, + "node_modules/@tsconfig/node14": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz", + "integrity": 
"sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==", + "license": "MIT" + }, + "node_modules/@tsconfig/node16": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.4.tgz", + "integrity": "sha512-vxhUy4J8lyeyinH7Azl1pdd43GJhZH/tP2weN8TntQblOY+A0XbT8DJk1/oCPuOOyg/Ja757rG0CgHcWC8OfMA==", + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "22.10.5", + "resolved": "https://registry.npmjs.org/@types/node/-/node-22.10.5.tgz", + "integrity": "sha512-F8Q+SeGimwOo86fiovQh8qiXfFEh2/ocYv7tU5pJ3EXMSSxk1Joj5wefpFK2fHTf/N6HKGSxIDBT9f3gCxXPkQ==", + "license": "MIT", + "dependencies": { + "undici-types": "~6.20.0" + } + }, + "node_modules/acorn": { + "version": "8.14.0", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz", + "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", + "license": "MIT", + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-walk": { + "version": "8.3.4", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.3.4.tgz", + "integrity": "sha512-ueEepnujpqee2o5aIYnvHU6C0A42MNdsIDeqy5BydrkuC5R1ZuUFnm27EeFJGoEHJQgn3uleRvmTXaJgfXbt4g==", + "license": "MIT", + "dependencies": { + "acorn": "^8.11.0" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/arg": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", + "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==", + "license": "MIT" + }, + "node_modules/component-emitter": { + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/component-emitter/-/component-emitter-1.3.1.tgz", + "integrity": "sha512-T0+barUSQRTUQASh8bx02dl+DhF54GtIDY13Y3m9oWTklKbb3Wv974meRpeZ3lp1JpLVECWWNHC4vaG2XHXouQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/create-require": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", + "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", + "license": "MIT" + }, + "node_modules/debug": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.7.tgz", + "integrity": "sha512-Er2nc/H7RrMXZBFCEim6TCmMk02Z8vLC2Rbi1KEBggpo0fS6l0S1nnapwmIi3yW/+GOJap1Krg4w0Hg80oCqgQ==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/diff": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", + "integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.3.1" + } + }, + "node_modules/engine.io-client": { + "version": "6.6.2", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.6.2.tgz", + "integrity": "sha512-TAr+NKeoVTjEVW8P3iHguO1LO6RlUz9O5Y8o7EY0fU+gY1NYqas7NN3slpFtbXEsLMHk0h90fJMfKjRkQ0qUIw==", + "license": "MIT", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1", + "engine.io-parser": "~5.2.1", + "ws": "~8.17.1", + "xmlhttprequest-ssl": "~2.1.1" + } + }, + "node_modules/engine.io-parser": { + "version": "5.2.3", + "resolved": 
"https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz", + "integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/make-error": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/make-error/-/make-error-1.3.6.tgz", + "integrity": "sha512-s8UhlNe7vPKomQhC1qFelMokr/Sc3AgNbso3n74mVPA5LTZwkB9NlXf4XPamLxJE8h0gh73rM94xvwRT2CVInw==", + "license": "ISC" + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/notepack.io": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/notepack.io/-/notepack.io-2.2.0.tgz", + "integrity": "sha512-9b5w3t5VSH6ZPosoYnyDONnUTF8o0UkBw7JLA6eBlYJWyGT1Q3vQa8Hmuj1/X6RYvHjjygBDgw6fJhe0JEojfw==", + "license": "MIT" + }, + "node_modules/socket.io-client": { + "version": "4.8.1", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.8.1.tgz", + "integrity": "sha512-hJVXfu3E28NmzGk8o1sHhN3om52tRvwYeidbj7xKy2eIIse5IoKX3USlS6Tqt3BHAtflLIkCQBkzVrEEfWUyYQ==", + "license": "MIT", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.2", + "engine.io-client": "~6.6.1", + "socket.io-parser": "~4.2.4" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-msgpack-parser": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/socket.io-msgpack-parser/-/socket.io-msgpack-parser-3.0.2.tgz", + "integrity": "sha512-1e76bJ1PCKi9H+JiYk+S29PBJvknHjQWM7Mtj0hjF2KxDA6b6rQxv3rTsnwBoz/haZOhlCDIMQvPATbqYeuMxg==", + "license": "MIT", + "dependencies": { + "component-emitter": "~1.3.0", + "notepack.io": "~2.2.0" + } + }, + "node_modules/socket.io-parser": { + "version": "4.2.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz", + "integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==", + "license": "MIT", + "dependencies": { + "@socket.io/component-emitter": "~3.1.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/ts-node": { + "version": "10.9.2", + "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.2.tgz", + "integrity": "sha512-f0FFpIdcHgn8zcPSbf1dRevwt047YMnaiJM3u2w2RewrB+fob/zePZcrOyQoLMMO7aBIddLcQIEK5dYjkLnGrQ==", + "license": "MIT", + "dependencies": { + "@cspotcode/source-map-support": "^0.8.0", + "@tsconfig/node10": "^1.0.7", + "@tsconfig/node12": "^1.0.7", + "@tsconfig/node14": "^1.0.0", + "@tsconfig/node16": "^1.0.2", + "acorn": "^8.4.1", + "acorn-walk": "^8.1.1", + "arg": "^4.1.0", + "create-require": "^1.1.0", + "diff": "^4.0.1", + "make-error": "^1.1.1", + "v8-compile-cache-lib": "^3.0.1", + "yn": "3.1.1" + }, + "bin": { + "ts-node": "dist/bin.js", + "ts-node-cwd": "dist/bin-cwd.js", + "ts-node-esm": "dist/bin-esm.js", + "ts-node-script": "dist/bin-script.js", + "ts-node-transpile-only": "dist/bin-transpile.js", + "ts-script": "dist/bin-script-deprecated.js" + }, + "peerDependencies": { + "@swc/core": ">=1.2.50", + "@swc/wasm": ">=1.2.50", + "@types/node": "*", + "typescript": ">=2.7" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "@swc/wasm": { + "optional": true + } + } + }, + "node_modules/typescript": { + "version": "5.7.2", + 
"resolved": "https://registry.npmjs.org/typescript/-/typescript-5.7.2.tgz", + "integrity": "sha512-i5t66RHxDvVN40HfDd1PsEThGNnlMCMT3jMUuoh9/0TaqWevNontacunWyN02LA9/fIbEWlcHZcgTKb9QoaLfg==", + "license": "Apache-2.0", + "peer": true, + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "6.20.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.20.0.tgz", + "integrity": "sha512-Ny6QZ2Nju20vw1SRHe3d9jVu6gJ+4e3+MMpqu7pqE5HT6WsTSlce++GQmK5UXS8mzV8DSYHrQH+Xrf2jVcuKNg==", + "license": "MIT" + }, + "node_modules/v8-compile-cache-lib": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", + "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==", + "license": "MIT" + }, + "node_modules/ws": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/xmlhttprequest-ssl": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.1.2.tgz", + "integrity": "sha512-TEU+nJVUUnA4CYJFLvK5X9AOeH4KvDvhIfm0vV1GaQRtchnG0hgK5p8hw/xjv8cunWYCsiPCSDzObPyhEwq3KQ==", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/yn": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz", + "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==", + "license": "MIT", + "engines": { + "node": ">=6" + } + } + } +} diff --git a/e2e/adapter/package.json b/e2e/adapter/package.json new file mode 100644 index 00000000..3a17f42f --- /dev/null +++ b/e2e/adapter/package.json @@ -0,0 +1,9 @@ +{ + "name": "adapter-e2e", + "dependencies": { + "@types/node": "^22", + "socket.io-client": "^4.8.1", + "socket.io-msgpack-parser": "^3.0.2", + "ts-node": "^10.9.2" + } +} diff --git a/e2e/adapter/src/bins/fred.rs b/e2e/adapter/src/bins/fred.rs new file mode 100644 index 00000000..5307b1ad --- /dev/null +++ b/e2e/adapter/src/bins/fred.rs @@ -0,0 +1,67 @@ +use fred::types::RespVersion; +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; +use socketioxide_redis::drivers::fred::fred_client as fred; +use socketioxide_redis::RedisAdapterCtr; +use tokio::net::TcpListener; +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let mut config = fred::prelude::Config::from_url("redis://127.0.0.1:6379?protocol=resp3")?; + config.version = RespVersion::RESP3; + let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?; + let adapter = RedisAdapterCtr::new_with_fred(client).await?; + #[allow(unused_mut)] + let mut builder = + SocketIo::builder().with_adapter::>(adapter); + + #[cfg(feature = "msgpack")] + { + builder = 
builder.with_parser(socketioxide::ParserConfig::msgpack()); + }; + + let (svc, io) = builder.build_svc(); + + io.ns("/", adapter_e2e::handler).await.unwrap(); + + #[cfg(feature = "v5")] + info!("Starting server with v5 protocol"); + #[cfg(feature = "v4")] + info!("Starting server with v4 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse() + .unwrap(); + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/bins/fred_cluster.rs b/e2e/adapter/src/bins/fred_cluster.rs new file mode 100644 index 00000000..1d9056da --- /dev/null +++ b/e2e/adapter/src/bins/fred_cluster.rs @@ -0,0 +1,76 @@ +use fred::types::RespVersion; +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; +use socketioxide_redis::drivers::fred::fred_client as fred; +use socketioxide_redis::RedisAdapterCtr; +use tokio::net::TcpListener; +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + let server_config = fred::prelude::ServerConfig::new_clustered(vec![ + ("127.0.0.1", 7000), + ("127.0.0.1", 7001), + ("127.0.0.1", 7002), + ("127.0.0.1", 7003), + ("127.0.0.1", 7004), + ("127.0.0.1", 7005), + ]); + let mut config = fred::prelude::Config::default(); + config.server = server_config; + config.version = RespVersion::RESP3; + let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?; + let adapter = RedisAdapterCtr::new_with_fred(client).await?; + #[allow(unused_mut)] + let mut builder = + SocketIo::builder().with_adapter::>(adapter); + + #[cfg(feature = "msgpack")] + { + builder = builder.with_parser(socketioxide::ParserConfig::msgpack()); + }; + + let (svc, io) = builder.build_svc(); + + io.ns("/", adapter_e2e::handler).await.unwrap(); + + #[cfg(feature = "v5")] + info!("Starting server with v5 protocol"); + #[cfg(feature = "v4")] + info!("Starting server with v4 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse() + .unwrap(); + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. 
+ let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/bins/redis.rs b/e2e/adapter/src/bins/redis.rs new file mode 100644 index 00000000..b45ba97e --- /dev/null +++ b/e2e/adapter/src/bins/redis.rs @@ -0,0 +1,64 @@ +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; +use socketioxide_redis::drivers::redis::redis_client as redis; +use socketioxide_redis::RedisAdapterCtr; +use tokio::net::TcpListener; +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + let client = redis::Client::open("redis://127.0.0.1:6379?protocol=resp3")?; + let adapter = RedisAdapterCtr::new_with_redis(&client).await?; + #[allow(unused_mut)] + let mut builder = + SocketIo::builder().with_adapter::>(adapter); + #[cfg(feature = "msgpack")] + { + builder = builder.with_parser(socketioxide::ParserConfig::msgpack()); + }; + + let (svc, io) = builder.build_svc(); + + io.ns("/", adapter_e2e::handler).await.unwrap(); + + #[cfg(feature = "v5")] + info!("Starting server with v5 protocol"); + #[cfg(feature = "v4")] + info!("Starting server with v4 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse() + .unwrap(); + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. 
+ let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/bins/redis_cluster.rs b/e2e/adapter/src/bins/redis_cluster.rs new file mode 100644 index 00000000..11d34e84 --- /dev/null +++ b/e2e/adapter/src/bins/redis_cluster.rs @@ -0,0 +1,71 @@ +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use socketioxide::SocketIo; +use socketioxide_redis::drivers::redis::redis_client as redis; +use socketioxide_redis::RedisAdapterCtr; +use tokio::net::TcpListener; +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + + tracing::subscriber::set_global_default(subscriber)?; + let builder = redis::cluster::ClusterClient::builder([ + "redis://127.0.0.1:7000?protocol=resp3", + "redis://127.0.0.1:7001?protocol=resp3", + "redis://127.0.0.1:7002?protocol=resp3", + "redis://127.0.0.1:7003?protocol=resp3", + "redis://127.0.0.1:7004?protocol=resp3", + "redis://127.0.0.1:7005?protocol=resp3", + ]); + let adapter = RedisAdapterCtr::new_with_cluster(builder).await?; + #[allow(unused_mut)] + let mut builder = + SocketIo::builder().with_adapter::>(adapter); + #[cfg(feature = "msgpack")] + { + builder = builder.with_parser(socketioxide::ParserConfig::msgpack()); + }; + + let (svc, io) = builder.build_svc(); + + io.ns("/", adapter_e2e::handler).await.unwrap(); + + #[cfg(feature = "v5")] + info!("Starting server with v5 protocol"); + #[cfg(feature = "v4")] + info!("Starting server with v4 protocol"); + let port: u16 = std::env::var("PORT") + .expect("a PORT env var should be set") + .parse() + .unwrap(); + + let listener = TcpListener::bind(("127.0.0.1", port)).await?; + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/e2e/adapter/src/lib.rs b/e2e/adapter/src/lib.rs new file mode 100644 index 00000000..3741ab69 --- /dev/null +++ b/e2e/adapter/src/lib.rs @@ -0,0 +1,59 @@ +//! Test all adapter methods. +//! +//! There are 7 types of requests: +//! * Broadcast a packet to all the matching sockets. +//! * Broadcast a packet to all the matching sockets and wait for a stream of acks. +//! * Disconnect matching sockets. +//! * Get all the rooms. +//! * Add matching sockets to rooms. +//! * Remove matching sockets to rooms. +//! * Fetch all the remote sockets matching the options. 
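The handler that follows exercises each request type listed above through the public `SocketIo` API. For illustration only (not part of this diff), a minimal sketch of how the new remote-socket fetch is typically consumed; the helper name is hypothetical and error handling is elided:

```rust
use socketioxide::{adapter::Adapter, SocketIo};

// Hypothetical helper: walk every socket known to the cluster, local or remote,
// and print the data that the e2e client later compares against.
async fn list_remote_sockets<A: Adapter>(io: SocketIo<A>) {
    for socket in io.fetch_sockets().await.unwrap() {
        // `RemoteSocket` implements `Debug` with its id, server_id and ns.
        println!("{socket:?}");
        // `socket.into_data()` would yield the raw `RemoteSocketData` that the
        // `fetch_sockets` handler below acks back to the client.
    }
}
```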
+ +use futures_util::StreamExt; +use socketioxide::{ + adapter::Adapter, + extract::{AckSender, SocketRef}, + SocketIo, +}; +pub async fn handler(s: SocketRef) { + s.join(["room1", "room2", "room4", "room5"]); + s.join(s.id); + s.on("broadcast", broadcast); + s.on("fetch_sockets", fetch_sockets); + s.on("broadcast_with_ack", broadcast_with_ack); + s.on("disconnect_socket", disconnect_socket); + s.on("rooms", rooms); +} + +async fn broadcast(io: SocketIo, s: SocketRef) { + io.emit("broadcast", &format!("hello from {}", s.id)) + .await + .unwrap(); +} +async fn broadcast_with_ack(io: SocketIo, ack: AckSender) { + let data: Vec = io + .emit_with_ack("broadcast_with_ack", &()) + .await + .unwrap() + .map(|(_, d)| d.unwrap()) + .collect() + .await; + ack.send(&data).unwrap(); +} +async fn fetch_sockets(io: SocketIo, ack: AckSender) { + let sockets: Vec<_> = io + .fetch_sockets() + .await + .unwrap() + .into_iter() + .map(|s| s.into_data()) + .collect(); + ack.send(&sockets).unwrap(); +} +async fn disconnect_socket(io: SocketIo) { + io.disconnect().await.unwrap(); +} +async fn rooms(io: SocketIo, ack: AckSender) { + let rooms = io.rooms().await.unwrap(); + ack.send(&rooms).unwrap(); +} diff --git a/e2e/adapter/tsconfig.json b/e2e/adapter/tsconfig.json new file mode 100644 index 00000000..7af9830d --- /dev/null +++ b/e2e/adapter/tsconfig.json @@ -0,0 +1,6 @@ +{ + "compilerOptions": { + "types": ["node"], + "esModuleInterop": true + } +} diff --git a/e2e/engineioxide/engineioxide.rs b/e2e/engineioxide/engineioxide.rs index 35a4700e..be4780dd 100644 --- a/e2e/engineioxide/engineioxide.rs +++ b/e2e/engineioxide/engineioxide.rs @@ -29,12 +29,12 @@ impl EngineIoHandler for MyHandler { println!("socket disconnect {}: {:?}", socket.id, reason); } - fn on_message(&self, msg: Str, socket: Arc>) { + fn on_message(self: &Arc, msg: Str, socket: Arc>) { println!("Ping pong message {:?}", msg); socket.emit(msg).ok(); } - fn on_binary(&self, data: Bytes, socket: Arc>) { + fn on_binary(self: &Arc, data: Bytes, socket: Arc>) { println!("Ping pong binary message {:?}", data); socket.emit_binary(data).ok(); } diff --git a/e2e/socketioxide/Cargo.toml b/e2e/socketioxide/Cargo.toml index c335bb2a..fcfe83a8 100644 --- a/e2e/socketioxide/Cargo.toml +++ b/e2e/socketioxide/Cargo.toml @@ -20,8 +20,8 @@ hyper = { workspace = true, features = ["server", "http1"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing-subscriber.workspace = true tracing.workspace = true -serde_json.workspace = true rmpv = { version = "1.3.0", features = ["with-serde"] } +serde_json.workspace = true [[bin]] name = "socketioxide-e2e" diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c9fd5df3..40f18aff 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -11,6 +11,8 @@ rmpv = { version = "1.3.0", features = ["with-serde"] } tower = { version = "0.5.0", default-features = false } tracing = "0.1.37" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -axum = "0.7.5" +axum = "0.8" hyper-util.version = "0.1.1" hyper = { version = "1.0.1", features = ["http1", "server"] } +socketioxide = { path = "../crates/socketioxide" } +serde_json = "1" diff --git a/examples/angular-todomvc/Cargo.toml b/examples/angular-todomvc/Cargo.toml index 9658078b..4c6b473c 100644 --- a/examples/angular-todomvc/Cargo.toml +++ b/examples/angular-todomvc/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = ["state"] } +socketioxide = { 
workspace = true, features = ["state"] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tower-http = { version = "0.5.0", features = ["cors", "fs"] } diff --git a/examples/angular-todomvc/src/main.rs b/examples/angular-todomvc/src/main.rs index 3f6b8f60..8e93061e 100644 --- a/examples/angular-todomvc/src/main.rs +++ b/examples/angular-todomvc/src/main.rs @@ -48,8 +48,12 @@ async fn main() -> Result<(), Box> { let mut todos = todos.lock().unwrap(); todos.clear(); todos.extend_from_slice(&new_todos); - - s.broadcast().emit("update-store", &new_todos).unwrap(); + async move { + s.broadcast() + .emit("update-store", &new_todos) + .await + .unwrap(); + } }, ); }); diff --git a/examples/axum-echo-tls/Cargo.toml b/examples/axum-echo-tls/Cargo.toml index 85cb9fde..9c3c9929 100644 --- a/examples/axum-echo-tls/Cargo.toml +++ b/examples/axum-echo-tls/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -socketioxide = { path = "../../crates/socketioxide" } +socketioxide.workspace = true axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing-subscriber.workspace = true diff --git a/examples/axum-echo/Cargo.toml b/examples/axum-echo/Cargo.toml index 747b78f5..16382771 100644 --- a/examples/axum-echo/Cargo.toml +++ b/examples/axum-echo/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -socketioxide = { path = "../../crates/socketioxide" } +socketioxide.workspace = true axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing-subscriber.workspace = true diff --git a/examples/background-task/Cargo.toml b/examples/background-task/Cargo.toml index 8c81d7e3..dbd490d6 100644 --- a/examples/background-task/Cargo.toml +++ b/examples/background-task/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { path = "../../crates/socketioxide" } +socketioxide.workspace = true axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tower.workspace = true diff --git a/examples/background-task/src/main.rs b/examples/background-task/src/main.rs index a833e7ed..0c25393a 100644 --- a/examples/background-task/src/main.rs +++ b/examples/background-task/src/main.rs @@ -13,9 +13,9 @@ async fn background_task(io: SocketIo) { loop { tokio::time::sleep(std::time::Duration::from_secs(1)).await; info!("Background task"); - let cnt = io.of("/").unwrap().sockets().unwrap().len(); + let cnt = io.of("/").unwrap().sockets().len(); let msg = format!("{}s, {} socket connected", i, cnt); - io.emit("tic tac !", &msg).unwrap(); + io.emit("tic tac !", &msg).await.unwrap(); i += 1; } diff --git a/examples/basic-crud-application/Cargo.toml b/examples/basic-crud-application/Cargo.toml index 552a4120..c7b1639d 100644 --- a/examples/basic-crud-application/Cargo.toml +++ b/examples/basic-crud-application/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = ["state"] } +socketioxide = { workspace = true, features = ["state"] } uuid = { version = "1.6.1", features = ["v4", "serde"] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/examples/basic-crud-application/src/handlers/todo.rs b/examples/basic-crud-application/src/handlers/todo.rs index 85e3e6ba..68640705 100644 --- 
a/examples/basic-crud-application/src/handlers/todo.rs +++ b/examples/basic-crud-application/src/handlers/todo.rs @@ -44,7 +44,12 @@ impl Todos { } } -pub fn create(s: SocketRef, Data(data): Data, ack: AckSender, todos: State) { +pub async fn create( + s: SocketRef, + Data(data): Data, + ack: AckSender, + todos: State, +) { let id = Uuid::new_v4(); let todo = Todo { id, inner: data }; @@ -53,7 +58,7 @@ pub fn create(s: SocketRef, Data(data): Data, ack: AckSender, todos let res: Response<_> = id.into(); ack.send(&res).ok(); - s.broadcast().emit("todo:created", &todo).ok(); + s.broadcast().emit("todo:created", &todo).await.ok(); } pub async fn read(Data(id): Data, ack: AckSender, todos: State) { @@ -62,21 +67,26 @@ pub async fn read(Data(id): Data, ack: AckSender, todos: State) { } pub async fn update(s: SocketRef, Data(data): Data, ack: AckSender, todos: State) { - let res = todos - .get_mut(&data.id) - .ok_or(Error::NotFound) - .map(|mut todo| { + let res = match todos.get_mut(&data.id) { + Some(mut todo) => { todo.inner = data.inner.clone(); - s.broadcast().emit("todo:updated", &data).ok(); - }); + s.broadcast().emit("todo:updated", &data).await.ok(); + Ok(()) + } + None => Err(Error::NotFound), + }; ack.send(&res).ok(); } pub async fn delete(s: SocketRef, Data(id): Data, ack: AckSender, todos: State) { - let res = todos.remove(&id).ok_or(Error::NotFound).map(|_| { - s.broadcast().emit("todo:deleted", &id).ok(); - }); + let res = match todos.remove(&id) { + Some(_) => { + s.broadcast().emit("todo:deleted", &id).await.ok(); + Ok(()) + } + None => Err(Error::NotFound), + }; ack.send(&res).ok(); } diff --git a/examples/chat/Cargo.toml b/examples/chat/Cargo.toml index 40493ce9..65150f6d 100644 --- a/examples/chat/Cargo.toml +++ b/examples/chat/Cargo.toml @@ -6,10 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = [ - "extensions", - "state", -] } +socketioxide = { workspace = true, features = ["extensions", "state"] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tower-http = { version = "0.5.0", features = ["cors", "fs"] } diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index f088be77..3fe072cc 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -5,11 +5,11 @@ use socketioxide::{ extract::{Data, Extension, SocketRef, State}, SocketIo, }; +use std::sync::Arc; use tower::ServiceBuilder; use tower_http::{cors::CorsLayer, services::ServeDir}; use tracing::info; use tracing_subscriber::FmtSubscriber; -use std::sync::Arc; #[derive(Deserialize, Serialize, Debug, Clone)] #[serde(transparent)] @@ -62,18 +62,18 @@ async fn main() -> Result<(), Box> { io.ns("/", |s: SocketRef| { s.on( "new message", - |s: SocketRef, Data::(msg), Extension::(username)| { + |s: SocketRef, Data::(msg), Extension::(username)| async move { let msg = &Res::Message { username, message: msg, }; - s.broadcast().emit("new message", msg).ok(); + s.broadcast().emit("new message", msg).await.ok(); }, ); s.on( "add user", - |s: SocketRef, Data::(username), user_cnt: State| { + |s: SocketRef, Data::(username), user_cnt: State| async move { if s.extensions.get::().is_some() { return; } @@ -85,33 +85,38 @@ async fn main() -> Result<(), Box> { num_users, username: Username(username), }; - s.broadcast().emit("user joined", res).ok(); + s.broadcast().emit("user joined", res).await.ok(); }, ); - 
s.on("typing", |s: SocketRef, Extension::(username)| { - s.broadcast() - .emit("typing", &Res::Username { username }) - .ok(); - }); + s.on( + "typing", + |s: SocketRef, Extension::(username)| async move { + s.broadcast() + .emit("typing", &Res::Username { username }) + .await + .ok(); + }, + ); s.on( "stop typing", - |s: SocketRef, Extension::(username)| { + |s: SocketRef, Extension::(username)| async move { s.broadcast() .emit("stop typing", &Res::Username { username }) + .await .ok(); }, ); s.on_disconnect( - |s: SocketRef, user_cnt: State, Extension::(username)| { + |s: SocketRef, user_cnt: State, Extension::(username)| async move { let num_users = user_cnt.remove_user(); let res = &Res::UserEvent { num_users, username, }; - s.broadcast().emit("user left", res).ok(); + s.broadcast().emit("user left", res).await.ok(); }, ); }); diff --git a/examples/hyper-echo/Cargo.toml b/examples/hyper-echo/Cargo.toml index 0823a241..9b97ba3e 100644 --- a/examples/hyper-echo/Cargo.toml +++ b/examples/hyper-echo/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = ["tracing"] } +socketioxide = { workspace = true, features = ["tracing"] } hyper = { workspace = true, features = ["server", "http1"] } hyper-util = { workspace = true, features = ["tokio"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } diff --git a/examples/private-messaging/Cargo.toml b/examples/private-messaging/Cargo.toml index 8127e946..c5fab6e4 100644 --- a/examples/private-messaging/Cargo.toml +++ b/examples/private-messaging/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = [ +socketioxide = { workspace = true, features = [ "extensions", "state", "tracing", diff --git a/examples/private-messaging/src/handlers.rs b/examples/private-messaging/src/handlers.rs index 131965c1..3a2a795d 100644 --- a/examples/private-messaging/src/handlers.rs +++ b/examples/private-messaging/src/handlers.rs @@ -47,7 +47,7 @@ struct PrivateMessageReq { content: String, } -pub fn on_connection( +pub async fn on_connection( s: SocketRef, Extension::>(session): Extension>, State(sessions): State, @@ -67,14 +67,14 @@ pub fn on_connection( s.emit("users", &users).unwrap(); let res = UserConnectedRes::new(&session, vec![]); - s.broadcast().emit("user connected", &res).unwrap(); + s.broadcast().emit("user connected", &res).await.unwrap(); s.on( "private message", |s: SocketRef, Data(PrivateMessageReq { to, content }), State::(msgs), - Extension::>(session)| { + Extension::>(session)| async move { let message = Message { from: session.user_id, to, @@ -82,24 +82,27 @@ pub fn on_connection( }; s.within(to.to_string()) .emit("private message", &message) + .await .ok(); msgs.add(message); }, ); - s.on_disconnect(|s: SocketRef, Extension::>(session)| { - session.set_connected(false); - let res = UserDisconnectedRes { - user_id: &session.user_id, - username: &session.username, - }; - s.broadcast().emit("user disconnected", &res).ok(); - }); + s.on_disconnect( + |s: SocketRef, Extension::>(session)| async move { + session.set_connected(false); + let res = UserDisconnectedRes { + user_id: &session.user_id, + username: &session.username, + }; + s.broadcast().emit("user disconnected", &res).await.ok(); + }, + ); } /// Handles the connection of a new user. 
/// Be careful not to emit anything to the user before the authentication is done. -pub fn authenticate_middleware( +pub async fn authenticate_middleware( s: SocketRef, Data(auth): Data, State(sessions): State, @@ -116,7 +119,7 @@ pub fn authenticate_middleware( session }; - s.join(session.user_id.to_string())?; + s.join(session.user_id.to_string()); Ok(()) } diff --git a/examples/react-rooms-chat/Cargo.toml b/examples/react-rooms-chat/Cargo.toml index b97d497c..3514e469 100644 --- a/examples/react-rooms-chat/Cargo.toml +++ b/examples/react-rooms-chat/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = ["state"] } +socketioxide = { workspace = true, features = ["state"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = "0.3" diff --git a/examples/react-rooms-chat/src/main.rs b/examples/react-rooms-chat/src/main.rs index df750161..a491773d 100644 --- a/examples/react-rooms-chat/src/main.rs +++ b/examples/react-rooms-chat/src/main.rs @@ -29,8 +29,8 @@ async fn on_connect(socket: SocketRef) { "join", |socket: SocketRef, Data::(room), store: State| async move { info!("Received join: {:?}", room); - let _ = socket.leave_all(); - let _ = socket.join(room.clone()); + socket.leave_all(); + socket.join(room.clone()); let messages = store.get(&room).await; let _ = socket.emit("messages", &Messages { messages }); }, @@ -49,14 +49,14 @@ async fn on_connect(socket: SocketRef) { store.insert(&data.room, response.clone()).await; - let _ = socket.within(data.room).emit("message", &response); + let _ = socket.within(data.room).emit("message", &response).await; }, ) } async fn handler(axum::extract::State(io): axum::extract::State) { info!("handler called"); - let _ = io.emit("hello", "world"); + let _ = io.emit("hello", "world").await; } #[tokio::main] diff --git a/examples/redis-whiteboard/Cargo.toml b/examples/redis-whiteboard/Cargo.toml new file mode 100644 index 00000000..93a77c29 --- /dev/null +++ b/examples/redis-whiteboard/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "redis-whiteboard" +version = "0.1.0" +edition = "2021" + +[dependencies] +socketioxide-redis = { path = "../../crates/socketioxide-redis", features = [ + "redis", + "redis-cluster", + "fred", +] } +socketioxide = { workspace = true, features = ["tracing", "msgpack"] } +axum.workspace = true +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } +tower-http = { version = "0.5.0", features = ["cors", "fs"] } +tower.workspace = true +tracing-subscriber.workspace = true +tracing.workspace = true +serde.workspace = true +rmpv.workspace = true + +[[bin]] +name = "redis" +path = "src/redis.rs" + +[[bin]] +name = "redis-cluster" +path = "src/redis_cluster.rs" + +[[bin]] +name = "fred" +path = "src/fred.rs" diff --git a/examples/redis-whiteboard/Readme.md b/examples/redis-whiteboard/Readme.md new file mode 100644 index 00000000..033bd4bf --- /dev/null +++ b/examples/redis-whiteboard/Readme.md @@ -0,0 +1,6 @@ +# Same example as the whiteboard, but with a Redis adapter + +You can spawn as many servers as you want on different ports (`PORT` env var) and then connect clients +to these different ports. + +The parser is set to msgpack in this example, but you can use any socket.io parser you want. 
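To try it, something like `PORT=3000 cargo run --bin redis` and `PORT=3001 cargo run --bin redis` in two terminals should give two instances sharing the same Redis server (command shape assumed from the `[[bin]]` entries above). The wiring the three binaries share boils down to the sketch below; it mirrors `src/redis.rs` with the hyper serving loop stripped out, and `on_connect` is a hypothetical stand-in for the real whiteboard logic:

```rust
use socketioxide::{adapter::Adapter, extract::SocketRef, ParserConfig, SocketIo};
use socketioxide_redis::{drivers::redis::redis_client as redis, RedisAdapter, RedisAdapterCtr};

// Hypothetical stand-in for the whiteboard connection handler.
async fn on_connect<A: Adapter>(s: SocketRef<A>) {
    println!("client {} connected", s.id);
}

async fn build_io() -> Result<(), Box<dyn std::error::Error>> {
    // Every instance points at the same Redis server, so broadcasts are shared.
    let client = redis::Client::open("redis://127.0.0.1:6379?protocol=resp3")?;
    let adapter = RedisAdapterCtr::new_with_redis(&client).await?;
    let (_svc, io) = SocketIo::builder()
        .with_adapter::<RedisAdapter<_>>(adapter)
        .with_parser(ParserConfig::msgpack())
        .build_svc();
    io.ns("/", on_connect).await.unwrap();
    // `_svc` would then be served with hyper on the port from the PORT env var,
    // exactly like the binaries of this example.
    Ok(())
}
```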
diff --git a/examples/redis-whiteboard/dist/index.html b/examples/redis-whiteboard/dist/index.html new file mode 100644 index 00000000..e8a0f2aa --- /dev/null +++ b/examples/redis-whiteboard/dist/index.html @@ -0,0 +1,23 @@ + + + + + Socket.IO whiteboard + + + + + + +
+      <div class="color black"></div>
+      <div class="color red"></div>
+      <div class="color green"></div>
+      <div class="color blue"></div>
+      <div class="color yellow"></div>
+    </div>
+ + + + + diff --git a/examples/redis-whiteboard/dist/main.js b/examples/redis-whiteboard/dist/main.js new file mode 100644 index 00000000..e41c37d4 --- /dev/null +++ b/examples/redis-whiteboard/dist/main.js @@ -0,0 +1,124 @@ +"use strict"; + +(function () { + const params = new URLSearchParams(window.location.search); + var socket = io(); + var canvas = document.getElementsByClassName("whiteboard")[0]; + var colors = document.getElementsByClassName("color"); + var context = canvas.getContext("2d"); + + var current = { + color: "black", + }; + var drawing = false; + + canvas.addEventListener("mousedown", onMouseDown, false); + canvas.addEventListener("mouseup", onMouseUp, false); + canvas.addEventListener("mouseout", onMouseUp, false); + canvas.addEventListener("mousemove", throttle(onMouseMove, 10), false); + + //Touch support for mobile devices + canvas.addEventListener("touchstart", onMouseDown, false); + canvas.addEventListener("touchend", onMouseUp, false); + canvas.addEventListener("touchcancel", onMouseUp, false); + canvas.addEventListener("touchmove", throttle(onMouseMove, 10), false); + + for (var i = 0; i < colors.length; i++) { + colors[i].addEventListener("click", onColorUpdate, false); + } + + socket.on("drawing", onDrawingEvent); + + window.addEventListener("resize", onResize, false); + onResize(); + + function drawLine(x0, y0, x1, y1, color, emit) { + context.beginPath(); + context.moveTo(x0, y0); + context.lineTo(x1, y1); + context.strokeStyle = color; + context.lineWidth = 2; + context.stroke(); + context.closePath(); + + if (!emit) { + return; + } + var w = canvas.width; + var h = canvas.height; + + socket.emit("drawing", { + x0: x0 / w, + y0: y0 / h, + x1: x1 / w, + y1: y1 / h, + color: color, + }); + } + + function onMouseDown(e) { + drawing = true; + current.x = e.clientX || e.touches[0].clientX; + current.y = e.clientY || e.touches[0].clientY; + } + + function onMouseUp(e) { + if (!drawing) { + return; + } + drawing = false; + drawLine( + current.x, + current.y, + e.clientX || e.touches[0].clientX, + e.clientY || e.touches[0].clientY, + current.color, + true, + ); + } + + function onMouseMove(e) { + if (!drawing) { + return; + } + drawLine( + current.x, + current.y, + e.clientX || e.touches[0].clientX, + e.clientY || e.touches[0].clientY, + current.color, + true, + ); + current.x = e.clientX || e.touches[0].clientX; + current.y = e.clientY || e.touches[0].clientY; + } + + function onColorUpdate(e) { + current.color = e.target.className.split(" ")[1]; + } + + // limit the number of events per second + function throttle(callback, delay) { + var previousCall = new Date().getTime(); + return function () { + var time = new Date().getTime(); + + if (time - previousCall >= delay) { + previousCall = time; + callback.apply(null, arguments); + } + }; + } + + function onDrawingEvent(data) { + var w = canvas.width; + var h = canvas.height; + drawLine(data.x0 * w, data.y0 * h, data.x1 * w, data.y1 * h, data.color); + } + + // make the canvas fill its parent + function onResize() { + canvas.width = window.innerWidth; + canvas.height = window.innerHeight; + } +})(); diff --git a/examples/redis-whiteboard/dist/style.css b/examples/redis-whiteboard/dist/style.css new file mode 100644 index 00000000..437a29cf --- /dev/null +++ b/examples/redis-whiteboard/dist/style.css @@ -0,0 +1,44 @@ + +/** + * Fix user-agent + */ + +* { + box-sizing: border-box; +} + +html, body { + height: 100%; + margin: 0; + padding: 0; +} + +/** + * Canvas + */ + +.whiteboard { + height: 100%; + width: 100%; + 
position: absolute; + left: 0; + right: 0; + bottom: 0; + top: 0; +} + +.colors { + position: fixed; +} + +.color { + display: inline-block; + height: 48px; + width: 48px; +} + +.color.black { background-color: black; } +.color.red { background-color: red; } +.color.green { background-color: green; } +.color.blue { background-color: blue; } +.color.yellow { background-color: yellow; } diff --git a/examples/redis-whiteboard/src/fred.rs b/examples/redis-whiteboard/src/fred.rs new file mode 100644 index 00000000..7bf14d84 --- /dev/null +++ b/examples/redis-whiteboard/src/fred.rs @@ -0,0 +1,71 @@ +//! A simple whiteboard example using Redis as the adapter. +//! It uses the fred crate to connect to a Redis server. +use rmpv::Value; +use socketioxide::{ + adapter::Adapter, + extract::{Data, SocketRef}, + ParserConfig, SocketIo, +}; +use socketioxide_redis::drivers::fred::fred_client as fred; +use socketioxide_redis::{FredAdapter, RedisAdapterCtr}; +use std::str::FromStr; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +use fred::{ + prelude::Config, + types::{Builder, RespVersion}, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + info!("connecting to redis"); + let mut config = Config::from_url("redis://127.0.0.1:6379")?; + // We have to manually set the version to RESP3. Fred defaults to RESP2. + config.version = RespVersion::RESP3; + let client = Builder::default() + .set_config(config) + .build_subscriber_client()?; + let adapter = RedisAdapterCtr::new_with_fred(client).await?; + info!("starting server"); + + let (layer, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_layer(); + + // It is heavily recommended to use generic fns instead of closures for handlers. + // This allows to be generic over the adapter you want to use. + async fn on_drawing(s: SocketRef
, Data(data): Data) { + s.broadcast().emit("drawing", &data).await.ok(); + } + fn on_connect(s: SocketRef) { + s.on("drawing", on_drawing); + } + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .nest_service("/", ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port: u16 = std::env::var("PORT") + .map(|s| u16::from_str(&s).unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/examples/redis-whiteboard/src/redis.rs b/examples/redis-whiteboard/src/redis.rs new file mode 100644 index 00000000..e52d2693 --- /dev/null +++ b/examples/redis-whiteboard/src/redis.rs @@ -0,0 +1,62 @@ +//! A simple whiteboard example using Redis as the adapter. +//! It uses the redis crate to connect to a Redis server. +use std::str::FromStr; + +use rmpv::Value; +use socketioxide::{ + adapter::Adapter, + extract::{Data, SocketRef}, + ParserConfig, SocketIo, +}; +use socketioxide_redis::{RedisAdapter, RedisAdapterCtr}; +use socketioxide_redis::drivers::redis::redis_client as redis; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + info!("connecting to redis"); + let client = redis::Client::open("redis://127.0.0.1:6379?protocol=resp3")?; + let adapter = RedisAdapterCtr::new_with_redis(&client).await?; + info!("starting server"); + + let (layer, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_layer(); + + // It is heavily recommended to use generic fns instead of closures for handlers. + // This allows to be generic over the adapter you want to use. + async fn on_drawing(s: SocketRef, Data(data): Data) { + s.broadcast().emit("drawing", &data).await.ok(); + } + fn on_connect(s: SocketRef) { + s.on("drawing", on_drawing); + } + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .nest_service("/", ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port: u16 = std::env::var("PORT") + .map(|s| u16::from_str(&s).unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/examples/redis-whiteboard/src/redis_cluster.rs b/examples/redis-whiteboard/src/redis_cluster.rs new file mode 100644 index 00000000..a14b7063 --- /dev/null +++ b/examples/redis-whiteboard/src/redis_cluster.rs @@ -0,0 +1,65 @@ +//! A simple whiteboard example using Redis as the adapter. +//! It uses the redis crate to connect to a Redis server. 
+use std::str::FromStr; + +use rmpv::Value; +use socketioxide::{ + adapter::Adapter, + extract::{Data, SocketRef}, + ParserConfig, SocketIo, +}; +use socketioxide_redis::drivers::redis::redis_client as redis; +use socketioxide_redis::{ClusterAdapter, RedisAdapterCtr}; +use tower::ServiceBuilder; +use tower_http::{cors::CorsLayer, services::ServeDir}; +use tracing::info; +use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::registry() + .with(fmt::layer()) + .with(EnvFilter::from_default_env()) + .init(); + + info!("connecting to redis"); + // single node cluster. In a real world scenario, you would have multiple nodes. + let builder = redis::cluster::ClusterClient::builder(std::iter::once( + "redis://127.0.0.1:6379?protocol=resp3", + )); + let adapter = RedisAdapterCtr::new_with_cluster(builder).await?; + info!("starting server"); + + let (layer, io) = SocketIo::builder() + .with_parser(ParserConfig::msgpack()) + .with_adapter::>(adapter) + .build_layer(); + + // It is heavily recommended to use generic fns instead of closures for handlers. + // This allows to be generic over the adapter you want to use. + async fn on_drawing(s: SocketRef, Data(data): Data) { + s.broadcast().emit("drawing", &data).await.ok(); + } + fn on_connect(s: SocketRef) { + s.on("drawing", on_drawing); + } + io.ns("/", on_connect).await?; + + let app = axum::Router::new() + .nest_service("/", ServeDir::new("dist")) + .layer( + ServiceBuilder::new() + .layer(CorsLayer::permissive()) // Enable CORS policy + .layer(layer), + ); + + let port: u16 = std::env::var("PORT") + .map(|s| u16::from_str(&s).unwrap()) + .unwrap_or(3000); + let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) + .await + .unwrap(); + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/examples/salvo-echo/Cargo.toml b/examples/salvo-echo/Cargo.toml index 3b29aaef..f14af161 100644 --- a/examples/salvo-echo/Cargo.toml +++ b/examples/salvo-echo/Cargo.toml @@ -7,7 +7,7 @@ rust-version = "1.67" # required by salvo # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = ["tracing"] } +socketioxide = { workspace = true, features = ["tracing"] } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing-subscriber.workspace = true tracing.workspace = true diff --git a/examples/viz-echo/Cargo.toml b/examples/viz-echo/Cargo.toml index 6e04aa4c..ae50f96e 100644 --- a/examples/viz-echo/Cargo.toml +++ b/examples/viz-echo/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] viz = "0.8.0" -socketioxide = { path = "../../crates/socketioxide" } +socketioxide.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tracing-subscriber.workspace = true tracing.workspace = true diff --git a/examples/webrtc-node-app/Cargo.toml b/examples/webrtc-node-app/Cargo.toml index 71dd89af..c19cf2bf 100644 --- a/examples/webrtc-node-app/Cargo.toml +++ b/examples/webrtc-node-app/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = ["msgpack"] } +socketioxide = { workspace = true, features = ["msgpack"] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tower-http = { 
version = "0.5.0", features = ["cors", "fs"] } diff --git a/examples/webrtc-node-app/src/main.rs b/examples/webrtc-node-app/src/main.rs index f33e8e3a..cabded63 100644 --- a/examples/webrtc-node-app/src/main.rs +++ b/examples/webrtc-node-app/src/main.rs @@ -35,49 +35,73 @@ async fn main() -> Result<(), Box> { .build_layer(); io.ns("/", |s: SocketRef| { - s.on("join", |s: SocketRef, Data(room_id): Data| { - let socket_cnt = s.within(room_id.clone()).sockets().unwrap().len(); - if socket_cnt == 0 { - tracing::info!("creating room {room_id} and emitting room_created socket event"); - s.join(room_id.clone()).unwrap(); - s.emit("room_created", &room_id).unwrap(); - } else if socket_cnt == 1 { - tracing::info!("joining room {room_id} and emitting room_joined socket event"); - s.join(room_id.clone()).unwrap(); - s.emit("room_joined", &room_id).unwrap(); - } else { - tracing::info!("Can't join room {room_id}, emitting full_room socket event"); - s.emit("full_room", &room_id).ok(); - } - }); + s.on( + "join", + |s: SocketRef, Data(room_id): Data| async move { + let socket_cnt = s.within(room_id.clone()).sockets().len(); + if socket_cnt == 0 { + tracing::info!( + "creating room {room_id} and emitting room_created socket event" + ); + s.join(room_id.clone()); + s.emit("room_created", &room_id).unwrap(); + } else if socket_cnt == 1 { + tracing::info!("joining room {room_id} and emitting room_joined socket event"); + s.join(room_id.clone()); + s.emit("room_joined", &room_id).unwrap(); + } else { + tracing::info!("Can't join room {room_id}, emitting full_room socket event"); + s.emit("full_room", &room_id).ok(); + } + }, + ); - s.on("start_call", |s: SocketRef, Data(room_id): Data| { - tracing::info!("broadcasting start_call event to peers in room {room_id}"); - s.to(room_id.clone()).emit("start_call", &room_id).ok(); - }); - s.on("webrtc_offer", |s: SocketRef, Data(event): Data| { - tracing::info!( - "broadcasting webrtc_offer event to peers in room {}", - event.room_id - ); - s.to(event.room_id).emit("webrtc_offer", &event.sdp).ok(); - }); - s.on("webrtc_answer", |s: SocketRef, Data(event): Data| { - tracing::info!( - "broadcasting webrtc_answer event to peers in room {}", - event.room_id - ); - s.to(event.room_id).emit("webrtc_answer", &event.sdp).ok(); - }); + s.on( + "start_call", + |s: SocketRef, Data(room_id): Data| async move { + tracing::info!("broadcasting start_call event to peers in room {room_id}"); + s.to(room_id.clone()) + .emit("start_call", &room_id) + .await + .ok(); + }, + ); + s.on( + "webrtc_offer", + |s: SocketRef, Data(event): Data| async move { + tracing::info!( + "broadcasting webrtc_offer event to peers in room {}", + event.room_id + ); + s.to(event.room_id) + .emit("webrtc_offer", &event.sdp) + .await + .ok(); + }, + ); + s.on( + "webrtc_answer", + |s: SocketRef, Data(event): Data| async move { + tracing::info!( + "broadcasting webrtc_answer event to peers in room {}", + event.room_id + ); + s.to(event.room_id) + .emit("webrtc_answer", &event.sdp) + .await + .ok(); + }, + ); s.on( "webrtc_ice_candidate", - |s: SocketRef, Data(event): Data| { + |s: SocketRef, Data(event): Data| async move { tracing::info!( "broadcasting ice_candidate event to peers in room {}", event.room_id ); s.to(event.room_id.clone()) .emit("webrtc_ice_candidate", &event) + .await .ok(); }, ); diff --git a/examples/whiteboard/Cargo.toml b/examples/whiteboard/Cargo.toml index dc173511..a608210c 100644 --- a/examples/whiteboard/Cargo.toml +++ b/examples/whiteboard/Cargo.toml @@ -4,10 +4,7 @@ version = "0.1.0" 
edition = "2021" [dependencies] -socketioxide = { path = "../../crates/socketioxide", features = [ - "tracing", - "msgpack", -] } +socketioxide = { workspace = true, features = ["tracing", "msgpack"] } axum.workspace = true tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tower-http = { version = "0.5.0", features = ["cors", "fs"] } diff --git a/examples/whiteboard/src/main.rs b/examples/whiteboard/src/main.rs index 7ded248c..2d96993c 100644 --- a/examples/whiteboard/src/main.rs +++ b/examples/whiteboard/src/main.rs @@ -22,8 +22,8 @@ async fn main() -> Result<(), Box> { .build_layer(); io.ns("/", |s: SocketRef| { - s.on("drawing", |s: SocketRef, Data::(data)| { - s.broadcast().emit("drawing", &data).unwrap(); + s.on("drawing", |s: SocketRef, Data::(data)| async move { + s.broadcast().emit("drawing", &data).await.unwrap(); }); });