From 725d0435165a5b4cd669d70620685ef8e47682a7 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Mon, 7 Oct 2024 14:45:44 +0200 Subject: [PATCH] [refactor] Calling directly `hostcalls.add_map_value` * Instead of passing the Filter instance of `response_headers_to_add` * Matching the Auth Service Signed-off-by: dd di cesare --- src/filter/http_context.rs | 6 +----- src/service.rs | 14 ++++---------- src/service/rate_limit.rs | 14 +++++++++----- tests/rate_limited.rs | 12 ++++++------ 4 files changed, 20 insertions(+), 26 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 0850d598..d890c4f0 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -105,11 +105,7 @@ impl Context for Filter { let some_op = self.operation_dispatcher.borrow().get_operation(token_id); if let Some(operation) = some_op { - GrpcService::process_grpc_response( - operation, - resp_size, - &mut self.response_headers_to_add, - ); + GrpcService::process_grpc_response(operation, resp_size); self.operation_dispatcher.borrow_mut().next(); if let Some(_op) = self.operation_dispatcher.borrow_mut().next() { diff --git a/src/service.rs b/src/service.rs index 0f12118b..c16281e7 100644 --- a/src/service.rs +++ b/src/service.rs @@ -54,11 +54,7 @@ impl GrpcService { &self.extension.failure_mode } - pub fn process_grpc_response( - operation: Rc, - resp_size: usize, - response_headers_to_add: &mut Vec<(String, String)>, - ) { + pub fn process_grpc_response(operation: Rc, resp_size: usize) { let failure_mode = operation.get_failure_mode(); let res_body_bytes = match hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size).unwrap() { @@ -79,11 +75,9 @@ impl GrpcService { }; match operation.get_extension_type() { ExtensionType::Auth => AuthService::process_auth_grpc_response(res, failure_mode), - ExtensionType::RateLimit => RateLimitService::process_ratelimit_grpc_response( - res, - failure_mode, - response_headers_to_add, - ), + ExtensionType::RateLimit => { + RateLimitService::process_ratelimit_grpc_response(res, failure_mode) + } } } diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 589ae77b..106538f9 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -6,7 +6,7 @@ use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; use crate::service::GrpcService; use protobuf::{Message, RepeatedField}; use proxy_wasm::hostcalls; -use proxy_wasm::types::Bytes; +use proxy_wasm::types::{Bytes, MapType}; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -37,7 +37,6 @@ impl RateLimitService { pub fn process_ratelimit_grpc_response( rl_resp: GrpcMessageResponse, failure_mode: &FailureMode, - response_headers_to_add: &mut Vec<(String, String)>, ) { match rl_resp { GrpcMessageResponse::RateLimit(RateLimitResponse { @@ -63,9 +62,14 @@ impl RateLimitService { response_headers_to_add: additional_headers, .. }) => { - for header in additional_headers { - response_headers_to_add.push((header.key, header.value)); - } + additional_headers.iter().for_each(|header| { + hostcalls::add_map_value( + MapType::HttpResponseHeaders, + header.get_key(), + header.get_value(), + ) + .unwrap() + }); } _ => {} } diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 21ba4988..16e905e4 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -377,12 +377,6 @@ fn it_passes_additional_headers() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) - .execute_and_expect(ReturnType::None) - .unwrap(); - - module - .call_proxy_on_response_headers(http_context, 0, false) - .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .expect_add_header_map_value( Some(MapType::HttpResponseHeaders), Some("test"), @@ -393,6 +387,12 @@ fn it_passes_additional_headers() { Some("other"), Some("header value"), ) + .execute_and_expect(ReturnType::None) + .unwrap(); + + module + .call_proxy_on_response_headers(http_context, 0, false) + .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); }