Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zstd compression support #1532

Merged
merged 12 commits into from
Nov 15, 2023
3 changes: 2 additions & 1 deletion tests/compression/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ bytes = "1"
http = "0.2"
http-body = "0.4"
hyper = "0.14.3"
paste = "1.0.12"
pin-project = "1.0"
prost = "0.12"
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
tokio-stream = "0.1"
tonic = {path = "../../tonic", features = ["gzip"]}
tonic = {path = "../../tonic", features = ["gzip", "zstd"]}
tower = {version = "0.4", features = []}
tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]}

Expand Down
54 changes: 43 additions & 11 deletions tests/compression/src/bidirectional_stream.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,45 @@
use super::*;
use http_body::Body;
use tonic::codec::CompressionEncoding;

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_enabled() {
util::parametrized_tests! {
client_enabled_server_enabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default())
.accept_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Gzip);
.accept_compressed(encoding)
.send_compressed(encoding);

let request_bytes_counter = Arc::new(AtomicUsize::new(0));
let response_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
req
#[derive(Clone)]
pub struct AssertRightEncoding {
encoding: CompressionEncoding,
}

#[allow(dead_code)]
impl AssertRightEncoding {
pub fn new(encoding: CompressionEncoding) -> Self {
Self { encoding }
}

pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
let expected = match self.encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", self.encoding),
};
assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);

req
}
}

tokio::spawn({
Expand All @@ -24,7 +49,9 @@ async fn client_enabled_server_enabled() {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.map_request(move |req| {
AssertRightEncoding::new(encoding).clone().call(req)
})
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
Expand All @@ -44,8 +71,8 @@ async fn client_enabled_server_enabled() {
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_compressed(CompressionEncoding::Gzip)
.accept_compressed(CompressionEncoding::Gzip);
.send_compressed(encoding)
.accept_compressed(encoding);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
Expand All @@ -56,7 +83,12 @@ async fn client_enabled_server_enabled() {
.await
.unwrap();

assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
let expected = match encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", encoding),
};
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);

let mut stream: Streaming<SomeData> = res.into_inner();

Expand Down
108 changes: 81 additions & 27 deletions tests/compression/src/client_stream.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,42 @@
use super::*;
use http_body::Body as _;
use http_body::Body;
use tonic::codec::CompressionEncoding;

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_enabled() {
util::parametrized_tests! {
client_enabled_server_enabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc =
test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);

let request_bytes_counter = Arc::new(AtomicUsize::new(0));

fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
req
#[derive(Clone)]
pub struct AssertRightEncoding {
encoding: CompressionEncoding,
}

#[allow(dead_code)]
impl AssertRightEncoding {
pub fn new(encoding: CompressionEncoding) -> Self {
Self { encoding }
}

pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
let expected = match self.encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", self.encoding),
};
assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);

req
}
}

tokio::spawn({
Expand All @@ -22,7 +45,9 @@ async fn client_enabled_server_enabled() {
Server::builder()
.layer(
ServiceBuilder::new()
.map_request(assert_right_encoding)
.map_request(move |req| {
AssertRightEncoding::new(encoding).clone().call(req)
})
.layer(measure_request_body_size_layer(
request_bytes_counter.clone(),
))
Expand All @@ -35,8 +60,8 @@ async fn client_enabled_server_enabled() {
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_compressed(CompressionEncoding::Gzip);
let mut client =
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
Expand All @@ -48,12 +73,17 @@ async fn client_enabled_server_enabled() {
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
}

#[tokio::test(flavor = "multi_thread")]
async fn client_disabled_server_enabled() {
util::parametrized_tests! {
client_disabled_server_enabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_disabled_server_enabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc =
test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);

let request_bytes_counter = Arc::new(AtomicUsize::new(0));

Expand Down Expand Up @@ -93,8 +123,14 @@ async fn client_disabled_server_enabled() {
assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
}

#[tokio::test(flavor = "multi_thread")]
async fn client_enabled_server_disabled() {
util::parametrized_tests! {
client_enabled_server_disabled,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc = test_server::TestServer::new(Svc::default());
Expand All @@ -107,8 +143,8 @@ async fn client_enabled_server_disabled() {
.unwrap();
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.send_compressed(CompressionEncoding::Gzip);
let mut client =
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);

let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
Expand All @@ -117,18 +153,31 @@ async fn client_enabled_server_disabled() {
let status = client.compress_input_client_stream(req).await.unwrap_err();

assert_eq!(status.code(), tonic::Code::Unimplemented);
let expected = match encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", encoding),
};
assert_eq!(
status.message(),
"Content is compressed with `gzip` which isn't supported"
format!(
"Content is compressed with `{}` which isn't supported",
expected
)
);
}

#[tokio::test(flavor = "multi_thread")]
async fn compressing_response_from_client_stream() {
util::parametrized_tests! {
compressing_response_from_client_stream,
zstd: CompressionEncoding::Zstd,
gzip: CompressionEncoding::Gzip,
}

#[allow(dead_code)]
async fn compressing_response_from_client_stream(encoding: CompressionEncoding) {
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);

let svc =
test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);

let response_bytes_counter = Arc::new(AtomicUsize::new(0));

Expand All @@ -153,13 +202,18 @@ async fn compressing_response_from_client_stream() {
}
});

let mut client = test_client::TestClient::new(mock_io_channel(client).await)
.accept_compressed(CompressionEncoding::Gzip);
let mut client =
test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);

let req = Request::new(Box::pin(tokio_stream::empty()));

let res = client.compress_output_client_stream(req).await.unwrap();
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
let expected = match encoding {
CompressionEncoding::Gzip => "gzip",
CompressionEncoding::Zstd => "zstd",
_ => panic!("unexpected encoding {:?}", encoding),
};
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
let bytes_sent = response_bytes_counter.load(SeqCst);
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
}
Loading
Loading