diff --git a/Cargo.lock b/Cargo.lock index 01785e71..d9dcf569 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2646,6 +2646,15 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "tokio-condvar" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7233b09174540ef9bf9fc8326bcad6ccebc631e7c9a1e2e48d956a133056f9d" +dependencies = [ + "tokio", +] + [[package]] name = "tokio-macros" version = "2.2.0" @@ -3177,6 +3186,7 @@ dependencies = [ "sonic-rs", "thiserror", "tokio", + "tokio-condvar", "tracing", "volo", ] diff --git a/volo-thrift/Cargo.toml b/volo-thrift/Cargo.toml index b1b5dd3f..2e5b2aae 100644 --- a/volo-thrift/Cargo.toml +++ b/volo-thrift/Cargo.toml @@ -48,6 +48,7 @@ tokio = { workspace = true, features = [ "parking_lot", ] } tracing.workspace = true +tokio-condvar = "0.1.0" [features] default = [] diff --git a/volo-thrift/src/client/mod.rs b/volo-thrift/src/client/mod.rs index a5873632..ae06bdc3 100644 --- a/volo-thrift/src/client/mod.rs +++ b/volo-thrift/src/client/mod.rs @@ -463,9 +463,10 @@ impl ClientBuilder +pub struct MessageService where - Resp: EntryMessage + Send + 'static, + Req: EntryMessage + Send + 'static + Sync, + Resp: EntryMessage + Send + 'static + Sync, MkT: MakeTransport, MkC: MakeCodec + Sync, { @@ -474,14 +475,14 @@ where #[cfg(feature = "multiplex")] inner: motore::utils::Either< pingpong::Client, - crate::transport::multiplex::Client, + crate::transport::multiplex::Client, >, read_biz_error: bool, } -impl Service for MessageService +impl Service for MessageService where - Req: EntryMessage + 'static + Send, + Req: Send + 'static + EntryMessage + Sync, Resp: Send + 'static + EntryMessage + Sync, MkT: MakeTransport, MkC: MakeCodec + Sync, @@ -531,8 +532,8 @@ where + Clone + Sync, Req: EntryMessage + Send + 'static + Sync + Clone, - Resp: EntryMessage + Send + 'static, - IL: Layer>, + Resp: EntryMessage + Send + 'static + Sync, + IL: Layer>, IL::Service: Service> + Sync + Clone + Send + 'static, >::Error: Send + Into, diff --git a/volo-thrift/src/codec/default/mod.rs b/volo-thrift/src/codec/default/mod.rs index 1bf98a78..e9e4def6 100644 --- a/volo-thrift/src/codec/default/mod.rs +++ b/volo-thrift/src/codec/default/mod.rs @@ -116,8 +116,7 @@ pub struct DefaultEncoder { impl Encoder for DefaultEncoder { - #[inline] - async fn encode( + async fn send( &mut self, cx: &mut Cx, msg: ThriftMessage, @@ -179,6 +178,52 @@ impl Encoder } // write_result } + + #[inline] + async fn encode( + &mut self, + cx: &mut Cx, + msg: ThriftMessage, + ) -> Result<(), ThriftException> { + cx.stats_mut().record_encode_start_at(); + + // first, we need to get the size of the message + let (real_size, malloc_size) = self.encoder.size(cx, &msg)?; + trace!( + "[VOLO] codec encode message real size: {}, malloc size: {}", + real_size, + malloc_size + ); + cx.stats_mut().set_write_size(real_size); + + // then we reserve the size of the message in the linked bytes + self.linked_bytes.reserve(malloc_size); + // after that, we encode the message into the linked bytes + self.encoder + .encode(cx, &mut self.linked_bytes, msg) + .map_err(|e| { + // record the error time + cx.stats_mut().record_encode_end_at(); + e + })?; + + cx.stats_mut().record_encode_end_at(); + Ok(()) + } + + async fn flush(&mut self) -> Result<(), ThriftException> { + let write_result: Result<(), ThriftException> = self + .linked_bytes + .write_all_vectored(&mut self.writer) + .await + .map_err(|e| e.into()); + write_result?; + self.writer.flush().await.map_err(Into::into) + } + + async fn reset(&mut self) { + self.linked_bytes.reset(); + } } pub struct DefaultDecoder { diff --git a/volo-thrift/src/codec/mod.rs b/volo-thrift/src/codec/mod.rs index 52f729c7..163df16b 100644 --- a/volo-thrift/src/codec/mod.rs +++ b/volo-thrift/src/codec/mod.rs @@ -26,11 +26,19 @@ pub trait Decoder: Send + 'static { /// /// Note: [`Encoder`] should be designed to be ready for reuse. pub trait Encoder: Send + 'static { + fn reset(&mut self) -> impl Future + Send; + fn send( + &mut self, + cx: &mut Cx, + msg: ThriftMessage, + ) -> impl Future> + Send; fn encode( &mut self, cx: &mut Cx, msg: ThriftMessage, ) -> impl Future> + Send; + + fn flush(&mut self) -> impl Future> + Send; } /// [`MakeCodec`] receives an [`AsyncRead`] and an [`AsyncWrite`] and returns a diff --git a/volo-thrift/src/transport/mod.rs b/volo-thrift/src/transport/mod.rs index f48bcafb..b881cb4f 100644 --- a/volo-thrift/src/transport/mod.rs +++ b/volo-thrift/src/transport/mod.rs @@ -1,5 +1,4 @@ pub(crate) mod incoming; -#[cfg(feature = "multiplex")] pub mod multiplex; pub mod pingpong; pub mod pool; diff --git a/volo-thrift/src/transport/multiplex/client.rs b/volo-thrift/src/transport/multiplex/client.rs index e874fca1..c17d614c 100644 --- a/volo-thrift/src/transport/multiplex/client.rs +++ b/volo-thrift/src/transport/multiplex/client.rs @@ -14,18 +14,18 @@ use crate::{ ClientError, EntryMessage, ThriftMessage, }; -pub struct MakeClientTransport +pub struct MakeClientTransport where MkT: MakeTransport, MkC: MakeCodec, { make_transport: MkT, make_codec: MkC, - _phantom: PhantomData Resp>, + _phantom: PhantomData<(fn() -> Resp, fn() -> Req)>, } -impl, Resp> Clone - for MakeClientTransport +impl, Req, Resp> Clone + for MakeClientTransport { fn clone(&self) -> Self { Self { @@ -36,7 +36,7 @@ impl, Resp> Cl } } -impl MakeClientTransport +impl MakeClientTransport where MkT: MakeTransport, MkC: MakeCodec, @@ -51,13 +51,14 @@ where } } -impl UnaryService
for MakeClientTransport +impl UnaryService
for MakeClientTransport where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { - type Response = ThriftTransport; + type Response = ThriftTransport; type Error = io::Error; async fn call(&self, target: Address) -> Result { @@ -72,22 +73,24 @@ where } } -pub struct Client +pub struct Client where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { #[allow(clippy::type_complexity)] - make_transport: PooledMakeTransport, Address>, + make_transport: PooledMakeTransport, Address>, _marker: PhantomData, } -impl Clone for Client +impl Clone for Client where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { fn clone(&self) -> Self { Self { @@ -97,11 +100,12 @@ where } } -impl Client +impl Client where MkT: MakeTransport, MkC: MakeCodec + Sync, - Resp: EntryMessage + Send + 'static, + Resp: EntryMessage + Send + 'static + Sync, + Req: EntryMessage + Send + 'static + Sync, { pub fn new(make_transport: MkT, pool_cfg: Option, make_codec: MkC) -> Self { let make_transport = MakeClientTransport::new(make_transport, make_codec); @@ -113,9 +117,9 @@ where } } -impl Service> for Client +impl Service> for Client where - Req: Send + 'static + EntryMessage, + Req: Send + 'static + EntryMessage + Sync, Resp: EntryMessage + Send + 'static + Sync, MkT: MakeTransport, MkC: MakeCodec + Sync, diff --git a/volo-thrift/src/transport/multiplex/mod.rs b/volo-thrift/src/transport/multiplex/mod.rs index 22ce370d..3e222966 100644 --- a/volo-thrift/src/transport/multiplex/mod.rs +++ b/volo-thrift/src/transport/multiplex/mod.rs @@ -1,6 +1,7 @@ mod client; mod server; mod thrift_transport; +pub mod utils; pub use client::Client; pub use server::serve; diff --git a/volo-thrift/src/transport/multiplex/server.rs b/volo-thrift/src/transport/multiplex/server.rs index 71b6f601..4c27f708 100644 --- a/volo-thrift/src/transport/multiplex/server.rs +++ b/volo-thrift/src/transport/multiplex/server.rs @@ -56,7 +56,7 @@ pub async fn serve( if let Err(e) = metainfo::METAINFO .scope( RefCell::new(mi), - encoder.encode::(&mut cx, msg), + encoder.send::(&mut cx, msg), ) .await { diff --git a/volo-thrift/src/transport/multiplex/thrift_transport.rs b/volo-thrift/src/transport/multiplex/thrift_transport.rs index edb126c6..2b555158 100644 --- a/volo-thrift/src/transport/multiplex/thrift_transport.rs +++ b/volo-thrift/src/transport/multiplex/thrift_transport.rs @@ -1,5 +1,7 @@ use std::{ cell::RefCell, + collections::VecDeque, + marker::PhantomData, sync::{ atomic::{AtomicBool, AtomicUsize}, Arc, @@ -13,6 +15,7 @@ use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{oneshot, Mutex}, }; +use tokio_condvar::Condvar; use volo::{ context::{Role, RpcInfo}, net::Address, @@ -21,7 +24,10 @@ use volo::{ use crate::{ codec::{Decoder, Encoder, MakeCodec}, context::{ClientContext, ThriftContext}, - transport::pool::{Poolable, Reservation}, + transport::{ + multiplex::utils::TxHashMap, + pool::{Poolable, Reservation}, + }, ClientError, EntryMessage, ThriftMessage, }; @@ -31,17 +37,14 @@ lazy_static::lazy_static! { } #[pin_project] -pub struct ThriftTransport { - write_half: Arc>>, - dirty: Arc, +pub struct ThriftTransport { + _phantom1: PhantomData E>, + #[allow(clippy::type_complexity)] tx_map: Arc< - Mutex< - rustc_hash::FxHashMap< - i32, - oneshot::Sender< - Result)>, ClientError>, - >, + TxHashMap< + oneshot::Sender< + Result)>, ClientError>, >, >, >, @@ -50,25 +53,120 @@ pub struct ThriftTransport { read_error: Arc, // read connection is closed read_closed: Arc, + // TODO make this to lockless + batch_queue: Arc>>>, + queue_cv: Arc, } -impl Clone for ThriftTransport { +impl Clone for ThriftTransport { fn clone(&self) -> Self { Self { - write_half: self.write_half.clone(), - dirty: self.dirty.clone(), tx_map: self.tx_map.clone(), write_error: self.write_error.clone(), read_error: self.read_error.clone(), read_closed: self.read_closed.clone(), + batch_queue: self.batch_queue.clone(), + _phantom1: PhantomData, + queue_cv: self.queue_cv.clone(), } } } -impl ThriftTransport +impl ThriftTransport where E: Encoder, + Req: EntryMessage + Send + 'static + Sync, + Resp: EntryMessage + Send + 'static + Sync, { + pub fn write_loop(&self, mut write_half: WriteHalf) { + let batch_queu = self.batch_queue.clone(); + let inner_tx_map = self.tx_map.clone(); + let inner_read_error: Arc = self.read_error.clone(); + let inner_read_closed = self.read_closed.clone(); + let inner_write_error = self.write_error.clone(); + let queue_cv = self.queue_cv.clone(); + tokio::spawn(async move { + let mut resolved = Vec::with_capacity(32); + let mut has_error; + loop { + { + resolved.clear(); + write_half.reset().await; + has_error = false; + let mut queue = batch_queu.lock().await; + while queue.is_empty() + && !inner_read_error.load(std::sync::atomic::Ordering::Relaxed) + && !inner_read_closed.load(std::sync::atomic::Ordering::Relaxed) + { + queue = queue_cv.wait(queue).await; + } + + if inner_read_error.load(std::sync::atomic::Ordering::Relaxed) + || inner_read_closed.load(std::sync::atomic::Ordering::Relaxed) + { + return; + } + + while !queue.is_empty() { + let current = queue.pop_front().unwrap(); + let seq = current.meta.seq_id; + resolved.push(seq); + let mut cx = ClientContext::new( + seq, + RpcInfo::with_role(Role::Client), + pilota::thrift::TMessageType::Call, + ); + let res = write_half.encode(&mut cx, current).await; + match res { + Ok(_) => {} + Err(err) => { + tracing::error!( + "[VOLO] multiplex connection encode error: {}", + err + ); + inner_write_error.store(true, std::sync::atomic::Ordering::Relaxed); + has_error = true; + while !queue.is_empty() { + let current = queue.pop_front().unwrap(); + resolved.push(current.meta.seq_id); + } + break; + } + } + } + if has_error { + for seq in resolved.iter() { + let _ = inner_tx_map.remove(seq).await.unwrap().send(Err( + ClientError::Application(ApplicationException::new( + ApplicationExceptionKind::UNKNOWN, + format!("write error "), + )), + )); + } + return; + } + let res = write_half.flush().await; + match res { + Ok(_) => {} + Err(err) => { + tracing::error!("[VOLO] multiplex connection flush error: {}", err,); + inner_write_error.store(true, std::sync::atomic::Ordering::Relaxed); + for seq in resolved.iter() { + let _ = inner_tx_map.remove(&seq).await.unwrap().send(Err( + ClientError::Application(ApplicationException::new( + ApplicationExceptionKind::UNKNOWN, + err.to_string(), + )), + )); + } + return; + } + } + } + } + }); + } + pub fn new< R: AsyncRead + Send + Sync + Unpin + 'static, W: AsyncWrite + Send + Sync + Unpin + 'static, @@ -92,12 +190,9 @@ where let write_half = WriteHalf { encoder, id }; #[allow(clippy::type_complexity)] let tx_map: Arc< - Mutex< - rustc_hash::FxHashMap< - i32, - oneshot::Sender< - Result)>, ClientError>, - >, + TxHashMap< + oneshot::Sender< + Result)>, ClientError>, >, >, > = Default::default(); @@ -108,6 +203,9 @@ where let inner_read_error = read_error.clone(); let read_closed = Arc::new(AtomicBool::new(false)); let inner_read_closed = read_closed.clone(); + let queue_cv = Arc::new(Condvar::new()); + let inner_queue_cv = queue_cv.clone(); + //// read loop tokio::spawn(async move { metainfo::METAINFO .scope(RefCell::new(Default::default()), async move { @@ -132,39 +230,43 @@ where e, target ); - let mut tx_map = inner_tx_map.lock().await; inner_read_error.store(true, std::sync::atomic::Ordering::Relaxed); - for (_, tx) in tx_map.drain() { - let _ = tx.send(Err(ClientError::Application( - ApplicationException::new( - ApplicationExceptionKind::UNKNOWN, - format!( - "multiplex connection error: {e}, target: {target}" + inner_queue_cv.notify_all(); + + inner_tx_map + .for_all_drain(|tx| { + let _ = tx.send(Err(ClientError::Application( + ApplicationException::new( + ApplicationExceptionKind::UNKNOWN, + format!( + "multiplex connection error: {e}, target: {target}" + ), ), - ), - ))); - } + ))); + }) + .await; return; } // we have checked the error above, so it's safe to unwrap here let res = res.unwrap(); if res.is_none() { // the connection is closed - let mut tx_map = inner_tx_map.lock().await; - if !tx_map.is_empty() { + if !inner_tx_map.is_empty().await { inner_read_error.store(true, std::sync::atomic::Ordering::Relaxed); - for (_, tx) in tx_map.drain() { - let _ = tx.send(Ok(None)); - } + inner_tx_map + .for_all_drain(|tx| { + let _ = tx.send(Ok(None)); + }) + .await; } inner_read_closed.store(true, std::sync::atomic::Ordering::Relaxed); + inner_queue_cv.notify_all(); return; } // now we get ThriftMessage let res = res.unwrap(); let seq_id = res.meta.seq_id; - let mut tx_map = inner_tx_map.lock().await; - if let Some(tx) = tx_map.remove(&seq_id) { + if let Some(tx) = inner_tx_map.remove(&seq_id).await { metainfo::METAINFO.with(|mi| { let mi = mi.take(); let _ = tx.send(Ok(Some((mi, cx, res)))); @@ -181,23 +283,27 @@ where }) .await; }); - Self { - write_half: Arc::new(Mutex::new(write_half)), - dirty: Arc::new(AtomicBool::new(false)), + let ret = Self { tx_map, write_error, read_error, read_closed, - } + batch_queue: Default::default(), + _phantom1: PhantomData, + queue_cv, + }; + ret.write_loop(write_half); + ret } } -impl ThriftTransport +impl ThriftTransport where E: Encoder, Resp: EntryMessage, + Req: EntryMessage, { - pub async fn send( + pub async fn send( &self, cx: &mut ClientContext, msg: ThriftMessage, @@ -216,38 +322,21 @@ where "multiplex connection closed".to_string(), ))); } - let (tx, rx) = oneshot::channel(); - let mut tx_map = self.tx_map.lock().await; - let seq_id = msg.meta.seq_id; - if !oneway { - tx_map.insert(seq_id, tx); - } - drop(tx_map); - let mut wh = self.write_half.lock().await; - // check connection dirty - if self.dirty.load(std::sync::atomic::Ordering::Relaxed) { - // connection is dirty, we should also set write error to indicate the connection should - // not be reused - self.write_error - .store(true, std::sync::atomic::Ordering::Relaxed); + if self.write_error.load(std::sync::atomic::Ordering::Relaxed) { return Err(ClientError::Application(ApplicationException::new( ApplicationExceptionKind::UNKNOWN, - "multiplex connection is dirty".to_string(), + "multiplex connection error".to_string(), ))); } - self.dirty.store(true, std::sync::atomic::Ordering::Relaxed); - let res = wh.send(cx, msg).await; - self.dirty - .store(false, std::sync::atomic::Ordering::Relaxed); - drop(wh); - if let Err(e) = res { - self.write_error - .store(true, std::sync::atomic::Ordering::Relaxed); - if !oneway { - let mut tx_map = self.tx_map.lock().await; - tx_map.remove(&seq_id); - } - return Err(e); + + let (tx, rx) = oneshot::channel(); + let seq_id = msg.meta.seq_id; + if !oneway { + self.tx_map.insert(seq_id, tx).await; + } + { + self.batch_queue.lock().await.push_back(msg); + self.queue_cv.notify_all(); } if oneway { return Ok(None); @@ -340,6 +429,7 @@ pub struct WriteHalf { id: usize, } +#[allow(dead_code)] impl WriteHalf where E: Encoder, @@ -354,12 +444,32 @@ where tracing::error!("[VOLO] transport[{}] encode error: {:?}", self.id, e); e })?; + Ok(()) + } + pub async fn reset(&mut self) { + self.encoder.reset().await; + } + + pub async fn encode( + &mut self, + cx: &mut impl ThriftContext, + msg: ThriftMessage, + ) -> Result<(), ClientError> { + self.encoder.encode(cx, msg).await.map_err(|mut e| { + e.append_msg(&format!(", rpcinfo: {:?}", cx.rpc_info())); + tracing::error!("[VOLO] transport[{}] encode error: {:?}", self.id, e); + e + })?; + Ok(()) + } + pub async fn flush(&mut self) -> Result<(), ClientError> { + self.encoder.flush().await?; Ok(()) } } -impl Poolable for ThriftTransport { +impl Poolable for ThriftTransport { fn reusable(&self) -> bool { !self.write_error.load(std::sync::atomic::Ordering::Relaxed) && !self.read_error.load(std::sync::atomic::Ordering::Relaxed) diff --git a/volo-thrift/src/transport/multiplex/utils.rs b/volo-thrift/src/transport/multiplex/utils.rs new file mode 100644 index 00000000..e708d7a8 --- /dev/null +++ b/volo-thrift/src/transport/multiplex/utils.rs @@ -0,0 +1,54 @@ +use std::array; + +use tokio::sync::Mutex; + +const SHARD_COUNT: usize = 64; + +pub struct TxHashMap { + sharded: [Mutex>; SHARD_COUNT], +} + +impl Default for TxHashMap { + fn default() -> Self { + TxHashMap { + sharded: array::from_fn(|_| Default::default()), + } + } +} + +impl TxHashMap +where + T: Sized, +{ + pub async fn remove(&self, key: &i32) -> Option { + self.sharded[(*key % (SHARD_COUNT as i32)) as usize] + .lock() + .await + .remove(key) + } + + pub async fn is_empty(&self) -> bool { + for s in self.sharded.iter() { + if !s.lock().await.is_empty() { + return false; + } + } + true + } + + pub async fn insert(&self, key: i32, value: T) -> Option { + self.sharded[(key % (SHARD_COUNT as i32)) as usize] + .lock() + .await + .insert(key, value) + } + + pub async fn for_all_drain(&self, mut f: impl FnMut(T) -> ()) { + for sharded in self.sharded.iter() { + let mut s = sharded.lock().await; + for data in s.drain() { + f(data.1) + } + } + } +} diff --git a/volo-thrift/src/transport/pingpong/server.rs b/volo-thrift/src/transport/pingpong/server.rs index b0530928..4bbb7ea9 100644 --- a/volo-thrift/src/transport/pingpong/server.rs +++ b/volo-thrift/src/transport/pingpong/server.rs @@ -98,7 +98,7 @@ pub async fn serve( }), ); if let Err(e) = async { - let result = encoder.encode(&mut cx, msg).await; + let result = encoder.send(&mut cx, msg).await; span_provider.leave_encode(&cx); result } @@ -140,7 +140,7 @@ pub async fn serve( thrift_exception_to_application_exception(e), ), ); - if let Err(e) = encoder.encode(&mut cx, msg).await { + if let Err(e) = encoder.send(&mut cx, msg).await { error!( "[VOLO] server send error error: {:?}, cx: {:?}, \ peer_addr: {:?}", diff --git a/volo-thrift/src/transport/pingpong/thrift_transport.rs b/volo-thrift/src/transport/pingpong/thrift_transport.rs index 76badb22..7a3ca035 100644 --- a/volo-thrift/src/transport/pingpong/thrift_transport.rs +++ b/volo-thrift/src/transport/pingpong/thrift_transport.rs @@ -131,9 +131,9 @@ where cx: &mut impl ThriftContext, msg: ThriftMessage, ) -> Result<(), ClientError> { - self.encoder.encode(cx, msg).await.map_err(|mut e| { + self.encoder.send(cx, msg).await.map_err(|mut e| { e.append_msg(&format!(", rpcinfo: {:?}", cx.rpc_info())); - tracing::error!("[VOLO] transport[{}] encode error: {:?}", self.id, e); + tracing::error!("[VOLO] transport[{}] send error: {:?}", self.id, e); e })?;