diff --git a/CHANGELOG.md b/CHANGELOG.md
index de5fa2d..cb78157 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,12 +5,18 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [0.0.9] - 2023-09-23
+## [0.0.10] - 2023-10-04
 
 ### Added
- - `dispatcher_internal_queue_size` parameter to builder, allowing to customize size of internal queue between dispatcher and connection.
+ - `internal_simultaneous_requests_threshold` parameter to builder, which allows customizing the maximum number of simultaneously created requests that the connection can handle effectively.
+
+### Changed
+ - Rewritten internal logic of the connection to Tarantool, improving performance: reading from and writing to the socket are now handled by separate tasks.
 
 ### Fixed
- - Increased size of internal queue between dispatcher and connection, which should significantly increase performance (previously it was degrading rapidly with a lot of parallel requests).
+ - Increased the size of the internal channel between dispatcher and connection, which should significantly increase performance (previously it degraded rapidly with a lot of concurrent requests).
+
+
+## [0.0.9] - 2023-09-23 (broken, yanked)
 
 ## [0.0.8] - 2023-09-05
diff --git a/Cargo.toml b/Cargo.toml
index a83e41b..147be1c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,7 @@
 [package]
 name = "tarantool-rs"
 description = "Asyncronous tokio-based client for Tarantool"
-version = "0.0.9"
+version = "0.0.10"
 edition = "2021"
 authors = ["Andrey Kononov flowneee3@gmail.com"]
 license = "MIT"
@@ -26,6 +26,7 @@ serde = { version = "1", features = ["derive"] }
 sha-1 = "0.10"
 thiserror = "1"
 tokio = { version = "1", features = ["rt", "net", "io-util", "macros", "time"] }
+tokio-stream = "0.1"
 tokio-util = { version = "0.7", default-features = false, features = ["codec"] }
 tracing = { version = "0.1", features = ["log"] }
@@ -40,6 +41,7 @@ serde_json = "1"
 tokio = { version = "1", features = ["full"] }
 tracing-test = { version = "0.2", features = ["no-env-filter"] }
 tarantool-test-container = { path = "tarantool-test-container" }
+rusty_tarantool = "*"
 
 [[example]]
 name = "cli_client"
@@ -60,3 +62,7 @@ harness = false
 [[bench]]
 name = "compare"
 harness = false
+
+[[bench]]
+name = "simple_loop"
+harness = false
diff --git a/benches/bench.rs b/benches/bench.rs
index 931b62a..3d870e4 100644
--- a/benches/bench.rs
+++ b/benches/bench.rs
@@ -26,7 +26,7 @@ pub fn bench_tarantool_rs(c: &mut Criterion) {
     // Bench logic
     // NOTE: on my PC converting to join add slight overhead (1-2 microseconds for 1 future input)
     // NOTE: on my PC 50 input load tarantool to 50% on single core
-    for parallel in [1, 2, 5, 10, 50].into_iter() {
+    for parallel in [1, 50, 250, 1000].into_iter() {
         group.bench_with_input(BenchmarkId::new("ping", parallel), &parallel, |b, p| {
             b.to_async(&tokio_rt).iter(|| async {
                 let make_fut = |_| conn.ping();
diff --git a/benches/simple_loop.rs b/benches/simple_loop.rs
new file mode 100644
index 0000000..9611412
--- /dev/null
+++ b/benches/simple_loop.rs
@@ -0,0 +1,48 @@
+use std::time::{Duration, Instant};
+
+use futures::{stream::repeat_with, StreamExt};
+use tarantool_rs::{Connection, ExecutorExt};
+
+type TarantoolTestContainer = tarantool_test_container::TarantoolTestContainer<
+    tarantool_test_container::TarantoolDefaultArgs,
+>;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    let container = TarantoolTestContainer::default();
+
+    let conn = Connection::builder()
+        .internal_simultaneous_requests_threshold(1000)
+        .build(format!("127.0.0.1:{}", container.connect_port()))
+        .await?;
+    // let conn = rusty_tarantool::tarantool::ClientConfig::new(
+    //     format!("127.0.0.1:{}", container.connect_port()),
+    //     "guest",
+    //     "",
+    // )
+    // .build();
+    // conn.ping().await?;
+
+    let mut counter = 0u64;
+    let mut last_measured_counter = 0;
+    let mut last_measured_ts = Instant::now();
+
+    let interval_secs = 2;
+    let interval = Duration::from_secs(interval_secs);
+
+    let mut stream = repeat_with(|| conn.ping()).buffer_unordered(1000);
+    while let Some(_) = stream.next().await {
+        counter += 1;
+        if last_measured_ts.elapsed() > interval {
+            last_measured_ts = Instant::now();
+            let counter_diff = counter - last_measured_counter;
+            last_measured_counter = counter;
+            println!(
+                "Iterations over last {interval_secs} seconds: {counter_diff}, per second: {}",
+                counter_diff / interval_secs
+            );
+        }
+    }
+
+    Ok(())
+}
diff --git a/src/builder.rs b/src/builder.rs
index aa75ee5..41b75fc 100644
--- a/src/builder.rs
+++ b/src/builder.rs
@@ -67,7 +67,7 @@ pub struct ConnectionBuilder {
     connect_timeout: Option<Duration>,
     reconnect_interval: Option<ReconnectInterval>,
     sql_statement_cache_capacity: usize,
-    dispatcher_internal_queue_size: usize,
+    internal_simultaneous_requests_threshold: usize,
 }
 
 impl Default for ConnectionBuilder {
@@ -81,7 +81,7 @@
             connect_timeout: None,
             reconnect_interval: Some(ReconnectInterval::default()),
             sql_statement_cache_capacity: DEFAULT_SQL_STATEMENT_CACHE_CAPACITY,
-            dispatcher_internal_queue_size: DEFAULT_DISPATCHER_INTERNAL_QUEUE_SIZE,
+            internal_simultaneous_requests_threshold: DEFAULT_DISPATCHER_INTERNAL_QUEUE_SIZE,
         }
     }
 }
@@ -92,17 +92,18 @@ impl ConnectionBuilder {
     where
         A: ToSocketAddrs + Display + Clone + Send + Sync + 'static,
     {
-        let (dispatcher, disaptcher_sender) = Dispatcher::new(
+        let (dispatcher_fut, disaptcher_sender) = Dispatcher::prepare(
             addr,
             self.user.as_deref(),
             self.password.as_deref(),
             self.timeout,
             self.reconnect_interval.clone(),
+            self.internal_simultaneous_requests_threshold,
         )
         .await?;
 
         // TODO: support setting custom executor
-        tokio::spawn(dispatcher.run());
+        tokio::spawn(dispatcher_fut);
 
         let conn = Connection::new(
             disaptcher_sender,
             self.timeout,
@@ -194,16 +195,20 @@ impl ConnectionBuilder {
         self
     }
 
-    /// Sets size of the internal queue between connection and dispatcher.
+    /// Prepares the `Connection` to process `value` simultaneously created requests.
     ///
-    /// This queue contains all requests, made from [`Connection`]s/[`Stream`]s/etc.
-    /// Increasing its size can help if you have a lot of requests, made concurrently
-    /// and frequently, however this will increase memory consumption slightly.
+    /// This is not a hard limit: making more simultaneous requests than this value
+    /// will degrade performance, so try increasing it if you are unsatisfied
+    /// with the performance.
+    ///
+    /// Internally the connection has multiple bounded channels, and this parameter
+    /// mostly affects the sizes of these channels. Increasing this value can help
+    /// if you create a lot of requests simultaneously, but it will also increase
+    /// memory consumption.
     ///
     /// By default set to 500, which should be reasonable compromise between memory
-    /// (about 50 KB) and performance.
-    pub fn dispatcher_internal_queue_size(&mut self, size: usize) -> &mut Self {
-        self.dispatcher_internal_queue_size = size;
+    /// (about 100 KB) and performance.
+    pub fn internal_simultaneous_requests_threshold(&mut self, value: usize) -> &mut Self {
+        self.internal_simultaneous_requests_threshold = value;
         self
     }
 }
diff --git a/src/errors.rs b/src/errors.rs
index 5d8d2fa..93096d2 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -7,7 +7,7 @@ use rmp::{
     encode::{RmpWriteErr, ValueWriteError},
 };
 use rmpv::Value;
-use tokio::time::error::Elapsed;
+use tokio::{task::JoinError, time::error::Elapsed};
 
 /// Error returned by Tarantool in response to a request.
 #[derive(Clone, Debug, thiserror::Error)]
@@ -66,8 +66,8 @@ pub enum Error {
     SpaceMissingPrimaryIndex,
 
     /// Underlying TCP connection closed.
-    #[error("TCP connection error")]
-    ConnectionError(#[from] Arc<tokio::io::Error>),
+    #[error("TCP connection IO error")]
+    Io(#[from] Arc<tokio::io::Error>),
     /// Underlying TCP connection was closed.
     #[error("TCP connection closed")]
     ConnectionClosed,
@@ -79,7 +79,7 @@
 impl From<tokio::io::Error> for Error {
     fn from(v: tokio::io::Error) -> Self {
-        Self::ConnectionError(Arc::new(v))
+        Self::Io(Arc::new(v))
     }
 }
 
@@ -87,7 +87,6 @@ impl From<CodecDecodeError> for Error {
     fn from(value: CodecDecodeError) -> Self {
         match value {
             CodecDecodeError::Io(x) => x.into(),
-            CodecDecodeError::Closed => Self::ConnectionClosed,
             CodecDecodeError::Decode(x) => x.into(),
         }
     }
 }
@@ -102,6 +101,17 @@ impl From<CodecEncodeError> for Error {
     }
 }
 
+impl From<ConnectionError> for Error {
+    fn from(value: ConnectionError) -> Self {
+        match value {
+            ConnectionError::Io(x) => x.into(),
+            ConnectionError::ConnectionClosed => Self::ConnectionClosed,
+            ConnectionError::Decode(x) => x.into(),
+            err @ ConnectionError::JoinError(_) => Self::Other(err.into()),
+        }
+    }
+}
+
 impl From<Elapsed> for Error {
     fn from(_: Elapsed) -> Self {
         Self::Timeout
@@ -296,28 +306,53 @@ impl DecodingErrorLocation {
 }
 
 /// Helper type to return errors from decoder.
-#[derive(Clone)]
+#[derive(Debug, thiserror::Error)]
 pub(crate) enum CodecDecodeError {
+    #[error(transparent)]
+    Io(#[from] tokio::io::Error),
+    #[error(transparent)]
+    Decode(#[from] DecodingError),
+}
+
+/// Helper type to return errors from encoder.
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum CodecEncodeError {
+    #[error(transparent)]
+    Io(#[from] tokio::io::Error),
+    #[error(transparent)]
+    Encode(#[from] EncodingError),
+}
+
+/// Error type returned to the client from the connection.
+#[derive(Clone, Debug, thiserror::Error)]
+pub(crate) enum ConnectionError {
+    #[error(transparent)]
     Io(Arc<tokio::io::Error>),
-    Closed,
-    Decode(DecodingError),
+    #[error("Connection closed")]
+    ConnectionClosed,
+    #[error(transparent)]
+    Decode(#[from] DecodingError),
+    #[error("Tokio JoinHandle error: {0:?}")]
+    JoinError(#[source] Arc<JoinError>),
 }
 
-impl From<tokio::io::Error> for CodecDecodeError {
-    fn from(v: tokio::io::Error) -> Self {
-        Self::Io(Arc::new(v))
+impl From<tokio::io::Error> for ConnectionError {
+    fn from(value: tokio::io::Error) -> Self {
+        Self::Io(Arc::new(value))
     }
 }
 
-/// Helper type to return errors from encoder.
-#[derive(Debug)]
-pub(crate) enum CodecEncodeError {
-    Io(tokio::io::Error),
-    Encode(EncodingError),
+impl From<CodecDecodeError> for ConnectionError {
+    fn from(value: CodecDecodeError) -> Self {
+        match value {
+            CodecDecodeError::Io(x) => x.into(),
+            CodecDecodeError::Decode(x) => x.into(),
+        }
+    }
 }
 
-impl From<tokio::io::Error> for CodecEncodeError {
-    fn from(v: tokio::io::Error) -> Self {
-        Self::Io(v)
+impl From<JoinError> for ConnectionError {
+    fn from(value: JoinError) -> Self {
+        Self::JoinError(Arc::new(value))
     }
 }
diff --git a/src/transport/connection.rs b/src/transport/connection.rs
index c6d6785..6d91898 100644
--- a/src/transport/connection.rs
+++ b/src/transport/connection.rs
@@ -1,41 +1,161 @@
-use std::{
-    collections::HashMap,
-    fmt::Display,
-    sync::atomic::{AtomicU32, Ordering},
-    time::Duration,
+use std::{collections::HashMap, fmt::Display, time::Duration};
+
+use futures::{
+    future::{Fuse, FusedFuture},
+    FutureExt, SinkExt, StreamExt, TryStreamExt,
 };
 
-use futures::{SinkExt, TryStreamExt};
 use tokio::{
-    io::AsyncReadExt,
-    net::{TcpStream, ToSocketAddrs},
-    sync::oneshot,
+    io::{AsyncReadExt, AsyncWriteExt},
+    net::{
+        tcp::{OwnedReadHalf, OwnedWriteHalf},
+        TcpStream, ToSocketAddrs,
+    },
+    pin,
+    sync::{
+        mpsc::{self},
+        oneshot,
+    },
+    task::JoinHandle,
 };
-use tokio_util::codec::Framed;
+use tokio_stream::wrappers::ReceiverStream;
+use tokio_util::codec::{FramedRead, FramedWrite};
 use tracing::{debug, trace, warn};
 
-use super::dispatcher::DispatcherResponse;
+use super::dispatcher::{DispatcherRequest, DispatcherResponse};
 use crate::{
     codec::{
         request::{Auth, EncodedRequest},
         response::{Response, ResponseBody},
         ClientCodec, Greeting,
     },
-    errors::{CodecDecodeError, CodecEncodeError, Error},
+    errors::{CodecEncodeError, ConnectionError, Error},
 };
 
-pub(crate) struct Connection {
-    stream: Framed<TcpStream, ClientCodec>,
+struct ConnectionData {
     in_flights: HashMap<u32, oneshot::Sender<DispatcherResponse>>,
-    next_sync: AtomicU32,
+    next_sync: u32,
+}
+
+impl Default for ConnectionData {
+    fn default() -> Self {
+        Self {
+            in_flights: HashMap::with_capacity(5),
+            next_sync: 0,
+        }
+    }
+}
+
+impl ConnectionData {
+    #[inline]
+    fn next_sync(&mut self) -> u32 {
+        let next = self.next_sync;
+        self.next_sync += 1;
+        next
+    }
+
+    /// Prepare a request for sending to the server.
+    ///
+    /// Sets the `sync` value and attempts to store the message in the in-flight storage.
+    ///
+    /// `Err` means the message was not prepared and should not be sent.
+    /// This function also takes care of reporting the error through `tx`.
+    #[inline]
+    fn try_prepare_request(
+        &mut self,
+        request: &mut EncodedRequest,
+        tx: oneshot::Sender<DispatcherResponse>,
+    ) -> Result<(), ()> {
+        let sync = self.next_sync();
+        *request.sync_mut() = sync;
+        trace!(
+            "Sending request with sync {}, stream_id {:?}",
+            request.sync,
+            request.stream_id
+        );
+        // TODO: replace with try_insert when stabilized
+        // If sync already assigned to another request, return an error
+        // for current request
+        if let Some(old) = self.in_flights.insert(request.sync, tx) {
+            let new = self
+                .in_flights
+                .insert(request.sync, old)
+                .expect("Shouldn't panic, value was just inserted");
+            if new.send(Err(Error::DuplicatedSync(request.sync))).is_err() {
+                warn!(
+                    "Failed to pass error to sync {}, receiver dropped",
+                    request.sync
+                );
+            }
+            return Err(());
+        }
+        Ok(())
+    }
+
+    /// Send the result of processing a request (identified by `sync`) to the client.
+    #[inline]
+    fn respond_to_client(&mut self, sync: u32, result: Result<Response, Error>) {
+        if let Some(tx) = self.in_flights.remove(&sync) {
+            if tx.send(result).is_err() {
+                warn!("Failed to pass response sync {}, receiver dropped", sync);
+            }
+        } else {
+            warn!("Unknown sync {}", sync);
+        }
+    }
+
+    /// Send an error to all in-flight requests and drop them.
+    #[inline]
+    fn send_error_to_all_in_flights(&mut self, err: ConnectionError) {
+        for (_, tx) in self.in_flights.drain() {
+            let _ = tx.send(Err(err.clone().into()));
+        }
+    }
+}
 
 // TODO: cancel
+async fn writer_task(
+    mut rx: mpsc::Receiver<EncodedRequest>,
+    mut stream: FramedWrite<OwnedWriteHalf, ClientCodec>,
+) -> Result<(), (u32, CodecEncodeError, Vec<EncodedRequest>)> {
+    let mut result = Ok(());
+    while let Some(x) = rx.recv().await {
+        let sync = x.sync;
+        if let Err(err) = stream.send(x).await {
+            // Close internal queue and extract all remaining requests
+            rx.close();
+            let mut remaining_requests = Vec::new();
+            while let Ok(next) = rx.try_recv() {
+                remaining_requests.push(next);
+            }
+
+            result = Err((sync, err, remaining_requests));
+            break;
+        }
+    }
+
+    if let Err(err) = stream.into_inner().shutdown().await {
+        warn!("Failed to shutdown TCP stream cleanly: {err}");
+    }
+
+    result
+}
+
+type WriterTaskJoinHandle = JoinHandle<Result<(), (u32, CodecEncodeError, Vec<EncodedRequest>)>>;
+
+pub(crate) struct Connection {
+    read_stream: FramedRead<OwnedReadHalf, ClientCodec>,
+    writer_tx: mpsc::Sender<EncodedRequest>,
+    writer_task_handle: WriterTaskJoinHandle,
+    data: ConnectionData,
+}
+
 impl Connection {
     async fn new_inner<A>(
         addr: A,
         user: Option<&str>,
         password: Option<&str>,
+        internal_simultaneous_requests_threshold: usize,
     ) -> Result<Self, Error>
     where
         A: ToSocketAddrs + Display,
@@ -50,16 +170,37 @@ impl Connection {
         debug!("Server: {}", greeting.server);
         trace!("Salt: {:?}", greeting.salt);
 
-        let mut this = Self {
-            stream: Framed::new(tcp, ClientCodec::default()),
-            in_flights: HashMap::with_capacity(5),
-            next_sync: AtomicU32::new(0),
-        };
+        let (read_tcp_stream, write_tcp_stream) = tcp.into_split();
+        let mut read_stream = FramedRead::new(read_tcp_stream, ClientCodec::default());
+        let mut write_stream = FramedWrite::new(write_tcp_stream, ClientCodec::default());
+
+        let mut conn_data = ConnectionData::default();
 
         if let Some(user) = user {
-            this.auth(user, password, &greeting.salt).await?;
+            Self::auth(
+                &mut read_stream,
+                &mut write_stream,
+                conn_data.next_sync(),
+                user,
+                password,
+                &greeting.salt,
+            )
+            .await?;
         }
 
+        // TODO: review size of this queue
+        // Make this queue slightly larger (5%) than the queue between Client and Dispatcher;
+        // multiply before dividing so that small thresholds don't truncate to a
+        // zero-capacity (panicking) channel
+        let (writer_tx, writer_rx) =
+            mpsc::channel(internal_simultaneous_requests_threshold * 105 / 100);
+        let writer_task_handle = tokio::spawn(writer_task(writer_rx, write_stream));
+
+        let this = Self {
+            read_stream,
+            writer_tx,
+            writer_task_handle,
+            data: conn_data,
+        };
+
         Ok(this)
     }
 
@@ -68,118 +209,154 @@ impl Connection {
         user: Option<&str>,
         password: Option<&str>,
         timeout: Option<Duration>,
+        internal_simultaneous_requests_threshold: usize,
     ) -> Result<Self, Error>
     where
         A: ToSocketAddrs + Display,
     {
         match timeout {
-            Some(dur) => tokio::time::timeout(dur, Self::new_inner(addr, user, password))
+            Some(dur) => tokio::time::timeout(
+                dur,
+                Self::new_inner(
+                    addr,
+                    user,
+                    password,
+                    internal_simultaneous_requests_threshold,
+                ),
+            )
+            .await
+            .map_err(|_| Error::ConnectTimeout)
+            .and_then(|x| x),
+            None => {
+                Self::new_inner(
+                    addr,
+                    user,
+                    password,
+                    internal_simultaneous_requests_threshold,
+                )
                 .await
-                .map_err(|_| Error::ConnectTimeout)
-                .and_then(|x| x),
-            None => Self::new_inner(addr, user, password).await,
+            }
         }
     }
 
-    async fn auth(&mut self, user: &str, password: Option<&str>, salt: &[u8]) -> Result<(), Error> {
+    async fn auth(
+        read_stream: &mut FramedRead<OwnedReadHalf, ClientCodec>,
+        write_stream: &mut FramedWrite<OwnedWriteHalf, ClientCodec>,
+        sync: u32,
+        user: &str,
+        password: Option<&str>,
+        salt: &[u8],
+    ) -> Result<(), Error> {
         let mut request = EncodedRequest::new(Auth::new(user, password, salt), None).unwrap();
-        *request.sync_mut() = self.next_sync();
+        *request.sync_mut() = sync;
 
         trace!("Sending auth request");
-        self.stream.send(request).await?;
+        write_stream.send(request).await?;
 
-        let resp = self.get_next_stream_value().await?;
+        let resp = Self::get_next_stream_value(read_stream).await?;
         match resp.body {
             ResponseBody::Ok(_x) => Ok(()),
             ResponseBody::Error(err) => Err(Error::Auth(err)),
         }
     }
 
-    // TODO: maybe other Ordering??
-    fn next_sync(&self) -> u32 {
-        self.next_sync.fetch_add(1, Ordering::SeqCst)
-    }
-
-    pub(super) async fn send_request(
-        &mut self,
-        mut request: EncodedRequest,
-        tx: oneshot::Sender<DispatcherResponse>,
-    ) -> Result<(), tokio::io::Error> {
-        let sync = self.next_sync();
-        *request.sync_mut() = sync;
-        trace!(
-            "Sending request with sync {}, stream_id {:?}",
-            request.sync,
-            request.stream_id
-        );
-        // TODO: replace with try_insert when stabilized
-        // If sync already assigned to another request, return an error
-        // for current request
-        if let Some(old) = self.in_flights.insert(request.sync, tx) {
-            let new = self
-                .in_flights
-                .insert(request.sync, old)
-                .expect("Shouldn't panic, value was just inserted");
-            if new.send(Err(Error::DuplicatedSync(request.sync))).is_err() {
-                warn!(
-                    "Failed to pass error to sync {}, receiver dropped",
-                    request.sync
-                );
-            }
-            return Ok(());
-        }
-        match self.stream.send(request).await {
-            Ok(x) => Ok(x),
-            Err(CodecEncodeError::Encode(err)) => {
-                if self
-                    .in_flights
-                    .remove(&sync)
-                    .expect("Shouldn't panic, value was just inserted")
-                    .send(Err(err.into()))
-                    .is_err()
-                {
-                    warn!("Failed to pass error to sync {}, receiver dropped", sync);
-                }
-                Ok(())
-            }
-            Err(CodecEncodeError::Io(err)) => Err(err),
-        }
-    }
-
-    fn pass_response(&mut self, response: Response) {
-        let sync = response.sync;
-        if let Some(tx) = self.in_flights.remove(&sync) {
-            if tx.send(Ok(response)).is_err() {
-                warn!("Failed to pass response sync {}, receiver dropped", sync);
-            }
-        } else {
-            warn!("Unknown sync {}", sync);
-        }
-    }
-
-    async fn get_next_stream_value(&mut self) -> Result<Response, CodecDecodeError> {
-        match self.stream.try_next().await {
+    #[inline]
+    async fn get_next_stream_value(
+        read_stream: &mut FramedRead<OwnedReadHalf, ClientCodec>,
+    ) -> Result<Response, ConnectionError> {
+        match read_stream.try_next().await {
             Ok(Some(x)) => Ok(x),
-            Ok(None) => Err(CodecDecodeError::Closed),
-            Err(e) => Err(e),
+            Ok(None) => Err(ConnectionError::ConnectionClosed),
+            Err(e) => Err(e.into()),
         }
     }
 
-    pub(super) async fn handle_next_response(&mut self) -> Result<(), CodecDecodeError> {
-        let resp = self.get_next_stream_value().await?;
+    #[inline]
+    fn handle_response(connection_data: &mut ConnectionData, response: Response) {
         trace!(
             "Received response for sync {}, schema version {}",
-            resp.sync,
-            resp.schema_version
+            response.sync,
+            response.schema_version
         );
-        self.pass_response(resp);
-        Ok(())
+        connection_data.respond_to_client(response.sync, Ok(response));
     }
 
-    /// Send error to all in flight requests and drop current transport.
-    pub(super) fn finish_with_error(&mut self, err: CodecDecodeError) {
-        for (_, tx) in self.in_flights.drain() {
-            let _ = tx.send(Err(err.clone().into()));
-        }
+    /// Run the connection until it breaks or `rx` is closed.
+    ///
+    /// `Ok` means `rx` was closed and the connection should not be restarted.
+    /// `Err` means the connection was dropped due to some error.
+    pub(crate) async fn run(
+        self,
+        client_rx: &mut ReceiverStream<DispatcherRequest>,
+    ) -> Result<(), ()> {
+        let Self {
+            mut read_stream,
+            writer_tx,
+            mut writer_task_handle,
+            mut data,
+        } = self;
+
+        let send_to_writer_future = Fuse::terminated();
+        pin!(send_to_writer_future);
+
+        let err = loop {
+            tokio::select! {
+                // Read value from TCP stream
+                next = Connection::get_next_stream_value(&mut read_stream) => {
+                    match next {
+                        Ok(x) => Connection::handle_response(&mut data, x),
+                        Err(err) => break err,
+                    }
+                }
+
+                // Read value from internal queue if nothing is being sent to the writer
+                next = client_rx.next(), if send_to_writer_future.is_terminated() => {
+                    if let Some((mut request, tx)) = next {
+                        // If we failed to prepare the request or the client has already
+                        // dropped the oneshot - just move on to the next request
+                        if tx.is_closed() || data
+                            .try_prepare_request(&mut request, tx)
+                            .is_err()
+                        {
+                            continue;
+                        }
+
+                        send_to_writer_future.set(writer_tx.send(request).fuse());
+                    } else {
+                        // TODO: actually don't quit until all in-flights processed
+                        debug!("All senders dropped");
+                        return Ok(());
+                    }
+                }
+
+                // Await sending request to writer.
+                // NOTE: For some reason checking Fuse for termination makes code _slightly_ faster
+                _send_res = &mut send_to_writer_future, if !send_to_writer_future.is_terminated() => {
+                    // TODO: somehow return EncodedRequest from Err variant, so it can be retried
+
+                    // Do nothing, since on success there is nothing to do,
+                    // and on error we can only respond to the client with ConnectionClosed,
+                    // which will happen anyway in the next branch on the next (or so) iteration.
+                }
+
+                // Wait for writer task to finish
+                writer_result = &mut writer_task_handle => {
+                    match writer_result {
+                        Err(err) => {
+                            break err.into();
+                        },
+                        // TODO: actually do something with remaining requests
+                        Ok(Err((sync, err, _remaining_requests))) => {
+                            data.respond_to_client(sync, Err(err.into()));
+                        },
+                        _ => {}
+                    }
+                    break ConnectionError::ConnectionClosed
+                }
+            }
+        };
+
+        data.send_error_to_all_in_flights(err);
+        Err(())
+    }
 }
diff --git a/src/transport/dispatcher.rs b/src/transport/dispatcher.rs
index be0b1c2..9d231d4 100644
--- a/src/transport/dispatcher.rs
+++ b/src/transport/dispatcher.rs
@@ -6,6 +6,7 @@ use tokio::{
     net::ToSocketAddrs,
     sync::{mpsc, oneshot},
 };
+use tokio_stream::wrappers::ReceiverStream;
 use tracing::{debug, error};
 
 use super::connection::Connection;
@@ -43,20 +44,21 @@ type ConnectDynFuture = dyn Future<Output = Result<Connection, Error>> + Send;
 ///
 /// Currently no-op, in future it should handle reconnects, schema reloading, pooling.
 pub(crate) struct Dispatcher {
-    rx: mpsc::Receiver<DispatcherRequest>,
-    conn: Connection,
+    rx: ReceiverStream<DispatcherRequest>,
+    conn: Option<Connection>,
     conn_factory: Box<dyn Fn() -> Pin<Box<ConnectDynFuture>> + Send + Sync>,
     reconnect_interval: Option<ReconnectInterval>,
 }
 
 impl Dispatcher {
-    pub(crate) async fn new<A>(
+    pub(crate) async fn prepare<A>(
         addr: A,
         user: Option<&str>,
         password: Option<&str>,
         connect_timeout: Option<Duration>,
         reconnect_interval: Option<ReconnectInterval>,
-    ) -> Result<(Self, DispatcherSender), Error>
+        internal_simultaneous_requests_threshold: usize,
+    ) -> Result<(impl Future<Output = ()>, DispatcherSender), Error>
     where
         A: ToSocketAddrs + Display + Clone + Send + Sync + 'static,
     {
@@ -68,22 +70,29 @@ impl Dispatcher {
             let password = password.clone();
             let connect_timeout = connect_timeout;
             Box::pin(async move {
-                Connection::new(addr, user.as_deref(), password.as_deref(), connect_timeout).await
+                Connection::new(
+                    addr,
+                    user.as_deref(),
+                    password.as_deref(),
+                    connect_timeout,
+                    internal_simultaneous_requests_threshold,
+                )
+                .await
             }) as Pin<Box<ConnectDynFuture>>
         });
 
         let conn = conn_factory().await?;
 
-        // TODO: test whether increased size can help with performance
-        let (tx, rx) = mpsc::channel(1);
+        let (tx, rx) = mpsc::channel(internal_simultaneous_requests_threshold);
 
         Ok((
             Self {
-                rx,
-                conn,
+                rx: ReceiverStream::new(rx),
+                conn: Some(conn),
                 conn_factory,
                 reconnect_interval,
-            },
+            }
+            .run(),
             DispatcherSender { tx },
         ))
     }
@@ -96,7 +105,7 @@ impl Dispatcher {
         loop {
             match (self.conn_factory)().await {
                 Ok(conn) => {
-                    self.conn = conn;
+                    self.conn = Some(conn);
                     return;
                 }
                 Err(err) => {
@@ -112,39 +121,14 @@ pub(crate) async fn run(mut self) {
         debug!("Starting dispatcher");
         loop {
-            if self.run_conn().await {
-                return;
-            }
-            self.reconnect().await;
-        }
-    }
-
-    pub(crate) async fn run_conn(&mut self) -> bool {
-        let err = loop {
-            tokio::select! {
-                next = self.conn.handle_next_response() => {
-                    if let Err(e) = next {
-                        break e;
-                    }
-                }
-                next = self.rx.recv() => {
-                    if let Some((request, tx)) = next {
-                        // Check whether tx is closed in case someone cancelled request
-                        // while it was in queue
-                        if !tx.is_closed() {
-                            if let Err(err) = self.conn.send_request(request, tx).await {
-                                break err.into();
-                            }
-                        }
-                    } else {
-                        debug!("All senders dropped");
-                        return true
-                    }
+            if let Some(conn) = self.conn.take() {
+                if conn.run(&mut self.rx).await.is_ok() {
+                    return;
                 }
+            } else {
+                self.reconnect().await;
             }
-        };
-        self.conn.finish_with_error(err);
-        false
+        }
     }
 }
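---

Usage note: a minimal sketch of how the renamed builder knob is exercised from the public API, assuming the `Connection`/`ExecutorExt` surface used in `benches/simple_loop.rs`; the address, port and threshold value here are illustrative, not taken from this patch.

```rust
use tarantool_rs::{Connection, ExecutorExt};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Assumed: a Tarantool instance listening on the default port 3301.
    // Sizes the internal channels for ~2000 simultaneously created requests
    // (the default is 500); larger values trade memory for throughput.
    let conn = Connection::builder()
        .internal_simultaneous_requests_threshold(2000)
        .build("127.0.0.1:3301")
        .await?;
    conn.ping().await?;
    Ok(())
}
```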