feat: reads using global ctx
ion-elgreco committed Dec 28, 2024
1 parent 79c22d6 commit f7af294
Showing 7 changed files with 305 additions and 2 deletions.
6 changes: 6 additions & 0 deletions python/datafusion/__init__.py
@@ -48,6 +48,8 @@

from .dataframe import DataFrame

from .io import read_parquet, read_avro, read_csv, read_json

from .expr import (
    Expr,
    WindowFrame,
@@ -89,6 +91,10 @@
    "functions",
    "object_store",
    "substrait",
    "read_parquet",
    "read_avro",
    "read_csv",
    "read_json",
]
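With these re-exports, the readers are available straight from the top-level package, so a caller no longer has to build a SessionContext by hand. A minimal before/after sketch (the CSV path is hypothetical):

from datafusion import SessionContext, read_csv

# Previously, a context had to be created explicitly:
ctx = SessionContext()
df_old = ctx.read_csv("example.csv")  # hypothetical local file

# With the global-context helpers, the module-level function is enough:
df_new = read_csv("example.csv")
df_new.show()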


181 changes: 181 additions & 0 deletions python/datafusion/io.py
@@ -0,0 +1,181 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""IO read functions using global context."""

import pathlib

from datafusion.dataframe import DataFrame
from datafusion.expr import Expr
import pyarrow
from ._internal import SessionContext as SessionContextInternal


def read_parquet(
    path: str | pathlib.Path,
    table_partition_cols: list[tuple[str, str]] | None = None,
    parquet_pruning: bool = True,
    file_extension: str = ".parquet",
    skip_metadata: bool = True,
    schema: pyarrow.Schema | None = None,
    file_sort_order: list[list[Expr]] | None = None,
) -> DataFrame:
    """Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.

    Args:
        path: Path to the Parquet file.
        table_partition_cols: Partition columns.
        parquet_pruning: Whether the parquet reader should use the predicate
            to prune row groups.
        file_extension: File extension; only files with this extension are
            selected for data input.
        skip_metadata: Whether the parquet reader should skip any metadata
            that may be in the file schema. This can help avoid schema
            conflicts due to metadata.
        schema: An optional schema representing the parquet files. If None,
            the parquet reader will try to infer it based on data in the
            file.
        file_sort_order: Sort order for the file.

    Returns:
        DataFrame representation of the read Parquet files.
    """
    if table_partition_cols is None:
        table_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_parquet(
            str(path),
            table_partition_cols,
            parquet_pruning,
            file_extension,
            skip_metadata,
            schema,
            file_sort_order,
        )
    )
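A minimal usage sketch for the helper above; the file path and column names are hypothetical, and passing an explicit schema is shown as a way to bypass inference, per the docstring:

import pyarrow as pa
from datafusion import read_parquet

# Hypothetical Parquet file; an explicit schema avoids inferring one from the file.
schema = pa.schema([pa.field("id", pa.int64()), pa.field("name", pa.string())])
df = read_parquet("data/events.parquet", schema=schema)
df.show()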


def read_json(
    path: str | pathlib.Path,
    schema: pyarrow.Schema | None = None,
    schema_infer_max_records: int = 1000,
    file_extension: str = ".json",
    table_partition_cols: list[tuple[str, str]] | None = None,
    file_compression_type: str | None = None,
) -> DataFrame:
    """Read a line-delimited JSON data source.

    Args:
        path: Path to the JSON file.
        schema: The data source schema.
        schema_infer_max_records: Maximum number of rows to read from JSON
            files for schema inference if needed.
        file_extension: File extension; only files with this extension are
            selected for data input.
        table_partition_cols: Partition columns.
        file_compression_type: File compression type.

    Returns:
        DataFrame representation of the read JSON files.
    """
    if table_partition_cols is None:
        table_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_json(
            str(path),
            schema,
            schema_infer_max_records,
            file_extension,
            table_partition_cols,
            file_compression_type,
        )
    )
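A short sketch of the JSON reader, using the same kinds of options the tests below exercise; the path is hypothetical:

import pyarrow as pa
from datafusion import read_json

# Line-delimited JSON; limit schema inference to a sample of rows.
df = read_json("data/records.json", schema_infer_max_records=100)

# Or supply the schema up front instead of inferring it.
schema = pa.schema([pa.field("A", pa.string(), nullable=True)])
df = read_json("data/records.json", schema=schema)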


def read_csv(
    path: str | pathlib.Path | list[str] | list[pathlib.Path],
    schema: pyarrow.Schema | None = None,
    has_header: bool = True,
    delimiter: str = ",",
    schema_infer_max_records: int = 1000,
    file_extension: str = ".csv",
    table_partition_cols: list[tuple[str, str]] | None = None,
    file_compression_type: str | None = None,
) -> DataFrame:
    """Read a CSV data source.

    Args:
        path: Path to the CSV file or list of files.
        schema: An optional schema representing the CSV files. If None, the
            CSV reader will try to infer it based on data in the file.
        has_header: Whether the CSV file has a header. If schema inference
            is run on a file with no headers, default column names are
            created.
        delimiter: An optional column delimiter.
        schema_infer_max_records: Maximum number of rows to read from CSV
            files for schema inference if needed.
        file_extension: File extension; only files with this extension are
            selected for data input.
        table_partition_cols: Partition columns.
        file_compression_type: File compression type.

    Returns:
        DataFrame representation of the read CSV files.
    """
    if table_partition_cols is None:
        table_partition_cols = []

    path = [str(p) for p in path] if isinstance(path, list) else str(path)

    return DataFrame(
        SessionContextInternal._global_ctx().read_csv(
            path,
            schema,
            has_header,
            delimiter,
            schema_infer_max_records,
            file_extension,
            table_partition_cols,
            file_compression_type,
        )
    )
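Because path also accepts a list, several CSV files that share a schema can be read as one DataFrame; the file names here are hypothetical:

from datafusion import read_csv

df = read_csv(
    ["data/part-0.csv", "data/part-1.csv"],
    has_header=True,
    delimiter=",",
)
print(df.count())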


def read_avro(
    path: str | pathlib.Path,
    schema: pyarrow.Schema | None = None,
    file_partition_cols: list[tuple[str, str]] | None = None,
    file_extension: str = ".avro",
) -> DataFrame:
    """Create a :py:class:`DataFrame` for reading an Avro data source.

    Args:
        path: Path to the Avro file.
        schema: The data source schema.
        file_partition_cols: Partition columns.
        file_extension: File extension to select.

    Returns:
        DataFrame representation of the read Avro file.
    """
    if file_partition_cols is None:
        file_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_avro(
            str(path), schema, file_partition_cols, file_extension
        )
    )
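And a sketch for the Avro reader; the file and column name are hypothetical, and the result is an ordinary DataFrame, so the usual projection API applies:

from datafusion import column, read_avro

df = read_avro("data/events.avro").select(column("id"))
df.show()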
1 change: 1 addition & 0 deletions python/tests/test_context.py
@@ -23,6 +23,7 @@
import pyarrow.dataset as ds
import pytest


from datafusion import (
    DataFrame,
    RuntimeConfig,
97 changes: 97 additions & 0 deletions python/tests/test_io.py
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import pathlib

from datafusion import column
import pyarrow as pa


from datafusion.io import read_avro, read_csv, read_json, read_parquet


def test_read_json_global_ctx(ctx):
    path = os.path.dirname(os.path.abspath(__file__))

    # Default
    test_data_path = os.path.join(path, "data_test_context", "data.json")
    df = read_json(test_data_path)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])

    # Schema
    schema = pa.schema(
        [
            pa.field("A", pa.string(), nullable=True),
        ]
    )
    df = read_json(test_data_path, schema=schema)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].schema == schema

    # File extension
    test_data_path = os.path.join(path, "data_test_context", "data.json")
    df = read_json(test_data_path, file_extension=".json")
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_parquet_global():
    parquet_df = read_parquet(path="parquet/data/alltypes_plain.parquet")
    parquet_df.show()
    assert parquet_df is not None

    path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet"
    parquet_df = read_parquet(path=path)
    assert parquet_df is not None


def test_read_csv():
    csv_df = read_csv(path="testing/data/csv/aggregate_test_100.csv")
    csv_df.select(column("c1")).show()


def test_read_csv_list():
    csv_df = read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
    expected = csv_df.count() * 2

    double_csv_df = read_csv(
        path=[
            "testing/data/csv/aggregate_test_100.csv",
            "testing/data/csv/aggregate_test_100.csv",
        ]
    )
    actual = double_csv_df.count()

    double_csv_df.select(column("c1")).show()
    assert actual == expected


def test_read_avro():
    avro_df = read_avro(path="testing/data/avro/alltypes_plain.avro")
    avro_df.show()
    assert avro_df is not None

    path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro"
    avro_df = read_avro(path=path)
    assert avro_df is not None
2 changes: 2 additions & 0 deletions python/tests/test_wrapper_coverage.py
@@ -34,6 +34,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None:
        return

    for attr in dir(internal_obj):
        if attr in ["_global_ctx"]:
            continue
        assert attr in dir(wrapped_obj)

        internal_attr = getattr(internal_obj, attr)
12 changes: 10 additions & 2 deletions src/context.rs
@@ -43,7 +43,7 @@ use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_tokio_runtime, wait_for_future};
use crate::utils::{get_global_ctx, get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
@@ -68,7 +68,7 @@ use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use tokio::task::JoinHandle;

/// Configuration options for a SessionContext
@@ -299,6 +299,14 @@ impl PySessionContext {
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }
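For reference, this classmethod is the hook the new Python helpers call. Each invocation clones the lazily created global SessionContext from get_global_ctx in src/utils.rs, so every call is backed by the same underlying context. A hedged sketch of how it is reached from Python (private API, shown only for illustration):

from datafusion._internal import SessionContext as SessionContextInternal

# Each call returns a fresh Python wrapper, but both clone the single global
# context created on first use (assumption: clones share underlying state).
ctx_a = SessionContextInternal._global_ctx()
ctx_b = SessionContextInternal._global_ctx()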

    /// Register an object store with the given name
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
8 changes: 8 additions & 0 deletions src/utils.rs
@@ -17,6 +17,7 @@

use crate::errors::DataFusionError;
use crate::TokioRuntime;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
use pyo3::prelude::*;
use std::future::Future;
@@ -35,6 +36,13 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
    RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
}

/// Utility to get the global DataFusion SessionContext
#[inline]
pub(crate) fn get_global_ctx() -> &'static SessionContext {
    static CTX: OnceLock<SessionContext> = OnceLock::new();
    CTX.get_or_init(|| SessionContext::new())
}

/// Utility to collect rust futures with GIL released
pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
where