diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 69a8c0c..e9204d1 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -142,7 +142,9 @@ impl Context for Filter { match op_res { Ok(operation) => { - if GrpcService::process_grpc_response(operation, resp_size).is_ok() { + if let Ok(result) = GrpcService::process_grpc_response(operation, resp_size) { + // add the response headers + self.response_headers_to_add.extend(result.response_headers); // call the next op match self.operation_dispatcher.borrow_mut().next() { Ok(some_op) => { diff --git a/src/service.rs b/src/service.rs index 575281a..404671f 100644 --- a/src/service.rs +++ b/src/service.rs @@ -54,7 +54,7 @@ impl GrpcService { pub fn process_grpc_response( operation: Rc, resp_size: usize, - ) -> Result<(), StatusCode> { + ) -> Result { let failure_mode = operation.get_failure_mode(); if let Some(res_body_bytes) = hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size).unwrap() @@ -92,6 +92,20 @@ impl GrpcService { } } +pub struct GrpcResult { + pub response_headers: Vec<(String, String)>, +} +impl GrpcResult { + pub fn default() -> Self { + Self { + response_headers: Vec::new(), + } + } + pub fn new(response_headers: Vec<(String, String)>) -> Self { + Self { response_headers } + } +} + pub type GrpcCallFn = fn( upstream_name: &str, service_name: &str, diff --git a/src/service/auth.rs b/src/service/auth.rs index a1e4e96..33af068 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -6,7 +6,7 @@ use crate::envoy::{ SocketAddress, StatusCode, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::GrpcService; +use crate::service::{GrpcResult, GrpcService}; use chrono::{DateTime, FixedOffset}; use log::{debug, warn}; use protobuf::well_known_types::Timestamp; @@ -125,7 +125,7 @@ impl AuthService { pub fn process_auth_grpc_response( auth_resp: GrpcMessageResponse, failure_mode: FailureMode, - ) -> Result<(), StatusCode> { + ) -> Result { if let GrpcMessageResponse::Auth(check_response) = auth_resp { // store dynamic metadata in filter state store_metadata(check_response.get_dynamic_metadata()); @@ -153,7 +153,7 @@ impl AuthService { ) .unwrap() }); - Ok(()) + Ok(GrpcResult::default()) } Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { debug!("process_auth_grpc_response: received DeniedHttpResponse"); diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 2f97f31..4d8f242 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -3,11 +3,11 @@ use crate::envoy::{ RateLimitDescriptor, RateLimitRequest, RateLimitResponse, RateLimitResponse_Code, StatusCode, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; -use crate::service::GrpcService; +use crate::service::{GrpcResult, GrpcService}; use log::warn; use protobuf::{Message, RepeatedField}; use proxy_wasm::hostcalls; -use proxy_wasm::types::{Bytes, MapType}; +use proxy_wasm::types::Bytes; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -38,7 +38,7 @@ impl RateLimitService { pub fn process_ratelimit_grpc_response( rl_resp: GrpcMessageResponse, failure_mode: FailureMode, - ) -> Result<(), StatusCode> { + ) -> Result { match rl_resp { GrpcMessageResponse::RateLimit(RateLimitResponse { overall_code: RateLimitResponse_Code::UNKNOWN, @@ -65,16 +65,13 @@ impl RateLimitService { response_headers_to_add: additional_headers, .. }) => { - // TODO: This should not be sent to the upstream! - additional_headers.iter().for_each(|header| { - hostcalls::add_map_value( - MapType::HttpResponseHeaders, - header.get_key(), - header.get_value(), - ) - .unwrap() - }); - Ok(()) + let result = GrpcResult::new( + additional_headers + .iter() + .map(|header| (header.get_key().to_owned(), header.get_value().to_owned())) + .collect(), + ); + Ok(result) } _ => { warn!("not a valid GrpcMessageResponse::RateLimit(RateLimitResponse)!"); diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 0c0e72b..b88daa6 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -353,6 +353,12 @@ fn it_passes_additional_headers() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) + .execute_and_expect(ReturnType::None) + .unwrap(); + + module + .call_proxy_on_response_headers(http_context, 0, false) + .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .expect_add_header_map_value( Some(MapType::HttpResponseHeaders), Some("test"), @@ -363,12 +369,6 @@ fn it_passes_additional_headers() { Some("other"), Some("header value"), ) - .execute_and_expect(ReturnType::None) - .unwrap(); - - module - .call_proxy_on_response_headers(http_context, 0, false) - .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); }