From f8224953b281bf157b3ddc18d88a7085852a6bdc Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Sun, 29 Sep 2024 14:19:55 +0200 Subject: [PATCH] feat: make register_csv accept a list of paths (#883) --- python/datafusion/context.py | 11 +++-- python/datafusion/tests/test_sql.py | 35 ++++++++++++++++ src/context.rs | 64 +++++++++++++++++++++++++---- 3 files changed, 99 insertions(+), 11 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 35a40ccd4..2c41faba6 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -714,7 +714,7 @@ def register_parquet( def register_csv( self, name: str, - path: str | pathlib.Path, + path: str | pathlib.Path | list[str | pathlib.Path], schema: pyarrow.Schema | None = None, has_header: bool = True, delimiter: str = ",", @@ -728,7 +728,7 @@ def register_csv( Args: name: Name of the table to register. - path: Path to the CSV file. + path: Path to the CSV file. It also accepts a list of Paths. schema: An optional schema representing the CSV file. If None, the CSV reader will try to infer it based on data in file. has_header: Whether the CSV file have a header. If schema inference @@ -741,9 +741,14 @@ def register_csv( selected for data input. file_compression_type: File compression type. """ + if isinstance(path, list): + path = [str(p) for p in path] + else: + path = str(path) + self.ctx.register_csv( name, - str(path), + path, schema, has_header, delimiter, diff --git a/python/datafusion/tests/test_sql.py b/python/datafusion/tests/test_sql.py index bd2ae58d7..e39a9f5c7 100644 --- a/python/datafusion/tests/test_sql.py +++ b/python/datafusion/tests/test_sql.py @@ -104,6 +104,41 @@ def test_register_csv(ctx, tmp_path): ctx.register_csv("csv4", path, file_compression_type="rar") +def test_register_csv_list(ctx, tmp_path): + path = tmp_path / "test.csv" + + int_values = [1, 2, 3, 4] + table = pa.Table.from_arrays( + [ + int_values, + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + write_csv(table, path) + ctx.register_csv("csv", path) + + csv_df = ctx.table("csv") + expected_count = csv_df.count() * 2 + ctx.register_csv( + "double_csv", + path=[ + path, + path, + ], + ) + + double_csv_df = ctx.table("double_csv") + actual_count = double_csv_df.count() + assert actual_count == expected_count + + int_sum = ctx.sql("select sum(int) from double_csv").to_pydict()[ + "sum(double_csv.int)" + ][0] + assert int_sum == 2 * sum(int_values) + + def test_register_parquet(ctx, tmp_path): path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data()) ctx.register_parquet("t", path) diff --git a/src/context.rs b/src/context.rs index 79db2e65c..7ad12ceb0 100644 --- a/src/context.rs +++ b/src/context.rs @@ -46,7 +46,8 @@ use crate::utils::{get_tokio_runtime, wait_for_future}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::ScalarValue; +use datafusion::catalog_common::TableReference; +use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ @@ -54,9 +55,12 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::MemTable; use datafusion::datasource::TableProvider; -use datafusion::execution::context::{SQLOptions, SessionConfig, SessionContext, TaskContext}; +use datafusion::execution::context::{ + DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext, +}; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; +use datafusion::execution::options::ReadOptions; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::{ @@ -621,7 +625,7 @@ impl PySessionContext { pub fn register_csv( &mut self, name: &str, - path: PathBuf, + path: &Bound<'_, PyAny>, schema: Option>, has_header: bool, delimiter: &str, @@ -630,9 +634,6 @@ impl PySessionContext { file_compression_type: Option, py: Python, ) -> PyResult<()> { - let path = path - .to_str() - .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; let delimiter = delimiter.as_bytes(); if delimiter.len() != 1 { return Err(PyValueError::new_err( @@ -648,8 +649,15 @@ impl PySessionContext { .file_compression_type(parse_file_compression_type(file_compression_type)?); options.schema = schema.as_ref().map(|x| &x.0); - let result = self.ctx.register_csv(name, path, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; + if path.is_instance_of::() { + let paths = path.extract::>()?; + let result = self.register_csv_from_multiple_paths(name, paths, options); + wait_for_future(py, result).map_err(DataFusionError::from)?; + } else { + let path = path.extract::()?; + let result = self.ctx.register_csv(name, &path, options); + wait_for_future(py, result).map_err(DataFusionError::from)?; + } Ok(()) } @@ -981,6 +989,46 @@ impl PySessionContext { async fn _table(&self, name: &str) -> datafusion::common::Result { self.ctx.table(name).await } + + async fn register_csv_from_multiple_paths( + &self, + name: &str, + table_paths: Vec, + options: CsvReadOptions<'_>, + ) -> datafusion::common::Result<()> { + let table_paths = table_paths.to_urls()?; + let session_config = self.ctx.copied_config(); + let listing_options = + options.to_listing_options(&session_config, self.ctx.copied_table_options()); + + let option_extension = listing_options.file_extension.clone(); + + if table_paths.is_empty() { + return exec_err!("No table paths were provided"); + } + + // check if the file extension matches the expected extension + for path in &table_paths { + let file_path = path.as_str(); + if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() { + return exec_err!( + "File path '{file_path}' does not match the expected extension '{option_extension}'" + ); + } + } + + let resolved_schema = options + .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone()) + .await?; + + let config = ListingTableConfig::new_with_multi_paths(table_paths) + .with_listing_options(listing_options) + .with_schema(resolved_schema); + let table = ListingTable::try_new(config)?; + self.ctx + .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?; + Ok(()) + } } pub fn convert_table_partition_cols(