diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index c795d0868..ddf2510da 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -12,11 +12,12 @@ futures-core = "0.3" http = "0.2" http-body = "0.4" hyper = "0.14.3" +paste = "1.0.12" pin-project = "1.0" prost = "0.11" tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tokio-stream = "0.1" -tonic = {path = "../../tonic", features = ["gzip"]} +tonic = {path = "../../tonic", features = ["gzip", "zstd"]} tower = {version = "0.4", features = []} tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]} diff --git a/tests/compression/src/bidirectional_stream.rs b/tests/compression/src/bidirectional_stream.rs index 4abb4d10d..bc27f62ac 100644 --- a/tests/compression/src/bidirectional_stream.rs +++ b/tests/compression/src/bidirectional_stream.rs @@ -1,20 +1,45 @@ use super::*; +use http_body::Body; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()) - .accept_compressed(CompressionEncoding::Gzip) - .send_compressed(CompressionEncoding::Gzip); + .accept_compressed(encoding) + .send_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); - fn assert_right_encoding(req: http::Request) -> http::Request { - assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); - req + #[derive(Clone)] + pub struct AssertRightEncoding { + encoding: CompressionEncoding, + } + + #[allow(dead_code)] + impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); + + req + } } tokio::spawn({ @@ -24,7 +49,9 @@ async fn client_enabled_server_enabled() { Server::builder() .layer( ServiceBuilder::new() - .map_request(assert_right_encoding) + .map_request(move |req| { + AssertRightEncoding::new(encoding).clone().call(req) + }) .layer(measure_request_body_size_layer( request_bytes_counter.clone(), )) @@ -44,8 +71,8 @@ async fn client_enabled_server_enabled() { }); let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip) - .accept_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding) + .accept_compressed(encoding); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -56,7 +83,12 @@ async fn client_enabled_server_enabled() { .await .unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let mut stream: Streaming = res.into_inner(); diff --git a/tests/compression/src/client_stream.rs b/tests/compression/src/client_stream.rs index 99d72751b..e1d862479 100644 --- a/tests/compression/src/client_stream.rs +++ b/tests/compression/src/client_stream.rs @@ -1,19 +1,42 @@ use super::*; -use http_body::Body as _; +use http_body::Body; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); - fn assert_right_encoding(req: http::Request) -> http::Request { - assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); - req + #[derive(Clone)] + pub struct AssertRightEncoding { + encoding: CompressionEncoding, + } + + #[allow(dead_code)] + impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); + + req + } } tokio::spawn({ @@ -22,7 +45,9 @@ async fn client_enabled_server_enabled() { Server::builder() .layer( ServiceBuilder::new() - .map_request(assert_right_encoding) + .map_request(move |req| { + AssertRightEncoding::new(encoding).clone().call(req) + }) .layer(measure_request_body_size_layer( request_bytes_counter.clone(), )) @@ -35,8 +60,8 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -48,12 +73,17 @@ async fn client_enabled_server_enabled() { assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled_server_enabled() { +util::parametrized_tests! { + client_disabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_disabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -93,8 +123,14 @@ async fn client_disabled_server_enabled() { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +util::parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -107,8 +143,8 @@ async fn client_enabled_server_disabled() { .unwrap(); }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(); let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]); @@ -117,18 +153,31 @@ async fn client_enabled_server_disabled() { let status = client.compress_input_client_stream(req).await.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; assert_eq!( status.message(), - "Content is compressed with `gzip` which isn't supported" + format!( + "Content is compressed with `{}` which isn't supported", + expected + ) ); } -#[tokio::test(flavor = "multi_thread")] -async fn compressing_response_from_client_stream() { +util::parametrized_tests! { + compressing_response_from_client_stream, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn compressing_response_from_client_stream(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -153,14 +202,19 @@ async fn compressing_response_from_client_stream() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let stream = tokio_stream::iter(vec![]); let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/compressing_request.rs b/tests/compression/src/compressing_request.rs index 7cdfd7cec..615cf3d2f 100644 --- a/tests/compression/src/compressing_request.rs +++ b/tests/compression/src/compressing_request.rs @@ -1,18 +1,103 @@ use super::*; -use http_body::Body as _; +use http_body::Body; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); + + let request_bytes_counter = Arc::new(AtomicUsize::new(0)); + + #[derive(Clone)] + pub struct AssertRightEncoding { + encoding: CompressionEncoding, + } + + #[allow(dead_code)] + impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); + + req + } + } + + tokio::spawn({ + let request_bytes_counter = request_bytes_counter.clone(); + async move { + Server::builder() + .layer( + ServiceBuilder::new() + .layer( + ServiceBuilder::new() + .map_request(move |req| { + AssertRightEncoding::new(encoding).clone().call(req) + }) + .layer(measure_request_body_size_layer(request_bytes_counter)) + .into_inner(), + ) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) + .await + .unwrap(); + } + }); + + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); + + for _ in 0..3 { + client + .compress_input_unary(SomeData { + data: [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec(), + }) + .await + .unwrap(); + let bytes_sent = request_bytes_counter.load(SeqCst); + assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); + } +} + +util::parametrized_tests! { + client_enabled_server_enabled_multi_encoding, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled_multi_encoding(encoding: CompressionEncoding) { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + + let svc = test_server::TestServer::new(Svc::default()) + .accept_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Zstd); let request_bytes_counter = Arc::new(AtomicUsize::new(0)); fn assert_right_encoding(req: http::Request) -> http::Request { - assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip"); + let supported_encodings = ["gzip", "zstd"]; + let req_encoding = req.headers().get("grpc-encoding").unwrap(); + assert!(supported_encodings.iter().any(|e| e == req_encoding)); + req } @@ -37,8 +122,8 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); for _ in 0..3 { client @@ -52,8 +137,14 @@ async fn client_enabled_server_enabled() { } } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -66,8 +157,8 @@ async fn client_enabled_server_disabled() { .unwrap(); }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .send_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding); let status = client .compress_input_unary(SomeData { @@ -77,9 +168,17 @@ async fn client_enabled_server_disabled() { .unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; assert_eq!( status.message(), - "Content is compressed with `gzip` which isn't supported" + format!( + "Content is compressed with `{}` which isn't supported", + expected + ) ); assert_eq!( @@ -87,13 +186,17 @@ async fn client_enabled_server_disabled() { "identity" ); } +parametrized_tests! { + client_mark_compressed_without_header_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} -#[tokio::test(flavor = "multi_thread")] -async fn client_mark_compressed_without_header_server_enabled() { +#[allow(dead_code)] +async fn client_mark_compressed_without_header_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding); tokio::spawn({ async move { diff --git a/tests/compression/src/compressing_response.rs b/tests/compression/src/compressing_response.rs index cc5d4f4cd..0cc8eb516 100644 --- a/tests/compression/src/compressing_response.rs +++ b/tests/compression/src/compressing_response.rs @@ -1,12 +1,21 @@ use super::*; use tonic::codec::CompressionEncoding; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); #[derive(Clone, Copy)] - struct AssertCorrectAcceptEncoding(S); + struct AssertCorrectAcceptEncoding { + service: S, + encoding: CompressionEncoding, + } impl Service> for AssertCorrectAcceptEncoding where @@ -20,20 +29,28 @@ async fn client_enabled_server_enabled() { &mut self, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - self.0.poll_ready(cx) + self.service.poll_ready(cx) } fn call(&mut self, req: http::Request) -> Self::Future { + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), + }; assert_eq!( - req.headers().get("grpc-accept-encoding").unwrap(), - "gzip,identity" + req.headers() + .get("grpc-accept-encoding") + .unwrap() + .to_str() + .unwrap(), + format!("{},identity", expected) ); - self.0.call(req) + self.service.call(req) } } - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -43,7 +60,10 @@ async fn client_enabled_server_enabled() { Server::builder() .layer( ServiceBuilder::new() - .layer(layer_fn(AssertCorrectAcceptEncoding)) + .layer(layer_fn(|service| AssertCorrectAcceptEncoding { + service, + encoding, + })) .layer(MapResponseBodyLayer::new(move |body| { util::CountBytesBody { inner: body, @@ -59,19 +79,72 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); + + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; for _ in 0..3 { let res = client.compress_output_unary(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE); } } +util::parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { + let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); + + let svc = test_server::TestServer::new(Svc::default()); + + let response_bytes_counter = Arc::new(AtomicUsize::new(0)); + + tokio::spawn({ + let response_bytes_counter = response_bytes_counter.clone(); + async move { + Server::builder() + // no compression enable on the server so responses should not be compressed + .layer( + ServiceBuilder::new() + .layer(MapResponseBodyLayer::new(move |body| { + util::CountBytesBody { + inner: body, + counter: response_bytes_counter.clone(), + } + })) + .into_inner(), + ) + .add_service(svc) + .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) + .await + .unwrap(); + } + }); + + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); + + let res = client.compress_output_unary(()).await.unwrap(); + + assert!(res.metadata().get("grpc-encoding").is_none()); + + let bytes_sent = response_bytes_counter.load(SeqCst); + assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); +} + #[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +async fn client_enabled_server_disabled_multi_encoding() { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -101,7 +174,8 @@ async fn client_enabled_server_disabled() { }); let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + .accept_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Zstd); let res = client.compress_output_unary(()).await.unwrap(); @@ -111,8 +185,14 @@ async fn client_enabled_server_disabled() { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled() { +util::parametrized_tests! { + client_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); #[derive(Clone, Copy)] @@ -139,8 +219,7 @@ async fn client_disabled() { } } - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -176,12 +255,17 @@ async fn client_disabled() { assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn server_replying_with_unsupported_encoding() { +util::parametrized_tests! { + server_replying_with_unsupported_encoding, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn server_replying_with_unsupported_encoding(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); fn add_weird_content_encoding(mut response: http::Response) -> http::Response { response @@ -203,8 +287,8 @@ async fn server_replying_with_unsupported_encoding() { .unwrap(); }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let status: Status = client.compress_output_unary(()).await.unwrap_err(); assert_eq!(status.code(), tonic::Code::Unimplemented); @@ -214,14 +298,20 @@ async fn server_replying_with_unsupported_encoding() { ); } -#[tokio::test(flavor = "multi_thread")] -async fn disabling_compression_on_single_response() { +util::parametrized_tests! { + disabling_compression_on_single_response, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn disabling_compression_on_single_response(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) - .send_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -246,23 +336,38 @@ async fn disabling_compression_on_single_response() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_unary(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); + let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn disabling_compression_on_response_but_keeping_compression_on_stream() { +util::parametrized_tests! { + disabling_compression_on_response_but_keeping_compression_on_stream, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn disabling_compression_on_response_but_keeping_compression_on_stream( + encoding: CompressionEncoding, +) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) - .send_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -287,12 +392,17 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_server_stream(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let mut stream: Streaming = res.into_inner(); @@ -311,14 +421,20 @@ async fn disabling_compression_on_response_but_keeping_compression_on_stream() { assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn disabling_compression_on_response_from_client_stream() { +util::parametrized_tests! { + disabling_compression_on_response_from_client_stream, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn disabling_compression_on_response_from_client_stream(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc { disable_compressing_on_response: true, }) - .send_compressed(CompressionEncoding::Gzip); + .send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -343,14 +459,20 @@ async fn disabling_compression_on_response_from_client_stream() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let stream = tokio_stream::iter(vec![]); let req = Request::new(Box::pin(stream)); let res = client.compress_output_client_stream(req).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let bytes_sent = response_bytes_counter.load(SeqCst); assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE); } diff --git a/tests/compression/src/server_stream.rs b/tests/compression/src/server_stream.rs index 453c055c8..ada653bc7 100644 --- a/tests/compression/src/server_stream.rs +++ b/tests/compression/src/server_stream.rs @@ -2,12 +2,17 @@ use super::*; use tonic::codec::CompressionEncoding; use tonic::Streaming; -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_enabled() { +util::parametrized_tests! { + client_enabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -32,12 +37,17 @@ async fn client_enabled_server_enabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_server_stream(()).await.unwrap(); - assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip"); + let expected = match encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", encoding), + }; + assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected); let mut stream: Streaming = res.into_inner(); @@ -56,12 +66,17 @@ async fn client_enabled_server_enabled() { assert!(response_bytes_counter.load(SeqCst) < UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_disabled_server_enabled() { +util::parametrized_tests! { + client_disabled_server_enabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_disabled_server_enabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); - let svc = - test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip); + let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding); let response_bytes_counter = Arc::new(AtomicUsize::new(0)); @@ -102,8 +117,14 @@ async fn client_disabled_server_enabled() { assert!(response_bytes_counter.load(SeqCst) > UNCOMPRESSED_MIN_BODY_SIZE); } -#[tokio::test(flavor = "multi_thread")] -async fn client_enabled_server_disabled() { +util::parametrized_tests! { + client_enabled_server_disabled, + zstd: CompressionEncoding::Zstd, + gzip: CompressionEncoding::Gzip, +} + +#[allow(dead_code)] +async fn client_enabled_server_disabled(encoding: CompressionEncoding) { let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10); let svc = test_server::TestServer::new(Svc::default()); @@ -131,8 +152,8 @@ async fn client_enabled_server_disabled() { } }); - let mut client = test_client::TestClient::new(mock_io_channel(client).await) - .accept_compressed(CompressionEncoding::Gzip); + let mut client = + test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding); let res = client.compress_output_server_stream(()).await.unwrap(); diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 57c3f7dbf..64b960be8 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -12,9 +12,26 @@ use std::{ task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::codec::CompressionEncoding; use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; +macro_rules! parametrized_tests { + ($fn_name:ident, $($test_name:ident: $input:expr),+ $(,)?) => { + paste::paste! { + $( + #[tokio::test(flavor = "multi_thread")] + async fn [<$fn_name _ $test_name>]() { + let input = $input; + $fn_name(input).await; + } + )+ + } + } +} + +pub(crate) use parametrized_tests; + /// A body that tracks how many bytes passes through it #[pin_project] pub struct CountBytesBody { @@ -100,3 +117,26 @@ pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { .await .unwrap() } + +#[derive(Clone)] +pub struct AssertRightEncoding { + encoding: CompressionEncoding, +} + +#[allow(dead_code)] +impl AssertRightEncoding { + pub fn new(encoding: CompressionEncoding) -> Self { + Self { encoding } + } + + pub fn call(self, req: http::Request) -> http::Request { + let expected = match self.encoding { + CompressionEncoding::Gzip => "gzip", + CompressionEncoding::Zstd => "zstd", + _ => panic!("unexpected encoding {:?}", self.encoding), + }; + assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected); + + req + } +} diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index ef3222fce..dc1786eb4 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -25,6 +25,7 @@ version = "0.9.2" [features] codegen = ["dep:async-trait"] gzip = ["dep:flate2"] +zstd = ["dep:zstd"] default = ["transport", "codegen", "prost"] prost = ["dep:prost"] tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:async-stream", "tokio/rt"] @@ -83,6 +84,7 @@ webpki-roots = { version = "0.25.0", optional = true } # compression flate2 = {version = "1.0", optional = true} +zstd = { version = "0.12.3", optional = true } [dev-dependencies] bencher = "0.1.5" diff --git a/tonic/src/codec/compression.rs b/tonic/src/codec/compression.rs index 7063bd865..bf94ca3fd 100644 --- a/tonic/src/codec/compression.rs +++ b/tonic/src/codec/compression.rs @@ -4,6 +4,8 @@ use bytes::{Buf, BytesMut}; #[cfg(feature = "gzip")] use flate2::read::{GzDecoder, GzEncoder}; use std::fmt; +#[cfg(feature = "zstd")] +use zstd::stream::read::{Decoder, Encoder}; pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; @@ -13,6 +15,8 @@ pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; pub struct EnabledCompressionEncodings { #[cfg(feature = "gzip")] pub(crate) gzip: bool, + #[cfg(feature = "zstd")] + pub(crate) zstd: bool, } impl EnabledCompressionEncodings { @@ -21,6 +25,8 @@ impl EnabledCompressionEncodings { match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => self.gzip, + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => self.zstd, } } @@ -29,14 +35,17 @@ impl EnabledCompressionEncodings { match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => self.gzip = true, + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => self.zstd = true, } } pub(crate) fn into_accept_encoding_header_value(self) -> Option { - if self.is_gzip_enabled() { - Some(http::HeaderValue::from_static("gzip,identity")) - } else { - None + match (self.is_gzip_enabled(), self.is_zstd_enabled()) { + (true, false) => Some(http::HeaderValue::from_static("gzip,identity")), + (false, true) => Some(http::HeaderValue::from_static("zstd,identity")), + (true, true) => Some(http::HeaderValue::from_static("gzip,zstd,identity")), + (false, false) => None, } } @@ -49,6 +58,16 @@ impl EnabledCompressionEncodings { const fn is_gzip_enabled(&self) -> bool { false } + + #[cfg(feature = "zstd")] + const fn is_zstd_enabled(&self) -> bool { + self.zstd + } + + #[cfg(not(feature = "zstd"))] + const fn is_zstd_enabled(&self) -> bool { + false + } } /// The compression encodings Tonic supports. @@ -59,6 +78,10 @@ pub enum CompressionEncoding { #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] Gzip, + #[allow(missing_docs)] + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + Zstd, } impl CompressionEncoding { @@ -67,7 +90,7 @@ impl CompressionEncoding { map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option { - if !enabled_encodings.is_gzip_enabled() { + if !enabled_encodings.is_gzip_enabled() && !enabled_encodings.is_zstd_enabled() { return None; } @@ -77,6 +100,8 @@ impl CompressionEncoding { split_by_comma(header_value_str).find_map(|value| match value { #[cfg(feature = "gzip")] "gzip" => Some(CompressionEncoding::Gzip), + #[cfg(feature = "zstd")] + "zstd" => Some(CompressionEncoding::Zstd), _ => None, }) } @@ -103,6 +128,10 @@ impl CompressionEncoding { "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { Ok(Some(CompressionEncoding::Gzip)) } + #[cfg(feature = "zstd")] + "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { + Ok(Some(CompressionEncoding::Zstd)) + } "identity" => Ok(None), other => { let mut status = Status::unimplemented(format!( @@ -123,17 +152,26 @@ impl CompressionEncoding { } } - pub(crate) fn into_header_value(self) -> http::HeaderValue { + #[allow(missing_docs)] + pub(crate) fn as_str(&self) -> &'static str { match self { #[cfg(feature = "gzip")] - CompressionEncoding::Gzip => http::HeaderValue::from_static("gzip"), + CompressionEncoding::Gzip => "gzip", + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => "zstd", } } + pub(crate) fn into_header_value(self) -> http::HeaderValue { + http::HeaderValue::from_static(self.as_str()) + } + pub(crate) fn encodings() -> &'static [Self] { &[ #[cfg(feature = "gzip")] CompressionEncoding::Gzip, + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd, ] } } @@ -144,6 +182,8 @@ impl fmt::Display for CompressionEncoding { match *self { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => write!(f, "gzip"), + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => write!(f, "zstd"), } } } @@ -162,6 +202,7 @@ pub(crate) fn compress( ) -> Result<(), std::io::Error> { let capacity = ((len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); + let mut out_writer = bytes::BufMut::writer(out_buf); match encoding { #[cfg(feature = "gzip")] @@ -171,10 +212,17 @@ pub(crate) fn compress( // FIXME: support customizing the compression level flate2::Compression::new(6), ); - let mut out_writer = bytes::BufMut::writer(out_buf); - std::io::copy(&mut gzip_encoder, &mut out_writer)?; } + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => { + let mut zstd_encoder = Encoder::new( + &decompressed_buf[0..len], + // FIXME: support customizing the compression level + zstd::DEFAULT_COMPRESSION_LEVEL, + )?; + std::io::copy(&mut zstd_encoder, &mut out_writer)?; + } } decompressed_buf.advance(len); @@ -193,15 +241,19 @@ pub(crate) fn decompress( let estimate_decompressed_len = len * 2; let capacity = ((estimate_decompressed_len / BUFFER_SIZE) + 1) * BUFFER_SIZE; out_buf.reserve(capacity); + let mut out_writer = bytes::BufMut::writer(out_buf); match encoding { #[cfg(feature = "gzip")] CompressionEncoding::Gzip => { let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); - let mut out_writer = bytes::BufMut::writer(out_buf); - std::io::copy(&mut gzip_decoder, &mut out_writer)?; } + #[cfg(feature = "zstd")] + CompressionEncoding::Zstd => { + let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; + std::io::copy(&mut zstd_decoder, &mut out_writer)?; + } } compressed_buf.advance(len);