diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 000000000..ce92d5a52
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+name: Python Release Build
+on:
+  pull_request:
+    branches: ["main"]
+  push:
+    tags: ["*-rc*"]
+    branches: ["branch-*"]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install ruff
+      # Update output format to enable automatic inline annotations.
+#      - name: Run Ruff
+#        run: |
+#          ruff check --output-format=github python/
+#          ruff format --check python/
+
+  generate-license:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+      - name: Generate license file
+        run: python ./dev/create_license.py
+      - uses: actions/upload-artifact@v4
+        with:
+          name: python-wheel-license
+          path: LICENSE.txt
+
+  build-python-mac-win:
+    needs: [generate-license]
+    name: Mac/Win
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+        os: [macos-latest, windows-latest]
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+
+      - name: Upgrade pip
+        run: python -m pip install --upgrade pip
+
+      - name: Install maturin
+        run: pip install maturin==1.5.1
+
+      - run: rm LICENSE.txt
+      - name: Download LICENSE.txt
+        uses: actions/download-artifact@v4
+        with:
+          name: python-wheel-license
+          path: .
+
+      - name: Install Protoc
+        uses: arduino/setup-protoc@v3
+        with:
+          version: "27.4"
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Cargo build
+        run: cd python && cargo build
+
+      - name: Build Python package
+        run: cd python && maturin build --release --strip
+
+      - name: List Windows wheels
+        if: matrix.os == 'windows-latest'
+        run: dir python\target\wheels\
+        # since the runner is dynamic shellcheck (from actionlint) can't infer this is powershell
+        # so we specify it explicitly
+        shell: powershell
+
+      - name: List Mac wheels
+        if: matrix.os != 'windows-latest'
+        run: cd python/target/wheels/
+
+      - name: Archive wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist-${{ matrix.os }}
+          path: python/target/wheels/*
+
+  build-macos-x86_64:
+    needs: [generate-license]
+    name: Mac x86_64
+    runs-on: macos-13
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: stable
+
+      - name: Upgrade pip
+        run: python -m pip install --upgrade pip
+
+      - name: Install maturin
+        run: pip install maturin==1.5.1
+
+      - run: rm LICENSE.txt
+      - name: Download LICENSE.txt
+        uses: actions/download-artifact@v4
+        with:
+          name: python-wheel-license
+          path: .
+
+      - name: Install Protoc
+        uses: arduino/setup-protoc@v3
+        with:
+          version: "27.4"
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+      - name: Cargo build
+        run: cd python && cargo build
+      - name: Build Python package
+        run: cd python && maturin build --release --strip
+      - name: List Mac wheels
+        run: cd python/target/wheels/
+
+      - name: Archive wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist-macos-aarch64
+          path: python/target/wheels/*
+
+  build-manylinux-x86_64:
+    needs: [generate-license]
+    name: Manylinux x86_64
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - run: rm LICENSE.txt
+
+      - name: Download LICENSE.txt
+        uses: actions/download-artifact@v4
+        with:
+          name: python-wheel-license
+          path: .
+
+      - run: cat LICENSE.txt
+      - name: Install Protoc
+        uses: arduino/setup-protoc@v3
+        with:
+          version: "27.4"
+          repo-token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Upgrade pip
+        run: python -m pip install --upgrade pip
+
+      - name: Install maturin
+        run: pip install maturin==1.5.1
+      - name: Cargo Build
+        run: cd python && cargo build
+
+      - name: Build Python package
+        run: cd python && maturin build --release --strip
+      - name: Archive wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: dist-manylinux-x86_64
+          path: python/target/wheels/*
+
+  build-manylinux-aarch64:
+    needs: [generate-license]
+    name: Manylinux arm64
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - run: rm LICENSE.txt
+      - name: Download LICENSE.txt
+        uses: actions/download-artifact@v4
+        with:
+          name: python-wheel-license
+          path: .
+ + - run: cat LICENSE.txt + + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + version: "27.4" + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Upgrade pip + run: python -m pip install --upgrade pip + + - name: Install maturin + run: pip install maturin==1.5.1 + - name: Cargo Build + run: cd python && cargo build + + - name: Build Python package + run: cd python && maturin build --release --strip + - name: Archive wheels + uses: actions/upload-artifact@v4 + with: + name: dist-manylinux-aarch64 + path: python/target/wheels/* + + build-sdist: + needs: [generate-license] + name: Source distribution + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: rm LICENSE.txt + - name: Download LICENSE.txt + uses: actions/download-artifact@v4 + with: + name: python-wheel-license + path: . + + - run: cat LICENSE.txt + + - name: Install Protoc + uses: arduino/setup-protoc@v3 + with: + version: "27.4" + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Upgrade pip + run: python -m pip install --upgrade pip + + - name: Install maturin + run: pip install maturin==1.5.1 + - name: Cargo Build + run: cd python && cargo build + - name: Build Python package + run: cd python && maturin build --release --sdist --out dist --strip + + - name: Assert sdist build does not generate wheels + run: | + if [ "$(ls -A target/wheels)" ]; then + echo "Error: Sdist build generated wheels" + exit 1 + else + echo "Directory is clean" + fi + shell: bash + + merge-build-artifacts: + runs-on: ubuntu-latest + needs: + - build-python-mac-win + - build-macos-x86_64 + - build-manylinux-x86_64 + - build-manylinux-aarch64 + - build-sdist + steps: + - name: Merge Build Artifacts + uses: actions/upload-artifact/merge@v4 + with: + name: dist + pattern: dist-* + + # NOTE: PyPI publish needs to be done manually for now after release passed the vote + # release: + # name: Publish in PyPI + # needs: [build-manylinux, build-python-mac-win] + # runs-on: ubuntu-latest + # steps: + # - uses: actions/download-artifact@v4 + # - name: Publish to PyPI + # uses: pypa/gh-action-pypi-publish@master + # with: + # user: __token__ + # password: ${{ secrets.pypi_password }} diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index cbda2632b..cdd5d30b6 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -117,10 +117,8 @@ jobs: export PATH=$PATH:$HOME/d/protoc/bin export ARROW_TEST_DATA=$(pwd)/testing/data export PARQUET_TEST_DATA=$(pwd)/parquet-testing/data - cargo test - cd examples - cargo run --example standalone_sql --features=ballista/standalone - cd ../python + cargo test --features=testcontainers + cd python cargo test env: CARGO_HOME: "/github/home/.cargo" diff --git a/Cargo.toml b/Cargo.toml index 9aa9d7071..a9e9556fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ resolver = "2" [workspace.dependencies] arrow = { version = "53", features = ["ipc_compression"] } arrow-flight = { version = "53", features = ["flight-sql-experimental"] } -clap = { version = "3", features = ["derive", "cargo"] } +clap = { version = "4.5", features = ["derive", "cargo"] } configure_me = { version = "0.4.0" } configure_me_codegen = { version = "0.4.4" } datafusion = "43.0.0" @@ -38,9 +38,11 @@ tonic-build = { version = "0.12", default-features = false, features = [ "transport", "prost" ] } -tracing = "0.1.36" +tracing = "0.1" tracing-appender = "0.2.2" -tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } +tracing-subscriber = { version = "0.3", features = 
["env-filter"] } +ctor = { version = "0.2" } +mimalloc = { version = "0.1" } tokio = { version = "1" } uuid = { version = "1.10", features = ["v4", "v7"] } @@ -54,7 +56,6 @@ dashmap = { version = "6.1" } async-trait = { version = "0.1.4" } serde = { version = "1.0" } tokio-stream = { version = "0.1" } -parse_arg = { version = "0.1" } url = { version = "2.5" } # cargo build --profile release-lto diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml index f8fc3694d..9b527e56d 100644 --- a/ballista-cli/Cargo.toml +++ b/ballista-cli/Cargo.toml @@ -25,18 +25,16 @@ keywords = ["ballista", "cli"] license = "Apache-2.0" homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" -rust-version = "1.72" readme = "README.md" [dependencies] ballista = { path = "../ballista/client", version = "0.12.0", features = ["standalone"] } -# datafusion-cli uses 4.5 clap, thus it does not depend on workspace -clap = { version = "4.5", features = ["derive", "cargo"] } +clap = { workspace = true, features = ["derive", "cargo"] } datafusion = { workspace = true } datafusion-cli = { workspace = true } dirs = "5.0.1" env_logger = { workspace = true } -mimalloc = { version = "0.1", default-features = false } +mimalloc = { workspace = true } rustyline = "14.0.0" tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml index a5d930307..e462367a6 100644 --- a/ballista/client/Cargo.toml +++ b/ballista/client/Cargo.toml @@ -25,7 +25,6 @@ repository = "https://github.com/apache/arrow-ballista" readme = "README.md" authors = ["Apache DataFusion "] edition = "2021" -rust-version = "1.72" [dependencies] async-trait = { workspace = true } @@ -33,23 +32,21 @@ ballista-core = { path = "../core", version = "0.12.0" } ballista-executor = { path = "../executor", version = "0.12.0", optional = true } ballista-scheduler = { path = "../scheduler", version = "0.12.0", optional = true } datafusion = { workspace = true } -datafusion-proto = { workspace = true } -futures = { workspace = true } log = { workspace = true } -parking_lot = { workspace = true } -tempfile = { workspace = true } + tokio = { workspace = true } url = { workspace = true } [dev-dependencies] ballista-executor = { path = "../executor", version = "0.12.0" } ballista-scheduler = { path = "../scheduler", version = "0.12.0" } -ctor = { version = "0.2" } +ctor = { workspace = true } +datafusion-proto = { workspace = true } env_logger = { workspace = true } -object_store = { workspace = true, features = ["aws"] } -testcontainers-modules = { version = "0.11", features = ["minio"] } +rstest = { version = "0.23" } +tempfile = { workspace = true } +tonic = { workspace = true } [features] default = ["standalone"] standalone = ["ballista-executor", "ballista-scheduler"] -testcontainers = [] diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs index ff603ea3e..89e1c7ce9 100644 --- a/ballista/client/src/extension.rs +++ b/ballista/client/src/extension.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-pub use ballista_core::utils::SessionConfigExt; +use ballista_core::extension::SessionConfigHelperExt; +pub use ballista_core::extension::{SessionConfigExt, SessionStateExt}; use ballista_core::{ serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams}, - utils::{create_grpc_client_connection, SessionStateExt}, + utils::create_grpc_client_connection, }; use datafusion::{ error::DataFusionError, @@ -122,7 +123,7 @@ impl SessionContextExt for SessionContext { let config = SessionConfig::new_with_ballista(); let scheduler_url = Extension::parse_url(url)?; log::info!( - "Connecting to Ballista scheduler at {}", + "Connecting to Ballista scheduler at: {}", scheduler_url.clone() ); let remote_session_id = @@ -245,10 +246,11 @@ impl Extension { .map_err(|e| DataFusionError::Configuration(e.to_string()))?; } Some(session_state) => { - ballista_executor::new_standalone_executor_from_state::< - datafusion_proto::protobuf::LogicalPlanNode, - datafusion_proto::protobuf::PhysicalPlanNode, - >(scheduler, concurrent_tasks, session_state) + ballista_executor::new_standalone_executor_from_state( + scheduler, + concurrent_tasks, + session_state, + ) .await .map_err(|e| DataFusionError::Configuration(e.to_string()))?; } diff --git a/ballista/client/tests/common/mod.rs b/ballista/client/tests/common/mod.rs index 30b8f9f90..1d2bca94b 100644 --- a/ballista/client/tests/common/mod.rs +++ b/ballista/client/tests/common/mod.rs @@ -19,64 +19,14 @@ use std::env; use std::error::Error; use std::path::PathBuf; -use ballista::prelude::SessionConfigExt; +use ballista::prelude::{SessionConfigExt, SessionContextExt}; use ballista_core::serde::{ protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec, }; use ballista_core::{ConfigProducer, RuntimeProducer}; use ballista_scheduler::SessionBuilder; use datafusion::execution::SessionState; -use datafusion::prelude::SessionConfig; -use object_store::aws::AmazonS3Builder; -use testcontainers_modules::minio::MinIO; -use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand}; -use testcontainers_modules::testcontainers::ContainerRequest; -use testcontainers_modules::{minio, testcontainers::ImageExt}; - -pub const REGION: &str = "eu-west-1"; -pub const BUCKET: &str = "ballista"; -pub const ACCESS_KEY_ID: &str = "MINIO"; -pub const SECRET_KEY: &str = "MINIOMINIO"; - -#[allow(dead_code)] -pub fn create_s3_store( - port: u16, -) -> std::result::Result { - AmazonS3Builder::new() - .with_endpoint(format!("http://localhost:{port}")) - .with_region(REGION) - .with_bucket_name(BUCKET) - .with_access_key_id(ACCESS_KEY_ID) - .with_secret_access_key(SECRET_KEY) - .with_allow_http(true) - .build() -} - -#[allow(dead_code)] -pub fn create_minio_container() -> ContainerRequest { - MinIO::default() - .with_env_var("MINIO_ACCESS_KEY", ACCESS_KEY_ID) - .with_env_var("MINIO_SECRET_KEY", SECRET_KEY) -} - -#[allow(dead_code)] -pub fn create_bucket_command() -> ExecCommand { - // this is hack to create a bucket without creating s3 client. - // this works with current testcontainer (and image) version 'RELEASE.2022-02-07T08-17-33Z'. - // (testcontainer does not await properly on latest image version) - // - // if testcontainer image version change to something newer we should use "mc mb /data/ballista" - // to crate a bucket. - ExecCommand::new(vec![ - "mkdir".to_string(), - format!("/data/{}", crate::common::BUCKET), - ]) - .with_cmd_ready_condition(CmdWaitFor::seconds(1)) -} - -// /// Remote ballista cluster to be used for local testing. 
-// static BALLISTA_CLUSTER: tokio::sync::OnceCell<(String, u16)> = -// tokio::sync::OnceCell::const_new(); +use datafusion::prelude::{SessionConfig, SessionContext}; /// Returns the parquet test data directory, which is by default /// stored in a git submodule rooted at @@ -161,17 +111,8 @@ pub async fn setup_test_cluster() -> (String, u16) { let host = "localhost".to_string(); - let scheduler_url = format!("http://{}:{}", host, addr.port()); - - let scheduler = loop { - match SchedulerGrpcClient::connect(scheduler_url.clone()).await { - Err(_) => { - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - log::info!("Attempting to connect to test scheduler..."); - } - Ok(scheduler) => break scheduler, - } - }; + let scheduler = + connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; ballista_executor::new_standalone_executor( scheduler, @@ -190,7 +131,6 @@ pub async fn setup_test_cluster() -> (String, u16) { #[allow(dead_code)] pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) { let config = SessionConfig::new_with_ballista(); - //let default_codec = BallistaCodec::default(); let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state( &session_state, @@ -200,22 +140,10 @@ pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (Stri let host = "localhost".to_string(); - let scheduler_url = format!("http://{}:{}", host, addr.port()); + let scheduler = + connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; - let scheduler = loop { - match SchedulerGrpcClient::connect(scheduler_url.clone()).await { - Err(_) => { - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - log::info!("Attempting to connect to test scheduler..."); - } - Ok(scheduler) => break scheduler, - } - }; - - ballista_executor::new_standalone_executor_from_state::< - datafusion_proto::protobuf::LogicalPlanNode, - datafusion_proto::protobuf::PhysicalPlanNode, - >( + ballista_executor::new_standalone_executor_from_state( scheduler, config.ballista_standalone_parallelism(), &session_state, @@ -253,22 +181,13 @@ pub async fn setup_test_cluster_with_builders( let host = "localhost".to_string(); - let scheduler_url = format!("http://{}:{}", host, addr.port()); - - let scheduler = loop { - match SchedulerGrpcClient::connect(scheduler_url.clone()).await { - Err(_) => { - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - log::info!("Attempting to connect to test scheduler..."); - } - Ok(scheduler) => break scheduler, - } - }; + let scheduler = + connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; ballista_executor::new_standalone_executor_from_builder( scheduler, config.ballista_standalone_parallelism(), - config_producer.clone(), + config_producer, runtime_producer, codec, Default::default(), @@ -281,6 +200,40 @@ pub async fn setup_test_cluster_with_builders( (host, addr.port()) } +async fn connect_to_scheduler( + scheduler_url: String, +) -> SchedulerGrpcClient { + let mut retry = 50; + loop { + match SchedulerGrpcClient::connect(scheduler_url.clone()).await { + Err(_) if retry > 0 => { + retry -= 1; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + log::debug!("Re-attempting to connect to test scheduler..."); + } + + Err(_) => { + log::error!("scheduler connection timed out"); + panic!("scheduler connection timed out") + } + Ok(scheduler) => break scheduler, + } + } +} + +#[allow(dead_code)] +pub async fn standalone_context() -> 
SessionContext { + SessionContext::standalone().await.unwrap() +} + +#[allow(dead_code)] +pub async fn remote_context() -> SessionContext { + let (host, port) = setup_test_cluster().await; + SessionContext::remote(&format!("df://{host}:{port}")) + .await + .unwrap() +} + #[ctor::ctor] fn init() { // Enable RUST_LOG logging configuration for test diff --git a/ballista/client/tests/context_standalone.rs b/ballista/client/tests/context_basic.rs similarity index 99% rename from ballista/client/tests/context_standalone.rs rename to ballista/client/tests/context_basic.rs index c17b53e59..5c137eecb 100644 --- a/ballista/client/tests/context_standalone.rs +++ b/ballista/client/tests/context_basic.rs @@ -23,7 +23,7 @@ mod common; // #[cfg(test)] #[cfg(feature = "standalone")] -mod standalone_tests { +mod basic { use ballista::prelude::SessionContextExt; use datafusion::arrow; use datafusion::arrow::util::pretty::pretty_format_batches; diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs new file mode 100644 index 000000000..d3f8fc930 --- /dev/null +++ b/ballista/client/tests/context_checks.rs @@ -0,0 +1,366 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; +#[cfg(test)] +mod supported { + + use crate::common::{remote_context, standalone_context}; + use ballista_core::config::BallistaConfig; + use datafusion::prelude::*; + use datafusion::{assert_batches_eq, prelude::SessionContext}; + use rstest::*; + + #[rstest::fixture] + fn test_data() -> String { + crate::common::example_test_data() + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_show( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_show_configs( + #[future(awt)] + #[case] + ctx: SessionContext, + ) -> datafusion::error::Result<()> { + let result = ctx + .sql("select name from information_schema.df_settings where name like 'datafusion.%' order by name limit 5") + .await? 
+ .collect() + .await?; + // + let expected = [ + "+------------------------------------------------------+", + "| name |", + "+------------------------------------------------------+", + "| datafusion.catalog.create_default_catalog_and_schema |", + "| datafusion.catalog.default_catalog |", + "| datafusion.catalog.default_schema |", + "| datafusion.catalog.format |", + "| datafusion.catalog.has_header |", + "+------------------------------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_show_configs_ballista( + #[future(awt)] + #[case] + ctx: SessionContext, + ) -> datafusion::error::Result<()> { + let state = ctx.state(); + let ballista_config_extension = + state.config().options().extensions.get::(); + + // ballista configuration should be registered with + // session state + assert!(ballista_config_extension.is_some()); + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.%' order by name limit 2") + .await? + .collect() + .await?; + + let expected = [ + "+---------------------------------------+----------+", + "| name | value |", + "+---------------------------------------+----------+", + "| ballista.grpc_client_max_message_size | 16777216 |", + "| ballista.job.name | |", + "+---------------------------------------+----------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_set_configs( + #[future(awt)] + #[case] + ctx: SessionContext, + ) -> datafusion::error::Result<()> { + ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") + .await? + .show() + .await?; + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.job.name' order by name limit 1") + .await? 
+ .collect() + .await?; + + let expected = [ + "+-------------------+-------------------------+", + "| name | value |", + "+-------------------+-------------------------+", + "| ballista.job.name | Super Cool Ballista App |", + "+-------------------+-------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + // select from ballista config + // check for SET = + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_show_tables( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx.sql("show tables").await?.collect().await?; + // + let expected = [ + "+---------------+--------------------+-------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+-------------+------------+", + "| datafusion | public | test | BASE TABLE |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | views | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | df_settings | VIEW |", + "| datafusion | information_schema | schemata | VIEW |", + "+---------------+--------------------+-------------+------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_create_external_table( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.sql(&format!("CREATE EXTERNAL TABLE tbl_test STORED AS PARQUET LOCATION '{}/alltypes_plain.parquet'", test_data, )).await?.show().await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from tbl_test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_dataframe( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + let df = ctx + .read_parquet( + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await? + .select_columns(&["id", "bool_col", "timestamp_col"])? 
+ .filter(col("id").gt(lit(5)))?; + + let result = df.collect().await?; + + let expected = [ + "+----+----------+---------------------+", + "| id | bool_col | timestamp_col |", + "+----+----------+---------------------+", + "| 6 | true | 2009-04-01T00:00:00 |", + "| 7 | false | 2009-04-01T00:01:00 |", + "+----+----------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[cfg(not(windows))] // test is failing at windows, can't debug it + async fn should_execute_sql_write( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_app_name_show( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") + .await? + .show() + .await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } +} diff --git a/ballista/client/tests/setup.rs b/ballista/client/tests/context_setup.rs similarity index 100% rename from ballista/client/tests/setup.rs rename to ballista/client/tests/context_setup.rs diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs new file mode 100644 index 000000000..805e81325 --- /dev/null +++ b/ballista/client/tests/context_unsupported.rs @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +/// # Tracking Unsupported Operations +/// +/// It provides indication if/when datafusion +/// gets support for them +#[cfg(test)] +mod unsupported { + use crate::common::{remote_context, standalone_context}; + use datafusion::prelude::*; + use datafusion::{assert_batches_eq, prelude::SessionContext}; + use rstest::*; + + #[rstest::fixture] + fn test_data() -> String { + crate::common::example_test_data() + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[should_panic] + async fn should_execute_explain_query_correctly( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap(); + + let result = ctx + .sql("EXPLAIN select count(*), id from test where id > 4 group by id") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = vec![ + "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Projection: count(*), test.id |", + "| | Aggregate: groupBy=[[test.id]], aggr=[[count(Int64(1)) AS count(*)]] |", + "| | Filter: test.id > Int32(4) |", + "| | TableScan: test projection=[id], partial_filters=[test.id > Int32(4)] |", + "| physical_plan | ProjectionExec: expr=[count(*)@1 as count(*), id@0 as id] |", + "| | AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(*)] |", + "| | CoalesceBatchesExec: target_batch_size=8192 |", + "| | RepartitionExec: partitioning=Hash([id@0], 16), input_partitions=1 |", + "| | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[count(*)] |", + "| | CoalesceBatchesExec: target_batch_size=8192 |", + "| | FilterExec: id@0 > 4 |", + "| | ParquetExec: file_groups={1 group: [[Users/ballista/git/arrow-ballista/ballista/client/testdata/alltypes_plain.parquet]]}, projection=[id], predicate=id@0 > 4, pruning_predicate=CASE WHEN id_null_count@1 = id_row_count@2 THEN false ELSE id_max@0 > 4 END, required_guarantees=[] |", + "| | |", + "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + 
#[tokio::test] + #[should_panic] + async fn should_support_sql_create_table( + #[future(awt)] + #[case] + ctx: SessionContext, + ) { + ctx.sql("CREATE TABLE tbl_test (id INT, value INT)") + .await + .unwrap() + .show() + .await + .unwrap(); + + // it does create table but it can't be queried + let _result = ctx + .sql("select * from tbl_test where id > 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + } + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[should_panic] + async fn should_support_caching_data_frame( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + let df = ctx + .read_parquet( + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap() + .select_columns(&["id", "bool_col", "timestamp_col"]) + .unwrap() + .filter(col("id").gt(lit(5))) + .unwrap(); + + let cached_df = df.cache().await.unwrap(); + let result = cached_df.collect().await.unwrap(); + + let expected = [ + "+----+----------+---------------------+", + "| id | bool_col | timestamp_col |", + "+----+----------+---------------------+", + "| 6 | true | 2009-04-01T00:00:00 |", + "| 7 | false | 2009-04-01T00:01:00 |", + "+----+----------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[should_panic] + // "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))" + async fn should_support_sql_insert_into( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap(); + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await + .unwrap() + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await + .unwrap(); + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await + .unwrap(); + + let _ = ctx + .sql("INSERT INTO written_table select * from written_table") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") + .await.unwrap() + .collect() + .await.unwrap(); + + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } +} diff --git a/ballista/client/tests/remote.rs b/ballista/client/tests/remote.rs deleted file mode 100644 index c03db8524..000000000 --- a/ballista/client/tests/remote.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod common; - -#[cfg(test)] -mod remote { - use ballista::prelude::SessionContextExt; - use datafusion::{assert_batches_eq, prelude::SessionContext}; - - #[tokio::test] - async fn should_execute_sql_show() -> datafusion::error::Result<()> { - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let test_data = crate::common::example_test_data(); - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[cfg(not(windows))] // test is failing at windows, can't debug it - async fn should_execute_sql_write() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - log::info!("writing to parquet .. {}", write_dir_path); - ctx.sql("select * from test") - .await? - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await?; - - log::info!("registering parquet .. {}", write_dir_path); - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await?; - log::info!("reading from written parquet .."); - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - log::info!("reading from written parquet .. 
DONE"); - assert_batches_eq!(expected, &result); - Ok(()) - } - - #[tokio::test] - async fn should_execute_show_tables() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx.sql("show tables").await?.collect().await?; - // - let expected = [ - "+---------------+--------------------+-------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+-------------+------------+", - "| datafusion | public | test | BASE TABLE |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | schemata | VIEW |", - "+---------------+--------------------+-------------+------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> { - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let test_data = crate::common::example_test_data(); - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") - .await? - .show() - .await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } -} diff --git a/ballista/client/tests/standalone.rs b/ballista/client/tests/standalone.rs deleted file mode 100644 index fe9f3df1a..000000000 --- a/ballista/client/tests/standalone.rs +++ /dev/null @@ -1,479 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -mod common; - -#[cfg(test)] -#[cfg(feature = "standalone")] -mod standalone { - use ballista::prelude::SessionContextExt; - use ballista_core::config::BallistaConfig; - use datafusion::prelude::*; - use datafusion::{assert_batches_eq, prelude::SessionContext}; - - #[tokio::test] - async fn should_execute_sql_show() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_show_configs() -> datafusion::error::Result<()> { - let ctx: SessionContext = SessionContext::standalone().await?; - - let result = ctx - .sql("select name from information_schema.df_settings where name like 'datafusion.%' order by name limit 5") - .await? - .collect() - .await?; - // - let expected = [ - "+------------------------------------------------------+", - "| name |", - "+------------------------------------------------------+", - "| datafusion.catalog.create_default_catalog_and_schema |", - "| datafusion.catalog.default_catalog |", - "| datafusion.catalog.default_schema |", - "| datafusion.catalog.format |", - "| datafusion.catalog.has_header |", - "+------------------------------------------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_show_configs_ballista() -> datafusion::error::Result<()> { - let ctx: SessionContext = SessionContext::standalone().await?; - let state = ctx.state(); - let ballista_config_extension = - state.config().options().extensions.get::(); - - // ballista configuration should be registered with - // session state - assert!(ballista_config_extension.is_some()); - - let result = ctx - .sql("select name, value from information_schema.df_settings where name like 'ballista.%' order by name limit 2") - .await? - .collect() - .await?; - - let expected = [ - "+---------------------------------------+----------+", - "| name | value |", - "+---------------------------------------+----------+", - "| ballista.grpc_client_max_message_size | 16777216 |", - "| ballista.job.name | |", - "+---------------------------------------+----------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_set_configs() -> datafusion::error::Result<()> { - let ctx: SessionContext = SessionContext::standalone().await?; - - ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") - .await? - .show() - .await?; - - let result = ctx - .sql("select name, value from information_schema.df_settings where name like 'ballista.job.name' order by name limit 1") - .await? 
- .collect() - .await?; - - let expected = [ - "+-------------------+-------------------------+", - "| name | value |", - "+-------------------+-------------------------+", - "| ballista.job.name | Super Cool Ballista App |", - "+-------------------+-------------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - // select from ballista config - // check for SET = - - #[tokio::test] - async fn should_execute_show_tables() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx.sql("show tables").await?.collect().await?; - // - let expected = [ - "+---------------+--------------------+-------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+-------------+------------+", - "| datafusion | public | test | BASE TABLE |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | schemata | VIEW |", - "+---------------+--------------------+-------------+------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - // - // TODO: It calls scheduler to generate the plan, but no - // but there is no ShuffleRead/Write in physical_plan - // - // ShuffleWriterExec: None, metrics=[output_rows=2, input_rows=2, write_time=1.782295ms, repart_time=1ns] - // ExplainExec, metrics=[] - // - #[tokio::test] - #[ignore = "It uses local files, will fail in CI"] - async fn should_execute_sql_explain() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("EXPLAIN select count(*), id from test where id > 4 group by id") - .await? 
- .collect() - .await?; - - let expected = vec![ - "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", - "| logical_plan | Projection: count(*), test.id |", - "| | Aggregate: groupBy=[[test.id]], aggr=[[count(Int64(1)) AS count(*)]] |", - "| | Filter: test.id > Int32(4) |", - "| | TableScan: test projection=[id], partial_filters=[test.id > Int32(4)] |", - "| physical_plan | ProjectionExec: expr=[count(*)@1 as count(*), id@0 as id] |", - "| | AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(*)] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | RepartitionExec: partitioning=Hash([id@0], 16), input_partitions=1 |", - "| | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[count(*)] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | FilterExec: id@0 > 4 |", - "| | ParquetExec: file_groups={1 group: [[Users/ballista/git/arrow-ballista/ballista/client/testdata/alltypes_plain.parquet]]}, projection=[id], predicate=id@0 > 4, pruning_predicate=CASE WHEN id_null_count@1 = id_row_count@2 THEN false ELSE id_max@0 > 4 END, required_guarantees=[] |", - "| | |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_create_external_table() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.sql(&format!("CREATE EXTERNAL TABLE tbl_test STORED AS PARQUET LOCATION '{}/alltypes_plain.parquet'", test_data, )).await?.show().await?; - - let result = ctx - .sql("select id, string_col, timestamp_col from tbl_test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[ignore = "Error serializing custom table - NotImplemented(LogicalExtensionCodec is not provided))"] - async fn should_execute_sql_create_table() -> datafusion::error::Result<()> { - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.sql("CREATE TABLE tbl_test (id INT, value INT)") - .await? - .show() - .await?; - - // it does create table but it can't be queried - let _result = ctx - .sql("select * from tbl_test where id > 0") - .await? 
- .collect() - .await?; - - Ok(()) - } - - #[tokio::test] - async fn should_execute_dataframe() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - - let df = ctx - .read_parquet( - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await? - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(5)))?; - - let result = df.collect().await?; - - let expected = [ - "+----+----------+---------------------+", - "| id | bool_col | timestamp_col |", - "+----+----------+---------------------+", - "| 6 | true | 2009-04-01T00:00:00 |", - "| 7 | false | 2009-04-01T00:01:00 |", - "+----+----------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[ignore = "Error serializing custom table - NotImplemented(LogicalExtensionCodec is not provided))"] - async fn should_execute_dataframe_cache() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - - let df = ctx - .read_parquet( - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await? - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(5)))?; - - let cached_df = df.cache().await?; - let result = cached_df.collect().await?; - - let expected = [ - "+----+----------+---------------------+", - "| id | bool_col | timestamp_col |", - "+----+----------+---------------------+", - "| 6 | true | 2009-04-01T00:00:00 |", - "| 7 | false | 2009-04-01T00:01:00 |", - "+----+----------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[ignore = "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"] - async fn should_execute_sql_insert() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - ctx.sql("select * from test") - .await? - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await?; - - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await?; - - let _ = ctx - .sql("INSERT INTO written_table select * from written_table") - .await? - .collect() - .await?; - - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") - .await? 
- .collect() - .await?; - - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[cfg(not(windows))] // test is failing at windows, can't debug it - async fn should_execute_sql_write() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - ctx.sql("select * from test") - .await? - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await?; - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await?; - - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - let ctx: SessionContext = SessionContext::standalone().await?; - - ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") - .await? - .show() - .await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? 
- .collect() - .await?; - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } -} diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 80a3d1028..1bf888582 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -34,25 +34,24 @@ exclude = ["*.proto"] rustc-args = ["--cfg", "docsrs"] [features] +build-binary = ["configure_me", "clap"] docsrs = [] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = ["datafusion/force_hash_collisions"] - [dependencies] arrow-flight = { workspace = true } async-trait = { workspace = true } chrono = { version = "0.4", default-features = false } -clap = { workspace = true } +clap = { workspace = true, optional = true } +configure_me = { workspace = true, optional = true } datafusion = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } futures = { workspace = true } - itertools = "0.13" log = { workspace = true } md-5 = { version = "^0.10.0" } -parse_arg = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } rand = { workspace = true } @@ -66,5 +65,5 @@ url = { workspace = true } tempfile = { workspace = true } [build-dependencies] -rustc_version = "0.4.0" +rustc_version = "0.4.1" tonic-build = { workspace = true } diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto index a40e6f2d2..cb3c148b4 100644 --- a/ballista/core/proto/ballista.proto +++ b/ballista/core/proto/ballista.proto @@ -172,7 +172,7 @@ message TaskInputPartitions { message KeyValuePair { string key = 1; - string value = 2; + optional string value = 2; } message Action { @@ -458,10 +458,6 @@ message MultiTaskDefinition { repeated KeyValuePair props = 9; } -message SessionSettings { - repeated KeyValuePair configs = 1; -} - message JobSessionConfig { string session_id = 1; repeated KeyValuePair configs = 2; @@ -526,9 +522,8 @@ message ExecuteQueryParams { bytes logical_plan = 1; string sql = 2 [deprecated=true]; // I'd suggest to remove this, if SQL needed use `flight-sql` } - oneof optional_session_id { - string session_id = 3; - } + + optional string session_id = 3; repeated KeyValuePair settings = 4; } diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 1ddd952be..cb7f7c5d7 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -18,8 +18,6 @@ //! 
Ballista configuration -use clap::ValueEnum; -use core::fmt; use std::collections::HashMap; use std::result; @@ -252,71 +250,57 @@ impl datafusion::config::ConfigExtension for BallistaConfig { // an enum used to configure the scheduler policy // needs to be visible to code generated by configure_me -#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, Copy, Debug, serde::Deserialize, Default)] +#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum TaskSchedulingPolicy { + #[default] PullStaged, PushStaged, } +#[cfg(feature = "build-binary")] impl std::str::FromStr for TaskSchedulingPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - ValueEnum::from_str(s, true) + clap::ValueEnum::from_str(s, true) } } - -impl parse_arg::ParseArgFromStr for TaskSchedulingPolicy { - fn describe_type(mut writer: W) -> fmt::Result { +#[cfg(feature = "build-binary")] +impl configure_me::parse_arg::ParseArgFromStr for TaskSchedulingPolicy { + fn describe_type(mut writer: W) -> core::fmt::Result { write!(writer, "The scheduler policy for the scheduler") } } // an enum used to configure the log rolling policy // needs to be visible to code generated by configure_me -#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, Copy, Debug, serde::Deserialize, Default)] +#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum LogRotationPolicy { Minutely, Hourly, Daily, + #[default] Never, } +#[cfg(feature = "build-binary")] impl std::str::FromStr for LogRotationPolicy { type Err = String; fn from_str(s: &str) -> std::result::Result { - ValueEnum::from_str(s, true) + clap::ValueEnum::from_str(s, true) } } -impl parse_arg::ParseArgFromStr for LogRotationPolicy { - fn describe_type(mut writer: W) -> fmt::Result { +#[cfg(feature = "build-binary")] +impl configure_me::parse_arg::ParseArgFromStr for LogRotationPolicy { + fn describe_type(mut writer: W) -> core::fmt::Result { write!(writer, "The log rotation policy") } } -// an enum used to configure the source data cache policy -// needs to be visible to code generated by configure_me -#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] -pub enum DataCachePolicy { - LocalDiskFile, -} - -impl std::str::FromStr for DataCachePolicy { - type Err = String; - - fn from_str(s: &str) -> std::result::Result { - ValueEnum::from_str(s, true) - } -} - -impl parse_arg::ParseArgFromStr for DataCachePolicy { - fn describe_type(mut writer: W) -> fmt::Result { - write!(writer, "The data cache policy") - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/ballista/core/src/diagram.rs b/ballista/core/src/diagram.rs new file mode 100644 index 000000000..9ef0da981 --- /dev/null +++ b/ballista/core/src/diagram.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::Result; +use crate::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec}; + +use datafusion::datasource::physical_plan::{CsvExec, ParquetExec}; +use datafusion::physical_plan::aggregates::AggregateExec; +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion::physical_plan::filter::FilterExec; +use datafusion::physical_plan::joins::HashJoinExec; +use datafusion::physical_plan::projection::ProjectionExec; +use datafusion::physical_plan::sorts::sort::SortExec; +use datafusion::physical_plan::ExecutionPlan; +use std::fs::File; +use std::io::{BufWriter, Write}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +pub fn produce_diagram(filename: &str, stages: &[Arc]) -> Result<()> { + let write_file = File::create(filename)?; + let mut w = BufWriter::new(&write_file); + writeln!(w, "digraph G {{")?; + + // draw stages and entities + for stage in stages { + writeln!(w, "\tsubgraph cluster{} {{", stage.stage_id())?; + writeln!(w, "\t\tlabel = \"Stage {}\";", stage.stage_id())?; + let mut id = AtomicUsize::new(0); + build_exec_plan_diagram( + &mut w, + stage.children()[0].as_ref(), + stage.stage_id(), + &mut id, + true, + )?; + writeln!(w, "\t}}")?; + } + + // draw relationships + for stage in stages { + let mut id = AtomicUsize::new(0); + build_exec_plan_diagram( + &mut w, + stage.children()[0].as_ref(), + stage.stage_id(), + &mut id, + false, + )?; + } + + write!(w, "}}")?; + Ok(()) +} + +fn build_exec_plan_diagram( + w: &mut BufWriter<&File>, + plan: &dyn ExecutionPlan, + stage_id: usize, + id: &mut AtomicUsize, + draw_entity: bool, +) -> Result { + let operator_str = if plan.as_any().downcast_ref::().is_some() { + "AggregateExec" + } else if plan.as_any().downcast_ref::().is_some() { + "SortExec" + } else if plan.as_any().downcast_ref::().is_some() { + "ProjectionExec" + } else if plan.as_any().downcast_ref::().is_some() { + "HashJoinExec" + } else if plan.as_any().downcast_ref::().is_some() { + "ParquetExec" + } else if plan.as_any().downcast_ref::().is_some() { + "CsvExec" + } else if plan.as_any().downcast_ref::().is_some() { + "FilterExec" + } else if plan.as_any().downcast_ref::().is_some() { + "ShuffleWriterExec" + } else if plan + .as_any() + .downcast_ref::() + .is_some() + { + "UnresolvedShuffleExec" + } else if plan + .as_any() + .downcast_ref::() + .is_some() + { + "CoalesceBatchesExec" + } else if plan + .as_any() + .downcast_ref::() + .is_some() + { + "CoalescePartitionsExec" + } else { + println!("Unknown: {plan:?}"); + "Unknown" + }; + + let node_id = id.load(Ordering::SeqCst); + id.store(node_id + 1, Ordering::SeqCst); + + if draw_entity { + writeln!( + w, + "\t\tstage_{stage_id}_exec_{node_id} [shape=box, label=\"{operator_str}\"];" + )?; + } + for child in plan.children() { + if let Some(shuffle) = child.as_any().downcast_ref::() { + if !draw_entity { + writeln!( + w, + "\tstage_{}_exec_1 -> stage_{}_exec_{};", + shuffle.stage_id, stage_id, node_id + )?; + } + } else { + // relationships within same entity + let child_id = + build_exec_plan_diagram(w, child.as_ref(), stage_id, id, draw_entity)?; + if draw_entity { + writeln!( + w, + "\t\tstage_{stage_id}_exec_{child_id} -> stage_{stage_id}_exec_{node_id};" + )?; + } + } + } + Ok(node_id) +} diff --git a/ballista/core/src/error.rs b/ballista/core/src/error.rs index 
cbdd90a71..05a706cce 100644 --- a/ballista/core/src/error.rs +++ b/ballista/core/src/error.rs @@ -37,15 +37,11 @@ pub enum BallistaError { NotImplemented(String), General(String), Internal(String), + Configuration(String), ArrowError(ArrowError), DataFusionError(DataFusionError), SqlError(parser::ParserError), IoError(io::Error), - // ReqwestError(reqwest::Error), - // HttpError(http::Error), - // KubeAPIError(kube::error::Error), - // KubeAPIRequestError(k8s_openapi::RequestError), - // KubeAPIResponseError(k8s_openapi::ResponseError), TonicError(tonic::transport::Error), GrpcError(tonic::Status), GrpcConnectionError(String), @@ -112,36 +108,6 @@ impl From for BallistaError { } } -// impl From for BallistaError { -// fn from(e: reqwest::Error) -> Self { -// BallistaError::ReqwestError(e) -// } -// } -// -// impl From for BallistaError { -// fn from(e: http::Error) -> Self { -// BallistaError::HttpError(e) -// } -// } - -// impl From for BallistaError { -// fn from(e: kube::error::Error) -> Self { -// BallistaError::KubeAPIError(e) -// } -// } - -// impl From for BallistaError { -// fn from(e: k8s_openapi::RequestError) -> Self { -// BallistaError::KubeAPIRequestError(e) -// } -// } - -// impl From for BallistaError { -// fn from(e: k8s_openapi::ResponseError) -> Self { -// BallistaError::KubeAPIResponseError(e) -// } -// } - impl From for BallistaError { fn from(e: tonic::transport::Error) -> Self { BallistaError::TonicError(e) @@ -191,15 +157,6 @@ impl Display for BallistaError { } BallistaError::SqlError(ref desc) => write!(f, "SQL error: {desc}"), BallistaError::IoError(ref desc) => write!(f, "IO error: {desc}"), - // BallistaError::ReqwestError(ref desc) => write!(f, "Reqwest error: {}", desc), - // BallistaError::HttpError(ref desc) => write!(f, "HTTP error: {}", desc), - // BallistaError::KubeAPIError(ref desc) => write!(f, "Kube API error: {}", desc), - // BallistaError::KubeAPIRequestError(ref desc) => { - // write!(f, "KubeAPI request error: {}", desc) - // } - // BallistaError::KubeAPIResponseError(ref desc) => { - // write!(f, "KubeAPI response error: {}", desc) - // } BallistaError::TonicError(desc) => write!(f, "Tonic error: {desc}"), BallistaError::GrpcError(desc) => write!(f, "Grpc error: {desc}"), BallistaError::GrpcConnectionError(desc) => { @@ -220,6 +177,9 @@ impl Display for BallistaError { ) } BallistaError::Cancelled => write!(f, "Task cancelled"), + BallistaError::Configuration(desc) => { + write!(f, "Configuration error: {desc}") + } } } } diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs index dae4bb8ee..785d3b0cb 100644 --- a/ballista/core/src/execution_plans/distributed_query.rs +++ b/ballista/core/src/execution_plans/distributed_query.rs @@ -17,7 +17,6 @@ use crate::client::BallistaClient; use crate::config::BallistaConfig; -use crate::serde::protobuf::execute_query_params::OptionalSessionId; use crate::serde::protobuf::{ execute_query_params::Query, execute_query_result, job_status, scheduler_grpc_client::SchedulerGrpcClient, ExecuteQueryParams, GetJobStatusParams, @@ -218,7 +217,7 @@ impl ExecutionPlan for DistributedQueryExec { .map( |datafusion::config::ConfigEntry { key, value, .. 
}| KeyValuePair { key: key.to_owned(), - value: value.clone().unwrap_or_else(|| String::from("")), + value: value.clone(), }, ) .collect(); @@ -226,9 +225,7 @@ impl ExecutionPlan for DistributedQueryExec { let query = ExecuteQueryParams { query: Some(Query::LogicalPlan(buf)), settings, - optional_session_id: Some(OptionalSessionId::SessionId( - self.session_id.clone(), - )), + session_id: Some(self.session_id.clone()), }; let stream = futures::stream::once( diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs new file mode 100644 index 000000000..25bdbad92 --- /dev/null +++ b/ballista/core/src/extension.rs @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::config::{ + BallistaConfig, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, + BALLISTA_STANDALONE_PARALLELISM, +}; +use crate::serde::protobuf::KeyValuePair; +use crate::serde::{BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec}; +use crate::utils::BallistaQueryPlanner; +use datafusion::execution::context::{QueryPlanner, SessionConfig, SessionState}; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use datafusion_proto::protobuf::LogicalPlanNode; +use std::sync::Arc; + +/// Provides methods which adapt [SessionState] +/// for Ballista usage +pub trait SessionStateExt { + /// Setups new [SessionState] for ballista usage + /// + /// State will be created with appropriate [SessionConfig] configured + fn new_ballista_state( + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result; + /// Upgrades [SessionState] for ballista usage + /// + /// State will be upgraded to appropriate [SessionConfig] + fn upgrade_for_ballista( + self, + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result; +} + +/// [SessionConfig] extension with methods needed +/// for Ballista configuration +pub trait SessionConfigExt { + /// Creates session config which has + /// ballista configuration initialized + fn new_with_ballista() -> SessionConfig; + + /// Overrides ballista's [LogicalExtensionCodec] + fn with_ballista_logical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig; + + /// Overrides ballista's [PhysicalExtensionCodec] + fn with_ballista_physical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig; + + /// returns [LogicalExtensionCodec] if set + /// or default ballista codec if not + fn ballista_logical_extension_codec(&self) -> Arc; + + /// returns [PhysicalExtensionCodec] if set + /// or default ballista codec if not + fn ballista_physical_extension_codec(&self) -> 
Arc; + + /// Overrides ballista's [QueryPlanner] + fn with_ballista_query_planner( + self, + planner: Arc, + ) -> SessionConfig; + + /// Returns ballista's [QueryPlanner] if overridden + fn ballista_query_planner( + &self, + ) -> Option>; + + /// Returns parallelism of standalone cluster + fn ballista_standalone_parallelism(&self) -> usize; + /// Sets parallelism of standalone cluster + /// + /// This option to be used to configure standalone session context + fn with_ballista_standalone_parallelism(self, parallelism: usize) -> Self; + + /// retrieves grpc client max message size + fn ballista_grpc_client_max_message_size(&self) -> usize; + + /// sets grpc client max message size + fn with_ballista_grpc_client_max_message_size(self, max_size: usize) -> Self; + + /// Sets ballista job name + fn with_ballista_job_name(self, job_name: &str) -> Self; +} + +/// [SessionConfigHelperExt] is set of [SessionConfig] extension methods +/// which are used internally (not exposed in client) +pub trait SessionConfigHelperExt { + /// converts [SessionConfig] to proto + fn to_key_value_pairs(&self) -> Vec; + /// updates [SessionConfig] from proto + fn update_from_key_value_pair(self, key_value_pairs: &[KeyValuePair]) -> Self; + /// updates mut [SessionConfig] from proto + fn update_from_key_value_pair_mut(&mut self, key_value_pairs: &[KeyValuePair]); +} + +impl SessionStateExt for SessionState { + fn new_ballista_state( + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result { + let config = BallistaConfig::default(); + + let planner = + BallistaQueryPlanner::::new(scheduler_url, config.clone()); + + let session_config = SessionConfig::new() + .with_information_schema(true) + .with_option_extension(config.clone()) + // Ballista disables this option + .with_round_robin_repartition(false); + + let runtime_config = RuntimeConfig::default(); + let runtime_env = RuntimeEnv::try_new(runtime_config)?; + let session_state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .with_runtime_env(Arc::new(runtime_env)) + .with_query_planner(Arc::new(planner)) + .with_session_id(session_id) + .build(); + + Ok(session_state) + } + + fn upgrade_for_ballista( + self, + scheduler_url: String, + session_id: String, + ) -> datafusion::error::Result { + let codec_logical = self.config().ballista_logical_extension_codec(); + let planner_override = self.config().ballista_query_planner(); + + let new_config = self + .config() + .options() + .extensions + .get::() + .cloned() + .unwrap_or_else(BallistaConfig::default); + + let session_config = self + .config() + .clone() + .with_option_extension(new_config.clone()) + // Ballista disables this option + .with_round_robin_repartition(false); + + let builder = SessionStateBuilder::new_from_existing(self) + .with_config(session_config) + .with_session_id(session_id); + + let builder = match planner_override { + Some(planner) => builder.with_query_planner(planner), + None => { + let planner = BallistaQueryPlanner::::with_extension( + scheduler_url, + new_config, + codec_logical, + ); + builder.with_query_planner(Arc::new(planner)) + } + }; + + Ok(builder.build()) + } +} + +impl SessionConfigExt for SessionConfig { + fn new_with_ballista() -> SessionConfig { + SessionConfig::new() + .with_option_extension(BallistaConfig::default()) + .with_target_partitions(16) + .with_round_robin_repartition(false) + } + fn with_ballista_logical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig { + let extension = 
BallistaConfigExtensionLogicalCodec::new(codec); + self.with_extension(Arc::new(extension)) + } + fn with_ballista_physical_extension_codec( + self, + codec: Arc, + ) -> SessionConfig { + let extension = BallistaConfigExtensionPhysicalCodec::new(codec); + self.with_extension(Arc::new(extension)) + } + + fn ballista_logical_extension_codec(&self) -> Arc { + self.get_extension::() + .map(|c| c.codec()) + .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default())) + } + fn ballista_physical_extension_codec(&self) -> Arc { + self.get_extension::() + .map(|c| c.codec()) + .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())) + } + + fn with_ballista_query_planner( + self, + planner: Arc, + ) -> SessionConfig { + let extension = BallistaQueryPlannerExtension::new(planner); + self.with_extension(Arc::new(extension)) + } + + fn ballista_query_planner( + &self, + ) -> Option> { + self.get_extension::() + .map(|c| c.planner()) + } + + fn ballista_standalone_parallelism(&self) -> usize { + self.options() + .extensions + .get::() + .map(|c| c.default_standalone_parallelism()) + .unwrap_or_else(|| BallistaConfig::default().default_standalone_parallelism()) + } + + fn ballista_grpc_client_max_message_size(&self) -> usize { + self.options() + .extensions + .get::() + .map(|c| c.default_grpc_client_max_message_size()) + .unwrap_or_else(|| { + BallistaConfig::default().default_grpc_client_max_message_size() + }) + } + + fn with_ballista_job_name(self, job_name: &str) -> Self { + if self.options().extensions.get::().is_some() { + self.set_str(BALLISTA_JOB_NAME, job_name) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_str(BALLISTA_JOB_NAME, job_name) + } + } + + fn with_ballista_grpc_client_max_message_size(self, max_size: usize) -> Self { + if self.options().extensions.get::().is_some() { + self.set_usize(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, max_size) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_usize(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, max_size) + } + } + + fn with_ballista_standalone_parallelism(self, parallelism: usize) -> Self { + if self.options().extensions.get::().is_some() { + self.set_usize(BALLISTA_STANDALONE_PARALLELISM, parallelism) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_usize(BALLISTA_STANDALONE_PARALLELISM, parallelism) + } + } +} + +impl SessionConfigHelperExt for SessionConfig { + fn to_key_value_pairs(&self) -> Vec { + self.options() + .entries() + .iter() + .map(|datafusion::config::ConfigEntry { key, value, .. }| { + log::trace!("sending configuration key: `{}`, value`{:?}`", key, value); + KeyValuePair { + key: key.to_owned(), + value: value.clone(), + } + }) + .collect() + } + + fn update_from_key_value_pair(self, key_value_pairs: &[KeyValuePair]) -> Self { + let mut s = self; + s.update_from_key_value_pair_mut(key_value_pairs); + s + } + + fn update_from_key_value_pair_mut(&mut self, key_value_pairs: &[KeyValuePair]) { + for KeyValuePair { key, value } in key_value_pairs { + match value { + Some(value) => { + log::trace!( + "setting up configuration key: `{}`, value: `{:?}`", + key, + value + ); + if let Err(e) = self.options_mut().set(key, value) { + // there is not much we can do about this error at the moment. 
+ // it used to be warning but it gets very verbose + // as even datafusion properties can't be parsed + log::debug!( + "could not set configuration key: `{}`, value: `{:?}`, reason: {}", + key, + value, + e.to_string() + ) + } + } + None => { + log::trace!( + "can't set up configuration key: `{}`, as value is None", + key, + ) + } + } + } + } +} + +/// Wrapper for [SessionConfig] extension +/// holding [LogicalExtensionCodec] if overridden +struct BallistaConfigExtensionLogicalCodec { + codec: Arc, +} + +impl BallistaConfigExtensionLogicalCodec { + fn new(codec: Arc) -> Self { + Self { codec } + } + fn codec(&self) -> Arc { + self.codec.clone() + } +} + +/// Wrapper for [SessionConfig] extension +/// holding [PhysicalExtensionCodec] if overridden +struct BallistaConfigExtensionPhysicalCodec { + codec: Arc, +} + +impl BallistaConfigExtensionPhysicalCodec { + fn new(codec: Arc) -> Self { + Self { codec } + } + fn codec(&self) -> Arc { + self.codec.clone() + } +} + +/// Wrapper for [SessionConfig] extension +/// holding overridden [QueryPlanner] +struct BallistaQueryPlannerExtension { + planner: Arc, +} + +impl BallistaQueryPlannerExtension { + fn new(planner: Arc) -> Self { + Self { planner } + } + fn planner(&self) -> Arc { + self.planner.clone() + } +} + +#[cfg(test)] +mod test { + use datafusion::{ + execution::{SessionState, SessionStateBuilder}, + prelude::SessionConfig, + }; + + use crate::{ + config::BALLISTA_JOB_NAME, + extension::{SessionConfigExt, SessionConfigHelperExt, SessionStateExt}, + }; + + // Ballista disables round robin repatriations + #[tokio::test] + async fn should_disable_round_robin_repartition() { + let state = SessionState::new_ballista_state( + "scheduler_url".to_string(), + "session_id".to_string(), + ) + .unwrap(); + + assert!(!state.config().round_robin_repartition()); + + let state = SessionStateBuilder::new().build(); + + assert!(state.config().round_robin_repartition()); + let state = state + .upgrade_for_ballista("scheduler_url".to_string(), "session_id".to_string()) + .unwrap(); + + assert!(!state.config().round_robin_repartition()); + } + #[test] + fn should_convert_to_key_value_pairs() { + // key value pairs should contain datafusion and ballista values + + let config = + SessionConfig::new_with_ballista().with_ballista_job_name("job_name"); + let pairs = config.to_key_value_pairs(); + + assert!(pairs.iter().any(|p| p.key == BALLISTA_JOB_NAME)); + assert!(pairs + .iter() + .any(|p| p.key == "datafusion.catalog.information_schema")) + } +} diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs index 4341f443a..7864d56ec 100644 --- a/ballista/core/src/lib.rs +++ b/ballista/core/src/lib.rs @@ -29,13 +29,14 @@ pub fn print_version() { pub mod client; pub mod config; pub mod consistent_hash; +pub mod diagram; pub mod error; pub mod event_loop; pub mod execution_plans; -pub mod utils; - -#[macro_use] +pub mod extension; +pub mod registry; pub mod serde; +pub mod utils; /// /// [RuntimeProducer] is a factory which creates runtime [RuntimeEnv] diff --git a/ballista/core/src/registry.rs b/ballista/core/src/registry.rs new file mode 100644 index 000000000..2f55e2809 --- /dev/null +++ b/ballista/core/src/registry.rs @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; +use datafusion::execution::{FunctionRegistry, SessionState}; +use datafusion::functions::all_default_functions; +use datafusion::functions_aggregate::all_default_aggregate_functions; +use datafusion::functions_window::all_default_window_functions; +use datafusion::logical_expr::planner::ExprPlanner; +use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; + +#[derive(Debug)] +pub struct BallistaFunctionRegistry { + pub scalar_functions: HashMap>, + pub aggregate_functions: HashMap>, + pub window_functions: HashMap>, +} + +impl Default for BallistaFunctionRegistry { + fn default() -> Self { + let scalar_functions = all_default_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let aggregate_functions = all_default_aggregate_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + let window_functions = all_default_window_functions() + .into_iter() + .map(|f| (f.name().to_string(), f)) + .collect(); + + Self { + scalar_functions, + aggregate_functions, + window_functions, + } + } +} + +impl FunctionRegistry for BallistaFunctionRegistry { + fn expr_planners(&self) -> Vec> { + vec![] + } + + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } + + fn udf(&self, name: &str) -> datafusion::common::Result> { + let result = self.scalar_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDF named \"{name}\" in the TaskContext" + )) + }) + } + + fn udaf(&self, name: &str) -> datafusion::common::Result> { + let result = self.aggregate_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDAF named \"{name}\" in the TaskContext" + )) + }) + } + + fn udwf(&self, name: &str) -> datafusion::common::Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + DataFusionError::Internal(format!( + "There is no UDWF named \"{name}\" in the TaskContext" + )) + }) + } +} + +impl From<&SessionState> for BallistaFunctionRegistry { + fn from(state: &SessionState) -> Self { + let scalar_functions = state.scalar_functions().clone(); + let aggregate_functions = state.aggregate_functions().clone(); + let window_functions = state.window_functions().clone(); + + Self { + scalar_functions, + aggregate_functions, + window_functions, + } + } +} diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs index d61ef331e..d4faef825 100644 --- a/ballista/core/src/serde/generated/ballista.rs +++ b/ballista/core/src/serde/generated/ballista.rs @@ -249,8 +249,8 @@ pub struct TaskInputPartitions { pub struct KeyValuePair { #[prost(string, tag = "1")] pub key: ::prost::alloc::string::String, - #[prost(string, tag = "2")] - pub value: 
::prost::alloc::string::String, + #[prost(string, optional, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::string::String>, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct Action { @@ -708,11 +708,6 @@ pub struct MultiTaskDefinition { pub props: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct SessionSettings { - #[prost(message, repeated, tag = "1")] - pub configs: ::prost::alloc::vec::Vec, -} -#[derive(Clone, PartialEq, ::prost::Message)] pub struct JobSessionConfig { #[prost(string, tag = "1")] pub session_id: ::prost::alloc::string::String, @@ -789,14 +784,12 @@ pub struct UpdateTaskStatusResult { } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecuteQueryParams { + #[prost(string, optional, tag = "3")] + pub session_id: ::core::option::Option<::prost::alloc::string::String>, #[prost(message, repeated, tag = "4")] pub settings: ::prost::alloc::vec::Vec, #[prost(oneof = "execute_query_params::Query", tags = "1, 2")] pub query: ::core::option::Option, - #[prost(oneof = "execute_query_params::OptionalSessionId", tags = "3")] - pub optional_session_id: ::core::option::Option< - execute_query_params::OptionalSessionId, - >, } /// Nested message and enum types in `ExecuteQueryParams`. pub mod execute_query_params { @@ -808,11 +801,6 @@ pub mod execute_query_params { #[prost(string, tag = "2")] Sql(::prost::alloc::string::String), } - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum OptionalSessionId { - #[prost(string, tag = "3")] - SessionId(::prost::alloc::string::String), - } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateSessionParams { diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 5400b00ca..d7d6474f7 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -125,28 +125,28 @@ pub struct BallistaLogicalExtensionCodec { } impl BallistaLogicalExtensionCodec { - // looks for a codec which can operate on this node - // returns a position of codec in the list. - // - // position is important with encoding process - // as there is a need to remember which codec - // in the list was used to encode message, - // so we can use it for decoding as well - - fn try_any( + /// looks for a codec which can operate on this node + /// returns a position of codec in the list and result. + /// + /// position is important with encoding process + /// as position of used codecs is needed + /// so the same codec can be used for decoding + fn try_any( &self, - mut f: impl FnMut(&dyn LogicalExtensionCodec) -> Result, - ) -> Result<(u8, T)> { + mut f: impl FnMut(&dyn LogicalExtensionCodec) -> Result, + ) -> Result<(u32, R)> { let mut last_err = None; for (position, codec) in self.file_format_codecs.iter().enumerate() { match f(codec.as_ref()) { - Ok(node) => return Ok((position as u8, node)), + Ok(result) => return Ok((position as u32, result)), Err(err) => last_err = Some(err), } } Err(last_err.unwrap_or_else(|| { - DataFusionError::NotImplemented("Empty list of composed codecs".to_owned()) + DataFusionError::Internal( + "List of provided extended logical codecs is empty".to_owned(), + ) })) } } @@ -155,10 +155,12 @@ impl Default for BallistaLogicalExtensionCodec { fn default() -> Self { Self { default_codec: Arc::new(DefaultLogicalExtensionCodec {}), + // Position in this list is important as it will be used for decoding. + // If new codec is added it should go to last position. 
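+ // Reordering existing entries would change these positions and break decoding of previously encoded file formats.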
file_format_codecs: vec![ + Arc::new(ParquetLogicalExtensionCodec {}), Arc::new(CsvLogicalExtensionCodec {}), Arc::new(JsonLogicalExtensionCodec {}), - Arc::new(ParquetLogicalExtensionCodec {}), Arc::new(ArrowLogicalExtensionCodec {}), Arc::new(AvroLogicalExtensionCodec {}), ], @@ -210,19 +212,17 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { buf: &[u8], ctx: &datafusion::prelude::SessionContext, ) -> Result<Arc<dyn FileFormatFactory>> { - if !buf.is_empty() { - // gets codec id from input buffer - let codec_number = buf[0]; - let codec = self.file_format_codecs.get(codec_number as usize).ok_or( - DataFusionError::NotImplemented("Can't find required codex".to_owned()), - )?; - - codec.try_decode_file_format(&buf[1..], ctx) - } else { - Err(DataFusionError::NotImplemented( - "File format blob should have more than 0 bytes".to_owned(), - )) - } + let proto = FileFormatProto::decode(buf) + .map_err(|e| DataFusionError::Internal(e.to_string()))?; + + let codec = self + .file_format_codecs + .get(proto.encoder_position as usize) + .ok_or(DataFusionError::Internal( + "Can't find required codec in file codec list".to_owned(), + ))?; + + codec.try_decode_file_format(&proto.blob, ctx) } fn try_encode_file_format( @@ -230,18 +230,17 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { buf: &mut Vec<u8>, node: Arc<dyn FileFormatFactory>, ) -> Result<()> { - let mut encoded_format = vec![]; - let (codec_number, _) = self.try_any(|codec| { - codec.try_encode_file_format(&mut encoded_format, node.clone()) - })?; - // we need to remember which codec in the list was used to - // encode this node. - buf.push(codec_number); - - // save actual encoded node - buf.append(&mut encoded_format); - - Ok(()) + let mut blob = vec![]; + let (encoder_position, _) = + self.try_any(|codec| codec.try_encode_file_format(&mut blob, node.clone()))?; + + let proto = FileFormatProto { + encoder_position, + blob, + }; + proto + .encode(buf) + .map_err(|e| DataFusionError::Internal(e.to_string())) } } @@ -429,6 +428,25 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { } } +/// FileFormatProto captures data encoded by file format codecs +/// +/// It captures the position of the codec used to encode the FileFormat +/// and the actual encoded value.
+/// +/// capturing codec position is required, as same codec can decode +/// blobs encoded by different encoders (probability is low but it +/// happened in the past) +/// +#[derive(Clone, PartialEq, prost::Message)] +struct FileFormatProto { + /// encoder id used to encode blob + /// (to be used for decoding) + #[prost(uint32, tag = 1)] + pub encoder_position: u32, + #[prost(bytes, tag = 2)] + pub blob: Vec, +} + #[cfg(test)] mod test { use datafusion::{ diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs index 372c1d910..0257cfa00 100644 --- a/ballista/core/src/serde/scheduler/from_proto.rs +++ b/ballista/core/src/serde/scheduler/from_proto.rs @@ -32,6 +32,7 @@ use std::sync::Arc; use std::time::Duration; use crate::error::BallistaError; +use crate::extension::SessionConfigHelperExt; use crate::serde::scheduler::{ Action, BallistaFunctionRegistry, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId, PartitionLocation, PartitionStats, @@ -39,7 +40,6 @@ use crate::serde::scheduler::{ }; use crate::serde::{protobuf, BallistaCodec}; -use crate::utils::SessionConfigExt; use crate::RuntimeProducer; use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime}; diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs index 2905455eb..a2c92ff8a 100644 --- a/ballista/core/src/serde/scheduler/mod.rs +++ b/ballista/core/src/serde/scheduler/mod.rs @@ -15,27 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::fmt::Debug; -use std::{collections::HashMap, fmt, sync::Arc}; - +use crate::error::BallistaError; +use crate::registry::BallistaFunctionRegistry; use datafusion::arrow::array::{ ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder, }; use datafusion::arrow::datatypes::{DataType, Field}; -use datafusion::common::DataFusionError; -use datafusion::execution::{FunctionRegistry, SessionState}; -use datafusion::functions::all_default_functions; -use datafusion::functions_aggregate::all_default_aggregate_functions; -use datafusion::functions_window::all_default_window_functions; -use datafusion::logical_expr::planner::ExprPlanner; -use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::Partitioning; use datafusion::prelude::SessionConfig; use serde::Serialize; - -use crate::error::BallistaError; +use std::fmt::Debug; +use std::{collections::HashMap, fmt, sync::Arc}; pub mod from_proto; pub mod to_proto; @@ -295,89 +286,3 @@ pub struct TaskDefinition { pub session_config: SessionConfig, pub function_registry: Arc, } - -#[derive(Debug)] -pub struct BallistaFunctionRegistry { - pub scalar_functions: HashMap>, - pub aggregate_functions: HashMap>, - pub window_functions: HashMap>, -} - -impl Default for BallistaFunctionRegistry { - fn default() -> Self { - let scalar_functions = all_default_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - let aggregate_functions = all_default_aggregate_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - let window_functions = all_default_window_functions() - .into_iter() - .map(|f| (f.name().to_string(), f)) - .collect(); - - Self { - scalar_functions, - aggregate_functions, - window_functions, - } - } -} - -impl FunctionRegistry for BallistaFunctionRegistry { - fn expr_planners(&self) -> Vec> { - vec![] - } 
- - fn udfs(&self) -> HashSet { - self.scalar_functions.keys().cloned().collect() - } - - fn udf(&self, name: &str) -> datafusion::common::Result> { - let result = self.scalar_functions.get(name); - - result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDF named \"{name}\" in the TaskContext" - )) - }) - } - - fn udaf(&self, name: &str) -> datafusion::common::Result> { - let result = self.aggregate_functions.get(name); - - result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDAF named \"{name}\" in the TaskContext" - )) - }) - } - - fn udwf(&self, name: &str) -> datafusion::common::Result> { - let result = self.window_functions.get(name); - - result.cloned().ok_or_else(|| { - DataFusionError::Internal(format!( - "There is no UDWF named \"{name}\" in the TaskContext" - )) - }) - } -} - -impl From<&SessionState> for BallistaFunctionRegistry { - fn from(state: &SessionState) -> Self { - let scalar_functions = state.scalar_functions().clone(); - let aggregate_functions = state.aggregate_functions().clone(); - let window_functions = state.window_functions().clone(); - - Self { - scalar_functions, - aggregate_functions, - window_functions, - } - } -} diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs index 6e39d126c..913e955d3 100644 --- a/ballista/core/src/utils.rs +++ b/ballista/core/src/utils.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::config::{ - BallistaConfig, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, - BALLISTA_STANDALONE_PARALLELISM, -}; +use crate::config::BallistaConfig; use crate::error::{BallistaError, Result}; -use crate::execution_plans::{ - DistributedQueryExec, ShuffleWriterExec, UnresolvedShuffleExec, -}; -use crate::serde::protobuf::KeyValuePair; +use crate::execution_plans::DistributedQueryExec; + +use crate::extension::SessionConfigExt; use crate::serde::scheduler::PartitionStats; -use crate::serde::{BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec}; +use crate::serde::BallistaLogicalExtensionCodec; use async_trait::async_trait; use datafusion::arrow::datatypes::Schema; @@ -34,33 +30,19 @@ use datafusion::arrow::ipc::writer::StreamWriter; use datafusion::arrow::ipc::CompressionType; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor}; -use datafusion::datasource::physical_plan::{CsvExec, ParquetExec}; use datafusion::error::DataFusionError; -use datafusion::execution::context::{ - QueryPlanner, SessionConfig, SessionContext, SessionState, -}; +use datafusion::execution::context::{QueryPlanner, SessionConfig, SessionState}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{DdlStatement, LogicalPlan, TableScan}; -use datafusion::physical_plan::aggregates::AggregateExec; -use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; -use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; -use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::joins::HashJoinExec; use datafusion::physical_plan::metrics::MetricsSet; -use datafusion::physical_plan::projection::ProjectionExec; -use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream}; use 
datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; use datafusion_proto::logical_plan::{AsLogicalPlan, LogicalExtensionCodec}; -use datafusion_proto::physical_plan::PhysicalExtensionCodec; -use datafusion_proto::protobuf::LogicalPlanNode; use futures::StreamExt; use log::error; -use std::io::{BufWriter, Write}; use std::marker::PhantomData; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{fs::File, pin::Pin}; @@ -68,14 +50,14 @@ use tonic::codegen::StdError; use tonic::transport::{Channel, Error, Server}; /// Default session builder using the provided configuration -pub fn default_session_builder(config: SessionConfig) -> SessionState { - SessionStateBuilder::new() +pub fn default_session_builder( + config: SessionConfig, +) -> datafusion::common::Result { + Ok(SessionStateBuilder::new() .with_default_features() .with_config(config) - .with_runtime_env(Arc::new( - RuntimeEnv::try_new(RuntimeConfig::default()).unwrap(), - )) - .build() + .with_runtime_env(Arc::new(RuntimeEnv::try_new(RuntimeConfig::default())?)) + .build()) } pub fn default_config_producer() -> SessionConfig { @@ -135,499 +117,6 @@ pub async fn collect_stream( Ok(batches) } -pub fn produce_diagram(filename: &str, stages: &[Arc]) -> Result<()> { - let write_file = File::create(filename)?; - let mut w = BufWriter::new(&write_file); - writeln!(w, "digraph G {{")?; - - // draw stages and entities - for stage in stages { - writeln!(w, "\tsubgraph cluster{} {{", stage.stage_id())?; - writeln!(w, "\t\tlabel = \"Stage {}\";", stage.stage_id())?; - let mut id = AtomicUsize::new(0); - build_exec_plan_diagram( - &mut w, - stage.children()[0].as_ref(), - stage.stage_id(), - &mut id, - true, - )?; - writeln!(w, "\t}}")?; - } - - // draw relationships - for stage in stages { - let mut id = AtomicUsize::new(0); - build_exec_plan_diagram( - &mut w, - stage.children()[0].as_ref(), - stage.stage_id(), - &mut id, - false, - )?; - } - - write!(w, "}}")?; - Ok(()) -} - -fn build_exec_plan_diagram( - w: &mut BufWriter<&File>, - plan: &dyn ExecutionPlan, - stage_id: usize, - id: &mut AtomicUsize, - draw_entity: bool, -) -> Result { - let operator_str = if plan.as_any().downcast_ref::().is_some() { - "AggregateExec" - } else if plan.as_any().downcast_ref::().is_some() { - "SortExec" - } else if plan.as_any().downcast_ref::().is_some() { - "ProjectionExec" - } else if plan.as_any().downcast_ref::().is_some() { - "HashJoinExec" - } else if plan.as_any().downcast_ref::().is_some() { - "ParquetExec" - } else if plan.as_any().downcast_ref::().is_some() { - "CsvExec" - } else if plan.as_any().downcast_ref::().is_some() { - "FilterExec" - } else if plan.as_any().downcast_ref::().is_some() { - "ShuffleWriterExec" - } else if plan - .as_any() - .downcast_ref::() - .is_some() - { - "UnresolvedShuffleExec" - } else if plan - .as_any() - .downcast_ref::() - .is_some() - { - "CoalesceBatchesExec" - } else if plan - .as_any() - .downcast_ref::() - .is_some() - { - "CoalescePartitionsExec" - } else { - println!("Unknown: {plan:?}"); - "Unknown" - }; - - let node_id = id.load(Ordering::SeqCst); - id.store(node_id + 1, Ordering::SeqCst); - - if draw_entity { - writeln!( - w, - "\t\tstage_{stage_id}_exec_{node_id} [shape=box, label=\"{operator_str}\"];" - )?; - } - for child in plan.children() { - if let Some(shuffle) = child.as_any().downcast_ref::() { - if !draw_entity { - writeln!( - w, - "\tstage_{}_exec_1 -> stage_{}_exec_{};", - shuffle.stage_id, stage_id, 
node_id - )?; - } - } else { - // relationships within same entity - let child_id = - build_exec_plan_diagram(w, child.as_ref(), stage_id, id, draw_entity)?; - if draw_entity { - writeln!( - w, - "\t\tstage_{stage_id}_exec_{child_id} -> stage_{stage_id}_exec_{node_id};" - )?; - } - } - } - Ok(node_id) -} - -/// Create a client DataFusion context that uses the BallistaQueryPlanner to send logical plans -/// to a Ballista scheduler -pub fn create_df_ctx_with_ballista_query_planner( - scheduler_url: String, - session_id: String, - config: &BallistaConfig, -) -> SessionContext { - // TODO: put ballista configuration as part of sessions state - // planner can get it from there. - // This would make it changeable during run time - // using SQL SET statement - let planner: Arc> = - Arc::new(BallistaQueryPlanner::new(scheduler_url, config.clone())); - - let session_config = SessionConfig::new_with_ballista() - .with_information_schema(true) - .with_option_extension(config.clone()); - - let session_state = SessionStateBuilder::new() - .with_default_features() - .with_config(session_config) - .with_runtime_env(Arc::new( - RuntimeEnv::try_new(RuntimeConfig::default()).unwrap(), - )) - .with_query_planner(planner) - .with_session_id(session_id) - .build(); - // the SessionContext created here is the client side context, but the session_id is from server side. - SessionContext::new_with_state(session_state) -} - -pub trait SessionStateExt { - fn new_ballista_state( - scheduler_url: String, - session_id: String, - ) -> datafusion::error::Result; - fn upgrade_for_ballista( - self, - scheduler_url: String, - session_id: String, - ) -> datafusion::error::Result; - #[deprecated] - fn ballista_config(&self) -> BallistaConfig; -} - -impl SessionStateExt for SessionState { - fn ballista_config(&self) -> BallistaConfig { - self.config() - .options() - .extensions - .get::() - .cloned() - .unwrap_or_else(BallistaConfig::default) - } - - fn new_ballista_state( - scheduler_url: String, - session_id: String, - ) -> datafusion::error::Result { - let config = BallistaConfig::default(); - - let planner = - BallistaQueryPlanner::::new(scheduler_url, config.clone()); - - let session_config = SessionConfig::new() - .with_information_schema(true) - .with_option_extension(config.clone()) - // Ballista disables this option - .with_round_robin_repartition(false); - - let runtime_config = RuntimeConfig::default(); - let runtime_env = RuntimeEnv::try_new(runtime_config)?; - let session_state = SessionStateBuilder::new() - .with_default_features() - .with_config(session_config) - .with_runtime_env(Arc::new(runtime_env)) - .with_query_planner(Arc::new(planner)) - .with_session_id(session_id) - .build(); - - Ok(session_state) - } - - fn upgrade_for_ballista( - self, - scheduler_url: String, - session_id: String, - ) -> datafusion::error::Result { - let codec_logical = self.config().ballista_logical_extension_codec(); - let planner_override = self.config().ballista_query_planner(); - - let new_config = self - .config() - .options() - .extensions - .get::() - .cloned() - .unwrap_or_else(BallistaConfig::default); - - let session_config = self - .config() - .clone() - .with_option_extension(new_config.clone()) - // Ballista disables this option - .with_round_robin_repartition(false); - - let builder = SessionStateBuilder::new_from_existing(self) - .with_config(session_config) - .with_session_id(session_id); - - let builder = match planner_override { - Some(planner) => builder.with_query_planner(planner), - None => { - let planner = 
BallistaQueryPlanner::::with_extension( - scheduler_url, - new_config, - codec_logical, - ); - builder.with_query_planner(Arc::new(planner)) - } - }; - - Ok(builder.build()) - } -} - -pub trait SessionConfigExt { - /// Creates session config which has - /// ballista configuration initialized - fn new_with_ballista() -> SessionConfig; - - /// Overrides ballista's [LogicalExtensionCodec] - fn with_ballista_logical_extension_codec( - self, - codec: Arc, - ) -> SessionConfig; - - /// Overrides ballista's [PhysicalExtensionCodec] - fn with_ballista_physical_extension_codec( - self, - codec: Arc, - ) -> SessionConfig; - - /// returns [LogicalExtensionCodec] if set - /// or default ballista codec if not - fn ballista_logical_extension_codec(&self) -> Arc; - - /// returns [PhysicalExtensionCodec] if set - /// or default ballista codec if not - fn ballista_physical_extension_codec(&self) -> Arc; - - /// Overrides ballista's [QueryPlanner] - fn with_ballista_query_planner( - self, - planner: Arc, - ) -> SessionConfig; - - /// Returns ballista's [QueryPlanner] if overridden - fn ballista_query_planner( - &self, - ) -> Option>; - - fn ballista_standalone_parallelism(&self) -> usize; - - fn ballista_grpc_client_max_message_size(&self) -> usize; - - fn to_key_value_pairs(&self) -> Vec; - - fn update_from_key_value_pair(self, key_value_pairs: &[KeyValuePair]) -> Self; - - fn with_ballista_job_name(self, job_name: &str) -> Self; - - fn with_ballista_grpc_client_max_message_size(self, max_size: usize) -> Self; - - fn with_ballista_standalone_parallelism(self, parallelism: usize) -> Self; - - fn update_from_key_value_pair_mut(&mut self, key_value_pairs: &[KeyValuePair]); -} - -impl SessionConfigExt for SessionConfig { - fn new_with_ballista() -> SessionConfig { - SessionConfig::new() - .with_option_extension(BallistaConfig::default()) - .with_target_partitions(16) - .with_round_robin_repartition(false) - } - fn with_ballista_logical_extension_codec( - self, - codec: Arc, - ) -> SessionConfig { - let extension = BallistaConfigExtensionLogicalCodec::new(codec); - self.with_extension(Arc::new(extension)) - } - fn with_ballista_physical_extension_codec( - self, - codec: Arc, - ) -> SessionConfig { - let extension = BallistaConfigExtensionPhysicalCodec::new(codec); - self.with_extension(Arc::new(extension)) - } - - fn ballista_logical_extension_codec(&self) -> Arc { - self.get_extension::() - .map(|c| c.codec()) - .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default())) - } - fn ballista_physical_extension_codec(&self) -> Arc { - self.get_extension::() - .map(|c| c.codec()) - .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())) - } - - fn with_ballista_query_planner( - self, - planner: Arc, - ) -> SessionConfig { - let extension = BallistaQueryPlannerExtension::new(planner); - self.with_extension(Arc::new(extension)) - } - - fn ballista_query_planner( - &self, - ) -> Option> { - self.get_extension::() - .map(|c| c.planner()) - } - - fn ballista_standalone_parallelism(&self) -> usize { - self.options() - .extensions - .get::() - .map(|c| c.default_standalone_parallelism()) - .unwrap_or_else(|| BallistaConfig::default().default_standalone_parallelism()) - } - - fn ballista_grpc_client_max_message_size(&self) -> usize { - self.options() - .extensions - .get::() - .map(|c| c.default_grpc_client_max_message_size()) - .unwrap_or_else(|| { - BallistaConfig::default().default_grpc_client_max_message_size() - }) - } - - fn to_key_value_pairs(&self) -> Vec { - self.options() - .entries() - 
.iter() - .filter(|v| v.value.is_some()) - .map( - // TODO MM make `value` optional value - |datafusion::config::ConfigEntry { key, value, .. }| { - log::trace!( - "sending configuration key: `{}`, value`{:?}`", - key, - value - ); - KeyValuePair { - key: key.to_owned(), - value: value.clone().unwrap(), - } - }, - ) - .collect() - } - - fn update_from_key_value_pair(self, key_value_pairs: &[KeyValuePair]) -> Self { - let mut s = self; - for KeyValuePair { key, value } in key_value_pairs { - log::trace!( - "setting up configuration key: `{}`, value: `{}`", - key, - value - ); - if let Err(e) = s.options_mut().set(key, value) { - log::warn!( - "could not set configuration key: `{}`, value: `{}`, reason: {}", - key, - value, - e.to_string() - ) - } - } - s - } - - fn update_from_key_value_pair_mut(&mut self, key_value_pairs: &[KeyValuePair]) { - for KeyValuePair { key, value } in key_value_pairs { - log::trace!( - "setting up configuration key : `{}`, value: `{}`", - key, - value - ); - if let Err(e) = self.options_mut().set(key, value) { - log::warn!( - "could not set configuration key: `{}`, value: `{}`, reason: {}", - key, - value, - e.to_string() - ) - } - } - } - - fn with_ballista_job_name(self, job_name: &str) -> Self { - if self.options().extensions.get::().is_some() { - self.set_str(BALLISTA_JOB_NAME, job_name) - } else { - self.with_option_extension(BallistaConfig::default()) - .set_str(BALLISTA_JOB_NAME, job_name) - } - } - - fn with_ballista_grpc_client_max_message_size(self, max_size: usize) -> Self { - if self.options().extensions.get::().is_some() { - self.set_usize(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, max_size) - } else { - self.with_option_extension(BallistaConfig::default()) - .set_usize(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, max_size) - } - } - - fn with_ballista_standalone_parallelism(self, parallelism: usize) -> Self { - if self.options().extensions.get::().is_some() { - self.set_usize(BALLISTA_STANDALONE_PARALLELISM, parallelism) - } else { - self.with_option_extension(BallistaConfig::default()) - .set_usize(BALLISTA_STANDALONE_PARALLELISM, parallelism) - } - } -} - -/// Wrapper for [SessionConfig] extension -/// holding [LogicalExtensionCodec] if overridden -struct BallistaConfigExtensionLogicalCodec { - codec: Arc, -} - -impl BallistaConfigExtensionLogicalCodec { - fn new(codec: Arc) -> Self { - Self { codec } - } - fn codec(&self) -> Arc { - self.codec.clone() - } -} - -/// Wrapper for [SessionConfig] extension -/// holding [PhysicalExtensionCodec] if overridden -struct BallistaConfigExtensionPhysicalCodec { - codec: Arc, -} - -impl BallistaConfigExtensionPhysicalCodec { - fn new(codec: Arc) -> Self { - Self { codec } - } - fn codec(&self) -> Arc { - self.codec.clone() - } -} - -/// Wrapper for [SessionConfig] extension -/// holding overridden [QueryPlanner] -struct BallistaQueryPlannerExtension { - planner: Arc, -} - -impl BallistaQueryPlannerExtension { - fn new(planner: Arc) -> Self { - Self { planner } - } - fn planner(&self) -> Arc { - self.planner.clone() - } -} - pub struct BallistaQueryPlanner { scheduler_url: String, config: BallistaConfig, @@ -830,17 +319,12 @@ mod test { error::Result, execution::{ runtime_env::{RuntimeConfig, RuntimeEnv}, - SessionState, SessionStateBuilder, + SessionStateBuilder, }, prelude::{SessionConfig, SessionContext}, }; - use crate::{ - config::BALLISTA_JOB_NAME, - utils::{LocalRun, SessionStateExt}, - }; - - use super::SessionConfigExt; + use crate::utils::LocalRun; fn context() -> SessionContext { let runtime_environment = 
RuntimeEnv::try_new(RuntimeConfig::new()).unwrap(); @@ -917,38 +401,4 @@ mod test { Ok(()) } - - // Ballista disables round robin repatriations - #[tokio::test] - async fn should_disable_round_robin_repartition() { - let state = SessionState::new_ballista_state( - "scheduler_url".to_string(), - "session_id".to_string(), - ) - .unwrap(); - - assert!(!state.config().round_robin_repartition()); - - let state = SessionStateBuilder::new().build(); - - assert!(state.config().round_robin_repartition()); - let state = state - .upgrade_for_ballista("scheduler_url".to_string(), "session_id".to_string()) - .unwrap(); - - assert!(!state.config().round_robin_repartition()); - } - #[test] - fn should_convert_to_key_value_pairs() { - // key value pairs should contain datafusion and ballista values - - let config = - SessionConfig::new_with_ballista().with_ballista_job_name("job_name"); - let pairs = config.to_key_value_pairs(); - - assert!(pairs.iter().any(|p| p.key == BALLISTA_JOB_NAME)); - assert!(pairs - .iter() - .any(|p| p.key == "datafusion.catalog.information_schema")) - } } diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index e1822e9c1..6a2dfa619 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -32,23 +32,24 @@ executor = "executor_config_spec.toml" [[bin]] name = "ballista-executor" path = "src/bin/main.rs" +required-features = ["build-binary"] [features] -default = ["mimalloc"] +build-binary = ["configure_me", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] +default = ["build-binary", "mimalloc"] [dependencies] -anyhow = "1" arrow = { workspace = true } arrow-flight = { workspace = true } async-trait = { workspace = true } ballista-core = { path = "../core", version = "0.12.0" } -configure_me = { workspace = true } +configure_me = { workspace = true, optional = true } dashmap = { workspace = true } datafusion = { workspace = true } datafusion-proto = { workspace = true } futures = { workspace = true } log = { workspace = true } -mimalloc = { version = "0.1", default-features = false, optional = true } +mimalloc = { workspace = true, optional = true } parking_lot = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = [ @@ -60,9 +61,9 @@ tokio = { workspace = true, features = [ ] } tokio-stream = { workspace = true, features = ["net"] } tonic = { workspace = true } -tracing = { workspace = true } -tracing-appender = { workspace = true } -tracing-subscriber = { workspace = true } +tracing = { workspace = true, optional = true } +tracing-appender = { workspace = true, optional = true } +tracing-subscriber = { workspace = true, optional = true } uuid = { workspace = true } [dev-dependencies] diff --git a/ballista/executor/build.rs b/ballista/executor/build.rs index 7d2b9b87b..21ce2d8fe 100644 --- a/ballista/executor/build.rs +++ b/ballista/executor/build.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-extern crate configure_me_codegen; - fn main() -> Result<(), String> { + #[cfg(feature = "build-binary")] println!("cargo:rerun-if-changed=executor_config_spec.toml"); + #[cfg(feature = "build-binary")] configure_me_codegen::build_script_auto() - .map_err(|e| format!("configure_me code generation failed: {e}")) + .map_err(|e| format!("configure_me code generation failed: {e}"))?; + + Ok(()) } diff --git a/ballista/executor/executor_config_spec.toml b/ballista/executor/executor_config_spec.toml index 209069de1..4379a0cd4 100644 --- a/ballista/executor/executor_config_spec.toml +++ b/ballista/executor/executor_config_spec.toml @@ -143,25 +143,3 @@ name = "executor_heartbeat_interval_seconds" type = "u64" doc = "The heartbeat interval in seconds to the scheduler for push-based task scheduling" default = "60" - -[[param]] -name = "data_cache_policy" -type = "ballista_core::config::DataCachePolicy" -doc = "Data cache policy, possible values: local-disk-file" - -[[param]] -name = "cache_dir" -type = "String" -doc = "Directory for cached source data" - -[[param]] -name = "cache_capacity" -type = "u64" -doc = "The maximum capacity can be used for cache. Default: 1GB" -default = "1073741824" - -[[param]] -name = "cache_io_concurrency" -type = "u32" -doc = "The number of worker threads for the runtime of caching. Default: 2" -default = "2" \ No newline at end of file diff --git a/ballista/executor/src/bin/main.rs b/ballista/executor/src/bin/main.rs index 9f5ed12f1..18abb9960 100644 --- a/ballista/executor/src/bin/main.rs +++ b/ballista/executor/src/bin/main.rs @@ -17,32 +17,22 @@ //! Ballista Rust executor binary. -use anyhow::Result; -use std::sync::Arc; - +use ballista_core::config::LogRotationPolicy; use ballista_core::print_version; +use ballista_executor::config::prelude::*; use ballista_executor::executor_process::{ start_executor_process, ExecutorProcessConfig, }; -use config::prelude::*; - -#[allow(unused_imports)] -#[macro_use] -extern crate configure_me; - -#[allow(clippy::all, warnings)] -mod config { - // Ideally we would use the include_config macro from configure_me, but then we cannot use - // #[allow(clippy::all)] to silence clippy warnings from the generated code - include!(concat!(env!("OUT_DIR"), "/executor_configure_me_config.rs")); -} +use std::env; +use std::sync::Arc; +use tracing_subscriber::EnvFilter; #[cfg(feature = "mimalloc")] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; #[tokio::main] -async fn main() -> Result<()> { +async fn main() -> ballista_core::error::Result<()> { // parse command-line arguments let (opt, _remaining_args) = Config::including_optional_config_files(&["/etc/ballista/executor.toml"]) @@ -53,46 +43,40 @@ async fn main() -> Result<()> { std::process::exit(0); } - let log_file_name_prefix = format!( - "executor_{}_{}", - opt.external_host - .clone() - .unwrap_or_else(|| "localhost".to_string()), - opt.bind_port - ); + let config: ExecutorProcessConfig = opt.try_into()?; + + let rust_log = env::var(EnvFilter::DEFAULT_ENV); + let log_filter = + EnvFilter::new(rust_log.unwrap_or(config.special_mod_log_level.clone())); + + let tracing = tracing_subscriber::fmt() + .with_ansi(false) + .with_thread_names(config.print_thread_info) + .with_thread_ids(config.print_thread_info) + .with_env_filter(log_filter); - let config = ExecutorProcessConfig { - special_mod_log_level: opt.log_level_setting, - external_host: opt.external_host, - bind_host: opt.bind_host, - port: opt.bind_port, - grpc_port: opt.bind_grpc_port, - scheduler_host: 
opt.scheduler_host, - scheduler_port: opt.scheduler_port, - scheduler_connect_timeout_seconds: opt.scheduler_connect_timeout_seconds, - concurrent_tasks: opt.concurrent_tasks, - task_scheduling_policy: opt.task_scheduling_policy, - work_dir: opt.work_dir, - log_dir: opt.log_dir, - log_file_name_prefix, - log_rotation_policy: opt.log_rotation_policy, - print_thread_info: opt.print_thread_info, - job_data_ttl_seconds: opt.job_data_ttl_seconds, - job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds, - grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size, - grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size, - executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds, - data_cache_policy: opt.data_cache_policy, - cache_dir: opt.cache_dir, - cache_capacity: opt.cache_capacity, - cache_io_concurrency: opt.cache_io_concurrency, - execution_engine: None, - function_registry: None, - config_producer: None, - runtime_producer: None, - logical_codec: None, - physical_codec: None, - }; + // File layer + if let Some(log_dir) = &config.log_dir { + let log_file = match config.log_rotation_policy { + LogRotationPolicy::Minutely => tracing_appender::rolling::minutely( + log_dir, + config.log_file_name_prefix(), + ), + LogRotationPolicy::Hourly => { + tracing_appender::rolling::hourly(log_dir, config.log_file_name_prefix()) + } + LogRotationPolicy::Daily => { + tracing_appender::rolling::daily(log_dir, config.log_file_name_prefix()) + } + LogRotationPolicy::Never => { + tracing_appender::rolling::never(log_dir, config.log_file_name_prefix()) + } + }; + + tracing.with_writer(log_file).init(); + } else { + tracing.init(); + } start_executor_process(Arc::new(config)).await } diff --git a/ballista/executor/src/config.rs b/ballista/executor/src/config.rs new file mode 100644 index 000000000..91b547327 --- /dev/null +++ b/ballista/executor/src/config.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use ballista_core::error::BallistaError; + +use crate::executor_process::ExecutorProcessConfig; + +// Ideally we would use the include_config macro from configure_me, but then we cannot use +// #[allow(clippy::all)] to silence clippy warnings from the generated code + +include!(concat!(env!("OUT_DIR"), "/executor_configure_me_config.rs")); + +impl TryFrom for ExecutorProcessConfig { + type Error = BallistaError; + + fn try_from(opt: Config) -> Result { + Ok(ExecutorProcessConfig { + special_mod_log_level: opt.log_level_setting, + external_host: opt.external_host, + bind_host: opt.bind_host, + port: opt.bind_port, + grpc_port: opt.bind_grpc_port, + scheduler_host: opt.scheduler_host, + scheduler_port: opt.scheduler_port, + scheduler_connect_timeout_seconds: opt.scheduler_connect_timeout_seconds, + concurrent_tasks: opt.concurrent_tasks, + task_scheduling_policy: opt.task_scheduling_policy, + work_dir: opt.work_dir, + log_dir: opt.log_dir, + log_rotation_policy: opt.log_rotation_policy, + print_thread_info: opt.print_thread_info, + job_data_ttl_seconds: opt.job_data_ttl_seconds, + job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds, + grpc_max_decoding_message_size: opt.grpc_server_max_decoding_message_size, + grpc_max_encoding_message_size: opt.grpc_server_max_encoding_message_size, + executor_heartbeat_interval_seconds: opt.executor_heartbeat_interval_seconds, + override_execution_engine: None, + override_function_registry: None, + override_config_producer: None, + override_runtime_producer: None, + override_logical_codec: None, + override_physical_codec: None, + }) + } +} diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs index 5121f016b..42084267c 100644 --- a/ballista/executor/src/execution_engine.rs +++ b/ballista/executor/src/execution_engine.rs @@ -27,7 +27,6 @@ use std::fmt::Debug; use std::sync::Arc; /// Execution engine extension point - pub trait ExecutionEngine: Sync + Send { fn create_query_stage_exec( &self, diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs index 758b34781..2094425d7 100644 --- a/ballista/executor/src/execution_loop.rs +++ b/ballista/executor/src/execution_loop.rs @@ -19,13 +19,13 @@ use crate::cpu_bound_executor::DedicatedExecutor; use crate::executor::Executor; use crate::{as_task_status, TaskExecutionTimes}; use ballista_core::error::BallistaError; +use ballista_core::extension::SessionConfigHelperExt; use ballista_core::serde::protobuf::{ scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult, TaskDefinition, TaskStatus, }; use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId}; use ballista_core::serde::BallistaCodec; -use ballista_core::utils::SessionConfigExt; use datafusion::execution::context::TaskContext; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::AsLogicalPlan; @@ -77,16 +77,14 @@ pub async fn poll_loop let task_status: Vec = sample_tasks_status(&mut task_status_receiver).await; - let poll_work_result: anyhow::Result< - tonic::Response, - tonic::Status, - > = scheduler - .poll_work(PollWorkParams { - metadata: Some(executor.metadata.clone()), - num_free_slots: available_task_slots.available_permits() as u32, - task_status, - }) - .await; + let poll_work_result: Result, tonic::Status> = + scheduler + .poll_work(PollWorkParams { + metadata: Some(executor.metadata.clone()), + num_free_slots: available_task_slots.available_permits() as u32, + task_status, + }) 
+ .await; match poll_work_result { Ok(result) => { @@ -163,9 +161,12 @@ async fn run_received_task { + Result::Ok(status) => { task_status.push(status); } Err(TryRecvError::Empty) => { diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs index d9246bfe9..1b029e171 100644 --- a/ballista/executor/src/executor.rs +++ b/ballista/executor/src/executor.rs @@ -23,9 +23,9 @@ use crate::execution_engine::QueryStageExecutor; use crate::metrics::ExecutorMetricsCollector; use crate::metrics::LoggingMetricsCollector; use ballista_core::error::BallistaError; +use ballista_core::registry::BallistaFunctionRegistry; use ballista_core::serde::protobuf; use ballista_core::serde::protobuf::ExecutorRegistration; -use ballista_core::serde::scheduler::BallistaFunctionRegistry; use ballista_core::serde::scheduler::PartitionId; use ballista_core::ConfigProducer; use ballista_core::RuntimeProducer; @@ -111,7 +111,6 @@ impl Executor { /// Create a new executor instance with given [RuntimeEnv], /// [ScalarUDF], [AggregateUDF] and [WindowUDF] - #[allow(clippy::too_many_arguments)] pub fn new( metadata: ExecutorRegistration, @@ -216,13 +215,13 @@ impl Executor { mod test { use crate::execution_engine::DefaultQueryStageExec; use crate::executor::Executor; - use arrow::datatypes::{Schema, SchemaRef}; - use arrow::record_batch::RecordBatch; use ballista_core::execution_plans::ShuffleWriterExec; use ballista_core::serde::protobuf::ExecutorRegistration; use ballista_core::serde::scheduler::PartitionId; use ballista_core::utils::default_config_producer; use ballista_core::RuntimeProducer; + use datafusion::arrow::datatypes::{Schema, SchemaRef}; + use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::TaskContext; diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index 07881ef58..fac02b48d 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -21,11 +21,9 @@ use std::net::SocketAddr; use std::sync::atomic::Ordering; use std::sync::Arc; use std::time::{Duration, Instant, UNIX_EPOCH}; -use std::{env, io}; -use anyhow::{Context, Result}; use arrow_flight::flight_service_server::FlightServiceServer; -use ballista_core::serde::scheduler::BallistaFunctionRegistry; +use ballista_core::registry::BallistaFunctionRegistry; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use futures::stream::FuturesUnordered; @@ -37,12 +35,11 @@ use tokio::signal; use tokio::sync::mpsc; use tokio::task::JoinHandle; use tokio::{fs, time}; -use tracing_subscriber::EnvFilter; use uuid::Uuid; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use ballista_core::config::{DataCachePolicy, LogRotationPolicy, TaskSchedulingPolicy}; +use ballista_core::config::{LogRotationPolicy, TaskSchedulingPolicy}; use ballista_core::error::BallistaError; use ballista_core::serde::protobuf::executor_resource::Resource; use ballista_core::serde::protobuf::executor_status::Status; @@ -83,14 +80,9 @@ pub struct ExecutorProcessConfig { pub work_dir: Option, pub special_mod_log_level: String, pub print_thread_info: bool, - pub log_file_name_prefix: String, pub log_rotation_policy: LogRotationPolicy, pub job_data_ttl_seconds: u64, pub job_data_clean_up_interval_seconds: u64, - pub data_cache_policy: Option, - pub cache_dir: Option, - pub cache_capacity: u64, - pub 
cache_io_concurrency: u32, /// The maximum size of a decoded message pub grpc_max_decoding_message_size: u32, /// The maximum size of an encoded message @@ -98,61 +90,70 @@ pub struct ExecutorProcessConfig { pub executor_heartbeat_interval_seconds: u64, /// Optional execution engine to use to execute physical plans, will default to /// DataFusion if none is provided. - pub execution_engine: Option>, + pub override_execution_engine: Option>, /// Overrides default function registry - pub function_registry: Option>, + pub override_function_registry: Option>, /// [RuntimeProducer] override option - pub runtime_producer: Option, + pub override_runtime_producer: Option, /// [ConfigProducer] override option - pub config_producer: Option, + pub override_config_producer: Option, /// [PhysicalExtensionCodec] override option - pub logical_codec: Option>, + pub override_logical_codec: Option>, /// [PhysicalExtensionCodec] override option - pub physical_codec: Option>, + pub override_physical_codec: Option>, } -pub async fn start_executor_process(opt: Arc) -> Result<()> { - let rust_log = env::var(EnvFilter::DEFAULT_ENV); - let log_filter = - EnvFilter::new(rust_log.unwrap_or(opt.special_mod_log_level.clone())); - // File layer - if let Some(log_dir) = opt.log_dir.clone() { - let log_file = match opt.log_rotation_policy { - LogRotationPolicy::Minutely => { - tracing_appender::rolling::minutely(log_dir, &opt.log_file_name_prefix) - } - LogRotationPolicy::Hourly => { - tracing_appender::rolling::hourly(log_dir, &opt.log_file_name_prefix) - } - LogRotationPolicy::Daily => { - tracing_appender::rolling::daily(log_dir, &opt.log_file_name_prefix) - } - LogRotationPolicy::Never => { - tracing_appender::rolling::never(log_dir, &opt.log_file_name_prefix) - } - }; - tracing_subscriber::fmt() - .with_ansi(false) - .with_thread_names(opt.print_thread_info) - .with_thread_ids(opt.print_thread_info) - .with_writer(log_file) - .with_env_filter(log_filter) - .init(); - } else { - // Console layer - tracing_subscriber::fmt() - .with_ansi(false) - .with_thread_names(opt.print_thread_info) - .with_thread_ids(opt.print_thread_info) - .with_writer(io::stdout) - .with_env_filter(log_filter) - .init(); +impl ExecutorProcessConfig { + pub fn log_file_name_prefix(&self) -> String { + format!( + "executor_{}_{}", + self.external_host + .clone() + .unwrap_or_else(|| "localhost".to_string()), + self.port + ) } +} +impl Default for ExecutorProcessConfig { + fn default() -> Self { + Self { + bind_host: "127.0.0.1".into(), + external_host: None, + port: 50051, + grpc_port: 50052, + scheduler_host: "localhost".into(), + scheduler_port: 50050, + scheduler_connect_timeout_seconds: 0, + concurrent_tasks: std::thread::available_parallelism().unwrap().get(), + task_scheduling_policy: Default::default(), + log_dir: None, + work_dir: None, + special_mod_log_level: "INFO".into(), + print_thread_info: true, + log_rotation_policy: Default::default(), + job_data_ttl_seconds: 604800, + job_data_clean_up_interval_seconds: 0, + grpc_max_decoding_message_size: 16777216, + grpc_max_encoding_message_size: 16777216, + executor_heartbeat_interval_seconds: 60, + override_execution_engine: None, + override_function_registry: None, + override_runtime_producer: None, + override_config_producer: None, + override_logical_codec: None, + override_physical_codec: None, + } + } +} + +pub async fn start_executor_process( + opt: Arc, +) -> ballista_core::error::Result<()> { let addr = format!("{}:{}", opt.bind_host, opt.port); - let addr = addr - .parse() - 
.with_context(|| format!("Could not parse address: {addr}"))?; + let addr = addr.parse().map_err(|e: std::net::AddrParseError| { + BallistaError::Configuration(e.to_string()) + })?; let scheduler_host = opt.scheduler_host.clone(); let scheduler_port = opt.scheduler_port; @@ -194,23 +195,26 @@ pub async fn start_executor_process(opt: Arc) -> Result<( // put them to session config let metrics_collector = Arc::new(LoggingMetricsCollector::default()); let config_producer = opt - .config_producer + .override_config_producer .clone() .unwrap_or_else(|| Arc::new(default_config_producer)); let wd = work_dir.clone(); - let runtime_producer: RuntimeProducer = Arc::new(move |_| { - let config = RuntimeConfig::new().with_temp_file_path(wd.clone()); - Ok(Arc::new(RuntimeEnv::try_new(config)?)) - }); + let runtime_producer: RuntimeProducer = + opt.override_runtime_producer.clone().unwrap_or_else(|| { + Arc::new(move |_| { + let config = RuntimeConfig::new().with_temp_file_path(wd.clone()); + Ok(Arc::new(RuntimeEnv::try_new(config)?)) + }) + }); let logical = opt - .logical_codec + .override_logical_codec .clone() .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default())); let physical = opt - .physical_codec + .override_physical_codec .clone() .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())); @@ -224,17 +228,21 @@ pub async fn start_executor_process(opt: Arc) -> Result<( &work_dir, runtime_producer, config_producer, - opt.function_registry.clone().unwrap_or_default(), + opt.override_function_registry.clone().unwrap_or_default(), metrics_collector, concurrent_tasks, - opt.execution_engine.clone(), + opt.override_execution_engine.clone(), )); let connect_timeout = opt.scheduler_connect_timeout_seconds as u64; let connection = if connect_timeout == 0 { create_grpc_client_connection(scheduler_url) .await - .context("Could not connect to scheduler") + .map_err(|_| { + BallistaError::GrpcConnectionError( + "Could not connect to scheduler".to_string(), + ) + }) } else { // this feature was added to support docker-compose so that we can have the executor // wait for the scheduler to start, or at least run for 10 seconds before failing so @@ -246,8 +254,11 @@ pub async fn start_executor_process(opt: Arc) -> Result<( { match create_grpc_client_connection(scheduler_url.clone()) .await - .context("Could not connect to scheduler") - { + .map_err(|_| { + BallistaError::GrpcConnectionError( + "Could not connect to scheduler".to_string(), + ) + }) { Ok(conn) => { info!("Connected to scheduler at {}", scheduler_url); x = Some(conn); @@ -265,8 +276,7 @@ pub async fn start_executor_process(opt: Arc) -> Result<( Some(conn) => Ok(conn), _ => Err(BallistaError::General(format!( "Timed out attempting to connect to scheduler at {scheduler_url}" - )) - .into()), + ))), } }?; @@ -486,7 +496,10 @@ async fn check_services( /// This function will be scheduled periodically for cleanup the job shuffle data left on the executor. /// Only directories will be checked cleaned. -async fn clean_shuffle_data_loop(work_dir: &str, seconds: u64) -> Result<()> { +async fn clean_shuffle_data_loop( + work_dir: &str, + seconds: u64, +) -> ballista_core::error::Result<()> { let mut dir = fs::read_dir(work_dir).await?; let mut to_deleted = Vec::new(); while let Some(child) = dir.next_entry().await? 
{ @@ -524,7 +537,7 @@ async fn clean_shuffle_data_loop(work_dir: &str, seconds: u64) -> Result<()> { } /// This function will clean up all shuffle data on this executor -async fn clean_all_shuffle_data(work_dir: &str) -> Result<()> { +async fn clean_all_shuffle_data(work_dir: &str) -> ballista_core::error::Result<()> { let mut dir = fs::read_dir(work_dir).await?; let mut to_deleted = Vec::new(); while let Some(child) = dir.next_entry().await? { @@ -549,7 +562,10 @@ async fn clean_all_shuffle_data(work_dir: &str) -> Result<()> { /// Determines if a directory contains files newer than the cutoff time. /// If return true, it means the directory contains files newer than the cutoff time. It satisfy the ttl and should not be deleted. -pub async fn satisfy_dir_ttl(dir: DirEntry, ttl_seconds: u64) -> Result { +pub async fn satisfy_dir_ttl( + dir: DirEntry, + ttl_seconds: u64, +) -> ballista_core::error::Result { let cutoff = get_time_before(ttl_seconds); let mut to_check = vec![dir]; diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs index a96a752c2..939b5a8f5 100644 --- a/ballista/executor/src/flight_service.rs +++ b/ballista/executor/src/flight_service.rs @@ -17,24 +17,24 @@ //! Implementation of the Apache Arrow Flight protocol that wraps an executor. -use arrow::ipc::reader::StreamReader; +use datafusion::arrow::ipc::reader::StreamReader; use std::convert::TryFrom; use std::fs::File; use std::pin::Pin; -use arrow::ipc::CompressionType; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; use ballista_core::error::BallistaError; use ballista_core::serde::decode_protobuf; use ballista_core::serde::scheduler::Action as BallistaAction; +use datafusion::arrow::ipc::CompressionType; -use arrow::ipc::writer::IpcWriteOptions; use arrow_flight::{ flight_service_server::FlightService, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, }; +use datafusion::arrow::ipc::writer::IpcWriteOptions; use datafusion::arrow::{error::ArrowError, record_batch::RecordBatch}; use futures::{Stream, StreamExt, TryStreamExt}; use log::{debug, info}; @@ -45,7 +45,6 @@ use tokio::{sync::mpsc::Sender, task}; use tokio_stream::wrappers::ReceiverStream; use tonic::metadata::MetadataValue; use tonic::{Request, Response, Status, Streaming}; -use tracing::warn; /// Service implementing the Apache Arrow Flight Protocol #[derive(Clone)] @@ -103,7 +102,10 @@ impl FlightService for BallistaFlightService { let schema = reader.schema(); task::spawn_blocking(move || { if let Err(e) = read_partition(reader, tx) { - warn!(error = %e, "error streaming shuffle partition"); + log::warn!( + "error streaming shuffle partition: {}", + e.to_string() + ); } }); diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs index bc9d23e87..23e68f85c 100644 --- a/ballista/executor/src/lib.rs +++ b/ballista/executor/src/lib.rs @@ -18,6 +18,8 @@ #![doc = include_str!("../README.md")] pub mod collect; +#[cfg(feature = "build-binary")] +pub mod config; pub mod execution_engine; pub mod execution_loop; pub mod executor; diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs index dc23c5308..57082fc2c 100644 --- a/ballista/executor/src/standalone.rs +++ b/ballista/executor/src/standalone.rs @@ -19,8 +19,9 @@ use crate::metrics::LoggingMetricsCollector; use crate::{execution_loop, executor::Executor, 
flight_service::BallistaFlightService}; use arrow_flight::flight_service_server::FlightServiceServer; use ballista_core::config::BallistaConfig; -use ballista_core::serde::scheduler::BallistaFunctionRegistry; -use ballista_core::utils::{default_config_producer, SessionConfigExt}; +use ballista_core::extension::SessionConfigExt; +use ballista_core::registry::BallistaFunctionRegistry; +use ballista_core::utils::default_config_producer; use ballista_core::{ error::Result, serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration}, @@ -47,10 +48,7 @@ use uuid::Uuid; /// /// This provides flexible way of configuring underlying /// components. -pub async fn new_standalone_executor_from_state< - T: 'static + AsLogicalPlan, - U: 'static + AsExecutionPlan, ->( +pub async fn new_standalone_executor_from_state( scheduler: SchedulerGrpcClient, concurrent_tasks: usize, session_state: &SessionState, diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index 642e63d48..fc3ca09a8 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -32,34 +32,33 @@ scheduler = "scheduler_config_spec.toml" [[bin]] name = "ballista-scheduler" path = "src/bin/main.rs" +required-features = ["build-binary"] [features] -default = [] -flight-sql = [] +build-binary = ["configure_me", "clap", "tracing-subscriber", "tracing-appender", "tracing", "ballista-core/build-binary"] +default = ["build-binary"] +flight-sql = ["base64"] keda-scaler = [] prometheus-metrics = ["prometheus", "once_cell"] rest-api = [] [dependencies] -anyhow = "1" arrow-flight = { workspace = true } async-trait = { workspace = true } axum = "0.7.7" ballista-core = { path = "../core", version = "0.12.0" } -base64 = { version = "0.22" } -clap = { workspace = true } -configure_me = { workspace = true } +base64 = { version = "0.22", optional = true } +clap = { workspace = true, optional = true } +configure_me = { workspace = true, optional = true } dashmap = { workspace = true } datafusion = { workspace = true } datafusion-proto = { workspace = true } futures = { workspace = true } -graphviz-rust = "0.9.0" http = "1.1" log = { workspace = true } object_store = { workspace = true } once_cell = { version = "1.16.0", optional = true } parking_lot = { workspace = true } -parse_arg = { workspace = true } prometheus = { version = "0.13", features = ["process"], optional = true } prost = { workspace = true } prost-types = { workspace = true } @@ -68,13 +67,12 @@ serde = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["full"] } tokio-stream = { workspace = true, features = ["net"] } tonic = { workspace = true } -tracing = { workspace = true } -tracing-appender = { workspace = true } -tracing-subscriber = { workspace = true } +tracing = { workspace = true, optional = true } +tracing-appender = { workspace = true, optional = true } +tracing-subscriber = { workspace = true, optional = true } uuid = { workspace = true } [dev-dependencies] -ballista-core = { path = "../core", version = "0.12.0" } [build-dependencies] configure_me_codegen = { workspace = true } diff --git a/ballista/scheduler/build.rs b/ballista/scheduler/build.rs index 5a3e00cc1..9f2f123f2 100644 --- a/ballista/scheduler/build.rs +++ b/ballista/scheduler/build.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-extern crate configure_me_codegen; - fn main() -> Result<(), String> { + #[cfg(feature = "build-binary")] println!("cargo:rerun-if-changed=scheduler_config_spec.toml"); + #[cfg(feature = "build-binary")] configure_me_codegen::build_script_auto() .map_err(|e| format!("configure_me code generation failed: {e}"))?; diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml index 804987d9a..20bceb5f2 100644 --- a/ballista/scheduler/scheduler_config_spec.toml +++ b/ballista/scheduler/scheduler_config_spec.toml @@ -82,9 +82,9 @@ doc = "Delayed interval for cleaning up finished job state. Default: 3600" [[param]] name = "task_distribution" -type = "ballista_scheduler::config::TaskDistribution" +type = "crate::config::TaskDistribution" doc = "The policy of distributing tasks to available executor slots, possible values: bias, round-robin, consistent-hash. Default: bias" -default = "ballista_scheduler::config::TaskDistribution::Bias" +default = "crate::config::TaskDistribution::Bias" [[param]] name = "consistent_hash_num_replicas" diff --git a/ballista/scheduler/src/bin/main.rs b/ballista/scheduler/src/bin/main.rs index 7d8b4b1b0..ea31810a9 100644 --- a/ballista/scheduler/src/bin/main.rs +++ b/ballista/scheduler/src/bin/main.rs @@ -17,36 +17,17 @@ //! Ballista Rust scheduler binary. -use std::sync::Arc; -use std::{env, io}; - -use anyhow::Result; - -use crate::config::{Config, ResultExt}; use ballista_core::config::LogRotationPolicy; +use ballista_core::error::BallistaError; use ballista_core::print_version; use ballista_scheduler::cluster::BallistaCluster; -use ballista_scheduler::config::{ - ClusterStorageConfig, SchedulerConfig, TaskDistribution, TaskDistributionPolicy, -}; +use ballista_scheduler::config::{Config, ResultExt}; use ballista_scheduler::scheduler_process::start_server; +use std::sync::Arc; +use std::{env, io}; use tracing_subscriber::EnvFilter; -#[allow(unused_imports)] -#[macro_use] -extern crate configure_me; - -#[allow(clippy::all, warnings)] -mod config { - // Ideally we would use the include_config macro from configure_me, but then we cannot use - // #[allow(clippy::all)] to silence clippy warnings from the generated code - include!(concat!( - env!("OUT_DIR"), - "/scheduler_configure_me_config.rs" - )); -} - -fn main() -> Result<()> { +fn main() -> ballista_core::error::Result<()> { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_io() .enable_time() @@ -56,7 +37,7 @@ fn main() -> Result<()> { runtime.block_on(inner()) } -async fn inner() -> Result<()> { +async fn inner() -> ballista_core::error::Result<()> { // parse options let (opt, _remaining_args) = Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"]) @@ -67,19 +48,23 @@ async fn inner() -> Result<()> { std::process::exit(0); } - let special_mod_log_level = opt.log_level_setting; - let log_dir = opt.log_dir; - let print_thread_info = opt.print_thread_info; + let rust_log = env::var(EnvFilter::DEFAULT_ENV); + let log_filter = EnvFilter::new(rust_log.unwrap_or(opt.log_level_setting.clone())); - let log_file_name_prefix = format!( - "scheduler_{}_{}_{}", - opt.namespace, opt.external_host, opt.bind_port - ); + let tracing = tracing_subscriber::fmt() + .with_ansi(false) + .with_thread_names(opt.print_thread_info) + .with_thread_ids(opt.print_thread_info) + .with_writer(io::stdout) + .with_env_filter(log_filter); - let rust_log = env::var(EnvFilter::DEFAULT_ENV); - let log_filter = EnvFilter::new(rust_log.unwrap_or(special_mod_log_level)); 
// File layer - if let Some(log_dir) = log_dir { + if let Some(log_dir) = &opt.log_dir { + let log_file_name_prefix = format!( + "scheduler_{}_{}_{}", + opt.namespace, opt.external_host, opt.bind_port + ); + let log_file = match opt.log_rotation_policy { LogRotationPolicy::Minutely => { tracing_appender::rolling::minutely(log_dir, &log_file_name_prefix) @@ -94,68 +79,19 @@ async fn inner() -> Result<()> { tracing_appender::rolling::never(log_dir, &log_file_name_prefix) } }; - tracing_subscriber::fmt() - .with_ansi(false) - .with_thread_names(print_thread_info) - .with_thread_ids(print_thread_info) - .with_writer(log_file) - .with_env_filter(log_filter) - .init(); + + tracing.with_writer(log_file).init(); } else { - // Console layer - tracing_subscriber::fmt() - .with_ansi(false) - .with_thread_names(print_thread_info) - .with_thread_ids(print_thread_info) - .with_writer(io::stdout) - .with_env_filter(log_filter) - .init(); + tracing.init(); } - let addr = format!("{}:{}", opt.bind_host, opt.bind_port); - let addr = addr.parse()?; - - let cluster_storage_config = ClusterStorageConfig::Memory; - - let task_distribution = match opt.task_distribution { - TaskDistribution::Bias => TaskDistributionPolicy::Bias, - TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin, - TaskDistribution::ConsistentHash => { - let num_replicas = opt.consistent_hash_num_replicas as usize; - let tolerance = opt.consistent_hash_tolerance as usize; - TaskDistributionPolicy::ConsistentHash { - num_replicas, - tolerance, - } - } - }; - - let config = SchedulerConfig { - namespace: opt.namespace, - external_host: opt.external_host, - bind_port: opt.bind_port, - scheduling_policy: opt.scheduler_policy, - event_loop_buffer_size: opt.event_loop_buffer_size, - task_distribution, - finished_job_data_clean_up_interval_seconds: opt - .finished_job_data_clean_up_interval_seconds, - finished_job_state_clean_up_interval_seconds: opt - .finished_job_state_clean_up_interval_seconds, - advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint, - cluster_storage: cluster_storage_config, - job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0) - .then_some(opt.job_resubmit_interval_ms), - executor_termination_grace_period: opt.executor_termination_grace_period, - scheduler_event_expected_processing_duration: opt - .scheduler_event_expected_processing_duration, - grpc_server_max_decoding_message_size: opt.grpc_server_max_decoding_message_size, - grpc_server_max_encoding_message_size: opt.grpc_server_max_encoding_message_size, - executor_timeout_seconds: opt.executor_timeout_seconds, - expire_dead_executor_interval_seconds: opt.expire_dead_executor_interval_seconds, - }; + let addr = addr.parse().map_err(|e: std::net::AddrParseError| { + BallistaError::Configuration(e.to_string()) + })?; + let config = opt.try_into()?; let cluster = BallistaCluster::new_from_config(&config).await?; - start_server(cluster, addr, Arc::new(config)).await?; + Ok(()) } diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs index 6e32510a0..07f646b8c 100644 --- a/ballista/scheduler/src/cluster/memory.rs +++ b/ballista/scheduler/src/cluster/memory.rs @@ -37,7 +37,7 @@ use crate::scheduler_server::{timestamp_millis, timestamp_secs, SessionBuilder}; use crate::state::session_manager::create_datafusion_context; use crate::state::task_manager::JobInfoCache; use ballista_core::serde::protobuf::job_status::Status; -use log::{error, info, warn}; +use log::{debug, error, info, warn}; use 
std::collections::{HashMap, HashSet}; use std::ops::DerefMut; @@ -45,7 +45,6 @@ use ballista_core::consistent_hash::node::Node; use datafusion::physical_plan::ExecutionPlan; use std::sync::Arc; use tokio::sync::{Mutex, MutexGuard}; -use tracing::debug; #[derive(Default)] pub struct InMemoryClusterState { @@ -290,7 +289,7 @@ pub struct InMemoryJobState { session_builder: SessionBuilder, /// Sender of job events job_event_sender: ClusterEventSender, - + /// Config producer config_producer: ConfigProducer, } @@ -408,7 +407,7 @@ impl JobState for InMemoryJobState { &self, config: &SessionConfig, ) -> Result> { - let session = create_datafusion_context(config, self.session_builder.clone()); + let session = create_datafusion_context(config, self.session_builder.clone())?; self.sessions.insert(session.session_id(), session.clone()); Ok(session) @@ -419,7 +418,7 @@ impl JobState for InMemoryJobState { session_id: &str, config: &SessionConfig, ) -> Result> { - let session = create_datafusion_context(config, self.session_builder.clone()); + let session = create_datafusion_context(config, self.session_builder.clone())?; self.sessions .insert(session_id.to_string(), session.clone()); diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs index 2869c8876..c54b0ceae 100644 --- a/ballista/scheduler/src/cluster/mod.rs +++ b/ballista/scheduler/src/cluster/mod.rs @@ -16,7 +16,6 @@ // under the License. use std::collections::{HashMap, HashSet}; -use std::fmt; use std::pin::Pin; use std::sync::Arc; @@ -69,9 +68,9 @@ impl std::str::FromStr for ClusterStorage { ValueEnum::from_str(s, true) } } - -impl parse_arg::ParseArgFromStr for ClusterStorage { - fn describe_type(mut writer: W) -> fmt::Result { +#[cfg(feature = "build-binary")] +impl configure_me::parse_arg::ParseArgFromStr for ClusterStorage { + fn describe_type(mut writer: W) -> std::fmt::Result { write!(writer, "The cluster storage backend for the scheduler") } } @@ -111,11 +110,21 @@ impl BallistaCluster { pub async fn new_from_config(config: &SchedulerConfig) -> Result { let scheduler = config.scheduler_name(); + let session_builder = config + .override_session_builder + .clone() + .unwrap_or_else(|| Arc::new(default_session_builder)); + + let config_producer = config + .override_config_producer + .clone() + .unwrap_or_else(|| Arc::new(default_config_producer)); + match &config.cluster_storage { ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory( scheduler, - Arc::new(default_session_builder), - Arc::new(default_config_producer), + session_builder, + config_producer, )), } } diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index ce542e519..b221ecb65 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -18,18 +18,28 @@ //! Ballista scheduler specific configuration -use ballista_core::config::TaskSchedulingPolicy; -use clap::ValueEnum; -use std::fmt; +use crate::SessionBuilder; +use ballista_core::{config::TaskSchedulingPolicy, ConfigProducer}; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use std::sync::Arc; + +#[cfg(feature = "build-binary")] +include!(concat!( + env!("OUT_DIR"), + "/scheduler_configure_me_config.rs" +)); /// Configurations for the ballista scheduler of scheduling jobs and tasks -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct SchedulerConfig { /// Namespace of this scheduler. 
Schedulers using the same cluster storage and namespace /// will share global cluster state. pub namespace: String, /// The external hostname of the scheduler pub external_host: String, + /// The bind host for the scheduler's gRPC service + pub bind_host: String, /// The bind port for the scheduler's gRPC service pub bind_port: u16, /// The task scheduling policy for the scheduler @@ -62,21 +72,31 @@ pub struct SchedulerConfig { pub executor_timeout_seconds: u64, /// The interval to check expired or dead executors pub expire_dead_executor_interval_seconds: u64, + + /// [ConfigProducer] override option + pub override_config_producer: Option, + /// [SessionBuilder] override option + pub override_session_builder: Option, + /// [PhysicalExtensionCodec] override option + pub override_logical_codec: Option>, + /// [PhysicalExtensionCodec] override option + pub override_physical_codec: Option>, } impl Default for SchedulerConfig { fn default() -> Self { Self { namespace: String::default(), - external_host: "localhost".to_string(), + external_host: "localhost".into(), bind_port: 50050, - scheduling_policy: TaskSchedulingPolicy::PullStaged, + bind_host: "127.0.0.1".into(), + scheduling_policy: Default::default(), event_loop_buffer_size: 10000, - task_distribution: TaskDistributionPolicy::Bias, + task_distribution: Default::default(), finished_job_data_clean_up_interval_seconds: 300, finished_job_state_clean_up_interval_seconds: 3600, advertise_flight_sql_endpoint: None, - cluster_storage: ClusterStorageConfig::Memory, + cluster_storage: Default::default(), job_resubmit_interval_ms: None, executor_termination_grace_period: 0, scheduler_event_expected_processing_duration: 0, @@ -84,6 +104,10 @@ impl Default for SchedulerConfig { grpc_server_max_encoding_message_size: 16777216, executor_timeout_seconds: 180, expire_dead_executor_interval_seconds: 15, + override_config_producer: None, + override_session_builder: None, + override_logical_codec: None, + override_physical_codec: None, } } } @@ -177,15 +201,17 @@ impl SchedulerConfig { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub enum ClusterStorageConfig { + #[default] Memory, } /// Policy of distributing tasks to available executor slots /// /// It needs to be visible to code generated by configure_me -#[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] +#[derive(Clone, Copy, Debug, serde::Deserialize)] +#[cfg_attr(feature = "build-binary", derive(clap::ValueEnum))] pub enum TaskDistribution { /// Eagerly assign tasks to executor slots. This will assign as many task slots per executor /// as are currently available @@ -200,24 +226,27 @@ pub enum TaskDistribution { ConsistentHash, } +#[cfg(feature = "build-binary")] impl std::str::FromStr for TaskDistribution { type Err = String; fn from_str(s: &str) -> std::result::Result { - ValueEnum::from_str(s, true) + clap::ValueEnum::from_str(s, true) } } -impl parse_arg::ParseArgFromStr for TaskDistribution { - fn describe_type(mut writer: W) -> fmt::Result { +#[cfg(feature = "build-binary")] +impl configure_me::parse_arg::ParseArgFromStr for TaskDistribution { + fn describe_type(mut writer: W) -> std::fmt::Result { write!(writer, "The executor slots policy for the scheduler") } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] pub enum TaskDistributionPolicy { /// Eagerly assign tasks to executor slots. This will assign as many task slots per executor /// as are currently available + #[default] Bias, /// Distribute tasks evenly across executors. 
This will try and iterate through available executors /// and assign one task to each executor until all tasks are assigned. @@ -231,3 +260,56 @@ pub enum TaskDistributionPolicy { tolerance: usize, }, } +#[cfg(feature = "build-binary")] +impl TryFrom for SchedulerConfig { + type Error = ballista_core::error::BallistaError; + + fn try_from(opt: Config) -> Result { + let task_distribution = match opt.task_distribution { + TaskDistribution::Bias => TaskDistributionPolicy::Bias, + TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin, + TaskDistribution::ConsistentHash => { + let num_replicas = opt.consistent_hash_num_replicas as usize; + let tolerance = opt.consistent_hash_tolerance as usize; + TaskDistributionPolicy::ConsistentHash { + num_replicas, + tolerance, + } + } + }; + + let config = SchedulerConfig { + namespace: opt.namespace, + external_host: opt.external_host, + bind_port: opt.bind_port, + bind_host: opt.bind_host, + scheduling_policy: opt.scheduler_policy, + event_loop_buffer_size: opt.event_loop_buffer_size, + task_distribution, + finished_job_data_clean_up_interval_seconds: opt + .finished_job_data_clean_up_interval_seconds, + finished_job_state_clean_up_interval_seconds: opt + .finished_job_state_clean_up_interval_seconds, + advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint, + cluster_storage: Default::default(), + job_resubmit_interval_ms: (opt.job_resubmit_interval_ms > 0) + .then_some(opt.job_resubmit_interval_ms), + executor_termination_grace_period: opt.executor_termination_grace_period, + scheduler_event_expected_processing_duration: opt + .scheduler_event_expected_processing_duration, + grpc_server_max_decoding_message_size: opt + .grpc_server_max_decoding_message_size, + grpc_server_max_encoding_message_size: opt + .grpc_server_max_encoding_message_size, + executor_timeout_seconds: opt.executor_timeout_seconds, + expire_dead_executor_interval_seconds: opt + .expire_dead_executor_interval_seconds, + override_config_producer: None, + override_logical_codec: None, + override_physical_codec: None, + override_session_builder: None, + }; + + Ok(config) + } +} diff --git a/ballista/scheduler/src/display.rs b/ballista/scheduler/src/display.rs index 9026e0f08..fa26331ef 100644 --- a/ballista/scheduler/src/display.rs +++ b/ballista/scheduler/src/display.rs @@ -87,7 +87,7 @@ impl<'a> DisplayableBallistaExecutionPlan<'a> { plan: &'a dyn ExecutionPlan, metrics: &'a Vec, } - impl<'a> fmt::Display for Wrapper<'a> { + impl fmt::Display for Wrapper<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let t = DisplayFormatType::Default; let mut visitor = IndentVisitor { @@ -121,7 +121,7 @@ struct IndentVisitor<'a, 'b> { metric_index: usize, } -impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { +impl ExecutionPlanVisitor for IndentVisitor<'_, '_> { type Error = fmt::Error; fn pre_visit( &mut self, @@ -150,7 +150,7 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } } -impl<'a> ToStringifiedPlan for DisplayableBallistaExecutionPlan<'a> { +impl ToStringifiedPlan for DisplayableBallistaExecutionPlan<'_> { fn to_stringified( &self, plan_type: datafusion::logical_expr::PlanType, diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index 4b9706079..bf6d484f0 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use anyhow::{Error, Result}; #[cfg(feature = "flight-sql")] use arrow_flight::flight_service_server::FlightServiceServer; +use ballista_core::error::BallistaError; use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer; -use ballista_core::serde::BallistaCodec; +use ballista_core::serde::{ + BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec, +}; use ballista_core::utils::create_grpc_server; use ballista_core::BALLISTA_VERSION; use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; @@ -41,7 +43,7 @@ pub async fn start_server( cluster: BallistaCluster, addr: SocketAddr, config: Arc, -) -> Result<()> { +) -> ballista_core::error::Result<()> { info!( "Ballista v{} Scheduler listening on {:?}", BALLISTA_VERSION, addr @@ -54,11 +56,23 @@ pub async fn start_server( let metrics_collector = default_metrics_collector()?; + let codec_logical = config + .override_logical_codec + .clone() + .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default())); + + let codec_physical = config + .override_physical_codec + .clone() + .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default())); + + let codec = BallistaCodec::new(codec_logical, codec_physical); + let mut scheduler_server: SchedulerServer = SchedulerServer::new( config.scheduler_name(), cluster, - BallistaCodec::default(), + codec, config, metrics_collector, ); @@ -95,9 +109,9 @@ pub async fn start_server( let listener = tokio::net::TcpListener::bind(&addr) .await - .map_err(Error::from)?; + .map_err(BallistaError::from)?; axum::serve(listener, final_route) .await - .map_err(Error::from) + .map_err(BallistaError::from) } diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index b03a99307..02c21a884 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -17,7 +17,8 @@ use axum::extract::ConnectInfo; use ballista_core::config::BALLISTA_JOB_NAME; -use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query}; +use ballista_core::extension::SessionConfigHelperExt; +use ballista_core::serde::protobuf::execute_query_params::Query; use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc; use ballista_core::serde::protobuf::{ execute_query_failure_result, execute_query_result, AvailableTaskSlots, @@ -31,7 +32,6 @@ use ballista_core::serde::protobuf::{ UpdateTaskStatusParams, UpdateTaskStatusResult, }; use ballista_core::serde::scheduler::ExecutorMetadata; -use ballista_core::utils::SessionConfigExt; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info, trace, warn}; @@ -337,25 +337,28 @@ impl SchedulerGrpc let query_params = request.into_inner(); if let ExecuteQueryParams { query: Some(query), - optional_session_id, + session_id, settings, } = query_params { let job_name = settings .iter() .find(|s| s.key == BALLISTA_JOB_NAME) - .map(|s| s.value.clone()) - .unwrap_or_else(|| "None".to_string()); + .and_then(|s| s.value.clone()) + .unwrap_or_default(); - let (session_id, session_ctx) = match optional_session_id { - Some(OptionalSessionId::SessionId(session_id)) => { + let (session_id, session_ctx) = match session_id { + Some(session_id) => { match self.state.session_manager.get_session(&session_id).await { Ok(ctx) => { - // [SessionConfig] will be updated from received properties + // Update [SessionConfig] using received properties 
// TODO MM can we do something better here? // move this to update session and use .update_session(&session_params.session_id, &session_config) - // instead of get_session + // instead of get_session. + // + // also we should consider sending properties if/when changed rather than + // all properties every time let state = ctx.state_ref(); let mut state = state.write(); diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 5fa222595..653e2d410 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -56,7 +56,8 @@ mod external_scaler; mod grpc; pub(crate) mod query_stage_scheduler; -pub type SessionBuilder = Arc SessionState + Send + Sync>; +pub type SessionBuilder = + Arc datafusion::common::Result + Send + Sync>; #[derive(Clone)] pub struct SchedulerServer { @@ -346,7 +347,7 @@ pub fn timestamp_millis() -> u64 { mod test { use std::sync::Arc; - use ballista_core::utils::SessionConfigExt; + use ballista_core::extension::SessionConfigExt; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::functions_aggregate::sum::sum; use datafusion::logical_expr::{col, LogicalPlan}; diff --git a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs index c3f3e7eb8..b9b49c7fe 100644 --- a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs +++ b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs @@ -359,14 +359,9 @@ mod tests { use datafusion::test_util::scan_empty_with_partitions; use std::sync::Arc; use std::time::Duration; - use tracing_subscriber::EnvFilter; #[tokio::test] async fn test_pending_job_metric() -> Result<()> { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .init(); - let plan = test_plan(10); let metrics_collector = Arc::new(TestMetricsCollector::default()); diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index 1e7d93844..e9c483456 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -19,10 +19,10 @@ use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::metrics::default_metrics_collector; use crate::scheduler_server::SchedulerServer; +use ballista_core::extension::SessionConfigExt; use ballista_core::serde::BallistaCodec; use ballista_core::utils::{ create_grpc_server, default_config_producer, default_session_builder, - SessionConfigExt, }; use ballista_core::ConfigProducer; use ballista_core::{ @@ -57,9 +57,11 @@ pub async fn new_standalone_scheduler_from_state( let session_config = session_state.config().clone(); let session_state = session_state.clone(); let session_builder = Arc::new(move |c: SessionConfig| { - SessionStateBuilder::new_from_existing(session_state.clone()) - .with_config(c) - .build() + Ok( + SessionStateBuilder::new_from_existing(session_state.clone()) + .with_config(c) + .build(), + ) }); let config_producer = Arc::new(move || session_config.clone()); diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs index f2c9bf1d8..68a2ebdfc 100644 --- a/ballista/scheduler/src/state/execution_graph_dot.rs +++ b/ballista/scheduler/src/state/execution_graph_dot.rs @@ -418,7 +418,7 @@ mod tests { use crate::state::execution_graph::ExecutionGraph; use crate::state::execution_graph_dot::ExecutionGraphDot; use 
ballista_core::error::{BallistaError, Result}; - use ballista_core::utils::SessionConfigExt; + use ballista_core::extension::SessionConfigExt; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; diff --git a/ballista/scheduler/src/state/session_manager.rs b/ballista/scheduler/src/state/session_manager.rs index 8a769edbd..598131670 100644 --- a/ballista/scheduler/src/state/session_manager.rs +++ b/ballista/scheduler/src/state/session_manager.rs @@ -67,7 +67,7 @@ impl SessionManager { pub fn create_datafusion_context( session_config: &SessionConfig, session_builder: SessionBuilder, -) -> Arc { +) -> datafusion::common::Result> { let session_state = if session_config.round_robin_repartition() { let session_config = session_config .clone() @@ -75,10 +75,10 @@ pub fn create_datafusion_context( .with_round_robin_repartition(false); log::warn!("session manager will override `datafusion.optimizer.enable_round_robin_repartition` to `false` "); - session_builder(session_config) + session_builder(session_config)? } else { - session_builder(session_config.clone()) + session_builder(session_config.clone())? }; - Arc::new(SessionContext::new_with_state(session_state)) + Ok(Arc::new(SessionContext::new_with_state(session_state))) } diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index 11b99ae57..cc8442f2f 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -24,7 +24,7 @@ use crate::state::executor_manager::ExecutorManager; use ballista_core::error::BallistaError; use ballista_core::error::Result; -use ballista_core::utils::SessionConfigExt; +use ballista_core::extension::SessionConfigHelperExt; use datafusion::prelude::SessionConfig; use crate::cluster::JobState; @@ -38,7 +38,7 @@ use dashmap::DashMap; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; -use log::{debug, error, info, warn}; +use log::{debug, error, info, trace, warn}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use std::collections::{HashMap, HashSet}; @@ -48,8 +48,6 @@ use std::time::Duration; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; -use tracing::trace; - type ActiveJobCache = Arc>; // TODO move to configuration file diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index 7f59f89dd..8e4565a45 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -16,6 +16,7 @@ // under the License. 
use ballista_core::error::{BallistaError, Result}; +use ballista_core::extension::SessionConfigExt; use datafusion::catalog::Session; use std::any::Any; use std::collections::HashMap; @@ -56,9 +57,7 @@ use crate::cluster::BallistaCluster; use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::state::execution_graph::{ExecutionGraph, ExecutionStage, TaskDescription}; -use ballista_core::utils::{ - default_config_producer, default_session_builder, SessionConfigExt, -}; +use ballista_core::utils::{default_config_producer, default_session_builder}; use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 84820d48a..941ec8498 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -25,7 +25,6 @@ homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" license = "Apache-2.0" publish = false -rust-version = "1.72" [features] ci = [] @@ -38,7 +37,7 @@ datafusion = { workspace = true } datafusion-proto = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } -mimalloc = { version = "0.1", optional = true, default-features = false } +mimalloc = { workspace = true, optional = true } rand = { workspace = true } serde = { workspace = true } serde_json = "1.0.78" diff --git a/docs/source/index.rst b/docs/source/index.rst index 959d5844b..9289eab75 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -55,6 +55,7 @@ Table of content user-guide/tuning-guide user-guide/metrics user-guide/faq + user-guide/extending-components .. _toc.contributors: diff --git a/docs/source/user-guide/extending-components.md b/docs/source/user-guide/extending-components.md new file mode 100644 index 000000000..60de1b7b1 --- /dev/null +++ b/docs/source/user-guide/extending-components.md @@ -0,0 +1,250 @@ + + +# Extending Ballista Scheduler And Executors + +The Ballista scheduler and executor provide a set of configuration options +which can be used to extend their basic functionality. They allow registering +new configuration extensions, object stores, logical and physical codecs, and more: + +- `function registry` - provides the possibility to override the set of built-in functions. +- `config producer` - function which creates a new `SessionConfig`, which can hold extended configuration options. +- `runtime producer` - function which creates a new `RuntimeEnv` based on the provided `SessionConfig`. +- `session builder` - function which creates a new `SessionState` for each user session. +- `logical codec` - overrides the `LogicalCodec`. +- `physical codec` - overrides the `PhysicalCodec`. + +The Ballista executor can be configured using `ExecutorProcessConfig`, which supports overriding the `function registry`, `runtime producer`, `config producer`, `logical codec` and `physical codec`. + +The Ballista scheduler can be tuned using `SchedulerConfig`, which supports overriding the `config producer`, `session builder`, `logical codec` and `physical codec`. + +## Example: Custom Object Store Integration + +Extending the basic building blocks will be demonstrated by integrating an S3 object store. For this, a new `ObjectStoreRegistry` and `S3Options` will be provided. The `ObjectStoreRegistry` creates new `ObjectStore` instances configured using `S3Options`. + +For this specific task a `config producer`, `runtime producer` and `session builder` have to be provided, and the client, scheduler and executor need to be configured.
+ +```rust +/// Custom [SessionConfig] constructor method +/// +/// This method registers config extension [S3Options] +/// which is used to configure [ObjectStore] with ACCESS and +/// SECRET key +pub fn custom_session_config_with_s3_options() -> SessionConfig { + SessionConfig::new_with_ballista() + .with_information_schema(true) + .with_option_extension(S3Options::default()) +} +``` + +```rust +/// Custom [RuntimeEnv] constructor method +/// +/// It will register [CustomObjectStoreRegistry] which will +/// use configuration extension [S3Options] to configure +/// and create [ObjectStore]s +pub fn custom_runtime_env_with_s3_support( + session_config: &SessionConfig, +) -> Result<Arc<RuntimeEnv>> { + let s3options = session_config + .options() + .extensions + .get::<S3Options>() + .ok_or(DataFusionError::Configuration( + "S3 Options not set".to_string(), + ))?; + + let config = RuntimeConfig::new().with_object_store_registry(Arc::new( + CustomObjectStoreRegistry::new(s3options.clone()), + )); + + Ok(Arc::new(RuntimeEnv::try_new(config)?)) +} +``` + +```rust +/// Custom [SessionState] constructor method +/// +/// It will configure [SessionState] with the provided [SessionConfig] +/// and [RuntimeEnv]. +pub fn custom_session_state_with_s3_support( + session_config: SessionConfig, +) -> datafusion::common::Result<SessionState> { + let runtime_env = custom_runtime_env_with_s3_support(&session_config)?; + + Ok(SessionStateBuilder::new() + .with_runtime_env(runtime_env) + .with_config(session_config) + .build()) +} +``` + +The `S3Options` & `CustomObjectStoreRegistry` implementations can be found in the examples sub-project. + +### Configuring Scheduler + +```rust +#[tokio::main] +async fn main() -> Result<()> { + // parse CLI options (default options which Ballista scheduler exposes) + let (opt, _remaining_args) = + Config::including_optional_config_files(&["/etc/ballista/scheduler.toml"]) + .unwrap_or_exit(); + + let addr = format!("{}:{}", opt.bind_host, opt.bind_port); + let addr = addr.parse()?; + + // converting CLI options to SchedulerConfig + let mut config: SchedulerConfig = opt.try_into()?; + + // overriding default config producer with custom producer + // which has required S3 configuration options + config.override_config_producer = + Some(Arc::new(custom_session_config_with_s3_options)); + + // overriding default session builder with one providing custom session configuration, + // runtime environment and session state.
+ config.override_session_builder = Some(Arc::new(|session_config: SessionConfig| { + custom_session_state_with_s3_support(session_config) + })); + let cluster = BallistaCluster::new_from_config(&config).await?; + start_server(cluster, addr, Arc::new(config)).await?; + Ok(()) +} +``` + +### Configuring Executor + +```rust +#[tokio::main] +async fn main() -> Result<()> { + // parse CLI options (default options which Ballista executor exposes) + let (opt, _remaining_args) = + Config::including_optional_config_files(&["/etc/ballista/executor.toml"]) + .unwrap_or_exit(); + + // Converting CLI options to executor configuration + let mut config: ExecutorProcessConfig = opt.try_into().unwrap(); + + // overriding default config producer with custom producer + // which has required S3 configuration options + config.override_config_producer = + Some(Arc::new(custom_session_config_with_s3_options)); + + // overriding default runtime producer with custom producer + // which knows how to create S3 connections + config.override_runtime_producer = + Some(Arc::new(|session_config: &SessionConfig| { + custom_runtime_env_with_s3_support(session_config) + })); + + start_executor_process(Arc::new(config)).await +} + +``` + +### Configuring Client + +```rust +let test_data = ballista_examples::test_util::examples_test_data(); + +// new session state with required custom session configuration and runtime environment +let state = + custom_session_state_with_s3_support(custom_session_config_with_s3_options())?; + +let ctx: SessionContext = + SessionContext::remote_with_state("df://localhost:50050", state).await?; + +// once we have it all set up we can configure the object store +// +// as the session config has the relevant S3 options registered and exposed, +// S3 configuration options can be changed using the SQL `SET` statement. + +ctx.sql("SET s3.allow_http = true").await?.show().await?; + +ctx.sql(&format!("SET s3.access_key_id = '{}'", S3_ACCESS_KEY_ID)) + .await? + .show() + .await?; + +ctx.sql(&format!("SET s3.secret_access_key = '{}'", S3_SECRET_KEY)) + .await? + .show() + .await?; + +ctx.sql("SET s3.endpoint = 'http://localhost:9000'") + .await? + .show() + .await?; +ctx.sql("SET s3.allow_http = true").await?.show().await?; + +ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), +) +.await?; + +let write_dir_path = &format!("s3://{}/write_test.parquet", S3_BUCKET); + +ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + +ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + +let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + +let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", +]; + +assert_batches_eq!(expected, &result); +``` + +## Example: Client Side Logical/Physical Codec + +Default physical and logical codecs can be replaced if needed. For the scheduler and executor the procedure is similar to the previous example; a short sketch is given below. On the client side the procedure is slightly different: `ballista::prelude::SessionConfigExt` provides methods to override the physical and logical codecs on the client side.
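For the scheduler and executor, a minimal sketch of the override could look like the following. This is an illustration only: the `override_logical_codec` and `override_physical_codec` field names are assumed here rather than taken from the source, and `BetterLogicalCodec`/`BetterPhysicalCodec` are the example codecs from the examples sub-project; check `SchedulerConfig` and `ExecutorProcessConfig` for the exact API.

```rust
// Sketch only: `override_logical_codec` / `override_physical_codec` are assumed
// field names on `SchedulerConfig` and `ExecutorProcessConfig`.
let scheduler_config = SchedulerConfig {
    override_logical_codec: Some(Arc::new(BetterLogicalCodec::default())),
    override_physical_codec: Some(Arc::new(BetterPhysicalCodec::default())),
    ..Default::default()
};

let executor_config = ExecutorProcessConfig {
    override_logical_codec: Some(Arc::new(BetterLogicalCodec::default())),
    override_physical_codec: Some(Arc::new(BetterPhysicalCodec::default())),
    ..Default::default()
};
```

The client-side configuration then looks like this: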
+ +```rust +let session_config = SessionConfig::new_with_ballista() + .with_information_schema(true) + .with_ballista_physical_extension_codec(Arc::new(BetterPhysicalCodec::default())) + .with_ballista_logical_extension_codec(Arc::new(BetterLogicalCodec::default())); + +let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .build(); + +let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; +``` diff --git a/docs/source/user-guide/tuning-guide.md b/docs/source/user-guide/tuning-guide.md index b1b61240b..22955b44c 100644 --- a/docs/source/user-guide/tuning-guide.md +++ b/docs/source/user-guide/tuning-guide.md @@ -32,7 +32,7 @@ For example, if there is a table "customer" that consists of 200 Parquet files, 200 partitions and the table scan and certain subsequent operations will also have 200 partitions. Conversely, if the table only has a single Parquet file then there will be a single partition and the work will not be able to scale even if the cluster has resource available. Ballista supports repartitioning within a query to improve parallelism. -The configuration setting `ballista.shuffle.partitions`can be set to the desired number of partitions. This is +The configuration setting `datafusion.execution.target_partitions`can be set to the desired number of partitions. This is currently a global setting for the entire context. The default value for this setting is 16. Note that Ballista will never decrease the number of partitions based on this setting and will only repartition if @@ -41,11 +41,17 @@ the source operation has fewer partitions than this setting. Example: Setting the desired number of shuffle partitions when creating a context. ```rust -let config = BallistaConfig::builder() - .set("ballista.shuffle.partitions", "200") - .build()?; +use ballista::extension::{SessionConfigExt, SessionContextExt}; -let ctx = BallistaContext::remote("localhost", 50050, &config).await?; +let session_config = SessionConfig::new_with_ballista() + .with_target_partitions(200); + +let state = SessionStateBuilder::new() + .with_default_features() + .with_config(session_config) + .build(); + +let ctx: SessionContext = SessionContext::remote_with_state(&url,state).await?; ``` ## Configuring Executor Concurrency Levels @@ -75,6 +81,8 @@ processes. The default is `pull-based`. The scheduler provides a REST API for monitoring jobs. See the [scheduler documentation](scheduler.md) for more information. 
+> This is an optional scheduler feature which should be enabled with the `rest-api` feature + To download a query plan in dot format from the scheduler, submit a request to the following API endpoint ``` diff --git a/examples/Cargo.toml b/examples/Cargo.toml index c87c039cf..743ff8264 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -26,7 +26,6 @@ license = "Apache-2.0" keywords = ["arrow", "distributed", "query", "sql"] edition = "2021" publish = false -rust-version = "1.72" [[example]] name = "standalone_sql" @@ -35,7 +34,14 @@ required-features = ["ballista/standalone"] [dependencies] ballista = { path = "../ballista/client", version = "0.12.0" } +ballista-core = { path = "../ballista/core", version = "0.12.0" } +ballista-executor = { path = "../ballista/executor", version = "0.12.0", default-features = false } +ballista-scheduler = { path = "../ballista/scheduler", version = "0.12.0", default-features = false } datafusion = { workspace = true } +env_logger = { workspace = true } +log = { workspace = true } +object_store = { workspace = true, features = ["aws"] } +parking_lot = { workspace = true } tokio = { workspace = true, features = [ "macros", "rt", @@ -43,4 +49,14 @@ tokio = { workspace = true, features = [ "sync", "parking_lot" ] } +url = { workspace = true } +[dev-dependencies] +ctor = { workspace = true } +env_logger = { workspace = true } +testcontainers-modules = { version = "0.11", features = ["minio"] } +tonic = { workspace = true } + +[features] +default = [] +testcontainers = [] diff --git a/examples/examples/custom-client.rs b/examples/examples/custom-client.rs new file mode 100644 index 000000000..9e7ec8595 --- /dev/null +++ b/examples/examples/custom-client.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista::extension::SessionContextExt; +use ballista_examples::object_store::{ + custom_session_config_with_s3_options, custom_session_state_with_s3_support, +}; +use datafusion::error::Result; +use datafusion::{assert_batches_eq, prelude::SessionContext}; + +/// bucket name to be used for this example +const S3_BUCKET: &str = "ballista"; +/// S3 access key +const S3_ACCESS_KEY_ID: &str = "MINIO"; +/// S3 secret key +const S3_SECRET_KEY: &str = "MINIOSECRET"; +/// +/// # Extending Ballista +/// +/// This example demonstrates how to extend the ballista scheduler and executor by registering a new object store registry. +/// It uses a local [minio](https://min.io) instance to act as an S3 object store. +/// +/// Ballista will be extended by providing a custom session configuration, runtime environment and session state.
+/// +/// Minio can be started: +/// +/// ```bash +/// docker run --rm -p 9000:9000 -p 9001:9001 -e "MINIO_ACCESS_KEY=MINIO" -e "MINIO_SECRET_KEY=MINIOSECRET" quay.io/minio/minio server /data --console-address ":9001" +/// ``` +/// After minio, we need to start `custom-scheduler` +/// +/// ```bash +/// cargo run --example custom-scheduler +/// ``` +/// +/// and `custom-executor` +/// +/// ```bash +/// cargo run --example custom-executor +/// ``` +/// +/// ```bash +/// cargo run --example custom-client +/// ``` +#[tokio::main] +async fn main() -> Result<()> { + let test_data = ballista_examples::test_util::examples_test_data(); + + // new sessions state with required custom session configuration and runtime environment + let state = + custom_session_state_with_s3_support(custom_session_config_with_s3_options())?; + + let ctx: SessionContext = + SessionContext::remote_with_state("df://localhost:50050", state).await?; + + // session config has relevant S3 options registered and exposed. + // S3 configuration options can be changed using `SET` statement + ctx.sql("SET s3.allow_http = true").await?.show().await?; + + ctx.sql(&format!("SET s3.access_key_id = '{}'", S3_ACCESS_KEY_ID)) + .await? + .show() + .await?; + ctx.sql(&format!("SET s3.secret_access_key = '{}'", S3_SECRET_KEY)) + .await? + .show() + .await?; + ctx.sql("SET s3.endpoint = 'http://localhost:9000'") + .await? + .show() + .await?; + ctx.sql("SET s3.allow_http = true").await?.show().await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let write_dir_path = &format!("s3://{}/write_test.parquet", S3_BUCKET); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) +} diff --git a/examples/examples/custom-executor.rs b/examples/examples/custom-executor.rs new file mode 100644 index 000000000..534182121 --- /dev/null +++ b/examples/examples/custom-executor.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use ballista_examples::object_store::{ + custom_runtime_env_with_s3_support, custom_session_config_with_s3_options, +}; + +use ballista_executor::executor_process::{ + start_executor_process, ExecutorProcessConfig, +}; +use datafusion::prelude::SessionConfig; +use std::sync::Arc; +/// +/// # Custom Ballista Executor +/// +/// This example demonstrates how to create custom ballista executors. +/// +#[tokio::main] +async fn main() -> ballista_core::error::Result<()> { + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Info) + .is_test(true) + .try_init(); + + let config: ExecutorProcessConfig = ExecutorProcessConfig { + // overriding default config producer with custom producer + // which has required S3 configuration options + override_config_producer: Some(Arc::new(custom_session_config_with_s3_options)), + // overriding default runtime producer with custom producer + // which knows how to create S3 connections + override_runtime_producer: Some(Arc::new(|session_config: &SessionConfig| { + custom_runtime_env_with_s3_support(session_config) + })), + ..Default::default() + }; + + start_executor_process(Arc::new(config)).await +} diff --git a/examples/examples/custom-scheduler.rs b/examples/examples/custom-scheduler.rs new file mode 100644 index 000000000..9783ae28e --- /dev/null +++ b/examples/examples/custom-scheduler.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista_core::error::BallistaError; +use ballista_examples::object_store::{ + custom_session_config_with_s3_options, custom_session_state_with_s3_support, +}; +use ballista_scheduler::cluster::BallistaCluster; +use ballista_scheduler::config::SchedulerConfig; +use ballista_scheduler::scheduler_process::start_server; +use datafusion::prelude::SessionConfig; +use std::net::AddrParseError; +use std::sync::Arc; + +/// +/// # Custom Ballista Scheduler +/// +/// This example demonstrates how to create custom ballista schedulers. +/// +#[tokio::main] +async fn main() -> ballista_core::error::Result<()> { + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Info) + .is_test(true) + .try_init(); + + let config: SchedulerConfig = SchedulerConfig { + // overriding default config producer with custom producer + // which has required S3 configuration options + override_config_producer: Some(Arc::new(custom_session_config_with_s3_options)), + // overriding default session builder with one providing custom session configuration, + // runtime environment and session state.
+ override_session_builder: Some(Arc::new(|session_config: SessionConfig| { + custom_session_state_with_s3_support(session_config) + })), + ..Default::default() + }; + + let addr = format!("{}:{}", config.bind_host, config.bind_port); + let addr = addr + .parse() + .map_err(|e: AddrParseError| BallistaError::Configuration(e.to_string()))?; + + let cluster = BallistaCluster::new_from_config(&config).await?; + start_server(cluster, addr, Arc::new(config)).await?; + + Ok(()) +} diff --git a/examples/src/lib.rs b/examples/src/lib.rs index 6dc48f6b9..f8d7cc59b 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -15,4 +15,7 @@ // specific language governing permissions and limitations // under the License. +/// Provides required structures and methods to +/// integrate with S3 object store +pub mod object_store; pub mod test_util; diff --git a/examples/src/object_store.rs b/examples/src/object_store.rs new file mode 100644 index 000000000..5b5e38a6a --- /dev/null +++ b/examples/src/object_store.rs @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Extending Ballista +//! +//! This example demonstrates extending standard ballista behavior, +//! integrating external [ObjectStoreRegistry]. +//! +//! [ObjectStore] is provided by [ObjectStoreRegistry], and configured +//! using [ExtensionOptions], which can be configured using SQL `SET` command. 
+ +use ballista::prelude::SessionConfigExt; +use datafusion::common::{config_err, exec_err}; +use datafusion::config::{ + ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit, +}; +use datafusion::error::Result; +use datafusion::execution::object_store::ObjectStoreRegistry; +use datafusion::execution::SessionState; +use datafusion::prelude::SessionConfig; +use datafusion::{ + error::DataFusionError, + execution::{ + runtime_env::{RuntimeConfig, RuntimeEnv}, + SessionStateBuilder, + }, +}; +use object_store::aws::AmazonS3Builder; +use object_store::local::LocalFileSystem; +use object_store::ObjectStore; +use parking_lot::RwLock; +use std::any::Any; +use std::fmt::Display; +use std::sync::Arc; +use url::Url; + +/// Custom [SessionConfig] constructor method +/// +/// This method registers config extension [S3Options] +/// which is used to configure [ObjectStore] with ACCESS and +/// SECRET key +pub fn custom_session_config_with_s3_options() -> SessionConfig { + SessionConfig::new_with_ballista() + .with_information_schema(true) + .with_option_extension(S3Options::default()) +} + +/// Custom [RuntimeEnv] constructor method +/// +/// It will register [CustomObjectStoreRegistry] which will +/// use configuration extension [S3Options] to configure +/// and created [ObjectStore]s +pub fn custom_runtime_env_with_s3_support( + session_config: &SessionConfig, +) -> Result> { + let s3options = session_config + .options() + .extensions + .get::() + .ok_or(DataFusionError::Configuration( + "S3 Options not set".to_string(), + ))?; + + let config = RuntimeConfig::new().with_object_store_registry(Arc::new( + CustomObjectStoreRegistry::new(s3options.clone()), + )); + + Ok(Arc::new(RuntimeEnv::try_new(config)?)) +} + +/// Custom [SessionState] constructor method +/// +/// It will configure [SessionState] with provided [SessionConfig], +/// and [RuntimeEnv]. +pub fn custom_session_state_with_s3_support( + session_config: SessionConfig, +) -> datafusion::common::Result { + let runtime_env = custom_runtime_env_with_s3_support(&session_config)?; + + Ok(SessionStateBuilder::new() + .with_runtime_env(runtime_env) + .with_config(session_config) + .build()) +} + +/// Custom [ObjectStoreRegistry] which will create +/// and configure [ObjectStore] using provided [S3Options] +#[derive(Debug)] +pub struct CustomObjectStoreRegistry { + local: Arc, + s3options: S3Options, +} + +impl CustomObjectStoreRegistry { + pub fn new(s3options: S3Options) -> Self { + Self { + s3options, + local: Arc::new(LocalFileSystem::new()), + } + } +} + +impl ObjectStoreRegistry for CustomObjectStoreRegistry { + fn register_store( + &self, + _url: &Url, + _store: Arc, + ) -> Option> { + unimplemented!("register_store not supported") + } + + fn get_store(&self, url: &Url) -> Result> { + let scheme = url.scheme(); + log::info!("get_store: {:?}", &self.s3options.config.read()); + match scheme { + "" | "file" => Ok(self.local.clone()), + "s3" => { + let s3store = + Self::s3_object_store_builder(url, &self.s3options.config.read())? 
+ .build()?; + + Ok(Arc::new(s3store)) + } + + _ => exec_err!("get_store - store not supported, url {}", url), + } + } +} + +impl CustomObjectStoreRegistry { + fn s3_object_store_builder( + url: &Url, + aws_options: &S3RegistryConfiguration, + ) -> Result { + let S3RegistryConfiguration { + access_key_id, + secret_access_key, + session_token, + region, + endpoint, + allow_http, + } = aws_options; + + let bucket_name = Self::get_bucket_name(url)?; + let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name); + + if let (Some(access_key_id), Some(secret_access_key)) = + (access_key_id, secret_access_key) + { + builder = builder + .with_access_key_id(access_key_id) + .with_secret_access_key(secret_access_key); + + if let Some(session_token) = session_token { + builder = builder.with_token(session_token); + } + } else { + return config_err!( + "'s3.access_key_id' & 's3.secret_access_key' must be configured" + ); + } + + if let Some(region) = region { + builder = builder.with_region(region); + } + + if let Some(endpoint) = endpoint { + if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) { + if !matches!(allow_http, Some(true)) && endpoint_url.scheme() == "http" { + return config_err!("Invalid endpoint: {endpoint}. HTTP is not allowed for S3 endpoints. To allow HTTP, set 's3.allow_http' to true"); + } + } + + builder = builder.with_endpoint(endpoint); + } + + if let Some(allow_http) = allow_http { + builder = builder.with_allow_http(*allow_http); + } + + Ok(builder) + } + + fn get_bucket_name(url: &Url) -> Result<&str> { + url.host_str().ok_or_else(|| { + DataFusionError::Execution(format!( + "Not able to parse bucket name from url: {}", + url.as_str() + )) + }) + } +} + +/// Custom [SessionConfig] extension which allows +/// users to configure [ObjectStore] access using SQL +/// interface +#[derive(Debug, Clone, Default)] +pub struct S3Options { + config: Arc>, +} + +impl ExtensionOptions for S3Options { + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn cloned(&self) -> Box { + Box::new(self.clone()) + } + + fn set(&mut self, key: &str, value: &str) -> Result<()> { + log::debug!("set config, key:{}, value:{}", key, value); + match key { + "access_key_id" => { + let mut c = self.config.write(); + c.access_key_id.set(key, value)?; + } + "secret_access_key" => { + let mut c = self.config.write(); + c.secret_access_key.set(key, value)?; + } + "session_token" => { + let mut c = self.config.write(); + c.session_token.set(key, value)?; + } + "region" => { + let mut c = self.config.write(); + c.region.set(key, value)?; + } + "endpoint" => { + let mut c = self.config.write(); + c.endpoint.set(key, value)?; + } + "allow_http" => { + let mut c = self.config.write(); + c.allow_http.set(key, value)?; + } + _ => { + log::warn!("Config value {} cant be set to {}", key, value); + return config_err!("Config value \"{}\" not found in S3Options", key); + } + } + Ok(()) + } + + fn entries(&self) -> Vec { + struct Visitor(Vec); + + impl Visit for Visitor { + fn some( + &mut self, + key: &str, + value: V, + description: &'static str, + ) { + self.0.push(ConfigEntry { + key: format!("{}.{}", S3Options::PREFIX, key), + value: Some(value.to_string()), + description, + }) + } + + fn none(&mut self, key: &str, description: &'static str) { + self.0.push(ConfigEntry { + key: format!("{}.{}", S3Options::PREFIX, key), + value: None, + description, + }) + } + } + let c = self.config.read(); + + let mut v = Visitor(vec![]); + c.access_key_id + 
.visit(&mut v, "access_key_id", "S3 Access Key"); + c.secret_access_key + .visit(&mut v, "secret_access_key", "S3 Secret Key"); + c.session_token + .visit(&mut v, "session_token", "S3 Session token"); + c.region.visit(&mut v, "region", "S3 region"); + c.endpoint.visit(&mut v, "endpoint", "S3 Endpoint"); + c.allow_http.visit(&mut v, "allow_http", "S3 Allow Http"); + + v.0 + } +} + +impl ConfigExtension for S3Options { + const PREFIX: &'static str = "s3"; +} +#[derive(Default, Debug, Clone)] +struct S3RegistryConfiguration { + /// Access Key ID + pub access_key_id: Option, + /// Secret Access Key + pub secret_access_key: Option, + /// Session token + pub session_token: Option, + /// AWS Region + pub region: Option, + /// OSS or COS Endpoint + pub endpoint: Option, + /// Allow HTTP (otherwise will always use https) + pub allow_http: Option, +} diff --git a/examples/tests/common/mod.rs b/examples/tests/common/mod.rs new file mode 100644 index 000000000..1e8091ed9 --- /dev/null +++ b/examples/tests/common/mod.rs @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista::prelude::SessionConfigExt; +use ballista_core::serde::{ + protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec, +}; +use ballista_core::{ConfigProducer, RuntimeProducer}; +use ballista_scheduler::SessionBuilder; +use datafusion::execution::SessionState; +use datafusion::prelude::SessionConfig; +use object_store::aws::AmazonS3Builder; +use testcontainers_modules::minio::MinIO; +use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand}; +use testcontainers_modules::testcontainers::ContainerRequest; +use testcontainers_modules::{minio, testcontainers::ImageExt}; + +pub const REGION: &str = "eu-west-1"; +pub const BUCKET: &str = "ballista"; +pub const ACCESS_KEY_ID: &str = "MINIO"; +pub const SECRET_KEY: &str = "MINIOMINIO"; + +#[allow(dead_code)] +pub fn create_s3_store( + host: &str, + port: u16, +) -> std::result::Result { + log::info!("create S3 client: host: {}, port: {}", host, port); + AmazonS3Builder::new() + .with_endpoint(format!("http://{host}:{port}")) + .with_region(REGION) + .with_bucket_name(BUCKET) + .with_access_key_id(ACCESS_KEY_ID) + .with_secret_access_key(SECRET_KEY) + .with_allow_http(true) + .build() +} + +#[allow(dead_code)] +pub fn create_minio_container() -> ContainerRequest { + MinIO::default() + .with_env_var("MINIO_ACCESS_KEY", ACCESS_KEY_ID) + .with_env_var("MINIO_SECRET_KEY", SECRET_KEY) +} + +#[allow(dead_code)] +pub fn create_bucket_command() -> ExecCommand { + // this is hack to create a bucket without creating s3 client. + // this works with current testcontainer (and image) version 'RELEASE.2022-02-07T08-17-33Z'. 
+ // (testcontainer does not await properly on latest image version) + // + // if testcontainer image version change to something newer we should use "mc mb /data/ballista" + // to crate a bucket. + ExecCommand::new(vec![ + "mkdir".to_string(), + format!("/data/{}", crate::common::BUCKET), + ]) + .with_cmd_ready_condition(CmdWaitFor::seconds(1)) +} + +/// starts a ballista cluster for integration tests +#[allow(dead_code)] +pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) { + let config = SessionConfig::new_with_ballista(); + + let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state( + &session_state, + ) + .await + .expect("scheduler to be created"); + + let host = "localhost".to_string(); + + let scheduler = + connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; + + ballista_executor::new_standalone_executor_from_state( + scheduler, + config.ballista_standalone_parallelism(), + &session_state, + ) + .await + .expect("executor to be created"); + + log::info!("test scheduler created at: {}:{}", host, addr.port()); + + (host, addr.port()) +} + +#[allow(dead_code)] +pub async fn setup_test_cluster_with_builders( + config_producer: ConfigProducer, + runtime_producer: RuntimeProducer, + session_builder: SessionBuilder, +) -> (String, u16) { + let config = config_producer(); + + let logical = config.ballista_logical_extension_codec(); + let physical = config.ballista_physical_extension_codec(); + let codec = BallistaCodec::new(logical, physical); + + let addr = ballista_scheduler::standalone::new_standalone_scheduler_with_builder( + session_builder, + config_producer.clone(), + codec.clone(), + ) + .await + .expect("scheduler to be created"); + + let host = "localhost".to_string(); + + let scheduler = + connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; + + ballista_executor::new_standalone_executor_from_builder( + scheduler, + config.ballista_standalone_parallelism(), + config_producer, + runtime_producer, + codec, + Default::default(), + ) + .await + .expect("executor to be created"); + + log::info!("test scheduler created at: {}:{}", host, addr.port()); + + (host, addr.port()) +} + +async fn connect_to_scheduler( + scheduler_url: String, +) -> SchedulerGrpcClient { + let mut retry = 50; + loop { + match SchedulerGrpcClient::connect(scheduler_url.clone()).await { + Err(_) if retry > 0 => { + retry -= 1; + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + log::debug!("Re-attempting to connect to test scheduler..."); + } + + Err(_) => { + log::error!("scheduler connection timed out"); + panic!("scheduler connection timed out") + } + Ok(scheduler) => break scheduler, + } + } +} + +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Info) + .parse_filters( + "ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug", + ) + //.parse_filters("ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug,datafusion=debug") + .is_test(true) + .try_init(); +} diff --git a/ballista/client/tests/object_store.rs b/examples/tests/object_store.rs similarity index 63% rename from ballista/client/tests/object_store.rs rename to examples/tests/object_store.rs index 83df931fc..3c3443e82 100644 --- a/ballista/client/tests/object_store.rs +++ b/examples/tests/object_store.rs @@ -27,11 +27,11 @@ mod common; #[cfg(test)] -#[cfg(feature = "standalone")] #[cfg(feature = "testcontainers")] 
mod standalone { use ballista::extension::SessionContextExt; + use ballista_examples::test_util::examples_test_data; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion::{ error::DataFusionError, @@ -52,12 +52,13 @@ mod standalone { .await .unwrap(); + let host = node.get_host().await.unwrap(); let port = node.get_host_port_ipv4(9000).await.unwrap(); - let object_store = crate::common::create_s3_store(port) + let object_store = crate::common::create_s3_store(&host.to_string(), port) .map_err(|e| DataFusionError::External(e.into()))?; - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); let config = RuntimeConfig::new(); let runtime_env = RuntimeEnv::try_new(config)?; @@ -116,6 +117,7 @@ mod standalone { mod remote { use ballista::extension::SessionContextExt; + use ballista_examples::test_util::examples_test_data; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion::{ error::DataFusionError, @@ -129,7 +131,7 @@ mod remote { #[tokio::test] async fn should_execute_sql_write() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); let container = crate::common::create_minio_container(); let node = container.start().await.unwrap(); @@ -138,9 +140,10 @@ mod remote { .await .unwrap(); + let host = node.get_host().await.unwrap(); let port = node.get_host_port_ipv4(9000).await.unwrap(); - let object_store = crate::common::create_s3_store(port) + let object_store = crate::common::create_s3_store(&host.to_string(), port) .map_err(|e| DataFusionError::External(e.into()))?; let config = RuntimeConfig::new(); @@ -210,15 +213,13 @@ mod remote { #[cfg(feature = "testcontainers")] mod custom_s3_config { + use crate::common::{ACCESS_KEY_ID, SECRET_KEY}; use ballista::extension::SessionContextExt; use ballista::prelude::SessionConfigExt; use ballista_core::RuntimeProducer; - use datafusion::common::{config_err, exec_err}; - use datafusion::config::{ - ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit, - }; - use datafusion::error::Result; - use datafusion::execution::object_store::ObjectStoreRegistry; + use ballista_examples::object_store::{CustomObjectStoreRegistry, S3Options}; + use ballista_examples::test_util::examples_test_data; + use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::SessionState; use datafusion::prelude::SessionConfig; use datafusion::{assert_batches_eq, prelude::SessionContext}; @@ -229,22 +230,13 @@ mod custom_s3_config { SessionStateBuilder, }, }; - use object_store::aws::AmazonS3Builder; - use object_store::local::LocalFileSystem; - use object_store::ObjectStore; - use parking_lot::RwLock; - use std::any::Any; - use std::fmt::Display; use std::sync::Arc; use testcontainers_modules::testcontainers::runners::AsyncRunner; - use url::Url; - - use crate::common::{ACCESS_KEY_ID, SECRET_KEY}; #[tokio::test] async fn should_configure_s3_execute_sql_write_remote( ) -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); // // Minio cluster setup @@ -256,8 +248,15 @@ mod custom_s3_config { .await .unwrap(); + let endpoint_host = node.get_host().await.unwrap(); let endpoint_port = node.get_host_port_ipv4(9000).await.unwrap(); + log::info!( + "MINIO testcontainers host: {}, port: {}", + endpoint_host, + endpoint_port + ); + // // Session Context and Ballista cluster setup // @@ -300,7 +299,7 @@ mod custom_s3_config { 
// object store registry. let session_builder = Arc::new(produce_state); - let state = session_builder(config_producer()); + let state = session_builder(config_producer())?; // setting up ballista cluster with new runtime, configuration, and session state producers let (host, port) = crate::common::setup_test_cluster_with_builders( @@ -325,8 +324,8 @@ mod custom_s3_config { .show() .await?; ctx.sql(&format!( - "SET s3.endpoint = 'http://localhost:{}'", - endpoint_port + "SET s3.endpoint = 'http://{}:{}'", + endpoint_host, endpoint_port )) .await? .show() @@ -380,10 +379,9 @@ mod custom_s3_config { // SessionConfig propagation across ballista cluster. #[tokio::test] - #[cfg(feature = "standalone")] async fn should_configure_s3_execute_sql_write_standalone( ) -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); // // Minio cluster setup @@ -395,6 +393,7 @@ mod custom_s3_config { .await .unwrap(); + let endpoint_host = node.get_host().await.unwrap(); let endpoint_port = node.get_host_port_ipv4(9000).await.unwrap(); // @@ -418,12 +417,7 @@ mod custom_s3_config { // object store registry. let session_builder = Arc::new(produce_state); - let state = session_builder(config_producer()); - - // // setting up ballista cluster with new runtime, configuration, and session state producers - // let (host, port) = - // crate::common::setup_test_cluster_with_state(state.clone()).await; - // let url = format!("df://{host}:{port}"); + let state = session_builder(config_producer())?; // // establishing cluster connection, let ctx: SessionContext = SessionContext::standalone_with_state(state).await?; @@ -439,8 +433,8 @@ mod custom_s3_config { .show() .await?; ctx.sql(&format!( - "SET s3.endpoint = 'http://localhost:{}'", - endpoint_port + "SET s3.endpoint = 'http://{}:{}'", + endpoint_host, endpoint_port )) .await? .show() @@ -487,251 +481,26 @@ mod custom_s3_config { Ok(()) } - fn produce_state(session_config: SessionConfig) -> SessionState { + fn produce_state( + session_config: SessionConfig, + ) -> datafusion::common::Result { let s3options = session_config .options() .extensions .get::() .ok_or(DataFusionError::Configuration( "S3 Options not set".to_string(), - )) - .unwrap(); + ))?; let config = RuntimeConfig::new().with_object_store_registry(Arc::new( CustomObjectStoreRegistry::new(s3options.clone()), )); - let runtime_env = RuntimeEnv::try_new(config).unwrap(); - SessionStateBuilder::new() + let runtime_env = RuntimeEnv::try_new(config)?; + + Ok(SessionStateBuilder::new() .with_runtime_env(runtime_env.into()) .with_config(session_config) - .build() - } - - #[derive(Debug)] - pub struct CustomObjectStoreRegistry { - local: Arc, - s3options: S3Options, - } - - impl CustomObjectStoreRegistry { - fn new(s3options: S3Options) -> Self { - Self { - s3options, - local: Arc::new(LocalFileSystem::new()), - } - } - } - - impl ObjectStoreRegistry for CustomObjectStoreRegistry { - fn register_store( - &self, - _url: &Url, - _store: Arc, - ) -> Option> { - unreachable!("register_store not supported ") - } - - fn get_store(&self, url: &Url) -> Result> { - let scheme = url.scheme(); - log::info!("get_store: {:?}", &self.s3options.config.read()); - match scheme { - "" | "file" => Ok(self.local.clone()), - "s3" => { - let s3store = Self::s3_object_store_builder( - url, - &self.s3options.config.read(), - )? 
- .build()?; - - Ok(Arc::new(s3store)) - } - - _ => exec_err!("get_store - store not supported, url {}", url), - } - } - } - - impl CustomObjectStoreRegistry { - pub fn s3_object_store_builder( - url: &Url, - aws_options: &S3RegistryConfiguration, - ) -> Result { - let S3RegistryConfiguration { - access_key_id, - secret_access_key, - session_token, - region, - endpoint, - allow_http, - } = aws_options; - - let bucket_name = Self::get_bucket_name(url)?; - let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name); - - if let (Some(access_key_id), Some(secret_access_key)) = - (access_key_id, secret_access_key) - { - builder = builder - .with_access_key_id(access_key_id) - .with_secret_access_key(secret_access_key); - - if let Some(session_token) = session_token { - builder = builder.with_token(session_token); - } - } else { - return config_err!( - "'s3.access_key_id' & 's3.secret_access_key' must be configured" - ); - } - - if let Some(region) = region { - builder = builder.with_region(region); - } - - if let Some(endpoint) = endpoint { - if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) { - if !matches!(allow_http, Some(true)) - && endpoint_url.scheme() == "http" - { - return config_err!("Invalid endpoint: {endpoint}. HTTP is not allowed for S3 endpoints. To allow HTTP, set 's3.allow_http' to true"); - } - } - - builder = builder.with_endpoint(endpoint); - } - - if let Some(allow_http) = allow_http { - builder = builder.with_allow_http(*allow_http); - } - - Ok(builder) - } - - fn get_bucket_name(url: &Url) -> Result<&str> { - url.host_str().ok_or_else(|| { - DataFusionError::Execution(format!( - "Not able to parse bucket name from url: {}", - url.as_str() - )) - }) - } - } - - #[derive(Debug, Clone, Default)] - pub struct S3Options { - config: Arc>, - } - - impl ExtensionOptions for S3Options { - fn as_any(&self) -> &dyn Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn Any { - self - } - - fn cloned(&self) -> Box { - Box::new(self.clone()) - } - - fn set(&mut self, key: &str, value: &str) -> Result<()> { - log::debug!("set config, key:{}, value:{}", key, value); - match key { - "access_key_id" => { - let mut c = self.config.write(); - c.access_key_id.set(key, value)?; - } - "secret_access_key" => { - let mut c = self.config.write(); - c.secret_access_key.set(key, value)?; - } - "session_token" => { - let mut c = self.config.write(); - c.session_token.set(key, value)?; - } - "region" => { - let mut c = self.config.write(); - c.region.set(key, value)?; - } - "endpoint" => { - let mut c = self.config.write(); - c.endpoint.set(key, value)?; - } - "allow_http" => { - let mut c = self.config.write(); - c.allow_http.set(key, value)?; - } - _ => { - log::warn!("Config value {} cant be set to {}", key, value); - return config_err!( - "Config value \"{}\" not found in S3Options", - key - ); - } - } - Ok(()) - } - - fn entries(&self) -> Vec { - struct Visitor(Vec); - - impl Visit for Visitor { - fn some( - &mut self, - key: &str, - value: V, - description: &'static str, - ) { - self.0.push(ConfigEntry { - key: format!("{}.{}", S3Options::PREFIX, key), - value: Some(value.to_string()), - description, - }) - } - - fn none(&mut self, key: &str, description: &'static str) { - self.0.push(ConfigEntry { - key: format!("{}.{}", S3Options::PREFIX, key), - value: None, - description, - }) - } - } - let c = self.config.read(); - - let mut v = Visitor(vec![]); - c.access_key_id - .visit(&mut v, "access_key_id", "S3 Access Key"); - c.secret_access_key - .visit(&mut v, 
"secret_access_key", "S3 Secret Key"); - c.session_token - .visit(&mut v, "session_token", "S3 Session token"); - c.region.visit(&mut v, "region", "S3 region"); - c.endpoint.visit(&mut v, "endpoint", "S3 Endpoint"); - c.allow_http.visit(&mut v, "allow_http", "S3 Allow Http"); - - v.0 - } - } - - impl ConfigExtension for S3Options { - const PREFIX: &'static str = "s3"; - } - #[derive(Default, Debug, Clone)] - pub struct S3RegistryConfiguration { - /// Access Key ID - pub access_key_id: Option, - /// Secret Access Key - pub secret_access_key: Option, - /// Session token - pub session_token: Option, - /// AWS Region - pub region: Option, - /// OSS or COS Endpoint - pub endpoint: Option, - /// Allow HTTP (otherwise will always use https) - pub allow_http: Option, + .build()) } } diff --git a/python/Cargo.toml b/python/Cargo.toml index b03f1e997..f70838226 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -25,21 +25,22 @@ description = "Apache Arrow Ballista Python Client" readme = "README.md" license = "Apache-2.0" edition = "2021" -rust-version = "1.72" include = ["/src", "/ballista", "/LICENSE.txt", "pyproject.toml", "Cargo.toml", "Cargo.lock"] publish = false [dependencies] async-trait = "0.1.77" -ballista = { path = "../ballista/client", version = "0.12.0", features = ["standalone"] } +ballista = { path = "../ballista/client", version = "0.12.0" } ballista-core = { path = "../ballista/core", version = "0.12.0" } +ballista-executor = { path = "../ballista/executor", version = "0.12.0", default-features = false } +ballista-scheduler = { path = "../ballista/scheduler", version = "0.12.0", default-features = false } datafusion = { version = "42", features = ["pyarrow", "avro"] } datafusion-proto = { version = "42" } datafusion-python = { version = "42" } pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] } pyo3-log = "0.11.0" -tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", "sync"] } +tokio = { version = "1.42", features = ["macros", "rt", "rt-multi-thread", "sync"] } [lib] crate-type = ["cdylib"] diff --git a/python/README.md b/python/README.md index 01b0a7f90..d8ba03f3d 100644 --- a/python/README.md +++ b/python/README.md @@ -26,6 +26,12 @@ part of the default Cargo workspace so that it doesn't cause overhead for mainta ## Creating a SessionContext +> [!IMPORTANT] +> Current approach is to support datafusion python API, there are know limitations of current approach, +> with some cases producing errors. +> We trying to come up with the best approach to support datafusion python interface. +> More details could be found at [#1142](https://github.com/apache/datafusion-ballista/issues/1142) + Creates a new context and connects to a Ballista scheduler process. 
```python @@ -33,22 +39,50 @@ from ballista import BallistaBuilder >>> ctx = BallistaBuilder().standalone() ``` -## Example SQL Usage +### Example SQL Usage ```python ->>> ctx.sql("create external table t stored as parquet location '/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet'") +>>> ctx.sql("create external table t stored as parquet location './testdata/test.parquet'") >>> df = ctx.sql("select * from t limit 5") >>> pyarrow_batches = df.collect() ``` -## Example DataFrame Usage +### Example DataFrame Usage ```python ->>> df = ctx.read_parquet('/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet').limit(5) +>>> df = ctx.read_parquet('./testdata/test.parquet').limit(5) >>> pyarrow_batches = df.collect() ``` -## Creating Virtual Environment +## Scheduler and Executor + +Scheduler and executors can be configured and started from python code. + +To start scheduler: + +```python +from ballista import BallistaScheduler + +scheduler = BallistaScheduler() + +scheduler.start() +scheduler.wait_for_termination() +``` + +For executor: + +```python +from ballista import BallistaExecutor + +executor = BallistaExecutor() + +executor.start() +executor.wait_for_termination() +``` + +## Development Process + +### Creating Virtual Environment ```shell python3 -m venv venv @@ -56,7 +90,7 @@ source venv/bin/activate pip3 install -r requirements.txt ``` -## Building +### Building ```shell maturin develop @@ -64,7 +98,7 @@ maturin develop Note that you can also run `maturin develop --release` to get a release build locally. -## Testing +### Testing ```shell python3 -m pytest diff --git a/python/ballista/__init__.py b/python/ballista/__init__.py index a143f17e9..4e80422b7 100644 --- a/python/ballista/__init__.py +++ b/python/ballista/__init__.py @@ -26,11 +26,11 @@ import pyarrow as pa from .ballista_internal import ( - BallistaBuilder, + BallistaBuilder, BallistaScheduler, BallistaExecutor ) __version__ = importlib_metadata.version(__name__) __all__ = [ - "BallistaBuilder", + "BallistaBuilder", "BallistaScheduler", "BallistaExecutor" ] \ No newline at end of file diff --git a/python/examples/client_remote.py b/python/examples/client_remote.py new file mode 100644 index 000000000..fd85858ac --- /dev/null +++ b/python/examples/client_remote.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# %% +from ballista import BallistaBuilder +from datafusion.context import SessionContext + +ctx: SessionContext = BallistaBuilder().remote("df://127.0.0.1:50050") + +# Select 1 to verify its working +ctx.sql("SELECT 1").show() + +# %% diff --git a/python/examples/example.py b/python/examples/client_standalone.py similarity index 79% rename from python/examples/example.py rename to python/examples/client_standalone.py index 61a9abbd2..dfe3c372f 100644 --- a/python/examples/example.py +++ b/python/examples/client_standalone.py @@ -15,18 +15,23 @@ # specific language governing permissions and limitations # under the License. +# %% + from ballista import BallistaBuilder from datafusion.context import SessionContext -# Ballista will initiate with an empty config -# set config variables with `config` ctx: SessionContext = BallistaBuilder()\ + .config("datafusion.catalog.information_schema","true")\ .config("ballista.job.name", "example ballista")\ - .config("ballista.shuffle.partitions", "16")\ .standalone() -#ctx_remote: SessionContext = ballista.remote("remote_ip", 50050) -# Select 1 to verify its working ctx.sql("SELECT 1").show() -#ctx_remote.sql("SELECT 2").show() \ No newline at end of file + +# %% +ctx.sql("SHOW TABLES").show() +# %% +ctx.sql("select name, value from information_schema.df_settings where name like 'ballista.job.name'").show() + + +# %% diff --git a/python/examples/executor.py b/python/examples/executor.py new file mode 100644 index 000000000..bb032f634 --- /dev/null +++ b/python/examples/executor.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# %% +from ballista import BallistaExecutor +# %% +executor = BallistaExecutor() +# %% +executor.start() +# %% +executor +# %% +executor.wait_for_termination() +# %% +# %% +executor.close() +# %% diff --git a/python/examples/readme_remote.py b/python/examples/readme_remote.py new file mode 100644 index 000000000..7e1c82d83 --- /dev/null +++ b/python/examples/readme_remote.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# %% + +from ballista import BallistaBuilder +from datafusion.context import SessionContext + +ctx: SessionContext = BallistaBuilder()\ + .config("ballista.job.name", "Readme Example Remote")\ + .config("datafusion.execution.target_partitions", "4")\ + .remote("df://127.0.0.1:50050") + +ctx.sql("create external table t stored as parquet location '../testdata/test.parquet'") + +# %% +df = ctx.sql("select * from t limit 5") +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% +df = ctx.read_parquet('../testdata/test.parquet').limit(5) +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% \ No newline at end of file diff --git a/python/examples/readme_standalone.py b/python/examples/readme_standalone.py new file mode 100644 index 000000000..15404e02d --- /dev/null +++ b/python/examples/readme_standalone.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# %% + +from ballista import BallistaBuilder +from datafusion.context import SessionContext + +ctx: SessionContext = BallistaBuilder()\ + .config("ballista.job.name", "Readme Example")\ + .config("datafusion.execution.target_partitions", "4")\ + .standalone() + +ctx.sql("create external table t stored as parquet location '../testdata/test.parquet'") + +# %% +df = ctx.sql("select * from t limit 5") +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% +df = ctx.read_parquet('../testdata/test.parquet').limit(5) +pyarrow_batches = df.collect() +pyarrow_batches[0].to_pandas() +# %% \ No newline at end of file diff --git a/python/examples/scheduler.py b/python/examples/scheduler.py new file mode 100644 index 000000000..1c40ce1ee --- /dev/null +++ b/python/examples/scheduler.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
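+
+# Starts a Ballista scheduler in-process with the default configuration:
+# start() launches it on a background task, wait_for_termination() blocks until
+# it stops, and close() aborts the background task.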
+ +# %% +from ballista import BallistaScheduler +# %% +scheduler = BallistaScheduler() +# %% +scheduler +# %% +scheduler.start() +# %% +scheduler.wait_for_termination() +# %% +scheduler.close() \ No newline at end of file diff --git a/python/pyproject.toml b/python/pyproject.toml index 2d06b225d..d9b6d2bd9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -16,7 +16,7 @@ # under the License. [build-system] -requires = ["maturin>=0.15,<0.16"] +requires = ["maturin>=1.5.1,<1.6.0"] build-backend = "maturin" [project] @@ -24,7 +24,7 @@ name = "ballista" description = "Python client for Apache Arrow Ballista Distributed SQL Query Engine" readme = "README.md" license = {file = "LICENSE.txt"} -requires-python = ">=3.6" +requires-python = ">=3.7" keywords = ["ballista", "sql", "rust", "distributed"] classifier = [ "Development Status :: 2 - Pre-Alpha", @@ -43,7 +43,7 @@ classifier = [ "Programming Language :: Rust", ] dependencies = [ - "pyarrow>=11.0.0", + "pyarrow>=11.0.0", "cloudpickle" ] [project.urls] @@ -61,4 +61,4 @@ include = [ ] exclude = [".github/**", "ci/**", ".asf.yaml"] # Require Cargo.lock is up to date -locked = true \ No newline at end of file +locked = true diff --git a/python/requirements.txt b/python/requirements.txt index a03a8f8d2..bfc0e03cf 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,3 +1,6 @@ -datafusion==35.0.0 +datafusion==42.0.0 pyarrow -pytest \ No newline at end of file +pytest +maturin==1.5.1 +cloudpickle +pandas \ No newline at end of file diff --git a/python/src/cluster.rs b/python/src/cluster.rs new file mode 100644 index 000000000..848fc4888 --- /dev/null +++ b/python/src/cluster.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
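+
+//! Python bindings for running a Ballista scheduler (`BallistaScheduler`) and an
+//! executor (`BallistaExecutor`) in-process. Both wrap the corresponding Rust
+//! process configuration and run the server on a background task of the shared
+//! Tokio runtime; `wait_for_termination` blocks until that task finishes and
+//! `close` aborts it.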
+ +use std::future::IntoFuture; +use std::sync::Arc; + +use crate::codec::{PyLogicalCodec, PyPhysicalCodec}; +use crate::utils::to_pyerr; +use crate::utils::{spawn_feature, wait_for_future}; +use ballista_executor::executor_process::{ + start_executor_process, ExecutorProcessConfig, +}; +use ballista_scheduler::cluster::BallistaCluster; +use ballista_scheduler::config::SchedulerConfig; +use ballista_scheduler::scheduler_process::start_server; +use pyo3::exceptions::PyException; +use pyo3::{pyclass, pymethods, PyResult, Python}; +use tokio::task::JoinHandle; + +#[pyclass(name = "BallistaScheduler", module = "ballista", subclass)] +pub struct PyScheduler { + config: SchedulerConfig, + handle: Option>, +} + +#[pymethods] +impl PyScheduler { + #[pyo3(signature = (bind_host=None, bind_port=None))] + #[new] + pub fn new(py: Python, bind_host: Option, bind_port: Option) -> Self { + let mut config = SchedulerConfig::default(); + + if let Some(bind_port) = bind_port { + config.bind_port = bind_port; + } + + if let Some(host) = bind_host { + config.bind_host = host; + } + + config.override_logical_codec = + Some(Arc::new(PyLogicalCodec::try_new(py).unwrap())); + config.override_physical_codec = + Some(Arc::new(PyPhysicalCodec::try_new(py).unwrap())); + + Self { + config, + handle: None, + } + } + + pub fn start(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_some() { + return Err(PyException::new_err("Scheduler already started")); + } + let cluster = wait_for_future(py, BallistaCluster::new_from_config(&self.config)) + .map_err(to_pyerr)?; + + let config = self.config.clone(); + let address = format!("{}:{}", config.bind_host, config.bind_port); + let address = address.parse()?; + let handle = spawn_feature(py, async move { + start_server(cluster, address, Arc::new(config)) + .await + .unwrap(); + }); + self.handle = Some(handle); + + Ok(()) + } + + pub fn wait_for_termination(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_none() { + return Err(PyException::new_err("Scheduler not started")); + } + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + match handle { + Some(handle) => wait_for_future(py, handle.into_future()) + .map_err(|e| PyException::new_err(e.to_string())), + None => Ok(()), + } + } + + pub fn close(&mut self) -> PyResult<()> { + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + if let Some(handle) = handle { + handle.abort() + } + + Ok(()) + } + + #[classattr] + pub fn version() -> &'static str { + ballista_core::BALLISTA_VERSION + } + + pub fn __str__(&self) -> String { + match self.handle { + Some(_) => format!( + "listening address={}:{}", + self.config.bind_host, self.config.bind_port, + ), + None => format!( + "configured address={}:{}", + self.config.bind_host, self.config.bind_port, + ), + } + } + + pub fn __repr__(&self) -> String { + format!( + "BallistaScheduler(listening address={}:{}, listening= {})", + self.config.bind_host, + self.config.bind_port, + self.handle.is_some() + ) + } +} + +#[pyclass(name = "BallistaExecutor", module = "ballista", subclass)] +pub struct PyExecutor { + config: Arc, + handle: Option>, +} + +#[pymethods] +impl PyExecutor { + #[pyo3(signature = (bind_port=None, bind_host =None, scheduler_host = None, scheduler_port = None, concurrent_tasks = None))] + #[new] + pub fn new( + py: Python, + bind_port: Option, + bind_host: Option, + scheduler_host: Option, + scheduler_port: Option, + concurrent_tasks: Option, + ) -> PyResult { + let mut config = 
ExecutorProcessConfig::default(); + if let Some(port) = bind_port { + config.port = port; + } + + if let Some(host) = bind_host { + config.bind_host = host; + } + + if let Some(port) = scheduler_port { + config.scheduler_port = port; + } + + if let Some(host) = scheduler_host { + config.scheduler_host = host; + } + + if let Some(concurrent_tasks) = concurrent_tasks { + config.concurrent_tasks = concurrent_tasks as usize + } + + config.override_logical_codec = Some(Arc::new(PyLogicalCodec::try_new(py)?)); + config.override_physical_codec = Some(Arc::new(PyPhysicalCodec::try_new(py)?)); + + let config = Arc::new(config); + Ok(Self { + config, + handle: None, + }) + } + + pub fn start(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_some() { + return Err(PyException::new_err("Executor already started")); + } + + let config = self.config.clone(); + + let handle = + spawn_feature( + py, + async move { start_executor_process(config).await.unwrap() }, + ); + self.handle = Some(handle); + + Ok(()) + } + + pub fn wait_for_termination(&mut self, py: Python) -> PyResult<()> { + if self.handle.is_none() { + return Err(PyException::new_err("Executor not started")); + } + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + match handle { + Some(handle) => wait_for_future(py, handle.into_future()) + .map_err(|e| PyException::new_err(e.to_string())) + .map(|_| ()), + None => Ok(()), + } + } + + pub fn close(&mut self) -> PyResult<()> { + let mut handle = None; + std::mem::swap(&mut self.handle, &mut handle); + + if let Some(handle) = handle { + handle.abort() + } + + Ok(()) + } + + #[classattr] + pub fn version() -> &'static str { + ballista_core::BALLISTA_VERSION + } + + pub fn __str__(&self) -> String { + match self.handle { + Some(_) => format!( + "listening address={}:{}, scheduler={}:{}", + self.config.bind_host, + self.config.port, + self.config.scheduler_host, + self.config.scheduler_port + ), + None => format!( + "configured address={}:{}, scheduler={}:{}", + self.config.bind_host, + self.config.port, + self.config.scheduler_host, + self.config.scheduler_port, + ), + } + } + + pub fn __repr__(&self) -> String { + format!( + "BallistaExecutor(address={}:{}, scheduler={}:{}, concurrent_tasks={} listening={})", + self.config.bind_host, + self.config.port, + self.config.scheduler_host, + self.config.scheduler_port, + self.config.concurrent_tasks, + self.handle.is_some() + ) + } +} diff --git a/python/src/codec.rs b/python/src/codec.rs new file mode 100644 index 000000000..c6b0b7e50 --- /dev/null +++ b/python/src/codec.rs @@ -0,0 +1,253 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
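+
+//! Logical and physical extension codecs used by the Python bindings. Both codecs
+//! currently delegate to the default Ballista codecs; the `CloudPickle` helper
+//! (cloudpickle `dumps`/`loads`) is kept for (de)serializing Python-defined UDFs,
+//! as noted in the `use cloudpickle here` comments below.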
+ +use ballista_core::serde::{ + BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec, +}; +use datafusion::logical_expr::ScalarUDF; +use datafusion_proto::logical_plan::LogicalExtensionCodec; +use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use pyo3::types::{PyAnyMethods, PyBytes, PyBytesMethods}; +use pyo3::{PyObject, PyResult, Python}; +use std::fmt::Debug; +use std::sync::Arc; + +static MODULE: &str = "cloudpickle"; +static FUN_LOADS: &str = "loads"; +static FUN_DUMPS: &str = "dumps"; + +/// Serde protocol for UD(a)F +#[derive(Debug)] +struct CloudPickle { + loads: PyObject, + dumps: PyObject, +} + +impl CloudPickle { + pub fn try_new(py: Python<'_>) -> PyResult { + let module = py.import_bound(MODULE)?; + let loads = module.getattr(FUN_LOADS)?.unbind(); + let dumps = module.getattr(FUN_DUMPS)?.unbind(); + + Ok(Self { loads, dumps }) + } + + pub fn pickle(&self, py: Python<'_>, py_any: &PyObject) -> PyResult> { + let b: PyObject = self.dumps.call1(py, (py_any,))?.extract(py)?; + let blob = b.downcast_bound::(py)?.clone(); + + Ok(blob.as_bytes().to_owned()) + } + + pub fn unpickle(&self, py: Python<'_>, blob: &[u8]) -> PyResult { + let t: PyObject = self.loads.call1(py, (blob,))?.extract(py)?; + + Ok(t) + } +} + +pub struct PyLogicalCodec { + inner: BallistaLogicalExtensionCodec, + cloudpickle: CloudPickle, +} + +impl PyLogicalCodec { + pub fn try_new(py: Python<'_>) -> PyResult { + Ok(Self { + inner: BallistaLogicalExtensionCodec::default(), + cloudpickle: CloudPickle::try_new(py)?, + }) + } +} + +impl Debug for PyLogicalCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyLogicalCodec").finish() + } +} + +impl LogicalExtensionCodec for PyLogicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[datafusion::logical_expr::LogicalPlan], + ctx: &datafusion::prelude::SessionContext, + ) -> datafusion::error::Result { + self.inner.try_decode(buf, inputs, ctx) + } + + fn try_encode( + &self, + node: &datafusion::logical_expr::Extension, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + table_ref: &datafusion::sql::TableReference, + schema: datafusion::arrow::datatypes::SchemaRef, + ctx: &datafusion::prelude::SessionContext, + ) -> datafusion::error::Result> + { + self.inner + .try_decode_table_provider(buf, table_ref, schema, ctx) + } + + fn try_encode_table_provider( + &self, + table_ref: &datafusion::sql::TableReference, + node: std::sync::Arc, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_table_provider(table_ref, node, buf) + } + + fn try_decode_file_format( + &self, + buf: &[u8], + ctx: &datafusion::prelude::SessionContext, + ) -> datafusion::error::Result< + std::sync::Arc, + > { + self.inner.try_decode_file_format(buf, ctx) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: std::sync::Arc, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_file_format(buf, node) + } + + fn try_decode_udf( + &self, + name: &str, + buf: &[u8], + ) -> datafusion::error::Result> + { + // use cloud pickle to decode udf + self.inner.try_decode_udf(name, buf) + } + + fn try_encode_udf( + &self, + node: &datafusion::logical_expr::ScalarUDF, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + // use cloud pickle to decode udf + self.inner.try_encode_udf(node, buf) + } + + fn try_decode_udaf( + &self, + name: &str, + buf: &[u8], + ) -> datafusion::error::Result> + 
{ + self.inner.try_decode_udaf(name, buf) + } + + fn try_encode_udaf( + &self, + node: &datafusion::logical_expr::AggregateUDF, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_udaf(node, buf) + } + + fn try_decode_udwf( + &self, + name: &str, + buf: &[u8], + ) -> datafusion::error::Result> + { + self.inner.try_decode_udwf(name, buf) + } + + fn try_encode_udwf( + &self, + node: &datafusion::logical_expr::WindowUDF, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode_udwf(node, buf) + } +} + +pub struct PyPhysicalCodec { + inner: BallistaPhysicalExtensionCodec, + cloudpickle: CloudPickle, +} + +impl Debug for PyPhysicalCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PyPhysicalCodec").finish() + } +} + +impl PyPhysicalCodec { + pub fn try_new(py: Python<'_>) -> PyResult { + Ok(Self { + inner: BallistaPhysicalExtensionCodec::default(), + cloudpickle: CloudPickle::try_new(py)?, + }) + } +} + +impl PhysicalExtensionCodec for PyPhysicalCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[std::sync::Arc], + registry: &dyn datafusion::execution::FunctionRegistry, + ) -> datafusion::error::Result< + std::sync::Arc, + > { + self.inner.try_decode(buf, inputs, registry) + } + + fn try_encode( + &self, + node: std::sync::Arc, + buf: &mut Vec, + ) -> datafusion::error::Result<()> { + self.inner.try_encode(node, buf) + } + + fn try_decode_udf( + &self, + name: &str, + _buf: &[u8], + ) -> datafusion::common::Result> { + // use cloudpickle here + datafusion::common::not_impl_err!( + "PhysicalExtensionCodec is not provided for scalar function {name}" + ) + } + + fn try_encode_udf( + &self, + _node: &ScalarUDF, + _buf: &mut Vec, + ) -> datafusion::common::Result<()> { + // use cloudpickle here + Ok(()) + } +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 41b4b6d31..13a6c38b9 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -15,32 +15,36 @@ // specific language governing permissions and limitations // under the License. 
+use crate::utils::wait_for_future; use ballista::prelude::*; +use cluster::{PyExecutor, PyScheduler}; use datafusion::execution::SessionStateBuilder; use datafusion::prelude::*; use datafusion_python::context::PySessionContext; -use datafusion_python::utils::wait_for_future; - -use std::collections::HashMap; - use pyo3::prelude::*; + +mod cluster; +#[allow(dead_code)] +mod codec; mod utils; -use utils::to_pyerr; + +pub(crate) struct TokioRuntime(tokio::runtime::Runtime); #[pymodule] fn ballista_internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { pyo3_log::init(); - // BallistaBuilder struct + m.add_class::()?; - // DataFusion struct m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) } -// Ballista Builder will take a HasMap/Dict Cionfg #[pyclass(name = "BallistaBuilder", module = "ballista", subclass)] pub struct PyBallistaBuilder { - conf: HashMap, + session_config: SessionConfig, } #[pymethods] @@ -48,56 +52,47 @@ impl PyBallistaBuilder { #[new] pub fn new() -> Self { Self { - conf: HashMap::new(), + session_config: SessionConfig::new_with_ballista(), } } pub fn config( mut slf: PyRefMut<'_, Self>, - k: &str, - v: &str, + key: &str, + value: &str, py: Python, ) -> PyResult { - slf.conf.insert(k.into(), v.into()); + let _ = slf.session_config.options_mut().set(key, value); Ok(slf.into_py(py)) } /// Construct the standalone instance from the SessionContext pub fn standalone(&self, py: Python) -> PyResult { - // Build the config - let config: SessionConfig = SessionConfig::from_string_hash_map(&self.conf)?; - // Build the state let state = SessionStateBuilder::new() - .with_config(config) + .with_config(self.session_config.clone()) .with_default_features() .build(); - // Build the context - let standalone_session = SessionContext::standalone_with_state(state); - // SessionContext is an async function - let ctx = wait_for_future(py, standalone_session)?; + let ctx = wait_for_future(py, SessionContext::standalone_with_state(state))?; - // Convert the SessionContext into a Python SessionContext Ok(ctx.into()) } /// Construct the remote instance from the SessionContext pub fn remote(&self, url: &str, py: Python) -> PyResult { - // Build the config - let config: SessionConfig = SessionConfig::from_string_hash_map(&self.conf)?; - // Build the state let state = SessionStateBuilder::new() - .with_config(config) + .with_config(self.session_config.clone()) .with_default_features() .build(); - // Build the context - let remote_session = SessionContext::remote_with_state(url, state); - // SessionContext is an async function - let ctx = wait_for_future(py, remote_session)?; + let ctx = wait_for_future(py, SessionContext::remote_with_state(url, state))?; - // Convert the SessionContext into a Python SessionContext Ok(ctx.into()) } + + #[classattr] + pub fn version() -> &'static str { + ballista_core::BALLISTA_VERSION + } } diff --git a/python/src/utils.rs b/python/src/utils.rs index 10278537e..f069475ea 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -15,10 +15,48 @@ // specific language governing permissions and limitations // under the License. 
+use std::future::Future; +use std::sync::OnceLock; +use tokio::task::JoinHandle; + use ballista_core::error::BallistaError; use pyo3::exceptions::PyException; -use pyo3::PyErr; +use pyo3::{PyErr, Python}; +use tokio::runtime::Runtime; + +use crate::TokioRuntime; pub(crate) fn to_pyerr(err: BallistaError) -> PyErr { PyException::new_err(err.to_string()) } + +#[inline] +pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime { + // NOTE: Other pyo3 python libraries have had issues with using tokio + // behind a forking app-server like `gunicorn` + // If we run into that problem, in the future we can look to `delta-rs` + // which adds a check in that disallows calls from a forked process + // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31 + static RUNTIME: OnceLock = OnceLock::new(); + RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap())) +} + +/// Utility to collect rust futures with GIL released +pub(crate) fn wait_for_future(py: Python, f: F) -> F::Output +where + F: Future + Send, + F::Output: Send, +{ + let runtime: &Runtime = &get_tokio_runtime().0; + py.allow_threads(|| runtime.block_on(f)) +} + +pub(crate) fn spawn_feature(py: Python, f: F) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send, +{ + let runtime: &Runtime = &get_tokio_runtime().0; + // do we need py.allow_threads ? + py.allow_threads(|| runtime.spawn(f)) +}