feat(batch): support batch read for file source (#15358)
wcy-fdu authored Mar 6, 2024
1 parent 0dad818 commit 961ad85
Showing 11 changed files with 300 additions and 23 deletions.
22 changes: 22 additions & 0 deletions ci/workflows/main-cron.yml
@@ -540,6 +540,28 @@ steps:
timeout_in_minutes: 25
retry: *auto-retry

- label: "S3_v2 source batch read on AWS (json parser)"
key: "s3-v2-source-batch-read-check-aws-json-parser"
command: "ci/scripts/s3-source-test.sh -p ci-release -s 'fs_source_batch.py json'"
if: |
!(build.pull_request.labels includes "ci/main-cron/skip-ci") && build.env("CI_STEPS") == null
|| build.pull_request.labels includes "ci/run-s3-source-tests"
|| build.env("CI_STEPS") =~ /(^|,)s3-source-tests?(,|$$)/
depends_on: build
plugins:
- seek-oss/aws-sm#v2.3.1:
env:
S3_SOURCE_TEST_CONF: ci_s3_source_test_aws
- docker-compose#v5.1.0:
run: rw-build-env
config: ci/docker-compose.yml
mount-buildkite-agent: true
environment:
- S3_SOURCE_TEST_CONF
- ./ci/plugins/upload-failure-logs
timeout_in_minutes: 25
retry: *auto-retry

- label: "S3_v2 source check on AWS (csv parser)"
key: "s3-v2-source-check-aws-csv-parser"
command: "ci/scripts/s3-source-test.sh -p ci-release -s 'fs_source_v2.py csv_without_header'"
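For reference, the CI_STEPS gate on the new step is a comma-separated membership test. A minimal sketch of how it behaves, assuming Buildkite's "$$" in pipeline YAML unescapes to a literal "$" (the sample values below are made up):

import re

# Effective pattern once "$$" is unescaped to "$": the step name must appear
# as a whole comma-separated entry, optionally pluralized.
PATTERN = re.compile(r"(^|,)s3-source-tests?(,|$)")

for ci_steps in ["s3-source-tests", "build,s3-source-test,e2e", "my-s3-source-tests"]:
    print(ci_steps, "->", bool(PATTERN.search(ci_steps)))
# s3-source-tests -> True
# build,s3-source-test,e2e -> True
# my-s3-source-tests -> False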
155 changes: 155 additions & 0 deletions e2e_test/s3/fs_source_batch.py
@@ -0,0 +1,155 @@
import os
import sys
import csv
import json
import random
import psycopg2

from time import sleep
from io import StringIO
from minio import Minio
from functools import partial

def gen_data(file_num, item_num_per_file):
assert item_num_per_file % 2 == 0, \
f'item_num_per_file should be even to ensure sum(mark) == 0: {item_num_per_file}'
return [
[{
'id': file_id * item_num_per_file + item_id,
'name': f'{file_id}_{item_id}',
'sex': item_id % 2,
'mark': (-1) ** (item_id % 2),
} for item_id in range(item_num_per_file)]
for file_id in range(file_num)
]

def format_json(data):
return [
'\n'.join([json.dumps(item) for item in file])
for file in data
]

def format_csv(data, with_header):
csv_files = []

for file_data in data:
ostream = StringIO()
writer = csv.DictWriter(ostream, fieldnames=file_data[0].keys())
if with_header:
writer.writeheader()
for item_data in file_data:
writer.writerow(item_data)
csv_files.append(ostream.getvalue())
return csv_files

def do_test(config, file_num, item_num_per_file, prefix, fmt):
conn = psycopg2.connect(
host="localhost",
port="4566",
user="root",
database="dev"
)

# Open a cursor to execute SQL statements
cur = conn.cursor()

def _source():
return f's3_test_{fmt}'

def _encode():
if fmt == 'json':
return 'JSON'
else:
return f"CSV (delimiter = ',', without_header = {str('without' in fmt).lower()})"

    # Create an S3 source over the uploaded files
cur.execute(f'''CREATE SOURCE {_source()}(
id int,
name TEXT,
sex int,
mark int,
) WITH (
connector = 's3_v2',
match_pattern = '{prefix}*.{fmt}',
s3.region_name = '{config['S3_REGION']}',
s3.bucket_name = '{config['S3_BUCKET']}',
s3.credentials.access = '{config['S3_ACCESS_KEY']}',
s3.credentials.secret = '{config['S3_SECRET_KEY']}',
s3.endpoint_url = 'https://{config['S3_ENDPOINT']}'
) FORMAT PLAIN ENCODE {_encode()};''')

total_rows = file_num * item_num_per_file
MAX_RETRIES = 40
for retry_no in range(MAX_RETRIES):
cur.execute(f'select count(*) from {_source()}')
result = cur.fetchone()
if result[0] == total_rows:
break
print(f"[retry {retry_no}] Now got {result[0]} rows in source, {total_rows} expected, wait 30s")
sleep(30)

stmt = f'select count(*), sum(id), sum(sex), sum(mark) from {_source()}'
print(f'Execute {stmt}')
cur.execute(stmt)
result = cur.fetchone()

print('Got:', result)

def _assert_eq(field, got, expect):
assert got == expect, f'{field} assertion failed: got {got}, expect {expect}.'

_assert_eq('count(*)', result[0], total_rows)
_assert_eq('sum(id)', result[1], (total_rows - 1) * total_rows / 2)
_assert_eq('sum(sex)', result[2], total_rows / 2)
_assert_eq('sum(mark)', result[3], 0)

print('Test pass')

cur.execute(f'drop source {_source()}')
cur.close()
conn.close()


if __name__ == "__main__":
FILE_NUM = 4001
ITEM_NUM_PER_FILE = 2
data = gen_data(FILE_NUM, ITEM_NUM_PER_FILE)

fmt = sys.argv[1]
FORMATTER = {
'json': format_json,
'csv_with_header': partial(format_csv, with_header=True),
'csv_without_header': partial(format_csv, with_header=False),
}
assert fmt in FORMATTER, f"Unsupported format: {fmt}"
formatted_files = FORMATTER[fmt](data)

config = json.loads(os.environ["S3_SOURCE_TEST_CONF"])
client = Minio(
config["S3_ENDPOINT"],
access_key=config["S3_ACCESS_KEY"],
secret_key=config["S3_SECRET_KEY"],
secure=True,
)
run_id = str(random.randint(1000, 9999))
_local = lambda idx: f'data_{idx}.{fmt}'
_s3 = lambda idx: f"{run_id}_data_{idx}.{fmt}"

# put s3 files
for idx, file_str in enumerate(formatted_files):
with open(_local(idx), "w") as f:
f.write(file_str)
os.fsync(f.fileno())

client.fput_object(
config["S3_BUCKET"],
_s3(idx),
_local(idx)
)

# do test
do_test(config, FILE_NUM, ITEM_NUM_PER_FILE, run_id, fmt)

# clean up s3 files
for idx, _ in enumerate(formatted_files):
client.remove_object(config["S3_BUCKET"], _s3(idx))
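The assertions at the end of do_test follow directly from how gen_data builds the rows: ids run from 0 to total_rows - 1, sex alternates 0/1, and mark alternates +1/-1 within each even-sized file. A small standalone check of those aggregates, using deliberately small made-up sizes:

# Re-derive the expected aggregates from gen_data's formulas (2 files x 4 rows).
file_num, item_num_per_file = 2, 4
rows = [
    {'id': f * item_num_per_file + i, 'sex': i % 2, 'mark': (-1) ** (i % 2)}
    for f in range(file_num) for i in range(item_num_per_file)
]
total = file_num * item_num_per_file
assert len(rows) == total                                        # count(*)
assert sum(r['id'] for r in rows) == (total - 1) * total // 2    # ids are 0..total-1
assert sum(r['sex'] for r in rows) == total // 2                 # sex alternates 0, 1
assert sum(r['mark'] for r in rows) == 0                         # mark alternates +1, -1
print('aggregates match')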
2 changes: 1 addition & 1 deletion proto/batch_plan.proto
@@ -59,7 +59,7 @@ message SourceNode {
uint32 source_id = 1;
repeated plan_common.ColumnCatalog columns = 2;
map<string, string> with_properties = 3;
bytes split = 4;
repeated bytes split = 4;
catalog.StreamSourceInfo info = 5;
}

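Making the split field repeated lets one batch SourceNode carry several encoded splits (one per file) instead of a single blob. An illustrative sketch of the shape of that change; json stands in for the real SplitImpl encode_to_bytes/restore_from_bytes encoding, and the split contents are made up:

import json

# One blob per split, carried as a list (repeated bytes split = 4).
splits = [
    {'name': 'run1_data_0.json', 'offset': 0, 'size': 123},
    {'name': 'run1_data_1.json', 'offset': 0, 'size': 456},
]
encoded = [json.dumps(s).encode() for s in splits]   # what the frontend would put into the plan
decoded = [json.loads(b) for b in encoded]           # what the batch executor would restore
assert decoded == splits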
18 changes: 12 additions & 6 deletions src/batch/src/executor/source.rs
@@ -17,6 +17,7 @@ use std::sync::Arc;

use futures::StreamExt;
use futures_async_stream::try_stream;
use itertools::Itertools;
use risingwave_common::array::{DataChunk, Op, StreamChunk};
use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId};
use risingwave_common::types::DataType;
@@ -43,7 +44,7 @@ pub struct SourceExecutor {
column_ids: Vec<ColumnId>,
metrics: Arc<SourceMetrics>,
source_id: TableId,
split: SplitImpl,
split_list: Vec<SplitImpl>,

schema: Schema,
identity: String,
Expand Down Expand Up @@ -89,7 +90,11 @@ impl BoxedExecutorBuilder for SourceExecutor {
.map(|column| ColumnId::from(column.get_column_desc().unwrap().column_id))
.collect();

let split = SplitImpl::restore_from_bytes(&source_node.split)?;
let split_list = source_node
.split
.iter()
.map(|split| SplitImpl::restore_from_bytes(split).unwrap())
.collect_vec();

let fields = source_node
.columns
@@ -105,8 +110,9 @@

if let ConnectorProperties::Iceberg(iceberg_properties) = config {
let iceberg_properties: IcebergProperties = *iceberg_properties;
if let SplitImpl::Iceberg(split) = split {
let split: IcebergSplit = split;
assert_eq!(split_list.len(), 1);
if let SplitImpl::Iceberg(split) = &split_list[0] {
let split: IcebergSplit = split.clone();
Ok(Box::new(IcebergScanExecutor::new(
iceberg_properties.to_iceberg_config(),
Some(split.snapshot_id),
@@ -135,7 +141,7 @@ impl BoxedExecutorBuilder for SourceExecutor {
column_ids,
metrics: source.context().source_metrics(),
source_id: TableId::new(source_node.source_id),
split,
split_list,
schema,
identity: source.plan_node().get_identity().clone(),
source_ctrl_opts,
@@ -173,7 +179,7 @@ impl SourceExecutor {
));
let stream = self
.source
.to_stream(Some(vec![self.split]), self.column_ids, source_ctx)
.to_stream(Some(self.split_list), self.column_ids, source_ctx)
.await?;

#[for_await]
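The builder change above amounts to: restore every split carried by the plan node, keep the single-split assumption for Iceberg, and hand the whole list to the file-source reader. A rough Python sketch of that control flow (function and connector names are illustrative, not the actual Rust API):

def build_source_executor(connector, encoded_splits, restore):
    # Restore every encoded split carried by the plan node.
    split_list = [restore(b) for b in encoded_splits]
    if connector == 'iceberg':
        # Iceberg batch scans still expect exactly one split.
        assert len(split_list) == 1
        return ('IcebergScanExecutor', split_list[0])
    # File sources (s3_v2, gcs, ...) read the whole list.
    return ('SourceExecutor', split_list)

print(build_source_executor('s3_v2', [b'a', b'b'], bytes.decode))
# ('SourceExecutor', ['a', 'b'])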
2 changes: 2 additions & 0 deletions src/connector/src/source/base.rs
@@ -453,6 +453,8 @@ impl ConnectorProperties {

pub fn support_multiple_splits(&self) -> bool {
matches!(self, ConnectorProperties::Kafka(_))
|| matches!(self, ConnectorProperties::OpendalS3(_))
|| matches!(self, ConnectorProperties::Gcs(_))
}
}

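support_multiple_splits is the flag that decides whether a single reader is built for all splits or one reader per split (the vec![reader] vs. per-split branch visible in the reader.rs hunk later in this diff). A sketch of the two fan-out strategies, with made-up file names; this mirrors the decision, not the real API:

def plan_readers(splits, supports_multiple_splits):
    if supports_multiple_splits:
        return [splits]              # one reader handles every split (Kafka, OpendalS3, Gcs)
    return [[s] for s in splits]     # otherwise, one reader per split

files = ['data_0.json', 'data_1.json', 'data_2.json']
print(plan_readers(files, True))    # [['data_0.json', 'data_1.json', 'data_2.json']]
print(plan_readers(files, False))   # [['data_0.json'], ['data_1.json'], ['data_2.json']]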
@@ -151,6 +151,7 @@ impl<Src: OpendalSource> OpendalReader<Src> {
offset += len;
batch_size += len;
batch.push(msg);

if batch.len() >= max_chunk_size {
source_ctx
.metrics
30 changes: 28 additions & 2 deletions src/connector/src/source/reader/reader.rs
@@ -33,7 +33,7 @@ use crate::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumer
use crate::source::filesystem::opendal_source::{
OpendalGcs, OpendalPosixFs, OpendalS3, OpendalSource,
};
use crate::source::filesystem::FsPageItem;
use crate::source::filesystem::{FsPageItem, OpendalFsSplit};
use crate::source::{
create_split_reader, BoxChunkSourceStream, BoxTryStream, Column, ConnectorProperties,
ConnectorState, SourceColumnDesc, SourceContext, SplitReader,
@@ -149,7 +149,6 @@ impl SourceReader {
vec![reader]
} else {
let to_reader_splits = splits.into_iter().map(|split| vec![split]);

try_join_all(to_reader_splits.into_iter().map(|splits| {
tracing::debug!(?splits, ?prop, "spawning connector split reader");
let props = prop.clone();
@@ -194,3 +193,30 @@ async fn build_opendal_fs_list_stream<Src: OpendalSource>(lister: OpendalEnumera
}
}
}

#[try_stream(boxed, ok = OpendalFsSplit<Src>, error = crate::error::ConnectorError)]
pub async fn build_opendal_fs_list_for_batch<Src: OpendalSource>(lister: OpendalEnumerator<Src>) {
let matcher = lister.get_matcher();
let mut object_metadata_iter = lister.list().await?;

while let Some(list_res) = object_metadata_iter.next().await {
match list_res {
Ok(res) => {
if matcher
.as_ref()
.map(|m| m.matches(&res.name))
.unwrap_or(true)
{
let split = OpendalFsSplit::new(res.name, 0, res.size as usize);
yield split
} else {
continue;
}
}
Err(err) => {
tracing::error!(error = %err.as_report(), "list object fail");
return Err(err);
}
}
}
}
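The new build_opendal_fs_list_for_batch stream lists the bucket, keeps only the objects matching the source's match_pattern, and turns each matching file into a split covering the whole object (offset 0, full size). A rough non-streaming Python analogue, using fnmatch in place of the real glob matcher:

from fnmatch import fnmatch

def list_splits_for_batch(objects, match_pattern=None):
    # objects: iterable of (name, size) pairs, as a bucket listing would return them.
    for name, size in objects:
        if match_pattern is None or fnmatch(name, match_pattern):
            yield {'name': name, 'offset': 0, 'size': size}

objects = [('1234_data_0.json', 98), ('1234_data_1.json', 102), ('other.csv', 77)]
print(list(list_splits_for_batch(objects, '*.json')))
# [{'name': '1234_data_0.json', 'offset': 0, 'size': 98},
#  {'name': '1234_data_1.json', 'offset': 0, 'size': 102}]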
4 changes: 0 additions & 4 deletions src/frontend/src/optimizer/plan_node/logical_source.rs
@@ -19,7 +19,6 @@ use std::rc::Rc;

use fixedbitset::FixedBitSet;
use pretty_xmlish::{Pretty, XmlNode};
use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::{
ColumnCatalog, ColumnDesc, Field, Schema, KAFKA_TIMESTAMP_COLUMN_NAME,
};
@@ -490,9 +489,6 @@ impl PredicatePushdown for LogicalSource {

impl ToBatch for LogicalSource {
fn to_batch(&self) -> Result<PlanRef> {
if self.core.is_new_fs_connector() {
bail_not_implemented!("New fs connector for batch");
}
let mut plan: PlanRef = BatchSource::new(self.core.clone()).into();

if let Some(exprs) = &self.output_exprs {
18 changes: 15 additions & 3 deletions src/frontend/src/scheduler/distributed/stage.rs
@@ -377,14 +377,22 @@ impl StageRunner {
));
}
} else if let Some(source_info) = self.stage.source_info.as_ref() {
for (id, split) in source_info.split_info().unwrap().iter().enumerate() {
let chunk_size = (source_info.split_info().unwrap().len() as f32
/ self.stage.parallelism.unwrap() as f32)
.ceil() as usize;
for (id, split) in source_info
.split_info()
.unwrap()
.chunks(chunk_size)
.enumerate()
{
let task_id = TaskIdPb {
query_id: self.stage.query_id.id.clone(),
stage_id: self.stage.id,
task_id: id as u32,
};
let plan_fragment = self
.create_plan_fragment(id as u32, Some(PartitionInfo::Source(split.clone())));
.create_plan_fragment(id as u32, Some(PartitionInfo::Source(split.to_vec())));
let worker =
self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
futures.push(self.schedule_task(
@@ -981,11 +989,15 @@ impl StageRunner {
let NodeBody::Source(mut source_node) = node_body else {
unreachable!();
};

let partition = partition
.expect("no partition info for seq scan")
.into_source()
.expect("PartitionInfo should be SourcePartitionInfo");
source_node.split = partition.encode_to_bytes().into();
source_node.split = partition
.into_iter()
.map(|split| split.encode_to_bytes().into())
.collect_vec();
PlanNodePb {
children: vec![],
identity,
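The scheduling change boils down to ceil-division chunking: with N splits and parallelism P, each batch task receives at most ceil(N / P) splits, and each chunk is re-encoded into that task's SourceNode. A small sketch of the assignment (file names and parallelism are made up):

import math

def assign_splits(splits, parallelism):
    chunk_size = math.ceil(len(splits) / parallelism)
    return [splits[i:i + chunk_size] for i in range(0, len(splits), chunk_size)]

splits = [f'file_{i}.json' for i in range(7)]
for task_id, chunk in enumerate(assign_splits(splits, parallelism=3)):
    print(task_id, chunk)
# 0 ['file_0.json', 'file_1.json', 'file_2.json']
# 1 ['file_3.json', 'file_4.json', 'file_5.json']
# 2 ['file_6.json']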