From 6974545d218c501a6555bbf70b0f5941bd7c3597 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sat, 28 Dec 2024 12:29:02 +0100 Subject: [PATCH] chore: set validation and typehint --- python/datafusion/context.py | 13 ++++++++++++- src/context.rs | 4 ++-- src/dataframe.rs | 21 +-------------------- src/utils.rs | 21 +++++++++++++++++++++ 4 files changed, 36 insertions(+), 23 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index a07b5d175..3fa133346 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -63,6 +63,15 @@ def __arrow_c_array__( # noqa: D105 ) -> tuple[object, object]: ... +class TableProviderExportable(Protocol): + """Type hint for object that has __datafusion_table_provider__ PyCapsule. + + https://datafusion.apache.org/python/user-guide/io/table_provider.html + """ + + def __datafusion_table_provider__(self) -> object: ... # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -685,7 +694,9 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) - def register_table_provider(self, name: str, provider: Any) -> None: + def register_table_provider( + self, name: str, provider: TableProviderExportable + ) -> None: """Register a table provider. This table provider must have a method called ``__datafusion_table_provider__`` diff --git a/src/context.rs b/src/context.rs index 8675e97df..0512285a7 100644 --- a/src/context.rs +++ b/src/context.rs @@ -43,7 +43,7 @@ use crate::store::StorageContexts; use crate::udaf::PyAggregateUDF; use crate::udf::PyScalarUDF; use crate::udwf::PyWindowUDF; -use crate::utils::{get_tokio_runtime, wait_for_future}; +use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future}; use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; @@ -576,7 +576,7 @@ impl PySessionContext { if provider.hasattr("__datafusion_table_provider__")? { let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; let capsule = capsule.downcast::()?; - // validate_pycapsule(capsule, "arrow_array_stream")?; + validate_pycapsule(capsule, "datafusion_table_provider")?; let provider = unsafe { capsule.reference::() }; let provider: ForeignTableProvider = provider.into(); diff --git a/src/dataframe.rs b/src/dataframe.rs index e7d6ca6d6..fcb46a756 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -44,7 +44,7 @@ use crate::expr::sort_expr::to_sort_expressions; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; -use crate::utils::{get_tokio_runtime, wait_for_future}; +use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future}; use crate::{ errors::DataFusionError, expr::{sort_expr::PySortExpr, PyExpr}, @@ -724,22 +724,3 @@ fn record_batch_into_schema( RecordBatch::try_new(schema, data_arrays) } - -fn validate_pycapsule(capsule: &Bound, name: &str) -> PyResult<()> { - let capsule_name = capsule.name()?; - if capsule_name.is_none() { - return Err(PyValueError::new_err( - "Expected schema PyCapsule to have name set.", - )); - } - - let capsule_name = capsule_name.unwrap().to_str()?; - if capsule_name != name { - return Err(PyValueError::new_err(format!( - "Expected name '{}' in PyCapsule, instead got '{}'", - name, capsule_name - ))); - } - - Ok(()) -} diff --git a/src/utils.rs b/src/utils.rs index 7fb23cafe..795589752 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -18,7 +18,9 @@ use crate::errors::DataFusionError; use crate::TokioRuntime; use datafusion::logical_expr::Volatility; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyCapsule; use std::future::Future; use std::sync::OnceLock; use tokio::runtime::Runtime; @@ -58,3 +60,22 @@ pub(crate) fn parse_volatility(value: &str) -> Result, name: &str) -> PyResult<()> { + let capsule_name = capsule.name()?; + if capsule_name.is_none() { + return Err(PyValueError::new_err( + "Expected schema PyCapsule to have name set.", + )); + } + + let capsule_name = capsule_name.unwrap().to_str()?; + if capsule_name != name { + return Err(PyValueError::new_err(format!( + "Expected name '{}' in PyCapsule, instead got '{}'", + name, capsule_name + ))); + } + + Ok(()) +}