diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index 9ddb8b6e5..10c6df1db 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -38,6 +38,8 @@ pub struct SchedulerConfig { pub namespace: String, /// The external hostname of the scheduler pub external_host: String, + /// The bind host for the scheduler's gRPC service + pub bind_host: String, /// The bind port for the scheduler's gRPC service pub bind_port: u16, /// The task scheduling policy for the scheduler @@ -87,6 +89,7 @@ impl std::fmt::Debug for SchedulerConfig { .field("namespace", &self.namespace) .field("external_host", &self.external_host) .field("bind_port", &self.bind_port) + .field("bind_host", &self.bind_host) .field("scheduling_policy", &self.scheduling_policy) .field("event_loop_buffer_size", &self.event_loop_buffer_size) .field("task_distribution", &self.task_distribution) @@ -137,6 +140,7 @@ impl Default for SchedulerConfig { namespace: String::default(), external_host: "localhost".into(), bind_port: 50050, + bind_host: "127.0.0.1".into(), scheduling_policy: Default::default(), event_loop_buffer_size: 10000, task_distribution: Default::default(), @@ -326,6 +330,7 @@ impl TryFrom for SchedulerConfig { namespace: opt.namespace, external_host: opt.external_host, bind_port: opt.bind_port, + bind_host: opt.bind_host, scheduling_policy: opt.scheduler_policy, event_loop_buffer_size: opt.event_loop_buffer_size, task_distribution, diff --git a/python/Cargo.toml b/python/Cargo.toml index b03f1e997..747f330a9 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -31,8 +31,10 @@ publish = false [dependencies] async-trait = "0.1.77" -ballista = { path = "../ballista/client", version = "0.12.0", features = ["standalone"] } +ballista = { path = "../ballista/client", version = "0.12.0" } ballista-core = { path = "../ballista/core", version = "0.12.0" } +ballista-executor = { path = "../ballista/executor", version = "0.12.0" } 
+ballista-scheduler = { path = "../ballista/scheduler", version = "0.12.0" } datafusion = { version = "42", features = ["pyarrow", "avro"] } datafusion-proto = { version = "42" } datafusion-python = { version = "42" } diff --git a/python/README.md b/python/README.md index 01b0a7f90..d8ba03f3d 100644 --- a/python/README.md +++ b/python/README.md @@ -26,6 +26,12 @@ part of the default Cargo workspace so that it doesn't cause overhead for mainta ## Creating a SessionContext +> [!IMPORTANT] +> The current approach is to support the DataFusion Python API. There are known limitations of this approach, +> with some cases producing errors. +> We are trying to come up with the best approach to support the DataFusion Python interface. +> More details can be found at [#1142](https://github.com/apache/datafusion-ballista/issues/1142) + Creates a new context and connects to a Ballista scheduler process. ```python @@ -33,22 +39,50 @@ from ballista import BallistaBuilder >>> ctx = BallistaBuilder().standalone() ``` -## Example SQL Usage +### Example SQL Usage ```python ->>> ctx.sql("create external table t stored as parquet location '/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet'") +>>> ctx.sql("create external table t stored as parquet location './testdata/test.parquet'") >>> df = ctx.sql("select * from t limit 5") >>> pyarrow_batches = df.collect() ``` -## Example DataFrame Usage +### Example DataFrame Usage ```python ->>> df = ctx.read_parquet('/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet').limit(5) +>>> df = ctx.read_parquet('./testdata/test.parquet').limit(5) >>> pyarrow_batches = df.collect() ``` -## Creating Virtual Environment +## Scheduler and Executor + +The scheduler and executors can be configured and started from Python code. 
+ +To start scheduler: + +```python +from ballista import BallistaScheduler + +scheduler = BallistaScheduler() + +scheduler.start() +scheduler.wait_for_termination() +``` + +For executor: + +```python +from ballista import BallistaExecutor + +executor = BallistaExecutor() + +executor.start() +executor.wait_for_termination() +``` + +## Development Process + +### Creating Virtual Environment ```shell python3 -m venv venv @@ -56,7 +90,7 @@ source venv/bin/activate pip3 install -r requirements.txt ``` -## Building +### Building ```shell maturin develop @@ -64,7 +98,7 @@ maturin develop Note that you can also run `maturin develop --release` to get a release build locally. -## Testing +### Testing ```shell python3 -m pytest diff --git a/python/ballista/__init__.py b/python/ballista/__init__.py index a143f17e9..4e80422b7 100644 --- a/python/ballista/__init__.py +++ b/python/ballista/__init__.py @@ -26,11 +26,11 @@ import pyarrow as pa from .ballista_internal import ( - BallistaBuilder, + BallistaBuilder, BallistaScheduler, BallistaExecutor ) __version__ = importlib_metadata.version(__name__) __all__ = [ - "BallistaBuilder", + "BallistaBuilder", "BallistaScheduler", "BallistaExecutor" ] \ No newline at end of file diff --git a/python/examples/client_remote.py b/python/examples/client_remote.py new file mode 100644 index 000000000..fd85858ac --- /dev/null +++ b/python/examples/client_remote.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# %% +from ballista import BallistaBuilder +from datafusion.context import SessionContext + +ctx: SessionContext = BallistaBuilder().remote("df://127.0.0.1:50050") + +# Select 1 to verify its working +ctx.sql("SELECT 1").show() + +# %% diff --git a/python/examples/example.py b/python/examples/client_standalone.py similarity index 79% rename from python/examples/example.py rename to python/examples/client_standalone.py index 61a9abbd2..dfe3c372f 100644 --- a/python/examples/example.py +++ b/python/examples/client_standalone.py @@ -15,18 +15,23 @@ # specific language governing permissions and limitations # under the License. 
+# %% + from ballista import BallistaBuilder from datafusion.context import SessionContext -# Ballista will initiate with an empty config -# set config variables with `config` ctx: SessionContext = BallistaBuilder()\ + .config("datafusion.catalog.information_schema","true")\ .config("ballista.job.name", "example ballista")\ - .config("ballista.shuffle.partitions", "16")\ .standalone() -#ctx_remote: SessionContext = ballista.remote("remote_ip", 50050) -# Select 1 to verify its working ctx.sql("SELECT 1").show() -#ctx_remote.sql("SELECT 2").show() \ No newline at end of file + +# %% +ctx.sql("SHOW TABLES").show() +# %% +ctx.sql("select name, value from information_schema.df_settings where name like 'ballista.job.name'").show() + + +# %% diff --git a/python/examples/executor.py b/python/examples/executor.py new file mode 100644 index 000000000..bb032f634 --- /dev/null +++ b/python/examples/executor.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# %% +from ballista import BallistaExecutor +# %% +executor = BallistaExecutor() +# %% +executor.start() +# %% +executor +# %% +executor.wait_for_termination() +# %% +# %% +executor.close() +# %% diff --git a/python/examples/readme_remote.py b/python/examples/readme_remote.py new file mode 100644 index 000000000..7e1c82d83 --- /dev/null +++ b/python/examples/readme_remote.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# %% + +from ballista import BallistaBuilder +from datafusion.context import SessionContext + +ctx: SessionContext = BallistaBuilder()\ + .config("ballista.job.name", "Readme Example Remote")\ + .config("datafusion.execution.target_partitions", "4")\ + .remote("df://127.0.0.1:50050") + +ctx.sql("create external table t stored as parquet location '../testdata/test.parquet'") + +# %% +df = ctx.sql("select * from t limit 5") +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% +df = ctx.read_parquet('../testdata/test.parquet').limit(5) +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% \ No newline at end of file diff --git a/python/examples/readme_standalone.py b/python/examples/readme_standalone.py new file mode 100644 index 000000000..15404e02d --- /dev/null +++ b/python/examples/readme_standalone.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# %% + +from ballista import BallistaBuilder +from datafusion.context import SessionContext + +ctx: SessionContext = BallistaBuilder()\ + .config("ballista.job.name", "Readme Example")\ + .config("datafusion.execution.target_partitions", "4")\ + .standalone() + +ctx.sql("create external table t stored as parquet location '../testdata/test.parquet'") + +# %% +df = ctx.sql("select * from t limit 5") +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% +df = ctx.read_parquet('../testdata/test.parquet').limit(5) +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% \ No newline at end of file diff --git a/python/examples/scheduler.py b/python/examples/scheduler.py new file mode 100644 index 000000000..1c40ce1ee --- /dev/null +++ b/python/examples/scheduler.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# %% +from ballista import BallistaScheduler +# %% +scheduler = BallistaScheduler() +# %% +scheduler +# %% +scheduler.start() +# %% +scheduler.wait_for_termination() +# %% +scheduler.close() \ No newline at end of file diff --git a/python/pyproject.toml b/python/pyproject.toml index cce88fd3b..d9b6d2bd9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -43,7 +43,7 @@ classifier = [ "Programming Language :: Rust", ] dependencies = [ - "pyarrow>=11.0.0", + "pyarrow>=11.0.0", "cloudpickle" ] [project.urls] diff --git a/python/requirements.txt b/python/requirements.txt index a03a8f8d2..bfc0e03cf 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,3 +1,6 @@ -datafusion==35.0.0 +datafusion==42.0.0 pyarrow -pytest \ No newline at end of file +pytest +maturin==1.5.1 +cloudpickle +pandas \ No newline at end of file diff --git a/python/src/cluster.rs b/python/src/cluster.rs new file mode 100644 index 000000000..aa4260ce2 --- /dev/null +++ b/python/src/cluster.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::future::IntoFuture; +use std::sync::Arc; + +use crate::codec::{PyLogicalCodec, PyPhysicalCodec}; +use crate::utils::to_pyerr; +use crate::utils::{spawn_feature, wait_for_future}; +use ballista_executor::executor_process::{ + start_executor_process, ExecutorProcessConfig, +}; +use ballista_scheduler::cluster::BallistaCluster; +use ballista_scheduler::config::SchedulerConfig; +use ballista_scheduler::scheduler_process::start_server; +use pyo3::exceptions::PyException; +use pyo3::{pyclass, pymethods, PyResult, Python}; +use tokio::task::JoinHandle; + +#[pyclass(name = "BallistaScheduler", module = "ballista", subclass)] +pub struct PyScheduler { + config: SchedulerConfig, + handle: Option>, +} + +#[pymethods] +impl PyScheduler { + #[pyo3(signature = (bind_host=None, bind_port=None))] + #[new] + pub fn new(py: Python, bind_host: Option, bind_port: Option) -> Self { + let mut config = SchedulerConfig::default(); + + if let Some(bind_port) = bind_port { + config.bind_port = bind_port; + } + + if let Some(host) = bind_host { + config.bind_host = host; + } + + config.override_logical_codec = + Some(Arc::new(PyLogicalCodec::try_new(py).unwrap())); + config.override_physical_codec = + Some(Arc::new(PyPhysicalCodec::try_new(py).unwrap())); + + Self { + config, + handle: None, + } + } + + pub fn start(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_some() { + return Err(PyException::new_err("Scheduler already started")); + } + let cluster = wait_for_future(py, BallistaCluster::new_from_config(&self.config)) + .map_err(to_pyerr)?; + + let config = self.config.clone(); + let address = format!("{}:{}", config.bind_host, config.bind_port); + let address = address.parse()?; + let handle = spawn_feature(py, async move { + start_server(cluster, address, Arc::new(config)) + .await + .unwrap(); + }); + self.handle = Some(handle); + + Ok(()) + } + + pub fn wait_for_termination(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_none() { + return 
Err(PyException::new_err("Scheduler not started")); + } + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + match handle { + Some(handle) => wait_for_future(py, handle.into_future()) + .map_err(|e| PyException::new_err(e.to_string())), + None => Ok(()), + } + } + + pub fn close(&mut self) -> PyResult<()> { + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + if let Some(handle) = handle { + handle.abort() + } + + Ok(()) + } + + #[classattr] + pub fn version() -> &'static str { + ballista_core::BALLISTA_VERSION + } + + pub fn __str__(&self) -> String { + match self.handle { + Some(_) => format!( + "listening address={}:{}", + self.config.bind_host, self.config.bind_port, + ), + None => format!( + "configured address={}:{}", + self.config.bind_host, self.config.bind_port, + ), + } + } + + pub fn __repr__(&self) -> String { + format!( + "BallistaScheduler(config={:?}, listening= {})", + self.config, + self.handle.is_some() + ) + } +} + +#[pyclass(name = "BallistaExecutor", module = "ballista", subclass)] +pub struct PyExecutor { + config: Arc, + handle: Option>, +} + +#[pymethods] +impl PyExecutor { + #[pyo3(signature = (bind_port=None, bind_host =None, scheduler_host = None, scheduler_port = None, concurrent_tasks = None))] + #[new] + pub fn new( + py: Python, + bind_port: Option, + bind_host: Option, + scheduler_host: Option, + scheduler_port: Option, + concurrent_tasks: Option, + ) -> PyResult { + let mut config = ExecutorProcessConfig::default(); + if let Some(port) = bind_port { + config.port = port; + } + + if let Some(host) = bind_host { + config.bind_host = host; + } + + if let Some(port) = scheduler_port { + config.scheduler_port = port; + } + + if let Some(host) = scheduler_host { + config.scheduler_host = host; + } + + if let Some(concurrent_tasks) = concurrent_tasks { + config.concurrent_tasks = concurrent_tasks as usize + } + + config.override_logical_codec = Some(Arc::new(PyLogicalCodec::try_new(py)?)); 
+ config.override_physical_codec = Some(Arc::new(PyPhysicalCodec::try_new(py)?)); + + let config = Arc::new(config); + Ok(Self { + config, + handle: None, + }) + } + + pub fn start(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_some() { + return Err(PyException::new_err("Executor already started")); + } + + let config = self.config.clone(); + + let handle = + spawn_feature( + py, + async move { start_executor_process(config).await.unwrap() }, + ); + self.handle = Some(handle); + + Ok(()) + } + + pub fn wait_for_termination(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_none() { + return Err(PyException::new_err("Executor not started")); + } + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + match handle { + Some(handle) => wait_for_future(py, handle.into_future()) + .map_err(|e| PyException::new_err(e.to_string())) + .map(|_| ()), + None => Ok(()), + } + } + + pub fn close(&mut self) -> PyResult<()> { + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + if let Some(handle) = handle { + handle.abort() + } + + Ok(()) + } + + #[classattr] + pub fn version() -> &'static str { + ballista_core::BALLISTA_VERSION + } + + pub fn __str__(&self) -> String { + match self.handle { + Some(_) => format!( + "listening address={}:{}, scheduler={}:{}", + self.config.bind_host, + self.config.port, + self.config.scheduler_host, + self.config.scheduler_port + ), + None => format!( + "configured address={}:{}, scheduler={}:{}", + self.config.bind_host, + self.config.port, + self.config.scheduler_host, + self.config.scheduler_port + ), + } + } + + pub fn __repr__(&self) -> String { + format!( + "BallistaExecutor(address={}:{}, scheduler={}:{}, listening={})", + self.config.bind_host, + self.config.port, + self.config.scheduler_host, + self.config.scheduler_port, + self.handle.is_some() + ) + } +} diff --git a/python/src/codec.rs b/python/src/codec.rs new file mode 100644 index 000000000..c6b0b7e50 --- 
/dev/null +++ b/python/src/codec.rs @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista_core::serde::{ + BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec, +}; +use datafusion::logical_expr::ScalarUDF; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use pyo3::types::{PyAnyMethods, PyBytes, PyBytesMethods}; +use pyo3::{PyObject, PyResult, Python}; +use std::fmt::Debug; +use std::sync::Arc; + +static MODULE: &str = "cloudpickle"; +static FUN_LOADS: &str = "loads"; +static FUN_DUMPS: &str = "dumps"; + +/// Serde protocol for UD(a)F +#[derive(Debug)] +struct CloudPickle { + loads: PyObject, + dumps: PyObject, +} + +impl CloudPickle { + pub fn try_new(py: Python<'_>) -> PyResult { + let module = py.import_bound(MODULE)?; + let loads = module.getattr(FUN_LOADS)?.unbind(); + let dumps = module.getattr(FUN_DUMPS)?.unbind(); + + Ok(Self { loads, dumps }) + } + + pub fn pickle(&self, py: Python<'_>, py_any: &PyObject) -> PyResult> { + let b: PyObject = self.dumps.call1(py, (py_any,))?.extract(py)?; + let blob = b.downcast_bound::(py)?.clone(); + + Ok(blob.as_bytes().to_owned()) + } + + pub fn unpickle(&self, py: 
Python<'_>, blob: &[u8]) -> PyResult { + let t: PyObject = self.loads.call1(py, (blob,))?.extract(py)?; + + Ok(t) + } +} + +pub struct PyLogicalCodec { + inner: BallistaLogicalExtensionCodec, + cloudpickle: CloudPickle, +} + +impl PyLogicalCodec { + pub fn try_new(py: Python<'_>) -> PyResult { + Ok(Self { + inner: BallistaLogicalExtensionCodec::default(), + cloudpickle: CloudPickle::try_new(py)?, + }) + } +} + +impl Debug for PyLogicalCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyLogicalCodec").finish() + } +} + +impl LogicalExtensionCodec for PyLogicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[datafusion::logical_expr::LogicalPlan], + ctx: &datafusion::prelude::SessionContext, + ) -> datafusion::error::Result { + self.inner.try_decode(buf, inputs, ctx) + } + + fn try_encode( + &self, + node: &datafusion::logical_expr::Extension, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + table_ref: &datafusion::sql::TableReference, + schema: datafusion::arrow::datatypes::SchemaRef, + ctx: &datafusion::prelude::SessionContext, + ) -> datafusion::error::Result> + { + self.inner + .try_decode_table_provider(buf, table_ref, schema, ctx) + } + + fn try_encode_table_provider( + &self, + table_ref: &datafusion::sql::TableReference, + node: std::sync::Arc, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_table_provider(table_ref, node, buf) + } + + fn try_decode_file_format( + &self, + buf: &[u8], + ctx: &datafusion::prelude::SessionContext, + ) -> datafusion::error::Result< + std::sync::Arc, + > { + self.inner.try_decode_file_format(buf, ctx) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: std::sync::Arc, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_file_format(buf, node) + } + + fn try_decode_udf( + &self, + name: &str, + buf: &[u8], + ) -> 
datafusion::error::Result> + { + // use cloud pickle to decode udf + self.inner.try_decode_udf(name, buf) + } + + fn try_encode_udf( + &self, + node: &datafusion::logical_expr::ScalarUDF, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + // use cloud pickle to encode udf + self.inner.try_encode_udf(node, buf) + } + + fn try_decode_udaf( + &self, + name: &str, + buf: &[u8], + ) -> datafusion::error::Result> + { + self.inner.try_decode_udaf(name, buf) + } + + fn try_encode_udaf( + &self, + node: &datafusion::logical_expr::AggregateUDF, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_udaf(node, buf) + } + + fn try_decode_udwf( + &self, + name: &str, + buf: &[u8], + ) -> datafusion::error::Result> + { + self.inner.try_decode_udwf(name, buf) + } + + fn try_encode_udwf( + &self, + node: &datafusion::logical_expr::WindowUDF, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_udwf(node, buf) + } +} + +pub struct PyPhysicalCodec { + inner: BallistaPhysicalExtensionCodec, + cloudpickle: CloudPickle, +} + +impl Debug for PyPhysicalCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyPhysicalCodec").finish() + } +} + +impl PyPhysicalCodec { + pub fn try_new(py: Python<'_>) -> PyResult { + Ok(Self { + inner: BallistaPhysicalExtensionCodec::default(), + cloudpickle: CloudPickle::try_new(py)?, + }) + } +} + +impl PhysicalExtensionCodec for PyPhysicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[std::sync::Arc], + registry: &dyn datafusion::execution::FunctionRegistry, + ) -> datafusion::error::Result< + std::sync::Arc, + > { + self.inner.try_decode(buf, inputs, registry) + } + + fn try_encode( + &self, + node: std::sync::Arc, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_udf( + &self, + name: &str, + _buf: &[u8], + ) -> datafusion::common::Result> { + // use cloudpickle here + 
datafusion::common::not_impl_err!( + "PhysicalExtensionCodec is not provided for scalar function {name}" + ) + } + + fn try_encode_udf( + &self, + _node: &ScalarUDF, + _buf: &mut Vec, + ) -> datafusion::common::Result<()> { + // use cloudpickle here + Ok(()) + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 41b4b6d31..13a6c38b9 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,32 +15,36 @@ // specific language governing permissions and limitations // under the License. +use crate::utils::wait_for_future; use ballista::prelude::*; +use cluster::{PyExecutor, PyScheduler}; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::*; use datafusion_python::context::PySessionContext; -use datafusion_python::utils::wait_for_future; - -use std::collections::HashMap; - use pyo3::prelude::*; + +mod cluster; +#[allow(dead_code)] +mod codec; mod utils; -use utils::to_pyerr; + +pub(crate) struct TokioRuntime(tokio::runtime::Runtime); #[pymodule] fn ballista_internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); - // BallistaBuilder struct + m.add_class::()?; - // DataFusion struct m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) } -// Ballista Builder will take a HasMap/Dict Cionfg #[pyclass(name = "BallistaBuilder", module = "ballista", subclass)] pub struct PyBallistaBuilder { - conf: HashMap, + session_config: SessionConfig, } #[pymethods] @@ -48,56 +52,47 @@ impl PyBallistaBuilder { #[new] pub fn new() -> Self { Self { - conf: HashMap::new(), + session_config: SessionConfig::new_with_ballista(), } } pub fn config( mut slf: PyRefMut<'_, Self>, - k: &str, - v: &str, + key: &str, + value: &str, py: Python, ) -> PyResult { - slf.conf.insert(k.into(), v.into()); + let _ = slf.session_config.options_mut().set(key, value); Ok(slf.into_py(py)) } /// Construct the standalone instance from the SessionContext pub fn standalone(&self, py: Python) -> PyResult { - // Build the config - let config: 
SessionConfig = SessionConfig::from_string_hash_map(&self.conf)?; - // Build the state let state = SessionStateBuilder::new() - .with_config(config) + .with_config(self.session_config.clone()) .with_default_features() .build(); - // Build the context - let standalone_session = SessionContext::standalone_with_state(state); - // SessionContext is an async function - let ctx = wait_for_future(py, standalone_session)?; + let ctx = wait_for_future(py, SessionContext::standalone_with_state(state))?; - // Convert the SessionContext into a Python SessionContext Ok(ctx.into()) } /// Construct the remote instance from the SessionContext pub fn remote(&self, url: &str, py: Python) -> PyResult { - // Build the config - let config: SessionConfig = SessionConfig::from_string_hash_map(&self.conf)?; - // Build the state let state = SessionStateBuilder::new() - .with_config(config) + .with_config(self.session_config.clone()) .with_default_features() .build(); - // Build the context - let remote_session = SessionContext::remote_with_state(url, state); - // SessionContext is an async function - let ctx = wait_for_future(py, remote_session)?; + let ctx = wait_for_future(py, SessionContext::remote_with_state(url, state))?; - // Convert the SessionContext into a Python SessionContext Ok(ctx.into()) } + + #[classattr] + pub fn version() -> &'static str { + ballista_core::BALLISTA_VERSION + } } diff --git a/python/src/utils.rs b/python/src/utils.rs index 10278537e..f069475ea 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -15,10 +15,48 @@ // specific language governing permissions and limitations // under the License. 
+use std::future::Future; +use std::sync::OnceLock; +use tokio::task::JoinHandle; + use ballista_core::error::BallistaError; use pyo3::exceptions::PyException; -use pyo3::PyErr; +use pyo3::{PyErr, Python}; +use tokio::runtime::Runtime; + +use crate::TokioRuntime; pub(crate) fn to_pyerr(err: BallistaError) -> PyErr { PyException::new_err(err.to_string()) } + +#[inline] +pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { + // NOTE: Other pyo3 python libraries have had issues with using tokio + // behind a forking app-server like `gunicorn` + // If we run into that problem, in the future we can look to `delta-rs` + // which adds a check in that disallows calls from a forked process + // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31 + static RUNTIME: OnceLock = OnceLock::new(); + RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap())) +} + +/// Utility to collect rust futures with GIL released +pub(crate) fn wait_for_future(py: Python, f: F) -> F::Output +where + F: Future + Send, + F::Output: Send, +{ + let runtime: &Runtime = &get_tokio_runtime().0; + py.allow_threads(|| runtime.block_on(f)) +} + +pub(crate) fn spawn_feature(py: Python, f: F) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send, +{ + let runtime: &Runtime = &get_tokio_runtime().0; + // do we need py.allow_threads ? + py.allow_threads(|| runtime.spawn(f)) +}