diff --git a/.github/workflows/github-ci.yml b/.github/workflows/github-ci.yml index b51e437f..24d80186 100644 --- a/.github/workflows/github-ci.yml +++ b/.github/workflows/github-ci.yml @@ -89,6 +89,7 @@ jobs: strategy: matrix: engineio-version: [v3, v4] + hyper-version: [hyper-v04, hyper-v1] steps: - uses: actions/checkout@v3 - uses: actions-rs/toolchain@v1 @@ -115,8 +116,8 @@ jobs: - name: Install deps & run tests run: | cd engine.io-protocol/test-suite && npm install && cd ../.. - cargo build --bin engineioxide-e2e --features ${{ matrix.engineio-version }} --release - cargo run --bin engineioxide-e2e --features ${{ matrix.engineio-version }} --release > server.txt & npm --prefix engine.io-protocol/test-suite test > client.txt + cargo build --bin engineioxide-${{ matrix.hyper-version }}-e2e --features ${{ matrix.engineio-version }} --release + cargo run --bin engineioxide-${{ matrix.hyper-version }}-e2e --features ${{ matrix.engineio-version }} --release > server.txt & npm --prefix engine.io-protocol/test-suite test > client.txt - name: Server output if: always() run: cat server.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 3801006f..b13bf969 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,17 @@ +# 0.7.0 +## socketioxide +* The `extensions` field on sockets has been moved to a separate optional feature flag named `extensions` +* All the `tracing` internal calls have been moved to a separate optional feature flag named `tracing` +* A compatibility layer is now available for hyper v1 under the feature flag `hyper-v1`. You can call `with_hyper_v1` on the `SocketIoLayer` or the `SocketIoService` to get a layer/service working with hyper v1. The default is still hyper v0. +* New example with hyper v1 standalone +* New example with [salvo](https://salvo.rs) (based on hyper v1) +* Socket.io packet encoding/decoding has been optimized, it is now between ~15% and ~50% faster than before + +## engineioxide +* All the `tracing` internal calls have been moved to a separate optional feature flag named `tracing` +* A compatibility layer is now available for hyper v1 under the feature flag `hyper-v1`. You can call `with_hyper_v1` on the `EngineIoLayer` or the `EngineIoService` to get a layer/service working with hyper v1. The default is still hyper v0. +* Sid generation is now done manually without external crates + # 0.6.0 ## socketioxide * New API for creating the socket.io layer/service. A cheaply clonable `SocketIo` struct is now returned with the layer/service and allows to access namespaces/rooms/sockets everywhere in the application. Moreover, it is now possible to add and remove namespaces dynamically through the `SocketIo` struct. diff --git a/Cargo.lock b/Cargo.lock index 37b1362c..e91cc6c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,41 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "aho-corasick" version = "1.1.2" @@ -26,6 +61,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "anes" version = "0.1.6" @@ -67,8 +117,8 @@ dependencies = [ "bytes", "futures-util", "http", - "http-body", - "hyper", + "http-body 0.4.5", + "hyper 0.14.27", "itoa", "matchit", "memchr", @@ -97,13 +147,27 @@ dependencies = [ "bytes", "futures-util", "http", - "http-body", + "http-body 0.4.5", "mime", "rustversion", "tower-layer", "tower-service", ] +[[package]] +name = "axum-echo" +version = "0.6.0" +dependencies = [ + "axum", + "futures", + "serde", + "serde_json", + "socketioxide", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -119,6 +183,12 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" + [[package]] name = "base64" version = "0.21.5" @@ -146,6 +216,27 @@ dependencies = [ "generic-array", ] +[[package]] +name = "brotli" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.14.0" @@ -176,6 +267,7 @@ version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" dependencies = [ + "jobserver", "libc", ] @@ -212,6 +304,16 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.4.7" @@ -237,6 +339,23 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" +[[package]] +name = "cookie" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cd91cf61412820176e137621345ee43b3f4423e589e7ae4e50d601d93e35ef8" +dependencies = [ + "aes-gcm", + "base64 0.21.5", + "hmac", + "percent-encoding", + "rand 0.8.5", + "sha2", + "subtle", + "time", + "version_check", +] + [[package]] name = "cpufeatures" version = "0.2.11" @@ -246,6 +365,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -315,6 +443,16 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "cruet" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "113a9e83d8f614be76de8df1f25bf9d0ea6e85ea573710a3d3f3abe1438ae49c" +dependencies = [ + "once_cell", + "regex", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -322,9 +460,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -344,6 +492,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" +[[package]] +name = "deranged" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -352,6 +509,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -373,15 +531,12 @@ dependencies = [ name = "engineio-echo" version = "0.6.0" dependencies = [ - "axum", - "engineioxide", "futures", - "hyper", + "hyper 0.14.27", "serde", "serde_json", + "socketioxide", "tokio", - "tower", - "tower-http", "tracing", "tracing-subscriber", "warp", @@ -392,15 +547,18 @@ name = "engineioxide" version = "0.6.0" dependencies = [ "async-trait", - "base64", + "base64 0.21.5", "bytes", "futures", "http", - "http-body", - "hyper", + "http-body 0.4.5", + "http-body 1.0.0-rc.2", + "hyper 0.14.27", + "hyper 1.0.0-rc.4", + "hyper-util", "memchr", "pin-project", - "rand", + "rand 0.8.5", "serde", "serde_json", "thiserror", @@ -418,13 +576,41 @@ version = "0.6.0" dependencies = [ "engineioxide", "futures", - "hyper", + "hyper 0.14.27", + "hyper 1.0.0-rc.4", + "hyper-util", "serde_json", "tokio", "tracing", "tracing-subscriber", ] +[[package]] +name = "enumflags2" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5998b4f30320c9d93aed72f63af821bfdac50465b75428fce77b48ec482c3939" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f95e2801cd355d4a1a3e3953ce6ee5ae9603a5c833455343a8bfe3f44d418246" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.5" @@ -435,6 +621,22 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "fastrand" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" + +[[package]] +name = "flate2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -549,6 +751,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.2.10" @@ -557,7 +770,17 @@ checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "ghash" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d930750de5717d2dd0b8c0d42c076c0e884c81a73e6cab859bbd2339c71e3e40" +dependencies = [ + "opaque-debug", + "polyval", ] [[package]] @@ -578,7 +801,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap", + "indexmap 1.9.3", "slab", "tokio", "tokio-util", @@ -609,7 +832,7 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" dependencies = [ - "base64", + "base64 0.21.5", "bytes", "headers-core", "http", @@ -633,6 +856,15 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "http" version = "0.2.9" @@ -655,6 +887,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-body" +version = "1.0.0-rc.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "951dfc2e32ac02d67c90c0d65bd27009a635dc9b381a2cc7d284ab01e3a0150d" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08ef12f041acdd397010e5fb6433270c147d3b8b2d0a840cd7fff8e531dca5c8" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body 1.0.0-rc.2", + "pin-project-lite", +] + [[package]] name = "http-range-header" version = "0.3.1" @@ -685,7 +940,7 @@ dependencies = [ "futures-util", "h2", "http", - "http-body", + "http-body 0.4.5", "httparse", "httpdate", "itoa", @@ -697,6 +952,76 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.0.0-rc.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d280a71f348bcc670fc55b02b63c53a04ac0bf2daff2980795aeaf53edae10e6" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2", + "http", + "http-body 1.0.0-rc.2", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "tokio", + "tracing", + "want", +] + +[[package]] +name = "hyper-echo" +version = "0.6.0" +dependencies = [ + "futures", + "hyper 0.14.27", + "serde", + "serde_json", + "socketioxide", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "hyper-util" +version = "0.0.0" +source = "git+https://github.com/hyperium/hyper-util.git#ced9f812460420017705fa7cae4dca7be9e23f4a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body 1.0.0-rc.2", + "hyper 1.0.0-rc.4", + "once_cell", + "pin-project-lite", + "socket2 0.5.5", + "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "hyper-v1-echo" +version = "0.6.0" +dependencies = [ + "futures", + "hyper 1.0.0-rc.4", + "hyper-util", + "serde", + "serde_json", + "socketioxide", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "idna" version = "0.4.0" @@ -717,6 +1042,25 @@ dependencies = [ "hashbrown 0.12.3", ] +[[package]] +name = "indexmap" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" +dependencies = [ + "equivalent", + "hashbrown 0.14.2", +] + +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "is-terminal" version = "0.4.9" @@ -743,6 +1087,15 @@ version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +[[package]] +name = "jobserver" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c37f63953c4c63420ed5fd3d6d398c719489b9f872b9fa683262f8edd363c7d" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.64" @@ -822,6 +1175,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime-infer" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f65b181c4fc4a9bfe77dbf7dfa5f34f292dc22c3b4267c505d1fa149a07e3559" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mime_guess" version = "2.0.4" @@ -848,7 +1211,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys", ] @@ -870,6 +1233,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "multimap" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1a5d38b9b352dbd913288736af36af41c48d61b1a8cd34bcecd727561b7d511" +dependencies = [ + "serde", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -920,6 +1292,12 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "overload" version = "0.1.1" @@ -987,6 +1365,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" + [[package]] name = "plotters" version = "0.3.5" @@ -1015,12 +1399,39 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "polyval" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52cff9d1d4dee5fe6d03729099f4a310a41179e0a10dbf542039873f2e826fb" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "proc-macro-crate" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8366a6159044a37876a2b9817124296703c586a5c92e2c53751fa06d8d43e8" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.69" @@ -1039,6 +1450,19 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + [[package]] name = "rand" version = "0.8.5" @@ -1046,8 +1470,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", ] [[package]] @@ -1057,7 +1491,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", ] [[package]] @@ -1066,7 +1509,16 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.10", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", ] [[package]] @@ -1167,7 +1619,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "base64", + "base64 0.21.5", ] [[package]] @@ -1182,6 +1634,112 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "salvo" +version = "0.58.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0206b25c533ed47049843951c74e884a005dcfe057b3af783059d9ae0e40d0f" +dependencies = [ + "salvo_core", +] + +[[package]] +name = "salvo-echo" +version = "0.6.0" +dependencies = [ + "futures", + "salvo", + "serde", + "serde_json", + "socketioxide", + "tokio", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "salvo-utils" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39b83ebce32e15342c188b1a687556fbc830111e393181e1693ce80ffdd2eda0" +dependencies = [ + "futures-channel", + "futures-util", + "http", + "hyper 1.0.0-rc.4", + "once_cell", + "pin-project-lite", + "socket2 0.5.5", + "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "salvo_core" +version = "0.58.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "661eabf006e4b5dde529a7a31dd875f82b9505d64ac82a517538729d9110941a" +dependencies = [ + "async-trait", + "base64 0.21.5", + "brotli", + "bytes", + "cookie", + "cruet", + "encoding_rs", + "enumflags2", + "flate2", + "form_urlencoded", + "futures-channel", + "futures-util", + "headers", + "http", + "http-body-util", + "hyper 1.0.0-rc.4", + "indexmap 2.0.2", + "mime", + "mime-infer", + "multer", + "multimap", + "once_cell", + "parking_lot", + "percent-encoding", + "pin-project", + "regex", + "salvo-utils", + "salvo_macros", + "serde", + "serde-xml-rs", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tempfile", + "textnonce", + "thiserror", + "tokio", + "tokio-util", + "tower", + "tracing", + "url", + "zstd", +] + +[[package]] +name = "salvo_macros" +version = "0.58.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d3886e7c3e07d17664800b7aa0015839759e231c29069dc101f7e351531bd1" +dependencies = [ + "cruet", + "proc-macro-crate", + "proc-macro2", + "quote", + "regex", + "syn", +] + [[package]] name = "same-file" version = "1.0.6" @@ -1212,6 +1770,18 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-xml-rs" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb3aa78ecda1ebc9ec9847d5d3aba7d618823446a049ba2491940506da6e2782" +dependencies = [ + "log", + "serde", + "thiserror", + "xml-rs", +] + [[package]] name = "serde_derive" version = "1.0.190" @@ -1267,6 +1837,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1336,24 +1917,6 @@ dependencies = [ "tracing-subscriber", ] -[[package]] -name = "socketio-echo" -version = "0.6.0" -dependencies = [ - "axum", - "futures", - "hyper", - "serde", - "serde_json", - "socketioxide", - "tokio", - "tower", - "tower-http", - "tracing", - "tracing-subscriber", - "warp", -] - [[package]] name = "socketioxide" version = "0.6.0" @@ -1364,8 +1927,10 @@ dependencies = [ "engineioxide", "futures", "http", - "http-body", - "hyper", + "http-body 0.4.5", + "http-body 1.0.0-rc.2", + "hyper 0.14.27", + "hyper 1.0.0-rc.4", "serde", "serde_json", "thiserror", @@ -1381,7 +1946,7 @@ name = "socketioxide-e2e" version = "0.6.0" dependencies = [ "futures", - "hyper", + "hyper 0.14.27", "serde_json", "socketioxide", "tokio", @@ -1395,6 +1960,12 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "syn" version = "2.0.38" @@ -1412,6 +1983,29 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "tempfile" +version = "3.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall", + "rustix", + "windows-sys", +] + +[[package]] +name = "textnonce" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7743f8d70cd784ed1dc33106a18998d77758d281dc40dc3e6d050cf0f5286683" +dependencies = [ + "base64 0.12.3", + "rand 0.7.3", +] + [[package]] name = "thiserror" version = "1.0.50" @@ -1442,6 +2036,35 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" +dependencies = [ + "deranged", + "itoa", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" +dependencies = [ + "time-core", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -1534,6 +2157,23 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" + +[[package]] +name = "toml_edit" +version = "0.20.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70f427fce4d84c72b5b732388bf4a9f4531b53f74e2887e3ecb2481f68f66d81" +dependencies = [ + "indexmap 2.0.2", + "toml_datetime", + "winnow", +] + [[package]] name = "tower" version = "0.4.13" @@ -1545,6 +2185,7 @@ dependencies = [ "pin-project", "pin-project-lite", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -1561,7 +2202,7 @@ dependencies = [ "futures-core", "futures-util", "http", - "http-body", + "http-body 0.4.5", "http-range-header", "pin-project-lite", "tower-layer", @@ -1660,7 +2301,7 @@ dependencies = [ "http", "httparse", "log", - "rand", + "rand 0.8.5", "sha1", "thiserror", "url", @@ -1709,6 +2350,16 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "url" version = "2.4.1" @@ -1768,7 +2419,7 @@ dependencies = [ "futures-util", "headers", "http", - "hyper", + "hyper 0.14.27", "log", "mime", "mime_guess", @@ -1788,6 +2439,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1954,3 +2611,46 @@ name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "winnow" +version = "0.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3b801d0e0a6726477cc207f60162da452f3a95adb368399bef20a946e06f65c" +dependencies = [ + "memchr", +] + +[[package]] +name = "xml-rs" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fcb9cbac069e033553e8bb871be2fbdffcab578eb25bd0f7c508cedc6dcd75a" + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.9+zstd.1.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e16efa8a874a0481a574084d34cc26fdb3b99627480f785888deb6386506656" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/e2e/engineioxide/Cargo.toml b/e2e/engineioxide/Cargo.toml index bf850543..f4f56a5e 100644 --- a/e2e/engineioxide/Cargo.toml +++ b/e2e/engineioxide/Cargo.toml @@ -6,7 +6,14 @@ edition = "2021" [dependencies] engineioxide = { path = "../../engineioxide", default-features = false, features = [ "tracing", + "hyper-v1", ] } +hyper-v1 = { package = "hyper", version = "1.0.0-rc.4", features = [ + "server", + "http1", + "http2", +] } +hyper-util = { git = "https://github.com/hyperium/hyper-util.git" } hyper = { version = "0.14.26" } tokio = { version = "1.13.0", features = ["full"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } @@ -15,8 +22,12 @@ serde_json = "1.0.95" futures = "0.3.27" [[bin]] -name = "engineioxide-e2e" -path = "engineioxide.rs" +name = "engineioxide-hyper-v1-e2e" +path = "engineioxide-hyper-v1.rs" + +[[bin]] +name = "engineioxide-hyper-v04-e2e" +path = "engineioxide-hyper-v04.rs" [features] v3 = ["engineioxide/v3"] diff --git a/e2e/engineioxide/engineioxide.rs b/e2e/engineioxide/engineioxide-hyper-v04.rs similarity index 100% rename from e2e/engineioxide/engineioxide.rs rename to e2e/engineioxide/engineioxide-hyper-v04.rs index edf04718..b9ebcc25 100644 --- a/e2e/engineioxide/engineioxide.rs +++ b/e2e/engineioxide/engineioxide-hyper-v04.rs @@ -43,6 +43,7 @@ async fn main() -> Result<(), Box> { .with_line_number(true) .with_max_level(Level::DEBUG) .finish(); + tracing::subscriber::set_global_default(subscriber)?; let config = EngineIoConfig::builder() .ping_interval(Duration::from_millis(300)) @@ -54,7 +55,6 @@ async fn main() -> Result<(), Box> { let svc = EngineIoService::with_config(MyHandler, config); let server = Server::bind(addr).serve(svc.into_make_service()); - tracing::subscriber::set_global_default(subscriber)?; #[cfg(feature = "v3")] tracing::info!("Starting server with v3 protocol"); diff --git a/e2e/engineioxide/engineioxide-hyper-v1.rs b/e2e/engineioxide/engineioxide-hyper-v1.rs new file mode 100644 index 00000000..eb40a21a --- /dev/null +++ b/e2e/engineioxide/engineioxide-hyper-v1.rs @@ -0,0 +1,86 @@ +//! This a end to end test server used with this [test suite](https://github.com/socketio/engine.io-protocol) + +use std::{sync::Arc, time::Duration}; + +use engineioxide::{ + config::EngineIoConfig, + handler::EngineIoHandler, + service::EngineIoService, + socket::{DisconnectReason, Socket}, +}; +use hyper_util::rt::TokioIo; +use hyper_v1::server::conn::http1; +use tokio::net::TcpListener; +use tracing::Level; +use tracing_subscriber::FmtSubscriber; + +#[derive(Debug, Clone)] +struct MyHandler; + +#[engineioxide::async_trait] +impl EngineIoHandler for MyHandler { + type Data = (); + + fn on_connect(&self, socket: Arc>) { + println!("socket connect {}", socket.id); + } + fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { + println!("socket disconnect {}: {:?}", socket.id, reason); + } + + fn on_message(&self, msg: String, socket: Arc>) { + println!("Ping pong message {:?}", msg); + socket.emit(msg).ok(); + } + + fn on_binary(&self, data: Vec, socket: Arc>) { + println!("Ping pong binary message {:?}", data); + socket.emit_binary(data).ok(); + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::DEBUG) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + + let config = EngineIoConfig::builder() + .ping_interval(Duration::from_millis(300)) + .ping_timeout(Duration::from_millis(200)) + .max_payload(1e6 as u64) + .build(); + + let svc = EngineIoService::with_config(MyHandler, config).with_hyper_v1(); + + let listener = TcpListener::bind("127.0.0.1:3000").await?; + + #[cfg(feature = "v3")] + tracing::info!("Starting server with v3 protocol"); + #[cfg(feature = "v4")] + tracing::info!("Starting server with v4 protocol"); + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/engineioxide/Cargo.toml b/engineioxide/Cargo.toml index 4f2b6f6e..007b3843 100644 --- a/engineioxide/Cargo.toml +++ b/engineioxide/Cargo.toml @@ -40,6 +40,15 @@ tracing = { version = "0.1.37", optional = true } memchr = { version = "2.5.0", optional = true } unicode-segmentation = { version = "1.10.1", optional = true } +# Hyper v1.0 +hyper-v1 = { package = "hyper", version = "1.0.0-rc.4", optional = true, features = [ + "server", + "http1", + "http2", +] } +http-body-v1 = { package = "http-body", version = "1.0.0-rc.2", optional = true } +hyper-util = { git = "https://github.com/hyperium/hyper-util.git", optional = true } + [dev-dependencies] tokio = { version = "1.26.0", features = ["macros", "parking_lot"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } @@ -58,3 +67,4 @@ v4 = [] v3 = ["memchr", "unicode-segmentation"] test-utils = [] tracing = ["dep:tracing"] +hyper-v1 = ["dep:hyper-v1", "dep:http-body-v1", "dep:hyper-util"] diff --git a/engineioxide/src/body.rs b/engineioxide/src/body.rs deleted file mode 100644 index edf8764e..00000000 --- a/engineioxide/src/body.rs +++ /dev/null @@ -1,101 +0,0 @@ -use bytes::Bytes; -use http::HeaderMap; -use http_body::{Body, Full, SizeHint}; -use pin_project::pin_project; -use std::pin::Pin; -use std::task::{Context, Poll}; -#[pin_project] -pub struct ResponseBody { - #[pin] - inner: ResponseBodyInner, -} - -impl ResponseBody { - pub fn empty_response() -> Self { - Self { - inner: ResponseBodyInner::EmptyResponse, - } - } - - pub fn custom_response(body: Full) -> Self { - Self { - inner: ResponseBodyInner::CustomBody { body }, - } - } - - pub fn new(body: B) -> Self { - Self { - inner: ResponseBodyInner::Body { body }, - } - } -} - -impl Default for ResponseBody { - fn default() -> Self { - Self::empty_response() - } -} - -#[pin_project(project = BodyProj)] -enum ResponseBodyInner { - EmptyResponse, - CustomBody { - #[pin] - body: Full, - }, - Body { - #[pin] - body: B, - }, -} - -impl Body for ResponseBody -where - B: Body, - B::Error: std::error::Error + 'static, -{ - type Data = Bytes; - type Error = B::Error; - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - match self.project().inner.project() { - BodyProj::EmptyResponse => Poll::Ready(None), - BodyProj::Body { body } => body.poll_data(cx), - BodyProj::CustomBody { body } => body.poll_data(cx).map_err(|err| match err {}), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.project() { - BodyProj::EmptyResponse => Poll::Ready(Ok(None)), - BodyProj::Body { body } => body.poll_trailers(cx), - BodyProj::CustomBody { body } => body.poll_trailers(cx).map_err(|err| match err {}), - } - } - - fn is_end_stream(&self) -> bool { - match &self.inner { - ResponseBodyInner::EmptyResponse => true, - ResponseBodyInner::Body { body } => body.is_end_stream(), - ResponseBodyInner::CustomBody { body } => body.is_end_stream(), - } - } - - fn size_hint(&self) -> SizeHint { - match &self.inner { - ResponseBodyInner::EmptyResponse => { - let mut hint = SizeHint::default(); - hint.set_upper(0); - hint - } - ResponseBodyInner::Body { body } => body.size_hint(), - ResponseBodyInner::CustomBody { body } => body.size_hint(), - } - } -} diff --git a/engineioxide/src/body/mod.rs b/engineioxide/src/body/mod.rs new file mode 100644 index 00000000..78d60be2 --- /dev/null +++ b/engineioxide/src/body/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "hyper-v1")] +pub mod request; + +pub mod response; diff --git a/engineioxide/src/body/request.rs b/engineioxide/src/body/request.rs new file mode 100644 index 00000000..760459b0 --- /dev/null +++ b/engineioxide/src/body/request.rs @@ -0,0 +1,101 @@ +//! Custom Request Body compat struct implementation to map [`http_body_v1::Body`] to a [`http_body::Body`] +//! Only enabled with the feature flag `hyper-v1` +//! +//! Eavily inspired from : https://github.com/davidpdrsn/tower-hyper-http-body-compat + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use http::HeaderMap; +use pin_project::pin_project; + +/// Wraps a body to implement the [`http_body::Body`] on it +#[pin_project] +pub struct IncomingBody { + #[pin] + body: B, + trailers: Option, +} +impl IncomingBody { + pub fn new(body: B) -> IncomingBody { + IncomingBody { + body, + trailers: None, + } + } +} + +impl http_body::Body for IncomingBody +where + B: http_body_v1::Body, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + match futures::ready!(this.body.poll_frame(cx)) { + Some(Ok(frame)) => { + let frame = match frame.into_data() { + Ok(data) => return Poll::Ready(Some(Ok(data))), + Err(frame) => frame, + }; + + match frame.into_trailers() { + Ok(trailers) => { + *this.trailers = Some(trailers); + } + Err(_frame) => {} + } + + Poll::Ready(None) + } + Some(Err(err)) => Poll::Ready(Some(Err(err))), + None => Poll::Ready(None), + } + } + + fn poll_trailers( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + loop { + let this = self.as_mut().project(); + + if let Some(trailers) = this.trailers.take() { + break Poll::Ready(Ok(Some(trailers))); + } + + match futures::ready!(this.body.poll_frame(cx)) { + Some(Ok(frame)) => match frame.into_trailers() { + Ok(trailers) => break Poll::Ready(Ok(Some(trailers))), + // we might get a trailers frame on next poll + // so loop and try again + Err(_frame) => {} + }, + Some(Err(err)) => break Poll::Ready(Err(err)), + None => break Poll::Ready(Ok(None)), + } + } + } + + fn size_hint(&self) -> http_body::SizeHint { + let size_hint = self.body.size_hint(); + let mut out = http_body::SizeHint::new(); + out.set_lower(size_hint.lower()); + if let Some(upper) = size_hint.upper() { + out.set_upper(upper); + } + out + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.body.is_end_stream() + } +} diff --git a/engineioxide/src/body/response.rs b/engineioxide/src/body/response.rs new file mode 100644 index 00000000..72ec7f3a --- /dev/null +++ b/engineioxide/src/body/response.rs @@ -0,0 +1,229 @@ +//! Response Body wrapper in order to return a custom body or the body from the inner service + +use bytes::Bytes; +use http::HeaderMap; +use http_body::{Body, Full, SizeHint}; +use pin_project::pin_project; +use std::convert::Infallible; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[pin_project(project = BodyProj)] +pub enum ResponseBody { + EmptyResponse, + CustomBody { + #[pin] + body: Full, + }, + Body { + #[pin] + body: B, + }, +} +impl Default for ResponseBody { + fn default() -> Self { + Self::empty_response() + } +} +impl ResponseBody { + pub fn empty_response() -> Self { + ResponseBody::EmptyResponse + } + + pub fn custom_response(body: Full) -> Self { + ResponseBody::CustomBody { body } + } + + pub fn new(body: B) -> Self { + ResponseBody::Body { body } + } +} +impl Body for ResponseBody +where + B: Body, + B::Error: std::error::Error + 'static, +{ + type Data = Bytes; + type Error = B::Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project() { + BodyProj::EmptyResponse => Poll::Ready(None), + BodyProj::Body { body } => body.poll_data(cx), + BodyProj::CustomBody { body } => body.poll_data(cx).map_err(|err| match err {}), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + match self.project() { + BodyProj::EmptyResponse => Poll::Ready(Ok(None)), + BodyProj::Body { body } => body.poll_trailers(cx), + BodyProj::CustomBody { body } => body.poll_trailers(cx).map_err(|err| match err {}), + } + } + + fn is_end_stream(&self) -> bool { + match self { + ResponseBody::EmptyResponse => true, + ResponseBody::Body { body } => body.is_end_stream(), + ResponseBody::CustomBody { body } => body.is_end_stream(), + } + } + + fn size_hint(&self) -> SizeHint { + match self { + ResponseBody::EmptyResponse => { + let mut hint = SizeHint::default(); + hint.set_upper(0); + hint + } + ResponseBody::Body { body } => body.size_hint(), + ResponseBody::CustomBody { body } => body.size_hint(), + } + } +} + +/// Implementation heavily inspired from https://github.com/davidpdrsn/tower-hyper-http-body-compat +#[cfg(feature = "hyper-v1")] +impl http_body_v1::Body for ResponseBody +where + B: http_body_v1::Body, + B::Error: std::error::Error + 'static, +{ + type Data = B::Data; + + type Error = B::Error; + + fn is_end_stream(&self) -> bool { + match &self { + ResponseBody::EmptyResponse => true, + ResponseBody::Body { body } => body.is_end_stream(), + ResponseBody::CustomBody { body } => body.is_end_stream(), + } + } + + fn size_hint(&self) -> http_body_v1::SizeHint { + match &self { + ResponseBody::EmptyResponse => { + let mut hint = http_body_v1::SizeHint::default(); + hint.set_upper(0); + hint + } + ResponseBody::Body { body } => body.size_hint(), + ResponseBody::CustomBody { body } => { + let size_hint = body.size_hint(); + let mut out = http_body_v1::SizeHint::new(); + out.set_lower(size_hint.lower()); + if let Some(upper) = size_hint.upper() { + out.set_upper(upper); + } + out + } + } + } + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.project() { + BodyProj::EmptyResponse => Poll::Ready(None), + BodyProj::Body { body } => body.poll_frame(cx), + BodyProj::CustomBody { mut body } => { + match body.as_mut().poll_data(cx) { + Poll::Ready(Some(Ok(buf))) => { + return Poll::Ready(Some(Ok(http_body_v1::Frame::data(buf)))) + } + Poll::Ready(Some(Err(_))) => unreachable!("unreachable error!"), + Poll::Ready(None) => {} + Poll::Pending => return Poll::Pending, + } + + match body.as_mut().poll_trailers(cx) { + Poll::Ready(Ok(Some(trailers))) => { + Poll::Ready(Some(Ok(http_body_v1::Frame::trailers(trailers)))) + } + Poll::Ready(Ok(None)) => Poll::Ready(None), + Poll::Ready(Err(_)) => unreachable!("unreachable error!"), + Poll::Pending => Poll::Pending, + } + } + } + } +} + +/// A body that is always empty and that implements [`http_body::Body`] and [`http_body_v1::Body`]. +pub struct Empty { + _marker: std::marker::PhantomData D>, +} + +impl http_body::Body for Empty { + type Data = D; + + type Error = Infallible; + + #[inline(always)] + fn poll_data( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll>> { + Poll::Ready(None) + } + + #[inline(always)] + fn poll_trailers( + self: Pin<&mut Self>, + _: &mut Context<'_>, + ) -> Poll, Self::Error>> { + Poll::Ready(Ok(None)) + } + + #[inline(always)] + fn is_end_stream(&self) -> bool { + true + } +} + +#[cfg(feature = "hyper-v1")] +impl http_body_v1::Body for Empty { + type Data = D; + type Error = Infallible; + + #[inline(always)] + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Poll::Ready(None) + } + + #[inline(always)] + fn is_end_stream(&self) -> bool { + true + } + + #[inline(always)] + fn size_hint(&self) -> http_body_v1::SizeHint { + http_body_v1::SizeHint::with_exact(0) + } +} + +impl std::fmt::Debug for Empty { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Empty").finish() + } +} + +impl Default for Empty { + fn default() -> Self { + Self { + _marker: std::marker::PhantomData, + } + } +} diff --git a/engineioxide/src/config.rs b/engineioxide/src/config.rs index a1fd657e..5ae40580 100644 --- a/engineioxide/src/config.rs +++ b/engineioxide/src/config.rs @@ -1,6 +1,6 @@ use std::time::Duration; -pub use crate::transport::TransportType; +pub use crate::service::TransportType; #[derive(Debug, Clone)] pub struct EngineIoConfig { diff --git a/engineioxide/src/engine.rs b/engineioxide/src/engine.rs index 2a63a5de..38b2f0c8 100644 --- a/engineioxide/src/engine.rs +++ b/engineioxide/src/engine.rs @@ -6,8 +6,8 @@ use std::{ use crate::{ config::EngineIoConfig, handler::EngineIoHandler, + service::TransportType, socket::{DisconnectReason, Socket, SocketReq}, - transport::TransportType, }; use crate::{service::ProtocolVersion, sid::Sid}; diff --git a/engineioxide/src/errors.rs b/engineioxide/src/errors.rs index 7ede6972..becdd2a3 100644 --- a/engineioxide/src/errors.rs +++ b/engineioxide/src/errors.rs @@ -2,8 +2,9 @@ use http::{Response, StatusCode}; use tokio::sync::mpsc; use tokio_tungstenite::tungstenite; +use crate::body::response::ResponseBody; +use crate::packet::Packet; use crate::sid::Sid; -use crate::{body::ResponseBody, packet::Packet}; #[derive(thiserror::Error, Debug)] pub enum Error { @@ -35,16 +36,10 @@ pub enum Error { #[error("http error response: {0:?}")] HttpErrorResponse(StatusCode), - #[error("transport unknown")] - UnknownTransport, #[error("unknown session id")] UnknownSessionID(Sid), - #[error("bad handshake method")] - BadHandshakeMethod, #[error("transport mismatch")] TransportMismatch, - #[error("unsupported protocol version")] - UnsupportedProtocolVersion, #[error("payload too large")] PayloadTooLarge, @@ -77,21 +72,15 @@ impl From for Response> { .status(413) .body(ResponseBody::empty_response()) .unwrap(), - Error::UnknownTransport => { - conn_err_resp("{\"code\":\"0\",\"message\":\"Transport unknown\"}") - } + Error::UnknownSessionID(_) => { conn_err_resp("{\"code\":\"1\",\"message\":\"Session ID unknown\"}") } - Error::BadHandshakeMethod => { - conn_err_resp("{\"code\":\"2\",\"message\":\"Bad handshake method\"}") - } + Error::TransportMismatch => { conn_err_resp("{\"code\":\"3\",\"message\":\"Bad request\"}") } - Error::UnsupportedProtocolVersion => { - conn_err_resp("{\"code\":\"5\",\"message\":\"Unsupported protocol version\"}") - } + _e => { #[cfg(feature = "tracing")] tracing::debug!("uncaught error {_e:?}"); diff --git a/engineioxide/src/futures.rs b/engineioxide/src/futures.rs deleted file mode 100644 index 545c7985..00000000 --- a/engineioxide/src/futures.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::body::ResponseBody; -use crate::errors::Error; -use bytes::Bytes; -use futures::ready; -use http::header::{CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, SEC_WEBSOCKET_ACCEPT, UPGRADE}; -use http::{HeaderValue, Response, StatusCode}; -use http_body::{Body, Full}; -use pin_project::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio_tungstenite::tungstenite::handshake::derive_accept_key; - -pub(crate) type BoxFuture = - Pin>, Error>> + Send>>; - -/// Create a response for http request -pub fn http_response( - code: StatusCode, - data: D, - is_binary: bool, -) -> Result>, http::Error> -where - D: Into, -{ - let body: Bytes = data.into(); - let res = Response::builder() - .status(code) - .header(CONTENT_LENGTH, body.len()); - if is_binary { - res.header(CONTENT_TYPE, "application/octet-stream") - } else { - res.header(CONTENT_TYPE, "text/plain; charset=UTF-8") - } - .body(ResponseBody::custom_response(Full::new(body))) -} - -/// Create a response for websocket upgrade -pub fn ws_response(ws_key: &HeaderValue) -> Result>, http::Error> { - let derived = derive_accept_key(ws_key.as_bytes()); - let sec = derived.parse::().unwrap(); - Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(UPGRADE, HeaderValue::from_static("websocket")) - .header(CONNECTION, HeaderValue::from_static("Upgrade")) - .header(SEC_WEBSOCKET_ACCEPT, sec) - .body(ResponseBody::empty_response()) -} - -#[pin_project] -pub struct ResponseFuture { - #[pin] - inner: ResponseFutureInner, -} - -impl ResponseFuture { - pub fn empty_response(code: u16) -> Self { - Self { - inner: ResponseFutureInner::EmptyResponse { code }, - } - } - pub fn ready(res: Result>, Error>) -> Self { - Self { - inner: ResponseFutureInner::ReadyResponse { res: Some(res) }, - } - } - pub fn new(future: F) -> Self { - Self { - inner: ResponseFutureInner::Future { future }, - } - } - pub fn async_response(future: BoxFuture) -> Self { - Self { - inner: ResponseFutureInner::AsyncResponse { future }, - } - } -} -#[pin_project(project = ResFutProj)] -enum ResponseFutureInner { - EmptyResponse { - code: u16, - }, - ReadyResponse { - res: Option>, Error>>, - }, - AsyncResponse { - future: BoxFuture, - }, - Future { - #[pin] - future: F, - }, -} - -impl Future for ResponseFuture -where - ResBody: Body, - F: Future, E>>, -{ - type Output = Result>, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let res = match self.project().inner.project() { - ResFutProj::Future { future } => ready!(future.poll(cx))?.map(ResponseBody::new), - - ResFutProj::EmptyResponse { code } => Response::builder() - .status(*code) - .body(ResponseBody::empty_response()) - .unwrap(), - ResFutProj::AsyncResponse { future } => ready!(future - .as_mut() - .poll(cx) - .map(|r| r.unwrap_or_else(|e| e.into()))), - ResFutProj::ReadyResponse { res } => res.take().unwrap().unwrap_or_else(|e| e.into()), - }; - Poll::Ready(Ok(res)) - } -} diff --git a/engineioxide/src/layer.rs b/engineioxide/src/layer.rs index a5e16d12..5a146416 100644 --- a/engineioxide/src/layer.rs +++ b/engineioxide/src/layer.rs @@ -1,6 +1,10 @@ use tower::Layer; -use crate::{config::EngineIoConfig, handler::EngineIoHandler, service::EngineIoService}; +use crate::{ + config::EngineIoConfig, + handler::EngineIoHandler, + service::{self, EngineIoService}, +}; #[derive(Debug, Clone)] pub struct EngineIoLayer { @@ -18,6 +22,12 @@ impl EngineIoLayer { pub fn from_config(handler: H, config: EngineIoConfig) -> Self { Self { config, handler } } + + #[cfg(feature = "hyper-v1")] + #[inline(always)] + pub fn with_hyper_v1(self) -> EngineIoHyperLayer { + EngineIoHyperLayer(self) + } } impl Layer for EngineIoLayer { @@ -27,3 +37,17 @@ impl Layer for EngineIoLayer { EngineIoService::with_config_inner(inner, self.handler.clone(), self.config.clone()) } } + +#[cfg(feature = "hyper-v1")] +#[derive(Debug, Clone)] +pub struct EngineIoHyperLayer(EngineIoLayer); + +#[cfg(feature = "hyper-v1")] +impl Layer for EngineIoHyperLayer { + type Service = service::hyper_v1::EngineIoHyperService; + + fn layer(&self, inner: S) -> Self::Service { + EngineIoService::with_config_inner(inner, self.0.handler.clone(), self.0.config.clone()) + .with_hyper_v1() + } +} diff --git a/engineioxide/src/lib.rs b/engineioxide/src/lib.rs index fef5a123..34a0bd26 100644 --- a/engineioxide/src/lib.rs +++ b/engineioxide/src/lib.rs @@ -1,7 +1,9 @@ pub use async_trait::async_trait; +pub use service::{ProtocolVersion, TransportType}; /// A Packet type to use when sending data to the client pub use socket::{DisconnectReason, Socket, SocketReq}; + #[cfg(not(any(feature = "v3", feature = "v4")))] compile_error!("At least one protocol version must be enabled"); @@ -13,11 +15,8 @@ pub mod service; pub mod sid; pub mod socket; -pub use service::ProtocolVersion; - mod body; mod engine; -mod futures; mod packet; mod peekable; mod transport; diff --git a/engineioxide/src/packet.rs b/engineioxide/src/packet.rs index 9000905f..ac23cbd7 100644 --- a/engineioxide/src/packet.rs +++ b/engineioxide/src/packet.rs @@ -3,7 +3,7 @@ use serde::{de::Error, Deserialize, Serialize}; use crate::config::EngineIoConfig; use crate::sid::Sid; -use crate::transport::TransportType; +use crate::TransportType; /// A Packet type to use when receiving and sending data from the client #[derive(Debug, PartialEq, PartialOrd)] diff --git a/engineioxide/src/service.rs b/engineioxide/src/service.rs deleted file mode 100644 index abadcd34..00000000 --- a/engineioxide/src/service.rs +++ /dev/null @@ -1,443 +0,0 @@ -use crate::{ - body::ResponseBody, - config::EngineIoConfig, - engine::EngineIo, - errors::Error, - futures::ResponseFuture, - handler::EngineIoHandler, - sid::Sid, - transport::{polling, ws, TransportType}, -}; -use bytes::Bytes; -use futures::future::{ready, Ready}; -use http::{Method, Request}; -use http_body::{Body, Empty}; -use hyper::{service::Service, Response}; -use std::{ - convert::Infallible, - fmt::Debug, - str::FromStr, - sync::Arc, - task::{Context, Poll}, -}; - -/// A [`Service`] that handles engine.io requests as a middleware. -/// If the request is not an engine.io request, it forwards it to the inner service. -/// If it is an engine.io request it will forward it to the appropriate [`transport`](crate::transport). -/// -/// By default, it uses a [`NotFoundService`] as the inner service so it can be used as a standalone [`Service`]. -pub struct EngineIoService { - inner: S, - engine: Arc>, -} - -impl EngineIoService { - /// Create a new [`EngineIoService`] with a [`NotFoundService`] as the inner service. - /// If the request is not an `EngineIo` request, it will always return a 404 response. - pub fn new(handler: H) -> Self { - EngineIoService::with_config(handler, EngineIoConfig::default()) - } - /// Create a new [`EngineIoService`] with a custom config - pub fn with_config(handler: H, config: EngineIoConfig) -> Self { - EngineIoService::with_config_inner(NotFoundService, handler, config) - } -} - -impl EngineIoService { - /// Create a new [`EngineIoService`] with a custom inner service. - pub fn with_inner(inner: S, handler: H) -> Self { - EngineIoService::with_config_inner(inner, handler, EngineIoConfig::default()) - } - - /// Create a new [`EngineIoService`] with a custom inner service and a custom config. - pub fn with_config_inner(inner: S, handler: H, config: EngineIoConfig) -> Self { - EngineIoService { - inner, - engine: Arc::new(EngineIo::new(handler, config)), - } - } - - /// Convert this [`EngineIoService`] into a [`MakeEngineIoService`]. - /// This is useful when using [`EngineIoService`] without layers. - pub fn into_make_service(self) -> MakeEngineIoService { - MakeEngineIoService::new(self) - } -} - -impl Clone for EngineIoService { - fn clone(&self) -> Self { - EngineIoService { - inner: self.inner.clone(), - engine: self.engine.clone(), - } - } -} - -/// The service implementation for [`EngineIoService`]. -impl Service> for EngineIoService -where - ResBody: Body + Send + 'static, - ReqBody: Body + Send + Unpin + 'static + Debug, - ::Error: Debug, - ::Data: Send, - S: Service, Response = Response>, - H: EngineIoHandler, -{ - type Response = Response>; - type Error = S::Error; - type Future = ResponseFuture; - - fn poll_ready(&mut self, cx: &mut Context) -> Poll> { - self.inner.poll_ready(cx) - } - - /// Handle the request. - /// Each request is parsed to a [`RequestInfo`] - /// If the request is an `EngineIo` request, it is handled by the corresponding [`transport`](crate::transport). - /// Otherwise, it is forwarded to the inner service. - fn call(&mut self, req: Request) -> Self::Future { - if req.uri().path().starts_with(&self.engine.config.req_path) { - let engine = self.engine.clone(); - match RequestInfo::parse(&req, &self.engine.config) { - Ok(RequestInfo { - protocol, - sid: None, - transport: TransportType::Polling, - method: Method::GET, - #[cfg(feature = "v3")] - b64, - }) => ResponseFuture::ready(polling::open_req( - engine, - protocol, - req, - #[cfg(feature = "v3")] - !b64, - )), - Ok(RequestInfo { - protocol, - sid: Some(sid), - transport: TransportType::Polling, - method: Method::GET, - .. - }) => ResponseFuture::async_response(Box::pin(polling::polling_req( - engine, protocol, sid, - ))), - Ok(RequestInfo { - protocol, - sid: Some(sid), - transport: TransportType::Polling, - method: Method::POST, - .. - }) => ResponseFuture::async_response(Box::pin(polling::post_req( - engine, protocol, sid, req, - ))), - Ok(RequestInfo { - protocol, - sid, - transport: TransportType::Websocket, - method: Method::GET, - .. - }) => ResponseFuture::ready(ws::new_req(engine, protocol, sid, req)), - Err(e) => { - #[cfg(feature = "tracing")] - tracing::debug!("error parsing request: {:?}", e); - ResponseFuture::ready(Ok(e.into())) - } - _req => { - #[cfg(feature = "tracing")] - tracing::debug!("invalid request: {:?}", _req); - ResponseFuture::empty_response(400) - } - } - } else { - ResponseFuture::new(self.inner.call(req)) - } - } -} - -impl Debug for EngineIoService { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EngineIoService").finish() - } -} - -/// A MakeService that always returns a clone of the [`EngineIoService`] it was created with. -pub struct MakeEngineIoService { - svc: EngineIoService, -} - -impl MakeEngineIoService { - /// Create a new [`MakeEngineIoService`] with a custom inner service. - pub fn new(svc: EngineIoService) -> Self { - MakeEngineIoService { svc } - } -} - -impl Service for MakeEngineIoService { - type Response = EngineIoService; - - type Error = Infallible; - - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _req: T) -> Self::Future { - ready(Ok(self.svc.clone())) - } -} - -/// A [`Service`] that always returns a 404 response and that is compatible with [`EngineIoService`]. -#[derive(Debug, Clone)] -pub struct NotFoundService; -impl Service> for NotFoundService -where - ReqBody: Body + Send + 'static + Debug, - ::Error: Debug, - ::Data: Send, -{ - type Response = Response>>; - type Error = Infallible; - type Future = Ready>>, Infallible>>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _: Request) -> Self::Future { - ready(Ok(Response::builder() - .status(404) - .body(ResponseBody::empty_response()) - .unwrap())) - } -} - -/// The protocol version used by the client. -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum ProtocolVersion { - V3 = 3, - V4 = 4, -} - -impl FromStr for ProtocolVersion { - type Err = Error; - - #[cfg(all(feature = "v3", feature = "v4"))] - fn from_str(s: &str) -> Result { - match s { - "3" => Ok(ProtocolVersion::V3), - "4" => Ok(ProtocolVersion::V4), - _ => Err(Error::UnsupportedProtocolVersion), - } - } - - #[cfg(all(feature = "v4", not(feature = "v3")))] - fn from_str(s: &str) -> Result { - match s { - "4" => Ok(ProtocolVersion::V4), - _ => Err(Error::UnsupportedProtocolVersion), - } - } - - #[cfg(all(feature = "v3", not(feature = "v4")))] - fn from_str(s: &str) -> Result { - match s { - "3" => Ok(ProtocolVersion::V3), - _ => Err(Error::UnsupportedProtocolVersion), - } - } -} - -/// The request information extracted from the request URI. -#[derive(Debug)] -struct RequestInfo { - /// The protocol version used by the client. - protocol: ProtocolVersion, - /// The socket id if present in the request. - sid: Option, - /// The transport type used by the client. - transport: TransportType, - /// The request method. - method: Method, - /// If the client asked for base64 encoding only. - #[cfg(feature = "v3")] - b64: bool, -} - -impl RequestInfo { - /// Parse the request URI to extract the [`TransportType`](crate::service::TransportType) and the socket id. - fn parse(req: &Request, config: &EngineIoConfig) -> Result { - let query = req.uri().query().ok_or(Error::UnknownTransport)?; - - let protocol: ProtocolVersion = query - .split('&') - .find(|s| s.starts_with("EIO=")) - .and_then(|s| s.split('=').nth(1)) - .ok_or(Error::UnsupportedProtocolVersion) - .and_then(|t| t.parse())?; - - let sid = query - .split('&') - .find(|s| s.starts_with("sid=")) - .and_then(|s| s.split('=').nth(1).map(|s1| s1.parse().ok())) - .flatten(); - - let transport: TransportType = query - .split('&') - .find(|s| s.starts_with("transport=")) - .and_then(|s| s.split('=').nth(1)) - .ok_or(Error::UnknownTransport) - .and_then(|t| t.parse())?; - - if !config.allowed_transport(transport) { - return Err(Error::TransportMismatch); - } - - #[cfg(feature = "v3")] - let b64: bool = query - .split('&') - .find(|s| s.starts_with("b64=")) - .map(|_| true) - .unwrap_or_default(); - - let method = req.method().clone(); - if !matches!(method, Method::GET) && sid.is_none() { - Err(Error::BadHandshakeMethod) - } else { - Ok(RequestInfo { - protocol, - sid, - transport, - method, - #[cfg(feature = "v3")] - b64, - }) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - fn build_request(path: &str) -> Request<()> { - Request::get(path).body(()).unwrap() - } - - #[test] - #[cfg(feature = "v4")] - fn request_info_polling() { - let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling"); - let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); - assert_eq!(info.sid, None); - assert_eq!(info.transport, TransportType::Polling); - assert_eq!(info.protocol, ProtocolVersion::V4); - assert_eq!(info.method, Method::GET); - } - - #[test] - #[cfg(feature = "v4")] - fn request_info_websocket() { - let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket"); - let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); - assert_eq!(info.sid, None); - assert_eq!(info.transport, TransportType::Websocket); - assert_eq!(info.protocol, ProtocolVersion::V4); - assert_eq!(info.method, Method::GET); - } - - #[test] - #[cfg(feature = "v3")] - fn request_info_polling_with_sid() { - let req = build_request( - "http://localhost:3000/socket.io/?EIO=3&transport=polling&sid=AAAAAAAAAAAAAAHs", - ); - let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); - assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap())); - assert_eq!(info.transport, TransportType::Polling); - assert_eq!(info.protocol, ProtocolVersion::V3); - assert_eq!(info.method, Method::GET); - } - - #[test] - #[cfg(feature = "v4")] - fn request_info_websocket_with_sid() { - let req = build_request( - "http://localhost:3000/socket.io/?EIO=4&transport=websocket&sid=AAAAAAAAAAAAAAHs", - ); - let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); - assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap())); - assert_eq!(info.transport, TransportType::Websocket); - assert_eq!(info.protocol, ProtocolVersion::V4); - assert_eq!(info.method, Method::GET); - } - - #[test] - #[cfg(feature = "v3")] - fn request_info_polling_with_bin_by_default() { - let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling"); - let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); - assert!(!req.b64); - } - - #[test] - #[cfg(feature = "v3")] - fn request_info_polling_withb64() { - assert!(cfg!(feature = "v3")); - - let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling&b64=1"); - let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); - assert!(req.b64); - } - - #[test] - #[cfg(feature = "v4")] - fn transport_unknown_err() { - let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=grpc"); - let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err(); - assert!(matches!(err, Error::UnknownTransport)); - } - #[test] - fn unsupported_protocol_version() { - let req = build_request("http://localhost:3000/socket.io/?EIO=2&transport=polling"); - let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err(); - assert!(matches!(err, Error::UnsupportedProtocolVersion)); - } - #[test] - #[cfg(feature = "v4")] - fn bad_handshake_method() { - let req = Request::post("http://localhost:3000/socket.io/?EIO=4&transport=polling") - .body(()) - .unwrap(); - let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err(); - assert!(matches!(err, Error::BadHandshakeMethod)); - } - - #[test] - #[cfg(feature = "v4")] - fn unsupported_transport() { - let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling"); - let err = RequestInfo::parse( - &req, - &EngineIoConfig::builder() - .transports([TransportType::Websocket]) - .build(), - ) - .unwrap_err(); - - assert!(matches!(err, Error::TransportMismatch)); - - let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket"); - let err = RequestInfo::parse( - &req, - &EngineIoConfig::builder() - .transports([TransportType::Polling]) - .build(), - ) - .unwrap_err(); - - assert!(matches!(err, Error::TransportMismatch)) - } -} diff --git a/engineioxide/src/service/futures.rs b/engineioxide/src/service/futures.rs new file mode 100644 index 00000000..2b8c0060 --- /dev/null +++ b/engineioxide/src/service/futures.rs @@ -0,0 +1,67 @@ +use crate::body::response::ResponseBody; +use crate::errors::Error; +use futures::ready; +use http::Response; +use pin_project::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub(crate) type BoxFuture = + Pin>, Error>> + Send>>; + +#[pin_project(project = ResFutProj)] +pub enum ResponseFuture { + EmptyResponse { + code: u16, + }, + ReadyResponse { + res: Option>, Error>>, + }, + AsyncResponse { + future: BoxFuture, + }, + Future { + #[pin] + future: F, + }, +} + +impl ResponseFuture { + pub fn empty_response(code: u16) -> Self { + ResponseFuture::EmptyResponse { code } + } + pub fn ready(res: Result>, Error>) -> Self { + ResponseFuture::ReadyResponse { res: Some(res) } + } + pub fn new(future: F) -> Self { + ResponseFuture::Future { future } + } + pub fn async_response(future: BoxFuture) -> Self { + ResponseFuture::AsyncResponse { future } + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = match self.project() { + ResFutProj::Future { future } => ready!(future.poll(cx))?.map(ResponseBody::new), + + ResFutProj::EmptyResponse { code } => Response::builder() + .status(*code) + .body(ResponseBody::empty_response()) + .unwrap(), + ResFutProj::AsyncResponse { future } => ready!(future + .as_mut() + .poll(cx) + .map(|r| r.unwrap_or_else(|e| e.into()))), + ResFutProj::ReadyResponse { res } => res.take().unwrap().unwrap_or_else(|e| e.into()), + }; + Poll::Ready(Ok(res)) + } +} diff --git a/engineioxide/src/service/hyper_v1.rs b/engineioxide/src/service/hyper_v1.rs new file mode 100644 index 00000000..38757488 --- /dev/null +++ b/engineioxide/src/service/hyper_v1.rs @@ -0,0 +1,117 @@ +//! Implement Services for hyper 1.0 +//! Only enabled with feature flag `hyper-v1` +use crate::{ + body::{ + request::IncomingBody, + response::{Empty, ResponseBody}, + }, + handler::EngineIoHandler, +}; +use bytes::Bytes; +use futures::future::{self, Ready}; +use http::Request; +use hyper::Response; +use hyper_v1::body::Incoming; +use std::{ + convert::Infallible, + task::{Context, Poll}, +}; + +use super::{futures::ResponseFuture, parser::dispatch_req, EngineIoService, NotFoundService}; + +/// A wrapper of [`EngineIoService`] that handles engine.io requests as a middleware from hyper-v1. +pub struct EngineIoHyperService(EngineIoService); +impl EngineIoHyperService +where + H: EngineIoHandler, +{ + pub(crate) fn new(svc: EngineIoService) -> Self { + EngineIoHyperService(svc) + } +} + +/// Tower Service implementation with an [`Incoming`] body and a [`http_body_v1::Body`] Body for hyper 1.0 +impl tower::Service> for EngineIoHyperService +where + ResBody: http_body_v1::Body + Send + 'static, + ReqBody: http_body_v1::Body + Send + Unpin + 'static + std::fmt::Debug, + ReqBody::Error: std::fmt::Debug, + ReqBody::Data: Send, + S: tower::Service, Response = Response>, + H: EngineIoHandler, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.0.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + if req.uri().path().starts_with(&self.0.engine.config.req_path) { + let req = req.map(IncomingBody::new); + dispatch_req( + req, + self.0.engine.clone(), + #[cfg(feature = "hyper-v1")] + true, // hyper-v1 enabled + ) + } else { + ResponseFuture::new(self.0.inner.call(req)) + } + } +} + +/// Hyper 1.0 Service implementation with an [`Incoming`] body and a [`http_body_v1::Body`] Body +impl hyper_v1::service::Service> for EngineIoHyperService +where + ResBody: http_body_v1::Body + Send + 'static, + S: hyper_v1::service::Service, Response = Response>, + H: EngineIoHandler, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + fn call(&self, req: Request) -> Self::Future { + if req.uri().path().starts_with(&self.0.engine.config.req_path) { + let req = req.map(IncomingBody::new); + dispatch_req( + req, + self.0.engine.clone(), + true, // hyper-v1 enabled + ) + } else { + ResponseFuture::new(self.0.inner.call(req)) + } + } +} + +impl Clone for EngineIoHyperService { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl std::fmt::Debug for EngineIoHyperService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("EngineIoHyperService") + .field(&self.0) + .finish() + } +} + +/// Implement a custom [`hyper_v1::service::Service`] for the [`NotFoundService`] +impl hyper_v1::service::Service> for NotFoundService { + type Response = Response>>; + type Error = Infallible; + type Future = Ready>>, Infallible>>; + + fn call(&self, _: Request) -> Self::Future { + future::ready(Ok(Response::builder() + .status(404) + .body(ResponseBody::empty_response()) + .unwrap())) + } +} diff --git a/engineioxide/src/service/mod.rs b/engineioxide/src/service/mod.rs new file mode 100644 index 00000000..fd8c07a0 --- /dev/null +++ b/engineioxide/src/service/mod.rs @@ -0,0 +1,171 @@ +use std::{ + convert::Infallible, + sync::Arc, + task::{Context, Poll}, +}; + +use ::futures::future::{self, Ready}; +use bytes::Bytes; +use http::{Request, Response}; +use http_body::Body; +use tower::Service; + +use crate::{ + body::response::{Empty, ResponseBody}, + config::EngineIoConfig, + engine::EngineIo, + handler::EngineIoHandler, +}; + +#[cfg(feature = "hyper-v1")] +pub mod hyper_v1; + +mod futures; +mod parser; + +pub use self::parser::{ProtocolVersion, TransportType}; +use self::{futures::ResponseFuture, parser::dispatch_req}; + +/// A [`Service`] that handles engine.io requests as a middleware. +/// If the request is not an engine.io request, it forwards it to the inner service. +/// If it is an engine.io request it will forward it to the appropriate [`transport`](crate::transport). +/// +/// By default, it uses a [`NotFoundService`] as the inner service so it can be used as a standalone [`Service`]. +pub struct EngineIoService { + inner: S, + engine: Arc>, +} + +impl EngineIoService { + /// Create a new [`EngineIoService`] with a [`NotFoundService`] as the inner service. + /// If the request is not an `EngineIo` request, it will always return a 404 response. + pub fn new(handler: H) -> Self { + EngineIoService::with_config(handler, EngineIoConfig::default()) + } + /// Create a new [`EngineIoService`] with a custom config + pub fn with_config(handler: H, config: EngineIoConfig) -> Self { + EngineIoService::with_config_inner(NotFoundService, handler, config) + } +} + +impl EngineIoService { + #[cfg(feature = "hyper-v1")] + #[inline(always)] + pub fn with_hyper_v1(self) -> hyper_v1::EngineIoHyperService { + hyper_v1::EngineIoHyperService::new(self) + } + /// Create a new [`EngineIoService`] with a custom inner service. + pub fn with_inner(inner: S, handler: H) -> Self { + EngineIoService::with_config_inner(inner, handler, EngineIoConfig::default()) + } + + /// Create a new [`EngineIoService`] with a custom inner service and a custom config. + pub fn with_config_inner(inner: S, handler: H, config: EngineIoConfig) -> Self { + EngineIoService { + inner, + engine: Arc::new(EngineIo::new(handler, config)), + } + } + + /// Convert this [`EngineIoService`] into a [`MakeEngineIoService`]. + /// This is useful when using [`EngineIoService`] without layers. + pub fn into_make_service(self) -> MakeEngineIoService { + MakeEngineIoService::new(self) + } +} + +impl Clone for EngineIoService { + fn clone(&self) -> Self { + EngineIoService { + inner: self.inner.clone(), + engine: self.engine.clone(), + } + } +} +impl std::fmt::Debug for EngineIoService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EngineIoService").finish() + } +} + +/// The service implementation for [`EngineIoService`]. +impl Service> for EngineIoService +where + ResBody: Body + Send + 'static, + ReqBody: Body + Send + Unpin + 'static + std::fmt::Debug, + ReqBody::Error: std::fmt::Debug, + ReqBody::Data: Send, + S: tower::Service, Response = Response>, + H: EngineIoHandler, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + if req.uri().path().starts_with(&self.engine.config.req_path) { + dispatch_req( + req, + self.engine.clone(), + #[cfg(feature = "hyper-v1")] + false, // hyper-v1 disabled + ) + } else { + ResponseFuture::new(self.inner.call(req)) + } + } +} + +/// A MakeService that always returns a clone of the [`EngineIoService`] it was created with. +pub struct MakeEngineIoService { + svc: EngineIoService, +} + +impl MakeEngineIoService { + /// Create a new [`MakeEngineIoService`] with a custom inner service. + pub fn new(svc: EngineIoService) -> Self { + MakeEngineIoService { svc } + } +} + +impl Service for MakeEngineIoService { + type Response = EngineIoService; + + type Error = Infallible; + + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _req: T) -> Self::Future { + future::ready(Ok(self.svc.clone())) + } +} + +/// A [`Service`] that always returns a 404 response and that is compatible with [`EngineIoService`] +/// *and* a [`hyper_v1::EngineIoHyperService`]. +#[derive(Debug, Clone)] +pub struct NotFoundService; + +impl Service> for NotFoundService { + type Response = Response>>; + type Error = Infallible; + type Future = Ready>>, Infallible>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: Request) -> Self::Future { + future::ready(Ok(Response::builder() + .status(404) + .body(ResponseBody::empty_response()) + .unwrap())) + } +} diff --git a/engineioxide/src/service/parser.rs b/engineioxide/src/service/parser.rs new file mode 100644 index 00000000..6fb84cba --- /dev/null +++ b/engineioxide/src/service/parser.rs @@ -0,0 +1,390 @@ +//! A Parser module to parse any `EngineIo` query + +use std::{str::FromStr, sync::Arc}; + +use futures::Future; +use http::{Method, Request, Response}; + +use crate::{ + body::response::ResponseBody, + config::EngineIoConfig, + engine::EngineIo, + handler::EngineIoHandler, + service::futures::ResponseFuture, + sid::Sid, + transport::{polling, ws}, +}; + +/// Dispatch a request according to the [`RequestInfo`] to the appropriate [`transport`](crate::transport). +pub fn dispatch_req( + req: Request, + engine: Arc>, + #[cfg(feature = "hyper-v1")] hyper_v1: bool, +) -> ResponseFuture +where + ReqBody: http_body::Body + Send + Unpin + 'static, + ReqBody::Data: Send, + ReqBody::Error: std::fmt::Debug, + ResBody: Send + 'static, + H: EngineIoHandler, + F: Future, +{ + match RequestInfo::parse(&req, &engine.config) { + Ok(RequestInfo { + protocol, + sid: None, + transport: TransportType::Polling, + method: Method::GET, + #[cfg(feature = "v3")] + b64, + }) => ResponseFuture::ready(polling::open_req( + engine, + protocol, + req, + #[cfg(feature = "v3")] + !b64, + )), + Ok(RequestInfo { + protocol, + sid: Some(sid), + transport: TransportType::Polling, + method: Method::GET, + .. + }) => ResponseFuture::async_response(Box::pin(polling::polling_req(engine, protocol, sid))), + Ok(RequestInfo { + protocol, + sid: Some(sid), + transport: TransportType::Polling, + method: Method::POST, + .. + }) => { + ResponseFuture::async_response(Box::pin(polling::post_req(engine, protocol, sid, req))) + } + Ok(RequestInfo { + protocol, + sid, + transport: TransportType::Websocket, + method: Method::GET, + .. + }) => ResponseFuture::ready(ws::new_req( + engine, + protocol, + sid, + req, + #[cfg(feature = "hyper-v1")] + hyper_v1, + )), + Err(e) => { + #[cfg(feature = "tracing")] + tracing::debug!("error parsing request: {:?}", e); + ResponseFuture::ready(Ok(e.into())) + } + _req => { + #[cfg(feature = "tracing")] + tracing::debug!("invalid request: {:?}", _req); + ResponseFuture::empty_response(400) + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum ParseError { + #[error("transport unknown")] + UnknownTransport, + #[error("bad handshake method")] + BadHandshakeMethod, + #[error("transport mismatch")] + TransportMismatch, + #[error("unsupported protocol version")] + UnsupportedProtocolVersion, +} + +/// Convert an error into an http response +/// If it is a known error, return the appropriate http status code +/// Otherwise, return a 500 +impl From for Response> { + fn from(err: ParseError) -> Self { + use ParseError::*; + let conn_err_resp = |message: &'static str| { + Response::builder() + .status(400) + .header("Content-Type", "application/json") + .body(ResponseBody::custom_response(message.into())) + .unwrap() + }; + match err { + UnknownTransport => conn_err_resp("{\"code\":\"0\",\"message\":\"Transport unknown\"}"), + BadHandshakeMethod => { + conn_err_resp("{\"code\":\"2\",\"message\":\"Bad handshake method\"}") + } + TransportMismatch => conn_err_resp("{\"code\":\"3\",\"message\":\"Bad request\"}"), + UnsupportedProtocolVersion => { + conn_err_resp("{\"code\":\"5\",\"message\":\"Unsupported protocol version\"}") + } + } + } +} + +/// The engine.io protocol +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ProtocolVersion { + V3 = 3, + V4 = 4, +} + +impl FromStr for ProtocolVersion { + type Err = ParseError; + + #[cfg(all(feature = "v3", feature = "v4"))] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + "4" => Ok(ProtocolVersion::V4), + _ => Err(ParseError::UnsupportedProtocolVersion), + } + } + + #[cfg(all(feature = "v4", not(feature = "v3")))] + fn from_str(s: &str) -> Result { + match s { + "4" => Ok(ProtocolVersion::V4), + _ => Err(ParseError::UnsupportedProtocolVersion), + } + } + + #[cfg(all(feature = "v3", not(feature = "v4")))] + fn from_str(s: &str) -> Result { + match s { + "3" => Ok(ProtocolVersion::V3), + _ => Err(ParseError::UnsupportedProtocolVersion), + } + } +} + +/// The type of the [`transport`](crate::transport) used by the client. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum TransportType { + Polling = 0x01, + Websocket = 0x02, +} + +impl FromStr for TransportType { + type Err = ParseError; + + fn from_str(s: &str) -> Result { + match s { + "websocket" => Ok(TransportType::Websocket), + "polling" => Ok(TransportType::Polling), + _ => Err(ParseError::UnknownTransport), + } + } +} +impl From for &'static str { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling", + TransportType::Websocket => "websocket", + } + } +} +impl From for String { + fn from(t: TransportType) -> Self { + match t { + TransportType::Polling => "polling".into(), + TransportType::Websocket => "websocket".into(), + } + } +} + +/// The request information extracted from the request URI. +#[derive(Debug)] +pub struct RequestInfo { + /// The protocol version used by the client. + pub protocol: ProtocolVersion, + /// The socket id if present in the request. + pub sid: Option, + /// The transport type used by the client. + pub transport: TransportType, + /// The request method. + pub method: Method, + /// If the client asked for base64 encoding only. + #[cfg(feature = "v3")] + pub b64: bool, +} + +impl RequestInfo { + /// Parse the request URI to extract the [`TransportType`](crate::service::TransportType) and the socket id. + fn parse(req: &Request, config: &EngineIoConfig) -> Result { + use ParseError::*; + let query = req.uri().query().ok_or(UnknownTransport)?; + + let protocol: ProtocolVersion = query + .split('&') + .find(|s| s.starts_with("EIO=")) + .and_then(|s| s.split('=').nth(1)) + .ok_or(UnsupportedProtocolVersion) + .and_then(|t| t.parse())?; + + let sid = query + .split('&') + .find(|s| s.starts_with("sid=")) + .and_then(|s| s.split('=').nth(1).map(|s1| s1.parse().ok())) + .flatten(); + + let transport: TransportType = query + .split('&') + .find(|s| s.starts_with("transport=")) + .and_then(|s| s.split('=').nth(1)) + .ok_or(UnknownTransport) + .and_then(|t| t.parse())?; + + if !config.allowed_transport(transport) { + return Err(TransportMismatch); + } + + #[cfg(feature = "v3")] + let b64: bool = query + .split('&') + .find(|s| s.starts_with("b64=")) + .map(|_| true) + .unwrap_or_default(); + + let method = req.method().clone(); + if !matches!(method, Method::GET) && sid.is_none() { + Err(BadHandshakeMethod) + } else { + Ok(RequestInfo { + protocol, + sid, + transport, + method, + #[cfg(feature = "v3")] + b64, + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn build_request(path: &str) -> Request<()> { + Request::get(path).body(()).unwrap() + } + + #[test] + #[cfg(feature = "v4")] + fn request_info_polling() { + let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling"); + let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); + assert_eq!(info.sid, None); + assert_eq!(info.transport, TransportType::Polling); + assert_eq!(info.protocol, ProtocolVersion::V4); + assert_eq!(info.method, Method::GET); + } + + #[test] + #[cfg(feature = "v4")] + fn request_info_websocket() { + let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket"); + let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); + assert_eq!(info.sid, None); + assert_eq!(info.transport, TransportType::Websocket); + assert_eq!(info.protocol, ProtocolVersion::V4); + assert_eq!(info.method, Method::GET); + } + + #[test] + #[cfg(feature = "v3")] + fn request_info_polling_with_sid() { + let req = build_request( + "http://localhost:3000/socket.io/?EIO=3&transport=polling&sid=AAAAAAAAAAAAAAHs", + ); + let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); + assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap())); + assert_eq!(info.transport, TransportType::Polling); + assert_eq!(info.protocol, ProtocolVersion::V3); + assert_eq!(info.method, Method::GET); + } + + #[test] + #[cfg(feature = "v4")] + fn request_info_websocket_with_sid() { + let req = build_request( + "http://localhost:3000/socket.io/?EIO=4&transport=websocket&sid=AAAAAAAAAAAAAAHs", + ); + let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); + assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap())); + assert_eq!(info.transport, TransportType::Websocket); + assert_eq!(info.protocol, ProtocolVersion::V4); + assert_eq!(info.method, Method::GET); + } + + #[test] + #[cfg(feature = "v3")] + fn request_info_polling_with_bin_by_default() { + let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling"); + let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); + assert!(!req.b64); + } + + #[test] + #[cfg(feature = "v3")] + fn request_info_polling_withb64() { + assert!(cfg!(feature = "v3")); + + let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling&b64=1"); + let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap(); + assert!(req.b64); + } + + #[test] + #[cfg(feature = "v4")] + fn transport_unknown_err() { + let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=grpc"); + let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err(); + assert!(matches!(err, ParseError::UnknownTransport)); + } + #[test] + fn unsupported_protocol_version() { + let req = build_request("http://localhost:3000/socket.io/?EIO=2&transport=polling"); + let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err(); + assert!(matches!(err, ParseError::UnsupportedProtocolVersion)); + } + #[test] + #[cfg(feature = "v4")] + fn bad_handshake_method() { + let req = Request::post("http://localhost:3000/socket.io/?EIO=4&transport=polling") + .body(()) + .unwrap(); + let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err(); + assert!(matches!(err, ParseError::BadHandshakeMethod)); + } + + #[test] + #[cfg(feature = "v4")] + fn unsupported_transport() { + let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling"); + let err = RequestInfo::parse( + &req, + &EngineIoConfig::builder() + .transports([TransportType::Websocket]) + .build(), + ) + .unwrap_err(); + + assert!(matches!(err, ParseError::TransportMismatch)); + + let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket"); + let err = RequestInfo::parse( + &req, + &EngineIoConfig::builder() + .transports([TransportType::Polling]) + .build(), + ) + .unwrap_err(); + + assert!(matches!(err, ParseError::TransportMismatch)) + } +} diff --git a/engineioxide/src/socket.rs b/engineioxide/src/socket.rs index 0fcf07fe..a977f8f5 100644 --- a/engineioxide/src/socket.rs +++ b/engineioxide/src/socket.rs @@ -21,7 +21,7 @@ use crate::{ config::EngineIoConfig, errors::Error, packet::Packet, peekable::PeekableReceiver, service::ProtocolVersion, }; -use crate::{sid::Sid, transport::TransportType}; +use crate::{service::TransportType, sid::Sid}; /// Http Request data used to create a socket #[derive(Debug)] diff --git a/engineioxide/src/transport/mod.rs b/engineioxide/src/transport/mod.rs index 06e17b4c..94d843a5 100644 --- a/engineioxide/src/transport/mod.rs +++ b/engineioxide/src/transport/mod.rs @@ -1,43 +1,4 @@ //! All transports modules available in engineioxide -use std::str::FromStr; - -use crate::errors::Error; - pub mod polling; pub mod ws; - -/// The type of the [`transport`](crate::transport) used by the client. -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum TransportType { - Polling = 0x01, - Websocket = 0x02, -} - -impl FromStr for TransportType { - type Err = Error; - - fn from_str(s: &str) -> Result { - match s { - "websocket" => Ok(TransportType::Websocket), - "polling" => Ok(TransportType::Polling), - _ => Err(Error::UnknownTransport), - } - } -} -impl From for &'static str { - fn from(t: TransportType) -> Self { - match t { - TransportType::Polling => "polling", - TransportType::Websocket => "websocket", - } - } -} -impl From for String { - fn from(t: TransportType) -> Self { - match t { - TransportType::Polling => "polling".into(), - TransportType::Websocket => "websocket".into(), - } - } -} diff --git a/engineioxide/src/transport/polling/mod.rs b/engineioxide/src/transport/polling/mod.rs index 47444b1e..0ccb4fce 100644 --- a/engineioxide/src/transport/polling/mod.rs +++ b/engineioxide/src/transport/polling/mod.rs @@ -1,27 +1,47 @@ //! The polling transport module handles polling, post and init requests use std::sync::Arc; +use bytes::Bytes; use futures::StreamExt; use http::{Request, Response, StatusCode}; -use http_body::Body; +use http_body::{Body, Full}; use crate::{ - body::ResponseBody, + body::response::ResponseBody, engine::EngineIo, errors::Error, - futures::http_response, handler::EngineIoHandler, packet::{OpenPacket, Packet}, - service::ProtocolVersion, + service::{ProtocolVersion, TransportType}, sid::Sid, transport::polling::payload::Payload, DisconnectReason, SocketReq, }; -use super::TransportType; - mod payload; +/// Create a response for http request +fn http_response( + code: StatusCode, + data: D, + is_binary: bool, +) -> Result>, http::Error> +where + D: Into, +{ + use http::header::*; + let body: Bytes = data.into(); + let res = Response::builder() + .status(code) + .header(CONTENT_LENGTH, body.len()); + if is_binary { + res.header(CONTENT_TYPE, "application/octet-stream") + } else { + res.header(CONTENT_TYPE, "text/plain; charset=UTF-8") + } + .body(ResponseBody::custom_response(Full::new(body))) +} + pub fn open_req( engine: Arc>, protocol: ProtocolVersion, diff --git a/engineioxide/src/transport/ws.rs b/engineioxide/src/transport/ws.rs index dd57c0a7..3aa58e7c 100644 --- a/engineioxide/src/transport/ws.rs +++ b/engineioxide/src/transport/ws.rs @@ -7,40 +7,60 @@ use std::sync::Arc; use futures::{ + future::Either, stream::{SplitSink, SplitStream}, SinkExt, StreamExt, TryStreamExt, }; -use http::{Request, Response, StatusCode}; -use hyper::upgrade::Upgraded; -use tokio::task::JoinHandle; +use http::{HeaderValue, Request, Response, StatusCode}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + task::JoinHandle, +}; use tokio_tungstenite::{ - tungstenite::{protocol::Role, Message}, + tungstenite::{handshake::derive_accept_key, protocol::Role, Message}, WebSocketStream, }; use crate::{ - body::ResponseBody, + body::response::ResponseBody, config::EngineIoConfig, engine::EngineIo, errors::Error, - futures::ws_response, handler::EngineIoHandler, packet::{OpenPacket, Packet}, service::ProtocolVersion, + service::TransportType, sid::Sid, - transport::TransportType, DisconnectReason, Socket, SocketReq, }; +/// Create a response for websocket upgrade +fn ws_response(ws_key: &HeaderValue) -> Result>, http::Error> { + let derived = derive_accept_key(ws_key.as_bytes()); + let sec = derived.parse::().unwrap(); + Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(http::header::UPGRADE, HeaderValue::from_static("websocket")) + .header( + http::header::CONNECTION, + HeaderValue::from_static("Upgrade"), + ) + .header(http::header::SEC_WEBSOCKET_ACCEPT, sec) + .body(ResponseBody::empty_response()) +} + /// Upgrade a websocket request to create a websocket connection. /// /// If a sid is provided in the query it means that is is upgraded from an existing HTTP polling request. In this case /// the http polling request is closed and the SID is kept for the websocket +/// +/// It can be used with hyper-v1 by setting the `hyper_v1` parameter to true pub fn new_req( engine: Arc>, protocol: ProtocolVersion, sid: Option, req: Request, + #[cfg(feature = "hyper-v1")] hyper_v1: bool, ) -> Result>, Error> { let (parts, _) = req.into_parts(); let ws_key = parts @@ -52,20 +72,43 @@ pub fn new_req( let req = Request::from_parts(parts, ()); tokio::spawn(async move { - match hyper::upgrade::on(req).await { - Ok(conn) => match on_init(engine, conn, protocol, sid, req_data).await { - Ok(_) => { - #[cfg(feature = "tracing")] - tracing::debug!("ws closed") - } - Err(_e) => { - #[cfg(feature = "tracing")] - tracing::debug!("ws closed with error: {:?}", _e) - } - }, + #[cfg(feature = "hyper-v1")] + let res = if hyper_v1 { + // Wraps the hyper-v1 upgrade so it implement `AsyncRead` and `AsyncWrite` + Either::Left( + hyper_v1::upgrade::on(req) + .await + .map(hyper_util::rt::TokioIo::new), + ) + } else { + Either::Right(hyper::upgrade::on(req).await) + }; + #[cfg(not(feature = "hyper-v1"))] + let res = Either::Right(hyper::upgrade::on(req).await); + + let res = match res { + Either::Left(Ok(conn)) => on_init(engine, conn, protocol, sid, req_data).await, + Either::Right(Ok(conn)) => on_init(engine, conn, protocol, sid, req_data).await, + Either::Left(Err(_e)) => { + #[cfg(feature = "tracing")] + tracing::debug!("ws upgrade error: {}", _e); + return; + } + Either::Right(Err(_e)) => { + #[cfg(feature = "tracing")] + tracing::debug!("ws upgrade error: {}", _e); + return; + } + }; + + match res { + Ok(_) => { + #[cfg(feature = "tracing")] + tracing::debug!("ws closed") + } Err(_e) => { #[cfg(feature = "tracing")] - tracing::debug!("ws upgrade error: {}", _e) + tracing::debug!("ws closed with error: {:?}", _e) } } }); @@ -78,13 +121,16 @@ pub fn new_req( /// Sends an open packet if it is not an upgrade from a polling request /// /// Read packets from the websocket and handle them, it will block until the connection is closed -async fn on_init( +async fn on_init( engine: Arc>, - conn: Upgraded, + conn: S, protocol: ProtocolVersion, sid: Option, req_data: SocketReq, -) -> Result<(), Error> { +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ let ws_init = move || WebSocketStream::from_raw_socket(conn, Role::Server, None); let (socket, ws) = if let Some(sid) = sid { match engine.get_socket(sid) { @@ -92,7 +138,7 @@ async fn on_init( Some(socket) if socket.is_ws() => return Err(Error::UpgradeError), Some(socket) => { let mut ws = ws_init().await; - upgrade_handshake::(protocol, &socket, &mut ws).await?; + upgrade_handshake::(protocol, &socket, &mut ws).await?; (socket, ws) } } @@ -114,7 +160,7 @@ async fn on_init( (socket, ws) }; let (tx, rx) = ws.split(); - let rx_handle = forward_to_socket::(socket.clone(), tx); + let rx_handle = forward_to_socket::(socket.clone(), tx); engine.handler.on_connect(socket.clone()); @@ -132,11 +178,14 @@ async fn on_init( } /// Forwards all packets received from a websocket to a EngineIo [`Socket`] -async fn forward_to_handler( +async fn forward_to_handler( engine: &Arc>, - mut rx: SplitStream>, + mut rx: SplitStream>, socket: &Arc>, -) -> Result<(), Error> { +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ while let Some(msg) = rx.try_next().await? { match msg { Message::Text(msg) => match Packet::try_from(msg)? { @@ -170,10 +219,13 @@ async fn forward_to_handler( /// Forwards all packets waiting to be sent to the websocket /// /// The websocket stream is flushed only when the internal channel is drained -fn forward_to_socket( +fn forward_to_socket( socket: Arc>, - mut tx: SplitSink, Message>, -) -> JoinHandle<()> { + mut tx: SplitSink, Message>, +) -> JoinHandle<()> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ // Pipe between websocket and internal socket channel tokio::spawn(async move { let mut internal_rx = socket.internal_rx.try_lock().unwrap(); @@ -220,11 +272,14 @@ fn forward_to_socket( }) } /// Send a Engine.IO [`OpenPacket`] to initiate a websocket connection -async fn init_handshake( +async fn init_handshake( sid: Sid, - ws: &mut WebSocketStream, + ws: &mut WebSocketStream, config: &EngineIoConfig, -) -> Result<(), Error> { +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ let packet = Packet::Open(OpenPacket::new(TransportType::Websocket, sid, config)); ws.send(Message::Text(packet.try_into()?)).await?; Ok(()) @@ -254,11 +309,14 @@ async fn init_handshake( ///│ ----- WebSocket frames ----- │ /// ``` #[cfg_attr(feature = "tracing", tracing::instrument(skip(socket, ws), fields(sid = socket.id.to_string())))] -async fn upgrade_handshake( +async fn upgrade_handshake( protocol: ProtocolVersion, socket: &Arc>, - ws: &mut WebSocketStream, -) -> Result<(), Error> { + ws: &mut WebSocketStream, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ #[cfg(feature = "tracing")] tracing::debug!("websocket connection upgrade"); diff --git a/examples/axum-echo/Cargo.toml b/examples/axum-echo/Cargo.toml new file mode 100644 index 00000000..90a074bd --- /dev/null +++ b/examples/axum-echo/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "axum-echo" +version = "0.6.0" +edition = "2021" + +[dependencies] +socketioxide = { path = "../../socketioxide" } +axum = { version = "0.6.20" } +tokio = { version = "1.33.0", features = ["full"] } +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing = "0.1.37" +serde = "1.0.188" +serde_json = "1.0.107" +futures = "0.3.28" + +[[bin]] +name = "axum-echo" +path = "axum_echo.rs" diff --git a/examples/socketio-echo/src/axum_echo.rs b/examples/axum-echo/axum_echo.rs similarity index 100% rename from examples/socketio-echo/src/axum_echo.rs rename to examples/axum-echo/axum_echo.rs diff --git a/examples/engineio-echo/Cargo.toml b/examples/engineio-echo/Cargo.toml deleted file mode 100644 index b69321aa..00000000 --- a/examples/engineio-echo/Cargo.toml +++ /dev/null @@ -1,32 +0,0 @@ -[package] -name = "engineio-echo" -version = "0.6.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -engineioxide = { path = "../../engineioxide" } -axum = { version = "0.6.20" } -warp = { version = "0.3.6" } -hyper = { version = "0.14.27" } -tokio = { version = "1.33.0", features = ["full"] } -tower = { version = "0.4.13" } -tower-http = { version = "0.4.4", features = ["cors"] } -tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -tracing = "0.1.37" -serde = "1.0.188" -serde_json = "1.0.107" -futures = "0.3.28" - -[[example]] -name = "engineio-axum-echo" -path = "src/axum_echo.rs" - -[[example]] -name = "engineio-hyper-echo" -path = "src/hyper_echo.rs" - -[[example]] -name = "engineio-warp-echo" -path = "src/warp_echo.rs" diff --git a/examples/engineio-echo/src/axum_echo.rs b/examples/engineio-echo/src/axum_echo.rs deleted file mode 100644 index 4f0ae0b4..00000000 --- a/examples/engineio-echo/src/axum_echo.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::sync::Arc; - -use axum::routing::get; -use axum::Server; -use engineioxide::{ - handler::EngineIoHandler, - layer::EngineIoLayer, - socket::{DisconnectReason, Socket}, -}; -use tracing::info; -use tracing_subscriber::FmtSubscriber; - -#[derive(Debug, Clone)] -struct MyHandler; - -#[engineioxide::async_trait] -impl EngineIoHandler for MyHandler { - type Data = (); - - fn on_connect(&self, socket: Arc>) { - println!("socket connect {}", socket.id); - } - fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { - println!("socket disconnect {}: {:?}", socket.id, reason); - } - - fn on_message(&self, msg: String, socket: Arc>) { - println!("Ping pong message {:?}", msg); - socket.emit(msg).ok(); - } - - fn on_binary(&self, data: Vec, socket: Arc>) { - println!("Ping pong binary message {:?}", data); - socket.emit_binary(data).ok(); - } -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - tracing::subscriber::set_global_default(FmtSubscriber::default())?; - - info!("Starting server"); - let app = axum::Router::new() - .route("/", get(|| async { "Hello, World!" })) - .layer(EngineIoLayer::new(MyHandler)); - - Server::bind(&"127.0.0.1:3000".parse().unwrap()) - .serve(app.into_make_service()) - .await?; - - Ok(()) -} diff --git a/examples/engineio-echo/src/hyper_echo.rs b/examples/engineio-echo/src/hyper_echo.rs deleted file mode 100644 index c93c5ef7..00000000 --- a/examples/engineio-echo/src/hyper_echo.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::sync::Arc; - -use engineioxide::{ - handler::EngineIoHandler, - service::EngineIoService, - socket::{DisconnectReason, Socket}, -}; -use hyper::Server; -use tracing::info; -use tracing_subscriber::FmtSubscriber; - -#[derive(Debug, Clone)] -struct MyHandler; - -#[engineioxide::async_trait] -impl EngineIoHandler for MyHandler { - type Data = (); - - fn on_connect(&self, socket: Arc>) { - println!("socket connect {}", socket.id); - } - fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { - println!("socket disconnect {}: {:?}", socket.id, reason); - } - - fn on_message(&self, msg: String, socket: Arc>) { - println!("Ping pong message {:?}", msg); - socket.emit(msg).ok(); - } - - fn on_binary(&self, data: Vec, socket: Arc>) { - println!("Ping pong binary message {:?}", data); - socket.emit_binary(data).ok(); - } -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - tracing::subscriber::set_global_default(FmtSubscriber::default())?; - - // We'll bind to 127.0.0.1:3000 - let addr = &"127.0.0.1:3000".parse().unwrap(); - let svc = EngineIoService::new(MyHandler); - - let server = Server::bind(addr).serve(svc.into_make_service()); - - info!("Starting server"); - - server.await?; - - Ok(()) -} diff --git a/examples/engineio-echo/src/warp_echo.rs b/examples/engineio-echo/src/warp_echo.rs deleted file mode 100644 index c97aa984..00000000 --- a/examples/engineio-echo/src/warp_echo.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::sync::Arc; - -use engineioxide::{ - handler::EngineIoHandler, - service::EngineIoService, - socket::{DisconnectReason, Socket}, -}; -use hyper::Server; -use tracing::info; -use tracing_subscriber::FmtSubscriber; -use warp::Filter; - -#[derive(Debug, Clone)] -struct MyHandler; - -#[engineioxide::async_trait] -impl EngineIoHandler for MyHandler { - type Data = (); - - fn on_connect(&self, socket: Arc>) { - println!("socket connect {}", socket.id); - } - fn on_disconnect(&self, socket: Arc>, reason: DisconnectReason) { - println!("socket disconnect {}: {:?}", socket.id, reason); - } - - fn on_message(&self, msg: String, socket: Arc>) { - println!("Ping pong message {:?}", msg); - socket.emit(msg).ok(); - } - - fn on_binary(&self, data: Vec, socket: Arc>) { - println!("Ping pong binary message {:?}", data); - socket.emit_binary(data).ok(); - } -} - -#[tokio::main] -async fn main() -> Result<(), Box> { - tracing::subscriber::set_global_default(FmtSubscriber::default())?; - - let filter = warp::any().map(|| "Hello From Warp!"); - let warp_svc = warp::service(filter); - - // We'll bind to 127.0.0.1:3000 - let addr = &"127.0.0.1:3000".parse().unwrap(); - let svc = EngineIoService::with_inner(warp_svc, MyHandler); - - let server = Server::bind(addr).serve(svc.into_make_service()); - - info!("Starting server"); - - server.await?; - - Ok(()) -} diff --git a/examples/hyper-echo/Cargo.toml b/examples/hyper-echo/Cargo.toml new file mode 100644 index 00000000..6da5794a --- /dev/null +++ b/examples/hyper-echo/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "hyper-echo" +version = "0.6.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +socketioxide = { path = "../../socketioxide" } +hyper = { version = "0.14.27" } +tokio = { version = "1.33.0", features = ["full"] } +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing = "0.1.37" +serde = "1.0.188" +serde_json = "1.0.107" +futures = "0.3.28" + +[[bin]] +name = "hyper-echo" +path = "hyper_echo.rs" diff --git a/examples/socketio-echo/src/hyper_echo.rs b/examples/hyper-echo/hyper_echo.rs similarity index 100% rename from examples/socketio-echo/src/hyper_echo.rs rename to examples/hyper-echo/hyper_echo.rs diff --git a/examples/hyper-v1-echo/Cargo.toml b/examples/hyper-v1-echo/Cargo.toml new file mode 100644 index 00000000..2cd98b72 --- /dev/null +++ b/examples/hyper-v1-echo/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "hyper-v1-echo" +version = "0.6.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +socketioxide = { path = "../../socketioxide", features = [ + "hyper-v1", + "tracing", +] } +hyper-v1 = { package = "hyper", version = "1.0.0-rc.4", optional = true, features = [ + "server", + "http1", + "http2", +] } +hyper-util = { git = "https://github.com/hyperium/hyper-util.git" } +tokio = { version = "1.33.0", features = ["full"] } +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing = "0.1.37" +serde = "1.0.188" +serde_json = "1.0.107" +futures = "0.3.28" + +[[bin]] +name = "hyper-v1-echo" +path = "hyper_v1_echo.rs" diff --git a/examples/hyper-v1-echo/hyper_v1_echo.rs b/examples/hyper-v1-echo/hyper_v1_echo.rs new file mode 100644 index 00000000..524db746 --- /dev/null +++ b/examples/hyper-v1-echo/hyper_v1_echo.rs @@ -0,0 +1,72 @@ +use std::net::SocketAddr; + +use hyper_util::rt::TokioIo; +use hyper_v1::server::conn::http1; +use serde_json::Value; +use socketioxide::SocketIo; +use tokio::net::TcpListener; +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + + let (svc, io) = SocketIo::new_svc(); + + io.ns("/", |socket, auth: Value| async move { + info!("Socket.IO connected: {:?} {:?}", socket.ns(), socket.id); + socket.emit("auth", auth).ok(); + + socket.on("message", |socket, data: Value, bin, _| async move { + info!("Received event: {:?} {:?}", data, bin); + socket.bin(bin).emit("message-back", data).ok(); + }); + + socket.on("message-with-ack", |_, data: Value, bin, ack| async move { + info!("Received event: {:?} {:?}", data, bin); + ack.bin(bin).send(data).ok(); + }); + + socket.on_disconnect(|socket, reason| async move { + info!("Socket.IO disconnected: {} {}", socket.id, reason); + }); + }); + + io.ns("/custom", |socket, auth: Value| async move { + info!("Socket.IO connected on: {:?} {:?}", socket.ns(), socket.id); + socket.emit("auth", auth).ok(); + }); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + let listener = TcpListener::bind(addr).await?; + + // Convert the `SocketIoService` so it works with hyper 1.0 + let svc = svc.with_hyper_v1(); + + // We start a loop to continuously accept incoming connections + loop { + let (stream, _) = listener.accept().await?; + + // Use an adapter to access something implementing `tokio::io` traits as if they implement + // `hyper::rt` IO traits. + let io = TokioIo::new(stream); + let svc = svc.clone(); + + // Spawn a tokio task to serve multiple connections concurrently + tokio::task::spawn(async move { + // Finally, we bind the incoming connection to our `hello` service + if let Err(err) = http1::Builder::new() + .serve_connection(io, svc) + .with_upgrades() + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } +} diff --git a/examples/salvo-echo/Cargo.toml b/examples/salvo-echo/Cargo.toml new file mode 100644 index 00000000..fae806cd --- /dev/null +++ b/examples/salvo-echo/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "salvo-echo" +version = "0.6.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +socketioxide = { path = "../../socketioxide", features = [ + "hyper-v1", + "tracing", +] } +salvo = { version = "0.58.2", features = ["tower-compat"] } +tokio = { version = "1.33.0", features = ["full"] } +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing = "0.1.37" +serde = "1.0.188" +serde_json = "1.0.107" +futures = "0.3.28" + +[[bin]] +name = "salvo-echo" +path = "salvo_echo.rs" diff --git a/examples/salvo-echo/salvo_echo.rs b/examples/salvo-echo/salvo_echo.rs new file mode 100644 index 00000000..350fdc77 --- /dev/null +++ b/examples/salvo-echo/salvo_echo.rs @@ -0,0 +1,53 @@ +use salvo::prelude::*; +use serde_json::Value; +use socketioxide::SocketIo; + +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; + +#[handler] +async fn hello() -> &'static str { + "Hello World" +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let subscriber = FmtSubscriber::builder() + .with_line_number(true) + .with_max_level(Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + + let (layer, io) = SocketIo::new_layer(); + + io.ns("/", |socket, auth: Value| async move { + info!("Socket.IO connected: {:?} {:?}", socket.ns(), socket.id); + socket.emit("auth", auth).ok(); + + socket.on("message", |socket, data: Value, bin, _| async move { + info!("Received event: {:?} {:?}", data, bin); + socket.bin(bin).emit("message-back", data).ok(); + }); + + socket.on("message-with-ack", |_, data: Value, bin, ack| async move { + info!("Received event: {:?} {:?}", data, bin); + ack.bin(bin).send(data).ok(); + }); + + socket.on_disconnect(|socket, reason| async move { + info!("Socket.IO disconnected: {} {}", socket.id, reason); + }); + }); + + io.ns("/custom", |socket, auth: Value| async move { + info!("Socket.IO connected on: {:?} {:?}", socket.ns(), socket.id); + socket.emit("auth", auth).ok(); + }); + + let layer = layer.with_hyper_v1().compat(); + let router = Router::with_path("/socket.io").hoop(layer).goal(hello); + let acceptor = TcpListener::new("127.0.0.1:3000").bind().await; + Server::new(acceptor).serve(router).await; + + Ok(()) +} diff --git a/examples/socketio-echo/Cargo.toml b/examples/warp-echo/Cargo.toml similarity index 57% rename from examples/socketio-echo/Cargo.toml rename to examples/warp-echo/Cargo.toml index f5758356..a57f6b3d 100644 --- a/examples/socketio-echo/Cargo.toml +++ b/examples/warp-echo/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "socketio-echo" +name = "engineio-echo" version = "0.6.0" edition = "2021" @@ -7,26 +7,15 @@ edition = "2021" [dependencies] socketioxide = { path = "../../socketioxide" } -axum = { version = "0.6.20" } warp = { version = "0.3.6" } hyper = { version = "0.14.27" } tokio = { version = "1.33.0", features = ["full"] } -tower = { version = "0.4.13" } -tower-http = { version = "0.4.4", features = ["cors"] } tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } tracing = "0.1.37" serde = "1.0.188" serde_json = "1.0.107" futures = "0.3.28" -[[example]] -name = "socketio-axum-echo" -path = "src/axum_echo.rs" - -[[example]] -name = "socketio-hyper-echo" -path = "src/hyper_echo.rs" - -[[example]] -name = "socketio-warp-echo" -path = "src/warp_echo.rs" +[[bin]] +name = "warp-echo" +path = "warp_echo.rs" diff --git a/examples/socketio-echo/src/warp_echo.rs b/examples/warp-echo/warp_echo.rs similarity index 100% rename from examples/socketio-echo/src/warp_echo.rs rename to examples/warp-echo/warp_echo.rs diff --git a/socketioxide/Cargo.toml b/socketioxide/Cargo.toml index 11d8a174..e31b3e4e 100644 --- a/socketioxide/Cargo.toml +++ b/socketioxide/Cargo.toml @@ -36,6 +36,14 @@ dashmap = { version = "5.4.0", optional = true } # Tracing tracing = { version = "0.1.37", optional = true } +# Hyper v0.1 +http-body-v1 = { package = "http-body", version = "1.0.0-rc.2", optional = true } +hyper-v1 = { package = "hyper", version = "1.0.0-rc.4", optional = true, features = [ + "server", + "http1", + "http2", +] } + [features] default = ["v5"] v5 = ["engineioxide/v4"] @@ -43,6 +51,7 @@ v4 = ["engineioxide/v3"] test-utils = [] tracing = ["dep:tracing", "engineioxide/tracing"] extensions = ["dep:dashmap"] +hyper-v1 = ["engineioxide/hyper-v1", "dep:http-body-v1", "dep:hyper-v1"] [dev-dependencies] engineioxide = { path = "../engineioxide", version = "0.6.0", features = [ diff --git a/socketioxide/src/client.rs b/socketioxide/src/client.rs index 969c0610..b8a7a46d 100644 --- a/socketioxide/src/client.rs +++ b/socketioxide/src/client.rs @@ -66,7 +66,7 @@ impl Client { tracing::debug!("auth: {:?}", auth); let sid = esocket.id; - if let Some(ns) = self.get_ns(&ns_path) { + if let Some(ns) = self.get_ns(ns_path) { ns.connect(sid, esocket.clone(), auth, self.config.clone())?; // cancel the connect timeout task for v5 @@ -196,7 +196,7 @@ impl EngineIoHandler for Client { if protocol == ProtocolVersion::V4 { #[cfg(feature = "tracing")] tracing::debug!("connecting to default namespace for v4"); - self.sock_connect(None, "/".into(), &socket).unwrap(); + self.sock_connect(None, "/", &socket).unwrap(); } #[cfg(feature = "v5")] diff --git a/socketioxide/src/hyper_v1.rs b/socketioxide/src/hyper_v1.rs new file mode 100644 index 00000000..beccdb52 --- /dev/null +++ b/socketioxide/src/hyper_v1.rs @@ -0,0 +1,72 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use crate::{adapter::Adapter, client::Client}; +use engineioxide::service::hyper_v1::EngineIoHyperService; +use http::{Request, Response}; +use http_body_v1::Body; +use hyper_v1::body::Incoming; +use hyper_v1::service::Service as HyperSvc; +use tower::Service as TowerSvc; + +/// [`Service`](tower::Service) implementation for `hyper 1.0` +/// It can be created with `with_hyper_v1` fn on [`SocketIoService`](crate::service::SocketIoService) +/// or [`SocketIoLayer`](crate::layer::SocketIoLayer) +pub struct SocketIoHyperService(EngineIoHyperService>, S>); + +impl SocketIoHyperService { + pub(crate) fn new(svc: EngineIoHyperService>, S>) -> Self { + Self(svc) + } +} + +/// Tower Service implementation with a [`http_body_v1::Body`] Body +impl TowerSvc> for SocketIoHyperService +where + ResBody: Body + Send + 'static, + ReqBody: Body + Send + 'static + std::fmt::Debug + Unpin, + ReqBody::Error: std::fmt::Debug, + ReqBody::Data: Send, + S: TowerSvc, Response = Response> + Clone, +{ + type Response = + >, S> as TowerSvc>>::Response; + type Error = >, S> as TowerSvc>>::Error; + type Future = >, S> as TowerSvc>>::Future; + + #[inline(always)] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + #[inline(always)] + fn call(&mut self, req: Request) -> Self::Future { + self.0.call(req) + } +} + +/// Hyper 1.0 Service implementation with an [`Incoming`] body and a [`http_body_v1::Body`] Body +impl HyperSvc> for SocketIoHyperService +where + ResBody: http_body_v1::Body + Send + 'static, + S: hyper_v1::service::Service, Response = Response>, + S: Clone, + A: Adapter, +{ + type Response = + >, S> as HyperSvc>>::Response; + type Error = >, S> as HyperSvc>>::Error; + type Future = >, S> as HyperSvc>>::Future; + + #[inline(always)] + fn call(&self, req: Request) -> Self::Future { + self.0.call(req) + } +} + +impl Clone for SocketIoHyperService { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} diff --git a/socketioxide/src/layer.rs b/socketioxide/src/layer.rs index 7def32d6..4701db00 100644 --- a/socketioxide/src/layer.rs +++ b/socketioxide/src/layer.rs @@ -25,6 +25,15 @@ impl SocketIoLayer { }; (layer, client) } + + /// Convert this [`Layer`] into a [`SocketIoHyperLayer`] to use with hyper v1 and its dependent frameworks. + /// + /// This is only available when the `hyper-v1` feature is enabled. + #[cfg(feature = "hyper-v1")] + #[inline(always)] + pub fn with_hyper_v1(self) -> SocketIoHyperLayer { + SocketIoHyperLayer(self) + } } impl Layer for SocketIoLayer { @@ -34,3 +43,22 @@ impl Layer for SocketIoLayer { SocketIoService::with_client(inner, self.client.clone()) } } + +/// A [`Layer`] for [`SocketIoService`] that works with hyper v1 and its dependent frameworks. +#[cfg(feature = "hyper-v1")] +pub struct SocketIoHyperLayer(SocketIoLayer); + +#[cfg(feature = "hyper-v1")] +impl Clone for SocketIoHyperLayer { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} +#[cfg(feature = "hyper-v1")] +impl Layer for SocketIoHyperLayer { + type Service = crate::hyper_v1::SocketIoHyperService; + + fn layer(&self, inner: S) -> Self::Service { + SocketIoService::with_client(inner, self.0.client.clone()).with_hyper_v1() + } +} diff --git a/socketioxide/src/lib.rs b/socketioxide/src/lib.rs index eab877c2..212b76f0 100644 --- a/socketioxide/src/lib.rs +++ b/socketioxide/src/lib.rs @@ -73,6 +73,8 @@ pub mod adapter; #[cfg(feature = "extensions")] pub mod extensions; +#[cfg(feature = "hyper-v1")] +pub mod hyper_v1; pub mod layer; pub mod service; diff --git a/socketioxide/src/service.rs b/socketioxide/src/service.rs index 478b9850..c9f7b2b7 100644 --- a/socketioxide/src/service.rs +++ b/socketioxide/src/service.rs @@ -59,6 +59,16 @@ impl SocketIoService { let svc = EngineIoService::with_config_inner(inner, client, engine_config); Self { engine_svc: svc } } + + /// Convert this [`Service`] into a [`SocketIoHyperService`](crate::hyper_v1::SocketIoHyperService) + /// to use with hyper v1 and its dependent frameworks. + /// + /// This is only available when the `hyper-v1` feature is enabled. + #[inline(always)] + #[cfg(feature = "hyper-v1")] + pub fn with_hyper_v1(self) -> crate::hyper_v1::SocketIoHyperService { + crate::hyper_v1::SocketIoHyperService::new(self.engine_svc.with_hyper_v1()) + } } impl Clone for SocketIoService {