diff --git a/volo-thrift/src/context.rs b/volo-thrift/src/context.rs index efc260e3..2786ef53 100644 --- a/volo-thrift/src/context.rs +++ b/volo-thrift/src/context.rs @@ -292,7 +292,7 @@ impl std::ops::DerefMut for ServerContext { } pub trait ThriftContext: volo::context::Context + Send + 'static { - fn encode_conn_reset(&self) -> Option; + fn encode_conn_reset(&self) -> bool; fn set_conn_reset_by_ttheader(&mut self, reset: bool); fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier); fn seq_id(&self) -> i32; @@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context + Send + 'stati impl ThriftContext for ClientContext { #[inline] - fn encode_conn_reset(&self) -> Option { - None + fn encode_conn_reset(&self) -> bool { + false } #[inline] @@ -342,12 +342,14 @@ impl ThriftContext for ClientContext { impl ThriftContext for ServerContext { #[inline] - fn encode_conn_reset(&self) -> Option { - Some(self.transport.is_conn_reset()) + fn encode_conn_reset(&self) -> bool { + self.transport.is_conn_reset() } #[inline] - fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {} + fn set_conn_reset_by_ttheader(&mut self, reset: bool) { + self.transport.set_conn_reset(reset) + } #[inline] fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) { diff --git a/volo-thrift/src/transport/multiplex/server.rs b/volo-thrift/src/transport/multiplex/server.rs index 71b6f601..5476b23b 100644 --- a/volo-thrift/src/transport/multiplex/server.rs +++ b/volo-thrift/src/transport/multiplex/server.rs @@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable}; use crate::{ codec::{Decoder, Encoder}, - context::ServerContext, + context::{ServerContext, ThriftContext as _}, protocol::TMessageType, server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage, EntryMessage, ServerError, ThriftMessage, @@ -40,7 +40,8 @@ pub async fn serve( // mpsc channel used to send responses to the loop let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE); - let (error_send_tx, mut error_send_rx) = mpsc::channel(1); + let (error_send_tx, mut error_send_rx) = + mpsc::channel::<(ServerContext, ThriftMessage)>(1); tokio::spawn({ let peer_addr = peer_addr.clone(); @@ -70,6 +71,9 @@ pub async fn serve( return; } stat_tracer.iter().for_each(|f| f(&cx)); + if cx.encode_conn_reset() { + return; + } } None => { // log it @@ -85,6 +89,7 @@ pub async fn serve( error_msg = error_send_rx.recv() => { match error_msg { Some((mut cx, msg)) => { + cx.set_conn_reset_by_ttheader(true); if let Err(e) = encoder .encode::(&mut cx, msg) .await @@ -185,11 +190,11 @@ pub async fn serve( metainfo::METAINFO .scope(RefCell::new(mi), async move { cx.stats.record_process_start_at(); - let resp = svc.call(&mut cx, req).await; + let resp = svc.call(&mut cx, req).await.map_err(Into::into); cx.stats.record_process_end_at(); if exit_mark.load(Ordering::Relaxed) { - cx.transport.set_conn_reset(true); + cx.set_conn_reset_by_ttheader(true); } let req_msg_type = cx.req_msg_type.expect("`req_msg_type` should be set."); @@ -201,7 +206,7 @@ pub async fn serve( let msg = ThriftMessage::mk_server_resp( &cx, resp.map_err(|e| { - server_error_to_application_exception(e.into()) + server_error_to_application_exception(e) }), ); let mi = metainfo::METAINFO.with(|m| m.take()); diff --git a/volo-thrift/src/transport/pingpong/server.rs b/volo-thrift/src/transport/pingpong/server.rs index 886bfac7..cdc526a0 100644 --- a/volo-thrift/src/transport/pingpong/server.rs +++ b/volo-thrift/src/transport/pingpong/server.rs @@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable}; use crate::{ codec::{Decoder, Encoder}, - context::{ServerContext, SERVER_CONTEXT_CACHE}, + context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE}, protocol::TMessageType, server_error_to_application_exception, thrift_exception_to_application_exception, tracing::SpanProvider, @@ -81,11 +81,11 @@ pub async fn serve( match msg { Ok(Some(ThriftMessage { data: Ok(req), .. })) => { cx.stats.record_process_start_at(); - let resp = service.call(&mut cx, req).await; + let resp = service.call(&mut cx, req).await.map_err(Into::into); cx.stats.record_process_end_at(); if exit_mark.load(Ordering::Relaxed) { - cx.transport.set_conn_reset(true); + cx.set_conn_reset_by_ttheader(true); } let req_msg_type = @@ -98,9 +98,7 @@ pub async fn serve( }); let msg = ThriftMessage::mk_server_resp( &cx, - resp.map_err(|e| { - server_error_to_application_exception(e.into()) - }), + resp.map_err(|e| server_error_to_application_exception(e)), ); if let Err(e) = async { let result = encoder.encode(&mut cx, msg).await; @@ -119,6 +117,9 @@ pub async fn serve( return Err(()); } } + if cx.transport.is_conn_reset() { + return Err(()); + } } Ok(Some(ThriftMessage { data: Err(_), .. })) => { volo_unreachable!(); @@ -138,6 +139,7 @@ pub async fn serve( e, cx, peer_addr ); cx.msg_type = Some(TMessageType::Exception); + cx.set_conn_reset_by_ttheader(true); if !matches!(e, ThriftException::Transport(_)) { let msg = ThriftMessage::mk_server_resp( &cx,