Skip to content

Commit

Permalink
refactor: replace more GAT-based async trait with RPITIT
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Sep 13, 2023
1 parent cbbee64 commit cc81f14
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 337 deletions.
11 changes: 3 additions & 8 deletions src/batch/src/exchange_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::fmt::Debug;
use std::future::Future;

use risingwave_common::array::DataChunk;
use risingwave_common::error::Result;

use crate::execution::grpc_exchange::GrpcExchangeSource;
use crate::execution::local_exchange::LocalExchangeSource;
Expand All @@ -24,11 +25,7 @@ use crate::task::TaskId;

/// Each `ExchangeSource` maps to one task, it takes the execution result from task chunk by chunk.
pub trait ExchangeSource: Send + Debug {
type TakeDataFuture<'a>: Future<Output = risingwave_common::error::Result<Option<DataChunk>>>
+ 'a
where
Self: 'a;
fn take_data(&mut self) -> Self::TakeDataFuture<'_>;
fn take_data(&mut self) -> impl Future<Output = Result<Option<DataChunk>>> + '_;

/// Get upstream task id.
fn get_task_id(&self) -> TaskId;
Expand All @@ -42,9 +39,7 @@ pub enum ExchangeSourceImpl {
}

impl ExchangeSourceImpl {
pub(crate) async fn take_data(
&mut self,
) -> risingwave_common::error::Result<Option<DataChunk>> {
pub(crate) async fn take_data(&mut self) -> Result<Option<DataChunk>> {
match self {
ExchangeSourceImpl::Grpc(grpc) => grpc.take_data().await,
ExchangeSourceImpl::Local(local) => local.take_data().await,
Expand Down
35 changes: 15 additions & 20 deletions src/batch/src/execution/grpc_exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::fmt::{Debug, Formatter};
use std::future::Future;

use futures::StreamExt;
use risingwave_common::array::DataChunk;
Expand Down Expand Up @@ -73,26 +72,22 @@ impl Debug for GrpcExchangeSource {
}

impl ExchangeSource for GrpcExchangeSource {
type TakeDataFuture<'a> = impl Future<Output = Result<Option<DataChunk>>> + 'a;

fn take_data(&mut self) -> Self::TakeDataFuture<'_> {
async {
let res = match self.stream.next().await {
None => {
return Ok(None);
}
Some(r) => r,
};
let task_data = res?;
let data = DataChunk::from_protobuf(task_data.get_record_batch()?)?.compact();
trace!(
"Receiver taskOutput = {:?}, data = {:?}",
self.task_output_id,
data
);
async fn take_data(&mut self) -> Result<Option<DataChunk>> {
let res = match self.stream.next().await {
None => {
return Ok(None);
}
Some(r) => r,
};
let task_data = res?;
let data = DataChunk::from_protobuf(task_data.get_record_batch()?)?.compact();
trace!(
"Receiver taskOutput = {:?}, data = {:?}",
self.task_output_id,
data
);

Ok(Some(data))
}
Ok(Some(data))
}

fn get_task_id(&self) -> TaskId {
Expand Down
31 changes: 13 additions & 18 deletions src/batch/src/execution/local_exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::fmt::{Debug, Formatter};
use std::future::Future;

use risingwave_common::array::DataChunk;
use risingwave_common::error::Result;
Expand Down Expand Up @@ -52,23 +51,19 @@ impl Debug for LocalExchangeSource {
}

impl ExchangeSource for LocalExchangeSource {
type TakeDataFuture<'a> = impl Future<Output = Result<Option<DataChunk>>> + 'a;

fn take_data(&mut self) -> Self::TakeDataFuture<'_> {
async {
let ret = self.task_output.direct_take_data().await?;
if let Some(data) = ret {
let data = data.compact();
trace!(
"Receiver task: {:?}, source task output: {:?}, data: {:?}",
self.task_id,
self.task_output.id(),
data
);
Ok(Some(data))
} else {
Ok(None)
}
async fn take_data(&mut self) -> Result<Option<DataChunk>> {
let ret = self.task_output.direct_take_data().await?;
if let Some(data) = ret {
let data = data.compact();
trace!(
"Receiver task: {:?}, source task output: {:?}, data: {:?}",
self.task_id,
self.task_output.id(),
data
);
Ok(Some(data))
} else {
Ok(None)
}
}

Expand Down
15 changes: 5 additions & 10 deletions src/batch/src/executor/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::collections::VecDeque;
use std::future::Future;

use assert_matches::assert_matches;
use futures::StreamExt;
Expand Down Expand Up @@ -246,15 +245,11 @@ impl FakeExchangeSource {
}

impl ExchangeSource for FakeExchangeSource {
type TakeDataFuture<'a> = impl Future<Output = Result<Option<DataChunk>>> + 'a;

fn take_data(&mut self) -> Self::TakeDataFuture<'_> {
async {
if let Some(chunk) = self.chunks.pop() {
Ok(chunk)
} else {
Ok(None)
}
async fn take_data(&mut self) -> Result<Option<DataChunk>> {
if let Some(chunk) = self.chunks.pop() {
Ok(chunk)
} else {
Ok(None)
}
}

Expand Down
1 change: 1 addition & 0 deletions src/batch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#![feature(result_option_inspect)]
#![feature(assert_matches)]
#![feature(lazy_cell)]
#![feature(return_position_impl_trait_in_trait)]

mod error;
pub mod exchange_source;
Expand Down
58 changes: 23 additions & 35 deletions src/connector/src/source/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,9 @@ impl MySqlOffset {
}

pub trait ExternalTableReader {
type CdcOffsetFuture<'a>: Future<Output = ConnectorResult<CdcOffset>> + Send + 'a
where
Self: 'a;

fn get_normalized_table_name(&self, table_name: &SchemaTableName) -> String;

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_>;
fn current_cdc_offset(&self) -> impl Future<Output = ConnectorResult<CdcOffset>> + Send + '_;

fn parse_binlog_offset(&self, offset: &str) -> ConnectorResult<CdcOffset>;

Expand Down Expand Up @@ -248,32 +244,28 @@ pub struct ExternalTableConfig {
}

impl ExternalTableReader for MySqlExternalTableReader {
type CdcOffsetFuture<'a> = impl Future<Output = ConnectorResult<CdcOffset>> + 'a;

fn get_normalized_table_name(&self, table_name: &SchemaTableName) -> String {
format!("`{}`", table_name.table_name)
}

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_> {
async move {
let mut conn = self
.pool
.get_conn()
.await
.map_err(|e| ConnectorError::Connection(anyhow!(e)))?;

let sql = "SHOW MASTER STATUS".to_string();
let mut rs = conn.query::<mysql_async::Row, _>(sql).await?;
let row = rs
.iter_mut()
.exactly_one()
.map_err(|e| ConnectorError::Internal(anyhow!("read binlog error: {}", e)))?;

Ok(CdcOffset::MySql(MySqlOffset {
filename: row.take("File").unwrap(),
position: row.take("Position").unwrap(),
}))
}
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
let mut conn = self
.pool
.get_conn()
.await
.map_err(|e| ConnectorError::Connection(anyhow!(e)))?;

let sql = "SHOW MASTER STATUS".to_string();
let mut rs = conn.query::<mysql_async::Row, _>(sql).await?;
let row = rs
.iter_mut()
.exactly_one()
.map_err(|e| ConnectorError::Internal(anyhow!("read binlog error: {}", e)))?;

Ok(CdcOffset::MySql(MySqlOffset {
filename: row.take("File").unwrap(),
position: row.take("Position").unwrap(),
}))
}

fn parse_binlog_offset(&self, offset: &str) -> ConnectorResult<CdcOffset> {
Expand Down Expand Up @@ -478,21 +470,17 @@ impl MySqlExternalTableReader {
}

impl ExternalTableReader for ExternalTableReaderImpl {
type CdcOffsetFuture<'a> = impl Future<Output = ConnectorResult<CdcOffset>> + 'a;

fn get_normalized_table_name(&self, table_name: &SchemaTableName) -> String {
match self {
ExternalTableReaderImpl::MySql(mysql) => mysql.get_normalized_table_name(table_name),
ExternalTableReaderImpl::Mock(mock) => mock.get_normalized_table_name(table_name),
}
}

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_> {
async move {
match self {
ExternalTableReaderImpl::MySql(mysql) => mysql.current_cdc_offset().await,
ExternalTableReaderImpl::Mock(mock) => mock.current_cdc_offset().await,
}
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
match self {
ExternalTableReaderImpl::MySql(mysql) => mysql.current_cdc_offset().await,
ExternalTableReaderImpl::Mock(mock) => mock.current_cdc_offset().await,
}
}

Expand Down
24 changes: 10 additions & 14 deletions src/connector/src/source/mock_external_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::sync::atomic::AtomicUsize;

use futures::stream::BoxStream;
Expand Down Expand Up @@ -91,24 +90,21 @@ impl MockExternalTableReader {
}

impl ExternalTableReader for MockExternalTableReader {
type CdcOffsetFuture<'a> = impl Future<Output = ConnectorResult<CdcOffset>> + 'a;

fn get_normalized_table_name(&self, _table_name: &SchemaTableName) -> String {
"`mock_table`".to_string()
}

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_> {
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
static IDX: AtomicUsize = AtomicUsize::new(0);
async move {
let idx = IDX.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if idx < self.binlog_watermarks.len() {
Ok(CdcOffset::MySql(self.binlog_watermarks[idx].clone()))
} else {
Ok(CdcOffset::MySql(MySqlOffset {
filename: "1.binlog".to_string(),
position: u64::MAX,
}))
}

let idx = IDX.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if idx < self.binlog_watermarks.len() {
Ok(CdcOffset::MySql(self.binlog_watermarks[idx].clone()))
} else {
Ok(CdcOffset::MySql(MySqlOffset {
filename: "1.binlog".to_string(),
position: u64::MAX,
}))
}
}

Expand Down
Loading

0 comments on commit cc81f14

Please sign in to comment.