diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index cbc7eea9..d890c4f0 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,9 +1,7 @@ -use crate::attribute::store_metadata; -use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; -use crate::envoy::{CheckResponse_oneof_http_response, RateLimitResponse, RateLimitResponse_Code}; +use crate::configuration::{FailureMode, FilterConfig}; use crate::operation_dispatcher::OperationDispatcher; use crate::policy::Policy; -use crate::service::grpc_message::GrpcMessageResponse; +use crate::service::GrpcService; use log::{debug, warn}; use proxy_wasm::traits::{Context, HttpContext}; use proxy_wasm::types::Action; @@ -59,108 +57,6 @@ impl Filter { Action::Continue } } - - fn handle_error_on_grpc_response(&self, failure_mode: &FailureMode) { - match failure_mode { - FailureMode::Deny => { - self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) - } - FailureMode::Allow => self.resume_http_request(), - } - } - - fn process_ratelimit_grpc_response( - &mut self, - rl_resp: GrpcMessageResponse, - failure_mode: &FailureMode, - ) { - match rl_resp { - GrpcMessageResponse::RateLimit(RateLimitResponse { - overall_code: RateLimitResponse_Code::UNKNOWN, - .. - }) => { - self.handle_error_on_grpc_response(failure_mode); - } - GrpcMessageResponse::RateLimit(RateLimitResponse { - overall_code: RateLimitResponse_Code::OVER_LIMIT, - response_headers_to_add: rl_headers, - .. - }) => { - let mut response_headers = vec![]; - for header in &rl_headers { - response_headers.push((header.get_key(), header.get_value())); - } - self.send_http_response(429, response_headers, Some(b"Too Many Requests\n")); - } - GrpcMessageResponse::RateLimit(RateLimitResponse { - overall_code: RateLimitResponse_Code::OK, - response_headers_to_add: additional_headers, - .. - }) => { - for header in additional_headers { - self.response_headers_to_add - .push((header.key, header.value)); - } - } - _ => {} - } - self.operation_dispatcher.borrow_mut().next(); - } - - fn process_auth_grpc_response( - &self, - auth_resp: GrpcMessageResponse, - failure_mode: &FailureMode, - ) { - if let GrpcMessageResponse::Auth(check_response) = auth_resp { - // store dynamic metadata in filter state - store_metadata(check_response.get_dynamic_metadata()); - - match check_response.http_response { - Some(CheckResponse_oneof_http_response::ok_response(ok_response)) => { - debug!( - "#{} process_auth_grpc_response: received OkHttpResponse", - self.context_id - ); - - ok_response - .get_response_headers_to_add() - .iter() - .for_each(|header| { - self.add_http_response_header( - header.get_header().get_key(), - header.get_header().get_value(), - ) - }); - } - Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { - debug!( - "#{} process_auth_grpc_response: received DeniedHttpResponse", - self.context_id - ); - - let mut response_headers = vec![]; - denied_response.get_headers().iter().for_each(|header| { - response_headers.push(( - header.get_header().get_key(), - header.get_header().get_value(), - )) - }); - self.send_http_response( - denied_response.get_status().code as u32, - response_headers, - Some(denied_response.get_body().as_ref()), - ); - return; - } - None => { - self.handle_error_on_grpc_response(failure_mode); - return; - } - } - } - self.operation_dispatcher.borrow_mut().next(); - } } impl HttpContext for Filter { @@ -209,30 +105,8 @@ impl Context for Filter { let some_op = self.operation_dispatcher.borrow().get_operation(token_id); if let Some(operation) = some_op { - let failure_mode = &operation.get_failure_mode(); - let res_body_bytes = match self.get_grpc_call_response_body(0, resp_size) { - Some(bytes) => bytes, - None => { - warn!("grpc response body is empty!"); - self.handle_error_on_grpc_response(failure_mode); - return; - } - }; - let res = - match GrpcMessageResponse::new(operation.get_extension_type(), &res_body_bytes) { - Ok(res) => res, - Err(e) => { - warn!( - "failed to parse grpc response body into GrpcMessageResponse message: {e}" - ); - self.handle_error_on_grpc_response(failure_mode); - return; - } - }; - match operation.get_extension_type() { - ExtensionType::Auth => self.process_auth_grpc_response(res, failure_mode), - ExtensionType::RateLimit => self.process_ratelimit_grpc_response(res, failure_mode), - } + GrpcService::process_grpc_response(operation, resp_size); + self.operation_dispatcher.borrow_mut().next(); if let Some(_op) = self.operation_dispatcher.borrow_mut().next() { } else { @@ -240,7 +114,7 @@ impl Context for Filter { } } else { warn!("No Operation found with token_id: {token_id}"); - self.handle_error_on_grpc_response(&FailureMode::Deny); // TODO(didierofrivia): Decide on what's the default failure mode + GrpcService::handle_error_on_grpc_response(&FailureMode::Deny); // TODO(didierofrivia): Decide on what's the default failure mode } } } diff --git a/src/service.rs b/src/service.rs index 9e988a1a..c16281e7 100644 --- a/src/service.rs +++ b/src/service.rs @@ -3,12 +3,15 @@ pub(crate) mod grpc_message; pub(crate) mod rate_limit; use crate::configuration::{Action, Extension, ExtensionType, FailureMode}; -use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; -use crate::service::grpc_message::GrpcMessageRequest; -use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::operation_dispatcher::Operation; +use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::grpc_message::{GrpcMessageRequest, GrpcMessageResponse}; +use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; +use log::warn; use protobuf::Message; -use proxy_wasm::types::{Bytes, MapType, Status}; +use proxy_wasm::hostcalls; +use proxy_wasm::types::{BufferType, Bytes, MapType, Status}; use std::cell::OnceCell; use std::rc::Rc; use std::time::Duration; @@ -50,6 +53,43 @@ impl GrpcService { pub fn failure_mode(&self) -> &FailureMode { &self.extension.failure_mode } + + pub fn process_grpc_response(operation: Rc, resp_size: usize) { + let failure_mode = operation.get_failure_mode(); + let res_body_bytes = + match hostcalls::get_buffer(BufferType::GrpcReceiveBuffer, 0, resp_size).unwrap() { + Some(bytes) => bytes, + None => { + warn!("grpc response body is empty!"); + GrpcService::handle_error_on_grpc_response(failure_mode); + return; + } + }; + let res = match GrpcMessageResponse::new(operation.get_extension_type(), &res_body_bytes) { + Ok(res) => res, + Err(e) => { + warn!("failed to parse grpc response body into GrpcMessageResponse message: {e}"); + GrpcService::handle_error_on_grpc_response(failure_mode); + return; + } + }; + match operation.get_extension_type() { + ExtensionType::Auth => AuthService::process_auth_grpc_response(res, failure_mode), + ExtensionType::RateLimit => { + RateLimitService::process_ratelimit_grpc_response(res, failure_mode) + } + } + } + + pub fn handle_error_on_grpc_response(failure_mode: &FailureMode) { + match failure_mode { + FailureMode::Deny => { + hostcalls::send_http_response(500, vec![], Some(b"Internal Server Error.\n")) + .unwrap() + } + FailureMode::Allow => hostcalls::resume_http_request().unwrap(), + } + } } pub type GrpcCallFn = fn( diff --git a/src/service/auth.rs b/src/service/auth.rs index b2261ac1..7b2812de 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -1,10 +1,14 @@ -use crate::attribute::get_attribute; +use crate::attribute::{get_attribute, store_metadata}; +use crate::configuration::FailureMode; use crate::envoy::{ Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, - AttributeContext_Request, CheckRequest, Metadata, SocketAddress, + AttributeContext_Request, CheckRequest, CheckResponse_oneof_http_response, Metadata, + SocketAddress, }; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; +use crate::service::GrpcService; use chrono::{DateTime, FixedOffset, Timelike}; +use log::debug; use protobuf::well_known_types::Timestamp; use protobuf::Message; use proxy_wasm::hostcalls; @@ -87,4 +91,49 @@ impl AuthService { peer.set_address(address); peer } + + pub fn process_auth_grpc_response(auth_resp: GrpcMessageResponse, failure_mode: &FailureMode) { + if let GrpcMessageResponse::Auth(check_response) = auth_resp { + // store dynamic metadata in filter state + store_metadata(check_response.get_dynamic_metadata()); + + match check_response.http_response { + Some(CheckResponse_oneof_http_response::ok_response(ok_response)) => { + debug!("process_auth_grpc_response: received OkHttpResponse"); + + ok_response + .get_response_headers_to_add() + .iter() + .for_each(|header| { + hostcalls::add_map_value( + MapType::HttpResponseHeaders, + header.get_header().get_key(), + header.get_header().get_value(), + ) + .unwrap() + }); + } + Some(CheckResponse_oneof_http_response::denied_response(denied_response)) => { + debug!("process_auth_grpc_response: received DeniedHttpResponse",); + + let mut response_headers = vec![]; + denied_response.get_headers().iter().for_each(|header| { + response_headers.push(( + header.get_header().get_key(), + header.get_header().get_value(), + )) + }); + hostcalls::send_http_response( + denied_response.get_status().code as u32, + response_headers, + Some(denied_response.get_body().as_ref()), + ) + .unwrap(); + } + None => { + GrpcService::handle_error_on_grpc_response(failure_mode); + } + } + } + } } diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 4a81884a..106538f9 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,7 +1,12 @@ -use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; +use crate::configuration::FailureMode; +use crate::envoy::{ + RateLimitDescriptor, RateLimitRequest, RateLimitResponse, RateLimitResponse_Code, +}; use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; +use crate::service::GrpcService; use protobuf::{Message, RepeatedField}; -use proxy_wasm::types::Bytes; +use proxy_wasm::hostcalls; +use proxy_wasm::types::{Bytes, MapType}; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -28,6 +33,47 @@ impl RateLimitService { Err(e) => Err(e), } } + + pub fn process_ratelimit_grpc_response( + rl_resp: GrpcMessageResponse, + failure_mode: &FailureMode, + ) { + match rl_resp { + GrpcMessageResponse::RateLimit(RateLimitResponse { + overall_code: RateLimitResponse_Code::UNKNOWN, + .. + }) => { + GrpcService::handle_error_on_grpc_response(failure_mode); + } + GrpcMessageResponse::RateLimit(RateLimitResponse { + overall_code: RateLimitResponse_Code::OVER_LIMIT, + response_headers_to_add: rl_headers, + .. + }) => { + let mut response_headers = vec![]; + for header in &rl_headers { + response_headers.push((header.get_key(), header.get_value())); + } + hostcalls::send_http_response(429, response_headers, Some(b"Too Many Requests\n")) + .unwrap(); + } + GrpcMessageResponse::RateLimit(RateLimitResponse { + overall_code: RateLimitResponse_Code::OK, + response_headers_to_add: additional_headers, + .. + }) => { + additional_headers.iter().for_each(|header| { + hostcalls::add_map_value( + MapType::HttpResponseHeaders, + header.get_key(), + header.get_value(), + ) + .unwrap() + }); + } + _ => {} + } + } } #[cfg(test)] diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 21ba4988..16e905e4 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -377,12 +377,6 @@ fn it_passes_additional_headers() { ) .expect_get_buffer_bytes(Some(BufferType::GrpcReceiveBuffer)) .returning(Some(&grpc_response)) - .execute_and_expect(ReturnType::None) - .unwrap(); - - module - .call_proxy_on_response_headers(http_context, 0, false) - .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .expect_add_header_map_value( Some(MapType::HttpResponseHeaders), Some("test"), @@ -393,6 +387,12 @@ fn it_passes_additional_headers() { Some("other"), Some("header value"), ) + .execute_and_expect(ReturnType::None) + .unwrap(); + + module + .call_proxy_on_response_headers(http_context, 0, false) + .expect_log(Some(LogLevel::Debug), Some("#2 on_http_response_headers")) .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); }