feat(batch): support file scan a directory of parquet files #17811

Merged 1 commit on Jul 25, 2024
2 changes: 1 addition & 1 deletion proto/batch_plan.proto
@@ -83,7 +83,7 @@
string s3_region = 4;
string s3_access_key = 5;
string s3_secret_key = 6;
string file_location = 7;
repeated string file_location = 7;

Check failure on line 86 in proto/batch_plan.proto (GitHub Actions / Check breaking changes in Protobuf files):
Field "7" with name "file_location" on message "FileScanNode" changed cardinality from "optional with implicit presence" to "repeated".
}
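On the Rust side, this cardinality change means the generated plan-node field becomes a Vec<String>, matching the executor change below. A rough sketch of what prost emits for the updated field (derives, prost attributes and the remaining FileScanNode fields are omitted here):

// Rough sketch of the generated struct after the change; the real generated
// code carries prost attributes and the other FileScanNode fields.
pub struct FileScanNode {
    // previously: pub file_location: ::prost::alloc::string::String,
    pub file_location: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
}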

message ProjectNode {
65 changes: 33 additions & 32 deletions src/batch/src/executor/s3_file_scan.rs
@@ -35,7 +35,7 @@ pub enum FileFormat {
/// S3 file scan executor. Currently only support parquet file format.
pub struct S3FileScanExecutor {
file_format: FileFormat,
location: String,
file_location: Vec<String>,
s3_region: String,
s3_access_key: String,
s3_secret_key: String,
@@ -61,7 +61,7 @@ impl Executor for S3FileScanExecutor {
impl S3FileScanExecutor {
pub fn new(
file_format: FileFormat,
location: String,
file_location: Vec<String>,
Contributor: I don't see an example of how to set multiple file_locations.

s3_region: String,
s3_access_key: String,
s3_secret_key: String,
@@ -71,7 +71,7 @@ impl S3FileScanExecutor {
) -> Self {
Self {
file_format,
location,
file_location,
s3_region,
s3_access_key,
s3_secret_key,
@@ -84,35 +84,36 @@ impl S3FileScanExecutor {
#[try_stream(ok = DataChunk, error = BatchError)]
async fn do_execute(self: Box<Self>) {
assert_eq!(self.file_format, FileFormat::Parquet);

let mut batch_stream_builder = create_parquet_stream_builder(
self.s3_region.clone(),
self.s3_access_key.clone(),
self.s3_secret_key.clone(),
self.location.clone(),
)
.await?;

let arrow_schema = batch_stream_builder.schema();
assert_eq!(arrow_schema.fields.len(), self.schema.fields.len());
for (field, arrow_field) in self.schema.fields.iter().zip(arrow_schema.fields.iter()) {
assert_eq!(*field.name, *arrow_field.name());
}

batch_stream_builder = batch_stream_builder.with_projection(ProjectionMask::all());

batch_stream_builder = batch_stream_builder.with_batch_size(self.batch_size);

let record_batch_stream = batch_stream_builder
.build()
.map_err(|e| anyhow!(e).context("fail to build arrow stream builder"))?;

#[for_await]
for record_batch in record_batch_stream {
let record_batch = record_batch.map_err(BatchError::Parquet)?;
let chunk = IcebergArrowConvert.chunk_from_record_batch(&record_batch)?;
debug_assert_eq!(chunk.data_types(), self.schema.data_types());
yield chunk;
for file in self.file_location {
let mut batch_stream_builder = create_parquet_stream_builder(
self.s3_region.clone(),
self.s3_access_key.clone(),
self.s3_secret_key.clone(),
file,
)
.await?;

let arrow_schema = batch_stream_builder.schema();
assert_eq!(arrow_schema.fields.len(), self.schema.fields.len());
for (field, arrow_field) in self.schema.fields.iter().zip(arrow_schema.fields.iter()) {
assert_eq!(*field.name, *arrow_field.name());
}

batch_stream_builder = batch_stream_builder.with_projection(ProjectionMask::all());

batch_stream_builder = batch_stream_builder.with_batch_size(self.batch_size);

let record_batch_stream = batch_stream_builder
.build()
.map_err(|e| anyhow!(e).context("fail to build arrow stream builder"))?;

#[for_await]
for record_batch in record_batch_stream {
let record_batch = record_batch.map_err(BatchError::Parquet)?;
let chunk = IcebergArrowConvert.chunk_from_record_batch(&record_batch)?;
debug_assert_eq!(chunk.data_types(), self.schema.data_types());
yield chunk;
}
}
}
}
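On the reviewer question above about setting multiple file_locations: the executor itself just receives a Vec<String>; the frontend expands a trailing-slash S3 directory into one location per file (see table_function.rs below). As a self-contained illustration of the same per-file read loop, here is a minimal sketch using the parquet crate against local files; the paths and batch size are placeholders.

use futures::TryStreamExt;
use parquet::arrow::ParquetRecordBatchStreamBuilder;

// One stream builder per file, drained sequentially: the same shape as
// S3FileScanExecutor::do_execute, but reading local files to stay self-contained.
async fn scan_parquet_files(paths: &[&str]) -> anyhow::Result<()> {
    for path in paths {
        let file = tokio::fs::File::open(path).await?;
        let builder = ParquetRecordBatchStreamBuilder::new(file)
            .await?
            .with_batch_size(1024);
        // Pin the stream so it can be drained with try_next().
        let mut stream = Box::pin(builder.build()?);
        while let Some(batch) = stream.try_next().await? {
            println!("{path}: {} rows", batch.num_rows());
        }
    }
    Ok(())
}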
47 changes: 47 additions & 0 deletions src/connector/src/source/iceberg/parquet_file_reader.rs
@@ -23,9 +23,14 @@ use futures::TryFutureExt;
use iceberg::io::{
FileIOBuilder, FileMetadata, FileRead, S3_ACCESS_KEY_ID, S3_REGION, S3_SECRET_ACCESS_KEY,
};
use iceberg::{Error, ErrorKind};
use opendal::layers::RetryLayer;
use opendal::services::S3;
use opendal::Operator;
use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader};
use parquet::arrow::ParquetRecordBatchStreamBuilder;
use parquet::file::metadata::ParquetMetaData;
use url::Url;

pub struct ParquetFileReader<R: FileRead> {
meta: FileMetadata,
@@ -83,3 +88,45 @@ pub async fn create_parquet_stream_builder(
.await
.map_err(|e| anyhow!(e))
}

pub async fn list_s3_directory(
s3_region: String,
s3_access_key: String,
s3_secret_key: String,
dir: String,
) -> Result<Vec<String>, anyhow::Error> {
let url = Url::parse(&dir)?;
let bucket = url.host_str().ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!("Invalid s3 url: {}, missing bucket", dir),
)
})?;

let prefix = format!("s3://{}/", bucket);
if dir.starts_with(&prefix) {
let mut builder = S3::default();
builder
.region(&s3_region)
.access_key_id(&s3_access_key)
.secret_access_key(&s3_secret_key)
.bucket(bucket);
let op = Operator::new(builder)?
.layer(RetryLayer::default())
.finish();

op.list(&dir[prefix.len()..])
.await
.map_err(|e| anyhow!(e))
.map(|list| {
list.into_iter()
.map(|entry| prefix.to_string() + entry.path())
.collect()
})
} else {
Err(Error::new(
ErrorKind::DataInvalid,
format!("Invalid s3 url: {}, should start with {}", dir, prefix),
))?
Comment on lines +127 to +130
Member: Why isn't this an assertion? The bucket was just extracted from dir, so the prefix should always match.
Contributor (author): Because we need to ensure the URL scheme is s3:// rather than something else like http://. The bucket is only the URL's host, which can't guarantee the protocol.

}
}
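A hedged usage sketch of the new helper (region, credentials and bucket are placeholders); the comments also restate why the scheme check above is an explicit error rather than an assertion.

use risingwave_connector::source::iceberg::list_s3_directory;

// Placeholder values throughout; only "s3://bucket/prefix/" locations are accepted.
// The explicit prefix check (instead of an assertion) matters because Url::parse
// also accepts e.g. "http://my-bucket/dir/" and would still report "my-bucket"
// as the host.
async fn example() -> anyhow::Result<Vec<String>> {
    let files = list_s3_directory(
        "us-east-1".to_string(),     // s3_region
        "my-access-key".to_string(), // s3_access_key
        "my-secret-key".to_string(), // s3_secret_key
        "s3://my-bucket/dir/".to_string(),
    )
    .await?;
    // Each entry comes back as a full "s3://my-bucket/<key>" location.
    Ok(files)
}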
55 changes: 49 additions & 6 deletions src/frontend/src/expr/table_function.rs
@@ -17,12 +17,12 @@ use std::sync::{Arc, LazyLock};
use itertools::Itertools;
use risingwave_common::array::arrow::IcebergArrowConvert;
use risingwave_common::types::{DataType, ScalarImpl, StructType};
use risingwave_connector::source::iceberg::create_parquet_stream_builder;
use risingwave_connector::source::iceberg::{create_parquet_stream_builder, list_s3_directory};
pub use risingwave_pb::expr::table_function::PbType as TableFunctionType;
use risingwave_pb::expr::PbTableFunction;
use tokio::runtime::Runtime;

use super::{infer_type, Expr, ExprImpl, ExprRewriter, RwResult};
use super::{infer_type, Expr, ExprImpl, ExprRewriter, Literal, RwResult};
use crate::catalog::function_catalog::{FunctionCatalog, FunctionKind};
use crate::error::ErrorCode::BindError;

@@ -68,7 +68,7 @@ impl TableFunction {

/// A special table function which would be transformed into `LogicalFileScan` by `TableFunctionToFileScanRule` in the optimizer.
/// select * from `file_scan`('parquet', 's3', region, ak, sk, location)
pub fn new_file_scan(args: Vec<ExprImpl>) -> RwResult<Self> {
pub fn new_file_scan(mut args: Vec<ExprImpl>) -> RwResult<Self> {
let return_type = {
// arguments:
// file format e.g. parquet
@@ -149,13 +149,43 @@ impl TableFunction {
.expect("failed to build file-scan runtime")
});

tokio::task::block_in_place(|| {
let files = if eval_args[5].ends_with('/') {
let files = tokio::task::block_in_place(|| {
RUNTIME.block_on(async {
let files = list_s3_directory(
eval_args[2].clone(),
eval_args[3].clone(),
eval_args[4].clone(),
eval_args[5].clone(),
)
.await?;

Ok::<Vec<String>, anyhow::Error>(files)
})
})?;

if files.is_empty() {
return Err(BindError(
"file_scan function only accepts non-empty directory".to_string(),
)
.into());
}

Some(files)
} else {
None
};

let schema = tokio::task::block_in_place(|| {
RUNTIME.block_on(async {
let parquet_stream_builder = create_parquet_stream_builder(
eval_args[2].clone(),
eval_args[3].clone(),
eval_args[4].clone(),
eval_args[5].clone(),
match files.as_ref() {
Some(files) => files[0].clone(),
None => eval_args[5].clone(),
},
)
.await?;

Expand All @@ -171,7 +201,20 @@ impl TableFunction {
StructType::new(rw_types),
))
})
})?
})?;

if let Some(files) = files {
// if the file location is a directory, we need to remove the last argument and add all files in the directory as arguments
Member (@fuyufjh, Jul 25, 2024): A bit dirty to list files in the binder, but acceptable to me.

args.remove(5);
for file in files {
args.push(ExprImpl::Literal(Box::new(Literal::new(
Some(ScalarImpl::Utf8(file.into())),
DataType::Varchar,
))));
}
}

schema
}
};
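In SQL terms, a location ending in '/' (for example file_scan('parquet', 's3', region, ak, sk, 's3://bucket/dir/')) now triggers the expansion above. A condensed, standalone sketch of that binder-side flow, using plain Strings instead of ExprImpl literals and a list_dir callback standing in for list_s3_directory:

// Condensed sketch of new_file_scan's directory handling; `list_dir` stands in
// for the blocking call to list_s3_directory.
fn expand_file_scan_args(
    mut args: Vec<String>,
    list_dir: impl Fn(&str) -> Vec<String>,
) -> Result<(Vec<String>, String), String> {
    let location = args[5].clone();
    if location.ends_with('/') {
        let files = list_dir(&location);
        if files.is_empty() {
            return Err("file_scan function only accepts non-empty directory".into());
        }
        // The schema is inferred from the first listed file, and the directory
        // argument is replaced by one argument per file.
        let schema_source = files[0].clone();
        args.remove(5);
        args.extend(files);
        Ok((args, schema_source))
    } else {
        // Single file: infer the schema directly from the given location.
        Ok((args, location))
    }
}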

2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/generic/file_scan.rs
@@ -38,7 +38,7 @@ pub struct FileScan {
pub s3_region: String,
pub s3_access_key: String,
pub s3_secret_key: String,
pub file_location: String,
pub file_location: Vec<String>,

#[educe(PartialEq(ignore))]
#[educe(Hash(ignore))]
2 changes: 1 addition & 1 deletion src/frontend/src/optimizer/plan_node/logical_file_scan.rs
@@ -47,7 +47,7 @@ impl LogicalFileScan {
s3_region: String,
s3_access_key: String,
s3_secret_key: String,
file_location: String,
file_location: Vec<String>,
) -> Self {
assert!("parquet".eq_ignore_ascii_case(&file_format));
assert!("s3".eq_ignore_ascii_case(&storage_type));
@@ -43,6 +43,7 @@ impl Rule for TableFunctionToFileScanRule {

let schema = Schema::new(fields);

assert!(logical_table_function.table_function().args.len() >= 6);
let mut eval_args = vec![];
for arg in &logical_table_function.table_function().args {
assert_eq!(arg.return_type(), DataType::Varchar);
@@ -56,14 +57,13 @@ impl Rule for TableFunctionToFileScanRule {
}
}
}
assert!(eval_args.len() == 6);
assert!("parquet".eq_ignore_ascii_case(&eval_args[0]));
assert!("s3".eq_ignore_ascii_case(&eval_args[1]));
let s3_region = eval_args[2].clone();
let s3_access_key = eval_args[3].clone();
let s3_secret_key = eval_args[4].clone();
let file_location = eval_args[5].clone();

// The rest of the arguments are file locations
let file_location = eval_args[5..].iter().cloned().collect_vec();
Comment on lines +65 to +66
Contributor (author): @tabVersion We set file_location here.

Some(
LogicalFileScan::new(
logical_table_function.ctx(),
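After the binder rewrite, everything from argument index 5 onward is a file location, which is what the relaxed length assertion and the slice above rely on. A small illustrative check with placeholder values:

#[test]
fn file_scan_argument_layout() {
    // The first five arguments are fixed (format, storage type, region,
    // access key, secret key); the rest are file locations.
    let eval_args = vec![
        "parquet".to_string(),
        "s3".to_string(),
        "us-east-1".to_string(),
        "my-access-key".to_string(),
        "my-secret-key".to_string(),
        "s3://my-bucket/dir/a.parquet".to_string(),
        "s3://my-bucket/dir/b.parquet".to_string(),
    ];
    let file_location: Vec<String> = eval_args[5..].to_vec();
    assert_eq!(file_location.len(), 2);
    assert!(file_location.iter().all(|f| f.ends_with(".parquet")));
}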
30 changes: 21 additions & 9 deletions src/frontend/src/scheduler/distributed/stage.rs
@@ -405,15 +405,27 @@ impl StageRunner {
expr_context.clone(),
));
}
} else if let Some(_file_scan_info) = self.stage.file_scan_info.as_ref() {
let task_id = PbTaskId {
query_id: self.stage.query_id.id.clone(),
stage_id: self.stage.id,
task_id: 0_u64,
};
let plan_fragment = self.create_plan_fragment(0_u64, Some(PartitionInfo::File));
let worker = self.choose_worker(&plan_fragment, 0_u32, self.stage.dml_table_id)?;
futures.push(self.schedule_task(task_id, plan_fragment, worker, expr_context.clone()));
} else if let Some(file_scan_info) = self.stage.file_scan_info.as_ref() {
let chunk_size = (file_scan_info.file_location.len() as f32
/ self.stage.parallelism.unwrap() as f32)
.ceil() as usize;
for (id, files) in file_scan_info.file_location.chunks(chunk_size).enumerate() {
let task_id = PbTaskId {
query_id: self.stage.query_id.id.clone(),
stage_id: self.stage.id,
task_id: id as u64,
};
let plan_fragment =
self.create_plan_fragment(id as u64, Some(PartitionInfo::File(files.to_vec())));
let worker =
self.choose_worker(&plan_fragment, id as u32, self.stage.dml_table_id)?;
futures.push(self.schedule_task(
task_id,
plan_fragment,
worker,
expr_context.clone(),
));
}
} else {
for id in 0..self.stage.parallelism.unwrap() {
let task_id = PbTaskId {
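The new scheduling path splits the file list into chunks of ceil(len / parallelism) and creates one task per chunk. A standalone sketch of that split (it assumes a non-empty file list, which the binder guarantees):

// Same ceil-division split as in StageRunner. With 10 files and parallelism 4
// this yields chunks of 3, 3, 3 and 1, so the last task may receive fewer files
// and fewer than `parallelism` tasks are created when files are scarce.
fn split_files(files: &[String], parallelism: usize) -> Vec<Vec<String>> {
    // Assumes !files.is_empty(); chunks(0) would panic otherwise.
    let chunk_size = (files.len() as f32 / parallelism as f32).ceil() as usize;
    files.chunks(chunk_size).map(|c| c.to_vec()).collect()
}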