diff --git a/Cargo.toml b/Cargo.toml
index 0467d8ab6..a803da0d4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@ tonic-build = { version = "0.12", default-features = false, features = [
 tracing = "0.1.36"
 tracing-appender = "0.2.2"
 tracing-subscriber = { version = "0.3.15", features = ["env-filter"] }
+ctor = { version = "0.2" }
 tokio = { version = "1" }
 uuid = { version = "1.10", features = ["v4", "v7"] }
diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml
index a5d930307..9614412f8 100644
--- a/ballista/client/Cargo.toml
+++ b/ballista/client/Cargo.toml
@@ -44,12 +44,11 @@ url = { workspace = true }
 [dev-dependencies]
 ballista-executor = { path = "../executor", version = "0.12.0" }
 ballista-scheduler = { path = "../scheduler", version = "0.12.0" }
-ctor = { version = "0.2" }
+ctor = { workspace = true }
 env_logger = { workspace = true }
-object_store = { workspace = true, features = ["aws"] }
-testcontainers-modules = { version = "0.11", features = ["minio"] }
+rstest = { version = "0.23" }
+tonic = { workspace = true }
 
 [features]
 default = ["standalone"]
 standalone = ["ballista-executor", "ballista-scheduler"]
-testcontainers = []
diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs
index ff603ea3e..272f0ca96 100644
--- a/ballista/client/src/extension.rs
+++ b/ballista/client/src/extension.rs
@@ -122,7 +122,7 @@ impl SessionContextExt for SessionContext {
         let config = SessionConfig::new_with_ballista();
         let scheduler_url = Extension::parse_url(url)?;
         log::info!(
-            "Connecting to Ballista scheduler at {}",
+            "Connecting to Ballista scheduler at: {}",
             scheduler_url.clone()
         );
         let remote_session_id =
@@ -245,10 +245,11 @@ impl Extension {
                 .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
             }
             Some(session_state) => {
-                ballista_executor::new_standalone_executor_from_state::<
-                    datafusion_proto::protobuf::LogicalPlanNode,
-                    datafusion_proto::protobuf::PhysicalPlanNode,
-                >(scheduler, concurrent_tasks, session_state)
+                ballista_executor::new_standalone_executor_from_state(
+                    scheduler,
+                    concurrent_tasks,
+                    session_state,
+                )
                 .await
                 .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
             }
diff --git a/ballista/client/tests/common/mod.rs b/ballista/client/tests/common/mod.rs
index 30b8f9f90..1d2bca94b 100644
--- a/ballista/client/tests/common/mod.rs
+++ b/ballista/client/tests/common/mod.rs
@@ -19,64 +19,14 @@ use std::env;
 use std::error::Error;
 use std::path::PathBuf;
 
-use ballista::prelude::SessionConfigExt;
+use ballista::prelude::{SessionConfigExt, SessionContextExt};
 use ballista_core::serde::{
     protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec,
 };
 use ballista_core::{ConfigProducer, RuntimeProducer};
 use ballista_scheduler::SessionBuilder;
 use datafusion::execution::SessionState;
-use datafusion::prelude::SessionConfig;
-use object_store::aws::AmazonS3Builder;
-use testcontainers_modules::minio::MinIO;
-use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand};
-use testcontainers_modules::testcontainers::ContainerRequest;
-use testcontainers_modules::{minio, testcontainers::ImageExt};
-
-pub const REGION: &str = "eu-west-1";
-pub const BUCKET: &str = "ballista";
-pub const ACCESS_KEY_ID: &str = "MINIO";
-pub const SECRET_KEY: &str = "MINIOMINIO";
-
-#[allow(dead_code)]
-pub fn create_s3_store(
-    port: u16,
-) -> std::result::Result<object_store::aws::AmazonS3, object_store::Error> {
-    AmazonS3Builder::new()
-        .with_endpoint(format!("http://localhost:{port}"))
-        .with_region(REGION)
-        .with_bucket_name(BUCKET)
-        .with_access_key_id(ACCESS_KEY_ID)
-        .with_secret_access_key(SECRET_KEY)
-        .with_allow_http(true)
-        .build()
-}
-
-#[allow(dead_code)]
-pub fn create_minio_container() -> ContainerRequest<minio::MinIO> {
-    MinIO::default()
-        .with_env_var("MINIO_ACCESS_KEY", ACCESS_KEY_ID)
-        .with_env_var("MINIO_SECRET_KEY", SECRET_KEY)
-}
-
-#[allow(dead_code)]
-pub fn create_bucket_command() -> ExecCommand {
-    // this is hack to create a bucket without creating s3 client.
-    // this works with current testcontainer (and image) version 'RELEASE.2022-02-07T08-17-33Z'.
-    // (testcontainer does not await properly on latest image version)
-    //
-    // if testcontainer image version change to something newer we should use "mc mb /data/ballista"
-    // to crate a bucket.
-    ExecCommand::new(vec![
-        "mkdir".to_string(),
-        format!("/data/{}", crate::common::BUCKET),
-    ])
-    .with_cmd_ready_condition(CmdWaitFor::seconds(1))
-}
-
-// /// Remote ballista cluster to be used for local testing.
-// static BALLISTA_CLUSTER: tokio::sync::OnceCell<(String, u16)> =
-//     tokio::sync::OnceCell::const_new();
+use datafusion::prelude::{SessionConfig, SessionContext};
 
 /// Returns the parquet test data directory, which is by default
 /// stored in a git submodule rooted at
@@ -161,17 +111,8 @@ pub async fn setup_test_cluster() -> (String, u16) {
 
     let host = "localhost".to_string();
 
-    let scheduler_url = format!("http://{}:{}", host, addr.port());
-
-    let scheduler = loop {
-        match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
-            Err(_) => {
-                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-                log::info!("Attempting to connect to test scheduler...");
-            }
-            Ok(scheduler) => break scheduler,
-        }
-    };
+    let scheduler =
+        connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await;
 
     ballista_executor::new_standalone_executor(
         scheduler,
@@ -190,7 +131,6 @@ #[allow(dead_code)]
 pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) {
     let config = SessionConfig::new_with_ballista();
-    //let default_codec = BallistaCodec::default();
 
     let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state(
         &session_state,
     )
@@ -200,22 +140,10 @@ pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) {
 
     let host = "localhost".to_string();
 
-    let scheduler_url = format!("http://{}:{}", host, addr.port());
+    let scheduler =
+        connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await;
 
-    let scheduler = loop {
-        match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
-            Err(_) => {
-                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-                log::info!("Attempting to connect to test scheduler...");
-            }
-            Ok(scheduler) => break scheduler,
-        }
-    };
-
-    ballista_executor::new_standalone_executor_from_state::<
-        datafusion_proto::protobuf::LogicalPlanNode,
-        datafusion_proto::protobuf::PhysicalPlanNode,
-    >(
+    ballista_executor::new_standalone_executor_from_state(
         scheduler,
         config.ballista_standalone_parallelism(),
         &session_state,
@@ -253,22 +181,13 @@ pub async fn setup_test_cluster_with_builders(
 
     let host = "localhost".to_string();
 
-    let scheduler_url = format!("http://{}:{}", host, addr.port());
-
-    let scheduler = loop {
-        match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
-            Err(_) => {
-                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-                log::info!("Attempting to connect to test scheduler...");
-            }
-            Ok(scheduler) => break scheduler,
-        }
-    };
+    let scheduler =
+        connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await;
 
     ballista_executor::new_standalone_executor_from_builder(
         scheduler,
         config.ballista_standalone_parallelism(),
-        config_producer.clone(),
+        config_producer,
         runtime_producer,
         codec,
         Default::default(),
@@ -281,6 +200,40 @@
     (host, addr.port())
 }
 
+async fn connect_to_scheduler(
+    scheduler_url: String,
+) -> SchedulerGrpcClient<tonic::transport::Channel> {
+    let mut retry = 50;
+    loop {
+        match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
+            Err(_) if retry > 0 => {
+                retry -= 1;
+                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+                log::debug!("Re-attempting to connect to test scheduler...");
+            }
+
+            Err(_) => {
+                log::error!("scheduler connection timed out");
+                panic!("scheduler connection timed out")
+            }
+            Ok(scheduler) => break scheduler,
+        }
+    }
+}
+
+#[allow(dead_code)]
+pub async fn standalone_context() -> SessionContext {
+    SessionContext::standalone().await.unwrap()
+}
+
+#[allow(dead_code)]
+pub async fn remote_context() -> SessionContext {
+    let (host, port) = setup_test_cluster().await;
+    SessionContext::remote(&format!("df://{host}:{port}"))
+        .await
+        .unwrap()
+}
+
 #[ctor::ctor]
 fn init() {
     // Enable RUST_LOG logging configuration for test
diff --git a/ballista/client/tests/context_standalone.rs b/ballista/client/tests/context_basic.rs
similarity index 99%
rename from ballista/client/tests/context_standalone.rs
rename to ballista/client/tests/context_basic.rs
index c17b53e59..5c137eecb 100644
--- a/ballista/client/tests/context_standalone.rs
+++ b/ballista/client/tests/context_basic.rs
@@ -23,7 +23,7 @@ mod common;
 
 // #[cfg(test)]
 #[cfg(feature = "standalone")]
-mod standalone_tests {
+mod basic {
     use ballista::prelude::SessionContextExt;
     use datafusion::arrow;
     use datafusion::arrow::util::pretty::pretty_format_batches;
diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs
new file mode 100644
index 000000000..d3f8fc930
--- /dev/null
+++ b/ballista/client/tests/context_checks.rs
@@ -0,0 +1,366 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
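+//
+// The checks below run against both deployment modes: each test is an
+// `rstest` parametrized case whose `ctx` argument is produced by the
+// `standalone_context()` / `remote_context()` fixtures from `tests/common`
+// and awaited per case via `#[future(awt)]`.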
+
+mod common;
+#[cfg(test)]
+mod supported {
+
+    use crate::common::{remote_context, standalone_context};
+    use ballista_core::config::BallistaConfig;
+    use datafusion::prelude::*;
+    use datafusion::{assert_batches_eq, prelude::SessionContext};
+    use rstest::*;
+
+    #[rstest::fixture]
+    fn test_data() -> String {
+        crate::common::example_test_data()
+    }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_execute_sql_show(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) -> datafusion::error::Result<()> {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await?;
+
+        let result = ctx
+            .sql("select string_col, timestamp_col from test where id > 4")
+            .await?
+            .collect()
+            .await?;
+        let expected = [
+            "+------------+---------------------+",
+            "| string_col | timestamp_col       |",
+            "+------------+---------------------+",
+            "| 31         | 2009-03-01T00:01:00 |",
+            "| 30         | 2009-04-01T00:00:00 |",
+            "| 31         | 2009-04-01T00:01:00 |",
+            "+------------+---------------------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_execute_sql_show_configs(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+    ) -> datafusion::error::Result<()> {
+        let result = ctx
+            .sql("select name from information_schema.df_settings where name like 'datafusion.%' order by name limit 5")
+            .await?
+            .collect()
+            .await?;
+        //
+        let expected = [
+            "+------------------------------------------------------+",
+            "| name                                                 |",
+            "+------------------------------------------------------+",
+            "| datafusion.catalog.create_default_catalog_and_schema |",
+            "| datafusion.catalog.default_catalog                   |",
+            "| datafusion.catalog.default_schema                    |",
+            "| datafusion.catalog.format                            |",
+            "| datafusion.catalog.has_header                        |",
+            "+------------------------------------------------------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_execute_sql_show_configs_ballista(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+    ) -> datafusion::error::Result<()> {
+        let state = ctx.state();
+        let ballista_config_extension =
+            state.config().options().extensions.get::<BallistaConfig>();
+
+        // ballista configuration should be registered with
+        // session state
+        assert!(ballista_config_extension.is_some());
+
+        let result = ctx
+            .sql("select name, value from information_schema.df_settings where name like 'ballista.%' order by name limit 2")
+            .await?
+            .collect()
+            .await?;
+
+        let expected = [
+            "+---------------------------------------+----------+",
+            "| name                                  | value    |",
+            "+---------------------------------------+----------+",
+            "| ballista.grpc_client_max_message_size | 16777216 |",
+            "| ballista.job.name                     |          |",
+            "+---------------------------------------+----------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_execute_sql_set_configs(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+    ) -> datafusion::error::Result<()> {
+        ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'")
+            .await?
+ .show() + .await?; + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.job.name' order by name limit 1") + .await? + .collect() + .await?; + + let expected = [ + "+-------------------+-------------------------+", + "| name | value |", + "+-------------------+-------------------------+", + "| ballista.job.name | Super Cool Ballista App |", + "+-------------------+-------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + // select from ballista config + // check for SET = + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_show_tables( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx.sql("show tables").await?.collect().await?; + // + let expected = [ + "+---------------+--------------------+-------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+-------------+------------+", + "| datafusion | public | test | BASE TABLE |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | views | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | df_settings | VIEW |", + "| datafusion | information_schema | schemata | VIEW |", + "+---------------+--------------------+-------------+------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_create_external_table( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.sql(&format!("CREATE EXTERNAL TABLE tbl_test STORED AS PARQUET LOCATION '{}/alltypes_plain.parquet'", test_data, )).await?.show().await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from tbl_test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_dataframe( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + let df = ctx + .read_parquet( + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await? + .select_columns(&["id", "bool_col", "timestamp_col"])? 
+ .filter(col("id").gt(lit(5)))?; + + let result = df.collect().await?; + + let expected = [ + "+----+----------+---------------------+", + "| id | bool_col | timestamp_col |", + "+----+----------+---------------------+", + "| 6 | true | 2009-04-01T00:00:00 |", + "| 7 | false | 2009-04-01T00:01:00 |", + "+----+----------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[cfg(not(windows))] // test is failing at windows, can't debug it + async fn should_execute_sql_write( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_app_name_show( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") + .await? + .show() + .await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } +} diff --git a/ballista/client/tests/setup.rs b/ballista/client/tests/context_setup.rs similarity index 100% rename from ballista/client/tests/setup.rs rename to ballista/client/tests/context_setup.rs diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs new file mode 100644 index 000000000..fb6a16c7c --- /dev/null +++ b/ballista/client/tests/context_unsupported.rs @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod common;
+
+/// # Tracking Unsupported Operations
+///
+/// This suite tracks operations which Ballista does not support yet,
+/// and provides an indication if/when DataFusion gains support for them.
+
+#[cfg(test)]
+mod unsupported {
+    use crate::common::{remote_context, standalone_context};
+    use datafusion::prelude::*;
+    use datafusion::{assert_batches_eq, prelude::SessionContext};
+    use rstest::*;
+
+    #[rstest::fixture]
+    fn test_data() -> String {
+        crate::common::example_test_data()
+    }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    #[should_panic]
+    async fn should_execute_explain_query_correctly(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await
+        .unwrap();
+
+        let result = ctx
+            .sql("EXPLAIN select count(*), id from test where id > 4 group by id")
+            .await
+            .unwrap()
+            .collect()
+            .await
+            .unwrap();
+
+        let expected = vec![
+            "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
+            "| plan_type     | plan                                                                                                                                                                                                                                                                                               |",
+            "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
+            "| logical_plan  | Projection: count(*), test.id                                                                                                                                                                                                                                                                      |",
+            "|               |   Aggregate: groupBy=[[test.id]], aggr=[[count(Int64(1)) AS count(*)]]                                                                                                                                                                                                                             |",
+            "|               |     Filter: test.id > Int32(4)                                                                                                                                                                                                                                                                     |",
+            "|               |       TableScan: test projection=[id], partial_filters=[test.id > Int32(4)]                                                                                                                                                                                                                        |",
+            "| physical_plan | ProjectionExec: expr=[count(*)@1 as count(*), id@0 as id]                                                                                                                                                                                                                                          |",
+            "|               |   AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(*)]                                                                                                                                                                                                                          |",
+            "|               |     CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                                                                                                                    |",
+            "|               |       RepartitionExec: partitioning=Hash([id@0], 16), input_partitions=1                                                                                                                                                                                                                           |",
+            "|               |         AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[count(*)]                                                                                                                                                                                                                             |",
+            "|               |           CoalesceBatchesExec: target_batch_size=8192                                                                                                                                                                                                                                              |",
+            "|               |             FilterExec: id@0 > 4                                                                                                                                                                                                                                                                   |",
+            "|               |               ParquetExec: file_groups={1 group: [[Users/ballista/git/arrow-ballista/ballista/client/testdata/alltypes_plain.parquet]]}, projection=[id], predicate=id@0 > 4, pruning_predicate=CASE WHEN id_null_count@1 = id_row_count@2 THEN false ELSE id_max@0 > 4 END, required_guarantees=[] |",
+            "|               |                                                                                                                                                                                                                                                                                                    |",
+            "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+    }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    
#[tokio::test] + #[should_panic] + async fn should_support_sql_create_table( + #[future(awt)] + #[case] + ctx: SessionContext, + ) { + ctx.sql("CREATE TABLE tbl_test (id INT, value INT)") + .await + .unwrap() + .show() + .await + .unwrap(); + + // it does create table but it can't be queried + let _result = ctx + .sql("select * from tbl_test where id > 0") + .await + .unwrap() + .collect() + .await + .unwrap(); + } + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[should_panic] + async fn should_support_caching_data_frame( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + let df = ctx + .read_parquet( + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap() + .select_columns(&["id", "bool_col", "timestamp_col"]) + .unwrap() + .filter(col("id").gt(lit(5))) + .unwrap(); + + let cached_df = df.cache().await.unwrap(); + let result = cached_df.collect().await.unwrap(); + + let expected = [ + "+----+----------+---------------------+", + "| id | bool_col | timestamp_col |", + "+----+----------+---------------------+", + "| 6 | true | 2009-04-01T00:00:00 |", + "| 7 | false | 2009-04-01T00:01:00 |", + "+----+----------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[should_panic] + // "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))" + async fn should_support_sql_insert_into( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap(); + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await + .unwrap() + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await + .unwrap(); + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await + .unwrap(); + + let _ = ctx + .sql("INSERT INTO written_table select * from written_table") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") + .await.unwrap() + .collect() + .await.unwrap(); + + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } +} diff --git a/ballista/client/tests/remote.rs b/ballista/client/tests/remote.rs deleted file mode 100644 index c03db8524..000000000 --- a/ballista/client/tests/remote.rs +++ /dev/null @@ -1,185 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod common; - -#[cfg(test)] -mod remote { - use ballista::prelude::SessionContextExt; - use datafusion::{assert_batches_eq, prelude::SessionContext}; - - #[tokio::test] - async fn should_execute_sql_show() -> datafusion::error::Result<()> { - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let test_data = crate::common::example_test_data(); - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[cfg(not(windows))] // test is failing at windows, can't debug it - async fn should_execute_sql_write() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - log::info!("writing to parquet .. {}", write_dir_path); - ctx.sql("select * from test") - .await? - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await?; - - log::info!("registering parquet .. {}", write_dir_path); - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await?; - log::info!("reading from written parquet .."); - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - log::info!("reading from written parquet .. 
DONE"); - assert_batches_eq!(expected, &result); - Ok(()) - } - - #[tokio::test] - async fn should_execute_show_tables() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx.sql("show tables").await?.collect().await?; - // - let expected = [ - "+---------------+--------------------+-------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+-------------+------------+", - "| datafusion | public | test | BASE TABLE |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | schemata | VIEW |", - "+---------------+--------------------+-------------+------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> { - let (host, port) = crate::common::setup_test_cluster().await; - let url = format!("df://{host}:{port}"); - - let test_data = crate::common::example_test_data(); - let ctx: SessionContext = SessionContext::remote(&url).await?; - - ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") - .await? - .show() - .await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } -} diff --git a/ballista/client/tests/standalone.rs b/ballista/client/tests/standalone.rs deleted file mode 100644 index fe9f3df1a..000000000 --- a/ballista/client/tests/standalone.rs +++ /dev/null @@ -1,479 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
-
-mod common;
-
-#[cfg(test)]
-#[cfg(feature = "standalone")]
-mod standalone {
-    use ballista::prelude::SessionContextExt;
-    use ballista_core::config::BallistaConfig;
-    use datafusion::prelude::*;
-    use datafusion::{assert_batches_eq, prelude::SessionContext};
-
-    #[tokio::test]
-    async fn should_execute_sql_show() -> datafusion::error::Result<()> {
-        let test_data = crate::common::example_test_data();
-
-        let ctx: SessionContext = SessionContext::standalone().await?;
-        ctx.register_parquet(
-            "test",
-            &format!("{test_data}/alltypes_plain.parquet"),
-            Default::default(),
-        )
-        .await?;
-
-        let result = ctx
-            .sql("select string_col, timestamp_col from test where id > 4")
-            .await?
-            .collect()
-            .await?;
-        let expected = [
-            "+------------+---------------------+",
-            "| string_col | timestamp_col       |",
-            "+------------+---------------------+",
-            "| 31         | 2009-03-01T00:01:00 |",
-            "| 30         | 2009-04-01T00:00:00 |",
-            "| 31         | 2009-04-01T00:01:00 |",
-            "+------------+---------------------+",
-        ];
-
-        assert_batches_eq!(expected, &result);
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn should_execute_sql_show_configs() -> datafusion::error::Result<()> {
-        let ctx: SessionContext = SessionContext::standalone().await?;
-
-        let result = ctx
-            .sql("select name from information_schema.df_settings where name like 'datafusion.%' order by name limit 5")
-            .await?
-            .collect()
-            .await?;
-        //
-        let expected = [
-            "+------------------------------------------------------+",
-            "| name                                                 |",
-            "+------------------------------------------------------+",
-            "| datafusion.catalog.create_default_catalog_and_schema |",
-            "| datafusion.catalog.default_catalog                   |",
-            "| datafusion.catalog.default_schema                    |",
-            "| datafusion.catalog.format                            |",
-            "| datafusion.catalog.has_header                        |",
-            "+------------------------------------------------------+",
-        ];
-
-        assert_batches_eq!(expected, &result);
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn should_execute_sql_show_configs_ballista() -> datafusion::error::Result<()> {
-        let ctx: SessionContext = SessionContext::standalone().await?;
-        let state = ctx.state();
-        let ballista_config_extension =
-            state.config().options().extensions.get::<BallistaConfig>();
-
-        // ballista configuration should be registered with
-        // session state
-        assert!(ballista_config_extension.is_some());
-
-        let result = ctx
-            .sql("select name, value from information_schema.df_settings where name like 'ballista.%' order by name limit 2")
-            .await?
-            .collect()
-            .await?;
-
-        let expected = [
-            "+---------------------------------------+----------+",
-            "| name                                  | value    |",
-            "+---------------------------------------+----------+",
-            "| ballista.grpc_client_max_message_size | 16777216 |",
-            "| ballista.job.name                     |          |",
-            "+---------------------------------------+----------+",
-        ];
-
-        assert_batches_eq!(expected, &result);
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn should_execute_sql_set_configs() -> datafusion::error::Result<()> {
-        let ctx: SessionContext = SessionContext::standalone().await?;
-
-        ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'")
-            .await?
- .collect() - .await?; - - let expected = [ - "+-------------------+-------------------------+", - "| name | value |", - "+-------------------+-------------------------+", - "| ballista.job.name | Super Cool Ballista App |", - "+-------------------+-------------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - // select from ballista config - // check for SET = - - #[tokio::test] - async fn should_execute_show_tables() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx.sql("show tables").await?.collect().await?; - // - let expected = [ - "+---------------+--------------------+-------------+------------+", - "| table_catalog | table_schema | table_name | table_type |", - "+---------------+--------------------+-------------+------------+", - "| datafusion | public | test | BASE TABLE |", - "| datafusion | information_schema | tables | VIEW |", - "| datafusion | information_schema | views | VIEW |", - "| datafusion | information_schema | columns | VIEW |", - "| datafusion | information_schema | df_settings | VIEW |", - "| datafusion | information_schema | schemata | VIEW |", - "+---------------+--------------------+-------------+------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - // - // TODO: It calls scheduler to generate the plan, but no - // but there is no ShuffleRead/Write in physical_plan - // - // ShuffleWriterExec: None, metrics=[output_rows=2, input_rows=2, write_time=1.782295ms, repart_time=1ns] - // ExplainExec, metrics=[] - // - #[tokio::test] - #[ignore = "It uses local files, will fail in CI"] - async fn should_execute_sql_explain() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("EXPLAIN select count(*), id from test where id > 4 group by id") - .await? 
- .collect() - .await?; - - let expected = vec![ - "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", - "| logical_plan | Projection: count(*), test.id |", - "| | Aggregate: groupBy=[[test.id]], aggr=[[count(Int64(1)) AS count(*)]] |", - "| | Filter: test.id > Int32(4) |", - "| | TableScan: test projection=[id], partial_filters=[test.id > Int32(4)] |", - "| physical_plan | ProjectionExec: expr=[count(*)@1 as count(*), id@0 as id] |", - "| | AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(*)] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | RepartitionExec: partitioning=Hash([id@0], 16), input_partitions=1 |", - "| | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[count(*)] |", - "| | CoalesceBatchesExec: target_batch_size=8192 |", - "| | FilterExec: id@0 > 4 |", - "| | ParquetExec: file_groups={1 group: [[Users/ballista/git/arrow-ballista/ballista/client/testdata/alltypes_plain.parquet]]}, projection=[id], predicate=id@0 > 4, pruning_predicate=CASE WHEN id_null_count@1 = id_row_count@2 THEN false ELSE id_max@0 > 4 END, required_guarantees=[] |", - "| | |", - "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_create_external_table() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.sql(&format!("CREATE EXTERNAL TABLE tbl_test STORED AS PARQUET LOCATION '{}/alltypes_plain.parquet'", test_data, )).await?.show().await?; - - let result = ctx - .sql("select id, string_col, timestamp_col from tbl_test where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[ignore = "Error serializing custom table - NotImplemented(LogicalExtensionCodec is not provided))"] - async fn should_execute_sql_create_table() -> datafusion::error::Result<()> { - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.sql("CREATE TABLE tbl_test (id INT, value INT)") - .await? - .show() - .await?; - - // it does create table but it can't be queried - let _result = ctx - .sql("select * from tbl_test where id > 0") - .await? 
- .collect() - .await?; - - Ok(()) - } - - #[tokio::test] - async fn should_execute_dataframe() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - - let df = ctx - .read_parquet( - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await? - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(5)))?; - - let result = df.collect().await?; - - let expected = [ - "+----+----------+---------------------+", - "| id | bool_col | timestamp_col |", - "+----+----------+---------------------+", - "| 6 | true | 2009-04-01T00:00:00 |", - "| 7 | false | 2009-04-01T00:01:00 |", - "+----+----------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[ignore = "Error serializing custom table - NotImplemented(LogicalExtensionCodec is not provided))"] - async fn should_execute_dataframe_cache() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - - let df = ctx - .read_parquet( - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await? - .select_columns(&["id", "bool_col", "timestamp_col"])? - .filter(col("id").gt(lit(5)))?; - - let cached_df = df.cache().await?; - let result = cached_df.collect().await?; - - let expected = [ - "+----+----------+---------------------+", - "| id | bool_col | timestamp_col |", - "+----+----------+---------------------+", - "| 6 | true | 2009-04-01T00:00:00 |", - "| 7 | false | 2009-04-01T00:01:00 |", - "+----+----------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[ignore = "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"] - async fn should_execute_sql_insert() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - ctx.sql("select * from test") - .await? - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await?; - - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await?; - - let _ = ctx - .sql("INSERT INTO written_table select * from written_table") - .await? - .collect() - .await?; - - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") - .await? 
- .collect() - .await?; - - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - - Ok(()) - } - - #[tokio::test] - #[cfg(not(windows))] // test is failing at windows, can't debug it - async fn should_execute_sql_write() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - - let ctx: SessionContext = SessionContext::standalone().await?; - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - ctx.sql("select * from test") - .await? - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await?; - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await?; - - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4") - .await? - .collect() - .await?; - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - Ok(()) - } - - #[tokio::test] - async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); - let ctx: SessionContext = SessionContext::standalone().await?; - - ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") - .await? - .show() - .await?; - - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await?; - - let result = ctx - .sql("select string_col, timestamp_col from test where id > 4") - .await? 
-            .collect()
-            .await?;
-        let expected = [
-            "+------------+---------------------+",
-            "| string_col | timestamp_col       |",
-            "+------------+---------------------+",
-            "| 31         | 2009-03-01T00:01:00 |",
-            "| 30         | 2009-04-01T00:00:00 |",
-            "| 31         | 2009-04-01T00:01:00 |",
-            "+------------+---------------------+",
-        ];
-
-        assert_batches_eq!(expected, &result);
-
-        Ok(())
-    }
-}
diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs
index 758b34781..402ad2736 100644
--- a/ballista/executor/src/execution_loop.rs
+++ b/ballista/executor/src/execution_loop.rs
@@ -163,9 +163,12 @@ async fn run_received_task(
+pub async fn new_standalone_executor_from_state(
     scheduler: SchedulerGrpcClient<tonic::transport::Channel>,
     concurrent_tasks: usize,
     session_state: &SessionState,
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 97b9f441b..65d9cd946 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -52,3 +52,13 @@ tokio = { workspace = true, features = [
     "parking_lot"
 ] }
 url = { workspace = true }
+
+[dev-dependencies]
+ctor = { workspace = true }
+env_logger = { workspace = true }
+testcontainers-modules = { version = "0.11", features = ["minio"] }
+tonic = { workspace = true }
+
+[features]
+default = []
+testcontainers = []
diff --git a/examples/tests/common/mod.rs b/examples/tests/common/mod.rs
new file mode 100644
index 000000000..1e8091ed9
--- /dev/null
+++ b/examples/tests/common/mod.rs
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
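+//
+// Shared test support for the examples integration tests: MinIO
+// testcontainer setup (`create_minio_container`, `create_bucket_command`),
+// an S3 `object_store` builder pointed at that container, and helpers that
+// stand up a standalone Ballista scheduler/executor pair, largely moved
+// here from `ballista/client/tests/common/mod.rs`.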
+
+use ballista::prelude::SessionConfigExt;
+use ballista_core::serde::{
+    protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec,
+};
+use ballista_core::{ConfigProducer, RuntimeProducer};
+use ballista_scheduler::SessionBuilder;
+use datafusion::execution::SessionState;
+use datafusion::prelude::SessionConfig;
+use object_store::aws::AmazonS3Builder;
+use testcontainers_modules::minio::MinIO;
+use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand};
+use testcontainers_modules::testcontainers::ContainerRequest;
+use testcontainers_modules::{minio, testcontainers::ImageExt};
+
+pub const REGION: &str = "eu-west-1";
+pub const BUCKET: &str = "ballista";
+pub const ACCESS_KEY_ID: &str = "MINIO";
+pub const SECRET_KEY: &str = "MINIOMINIO";
+
+#[allow(dead_code)]
+pub fn create_s3_store(
+    host: &str,
+    port: u16,
+) -> std::result::Result<object_store::aws::AmazonS3, object_store::Error> {
+    log::info!("create S3 client: host: {}, port: {}", host, port);
+    AmazonS3Builder::new()
+        .with_endpoint(format!("http://{host}:{port}"))
+        .with_region(REGION)
+        .with_bucket_name(BUCKET)
+        .with_access_key_id(ACCESS_KEY_ID)
+        .with_secret_access_key(SECRET_KEY)
+        .with_allow_http(true)
+        .build()
+}
+
+#[allow(dead_code)]
+pub fn create_minio_container() -> ContainerRequest<minio::MinIO> {
+    MinIO::default()
+        .with_env_var("MINIO_ACCESS_KEY", ACCESS_KEY_ID)
+        .with_env_var("MINIO_SECRET_KEY", SECRET_KEY)
+}
+
+#[allow(dead_code)]
+pub fn create_bucket_command() -> ExecCommand {
+    // this is a hack to create a bucket without creating an s3 client.
+    // it works with the current testcontainer (and image) version 'RELEASE.2022-02-07T08-17-33Z'.
+    // (testcontainer does not await properly on the latest image version)
+    //
+    // if the testcontainer image version changes to something newer we should use "mc mb /data/ballista"
+    // to create a bucket.
+    ExecCommand::new(vec![
+        "mkdir".to_string(),
+        format!("/data/{}", crate::common::BUCKET),
+    ])
+    .with_cmd_ready_condition(CmdWaitFor::seconds(1))
+}
+
+/// starts a ballista cluster for integration tests
+#[allow(dead_code)]
+pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) {
+    let config = SessionConfig::new_with_ballista();
+
+    let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state(
+        &session_state,
+    )
+    .await
+    .expect("scheduler to be created");
+
+    let host = "localhost".to_string();
+
+    let scheduler =
+        connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await;
+
+    ballista_executor::new_standalone_executor_from_state(
+        scheduler,
+        config.ballista_standalone_parallelism(),
+        &session_state,
+    )
+    .await
+    .expect("executor to be created");
+
+    log::info!("test scheduler created at: {}:{}", host, addr.port());
+
+    (host, addr.port())
+}
+
+#[allow(dead_code)]
+pub async fn setup_test_cluster_with_builders(
+    config_producer: ConfigProducer,
+    runtime_producer: RuntimeProducer,
+    session_builder: SessionBuilder,
+) -> (String, u16) {
+    let config = config_producer();
+
+    let logical = config.ballista_logical_extension_codec();
+    let physical = config.ballista_physical_extension_codec();
+    let codec = BallistaCodec::new(logical, physical);
+
+    let addr = ballista_scheduler::standalone::new_standalone_scheduler_with_builder(
+        session_builder,
+        config_producer.clone(),
+        codec.clone(),
+    )
+    .await
+    .expect("scheduler to be created");
+
+    let host = "localhost".to_string();
+
+    let scheduler =
+        connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await;
+
+    ballista_executor::new_standalone_executor_from_builder(
+        scheduler,
+        config.ballista_standalone_parallelism(),
+        config_producer,
+        runtime_producer,
+        codec,
+        Default::default(),
+    )
+    .await
+    .expect("executor to be created");
+
+    log::info!("test scheduler created at: {}:{}", host, addr.port());
+
+    (host, addr.port())
+}
+
+async fn connect_to_scheduler(
+    scheduler_url: String,
+) -> SchedulerGrpcClient<tonic::transport::Channel> {
+    let mut retry = 50;
+    loop {
+        match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
+            Err(_) if retry > 0 => {
+                retry -= 1;
+                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+                log::debug!("Re-attempting to connect to test scheduler...");
+            }
+
+            Err(_) => {
+                log::error!("scheduler connection timed out");
+                panic!("scheduler connection timed out")
+            }
+            Ok(scheduler) => break scheduler,
+        }
+    }
+}
+
+#[ctor::ctor]
+fn init() {
+    // Enable RUST_LOG logging configuration for test
+    let _ = env_logger::builder()
+        .filter_level(log::LevelFilter::Info)
+        .parse_filters(
+            "ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug",
+        )
+        //.parse_filters("ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug,datafusion=debug")
+        .is_test(true)
+        .try_init();
+}
diff --git a/ballista/client/tests/object_store.rs b/examples/tests/object_store.rs
similarity index 65%
rename from ballista/client/tests/object_store.rs
rename to examples/tests/object_store.rs
index b36fd951b..6b0cacf09 100644
--- a/ballista/client/tests/object_store.rs
+++ b/examples/tests/object_store.rs
@@ -27,11 +27,11 @@
 mod common;
 
 #[cfg(test)]
-#[cfg(feature = "standalone")]
 #[cfg(feature = "testcontainers")]
 mod standalone {
     use ballista::extension::SessionContextExt;
+    use ballista_examples::test_util::examples_test_data;
     use datafusion::{assert_batches_eq, prelude::SessionContext};
     use datafusion::{
error::DataFusionError, @@ -52,12 +52,13 @@ mod standalone { .await .unwrap(); + let host = node.get_host().await.unwrap(); let port = node.get_host_port_ipv4(9000).await.unwrap(); - let object_store = crate::common::create_s3_store(port) + let object_store = crate::common::create_s3_store(&host.to_string(), port) .map_err(|e| DataFusionError::External(e.into()))?; - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); let config = RuntimeConfig::new(); let runtime_env = RuntimeEnv::new(config)?; @@ -116,6 +117,7 @@ mod standalone { mod remote { use ballista::extension::SessionContextExt; + use ballista_examples::test_util::examples_test_data; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion::{ error::DataFusionError, @@ -129,7 +131,7 @@ mod remote { #[tokio::test] async fn should_execute_sql_write() -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); let container = crate::common::create_minio_container(); let node = container.start().await.unwrap(); @@ -138,9 +140,10 @@ mod remote { .await .unwrap(); + let host = node.get_host().await.unwrap(); let port = node.get_host_port_ipv4(9000).await.unwrap(); - let object_store = crate::common::create_s3_store(port) + let object_store = crate::common::create_s3_store(&host.to_string(), port) .map_err(|e| DataFusionError::External(e.into()))?; let config = RuntimeConfig::new(); @@ -210,15 +213,12 @@ mod remote { #[cfg(feature = "testcontainers")] mod custom_s3_config { + use crate::common::{ACCESS_KEY_ID, SECRET_KEY}; use ballista::extension::SessionContextExt; use ballista::prelude::SessionConfigExt; use ballista_core::RuntimeProducer; - use datafusion::common::{config_err, exec_err}; - use datafusion::config::{ - ConfigEntry, ConfigExtension, ConfigField, ExtensionOptions, Visit, - }; - use datafusion::error::Result; - use datafusion::execution::object_store::ObjectStoreRegistry; + use ballista_examples::object_store::{CustomObjectStoreRegistry, S3Options}; + use ballista_examples::test_util::examples_test_data; use datafusion::execution::SessionState; use datafusion::prelude::SessionConfig; use datafusion::{assert_batches_eq, prelude::SessionContext}; @@ -229,22 +229,13 @@ mod custom_s3_config { SessionStateBuilder, }, }; - use object_store::aws::AmazonS3Builder; - use object_store::local::LocalFileSystem; - use object_store::ObjectStore; - use parking_lot::RwLock; - use std::any::Any; - use std::fmt::Display; use std::sync::Arc; use testcontainers_modules::testcontainers::runners::AsyncRunner; - use url::Url; - - use crate::common::{ACCESS_KEY_ID, SECRET_KEY}; #[tokio::test] async fn should_configure_s3_execute_sql_write_remote( ) -> datafusion::error::Result<()> { - let test_data = crate::common::example_test_data(); + let test_data = examples_test_data(); // // Minio cluster setup @@ -256,8 +247,15 @@ mod custom_s3_config { .await .unwrap(); + let endpoint_host = node.get_host().await.unwrap(); let endpoint_port = node.get_host_port_ipv4(9000).await.unwrap(); + log::info!( + "MINIO testcontainers host: {}, port: {}", + endpoint_host, + endpoint_port + ); + // // Session Context and Ballista cluster setup // @@ -325,8 +323,8 @@ mod custom_s3_config { .show() .await?; ctx.sql(&format!( - "SET s3.endpoint = 'http://localhost:{}'", - endpoint_port + "SET s3.endpoint = 'http://{}:{}'", + endpoint_host, endpoint_port )) .await? 
        .show()
@@ -380,10 +378,9 @@ mod custom_s3_config {
     // SessionConfig propagation across ballista cluster.
 
     #[tokio::test]
-    #[cfg(feature = "standalone")]
     async fn should_configure_s3_execute_sql_write_standalone(
     ) -> datafusion::error::Result<()> {
-        let test_data = crate::common::example_test_data();
+        let test_data = examples_test_data();
 
         //
         // Minio cluster setup
@@ -420,11 +417,6 @@ mod custom_s3_config {
         let session_builder = Arc::new(produce_state);
         let state = session_builder(config_producer());
 
-        // // setting up ballista cluster with new runtime, configuration, and session state producers
-        // let (host, port) =
-        //     crate::common::setup_test_cluster_with_state(state.clone()).await;
-        // let url = format!("df://{host}:{port}");
-
         // establishing cluster connection,
         let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
@@ -507,231 +499,4 @@ mod custom_s3_config {
         .with_config(session_config)
         .build()
     }
-
-    #[derive(Debug)]
-    pub struct CustomObjectStoreRegistry {
-        local: Arc<LocalFileSystem>,
-        s3options: S3Options,
-    }
-
-    impl CustomObjectStoreRegistry {
-        fn new(s3options: S3Options) -> Self {
-            Self {
-                s3options,
-                local: Arc::new(LocalFileSystem::new()),
-            }
-        }
-    }
-
-    impl ObjectStoreRegistry for CustomObjectStoreRegistry {
-        fn register_store(
-            &self,
-            _url: &Url,
-            _store: Arc<dyn ObjectStore>,
-        ) -> Option<Arc<dyn ObjectStore>> {
-            unreachable!("register_store not supported ")
-        }
-
-        fn get_store(&self, url: &Url) -> Result<Arc<dyn ObjectStore>> {
-            let scheme = url.scheme();
-            log::info!("get_store: {:?}", &self.s3options.config.read());
-            match scheme {
-                "" | "file" => Ok(self.local.clone()),
-                "s3" => {
-                    let s3store = Self::s3_object_store_builder(
-                        url,
-                        &self.s3options.config.read(),
-                    )?
-                    .build()?;
-
-                    Ok(Arc::new(s3store))
-                }
-
-                _ => exec_err!("get_store - store not supported, url {}", url),
-            }
-        }
-    }
-
-    impl CustomObjectStoreRegistry {
-        pub fn s3_object_store_builder(
-            url: &Url,
-            aws_options: &S3RegistryConfiguration,
-        ) -> Result<AmazonS3Builder> {
-            let S3RegistryConfiguration {
-                access_key_id,
-                secret_access_key,
-                session_token,
-                region,
-                endpoint,
-                allow_http,
-            } = aws_options;
-
-            let bucket_name = Self::get_bucket_name(url)?;
-            let mut builder = AmazonS3Builder::from_env().with_bucket_name(bucket_name);
-
-            if let (Some(access_key_id), Some(secret_access_key)) =
-                (access_key_id, secret_access_key)
-            {
-                builder = builder
-                    .with_access_key_id(access_key_id)
-                    .with_secret_access_key(secret_access_key);
-
-                if let Some(session_token) = session_token {
-                    builder = builder.with_token(session_token);
-                }
-            } else {
-                return config_err!(
-                    "'s3.access_key_id' & 's3.secret_access_key' must be configured"
-                );
-            }
-
-            if let Some(region) = region {
-                builder = builder.with_region(region);
-            }
-
-            if let Some(endpoint) = endpoint {
-                if let Ok(endpoint_url) = Url::try_from(endpoint.as_str()) {
-                    if !matches!(allow_http, Some(true))
-                        && endpoint_url.scheme() == "http"
-                    {
-                        return config_err!("Invalid endpoint: {endpoint}. HTTP is not allowed for S3 endpoints.
-                        To allow HTTP, set 's3.allow_http' to true");
-                    }
-                }
-
-                builder = builder.with_endpoint(endpoint);
-            }
-
-            if let Some(allow_http) = allow_http {
-                builder = builder.with_allow_http(*allow_http);
-            }
-
-            Ok(builder)
-        }
-
-        fn get_bucket_name(url: &Url) -> Result<&str> {
-            url.host_str().ok_or_else(|| {
-                DataFusionError::Execution(format!(
-                    "Not able to parse bucket name from url: {}",
-                    url.as_str()
-                ))
-            })
-        }
-    }
-
-    #[derive(Debug, Clone, Default)]
-    pub struct S3Options {
-        config: Arc<RwLock<S3RegistryConfiguration>>,
-    }
-
-    impl ExtensionOptions for S3Options {
-        fn as_any(&self) -> &dyn Any {
-            self
-        }
-
-        fn as_any_mut(&mut self) -> &mut dyn Any {
-            self
-        }
-
-        fn cloned(&self) -> Box<dyn ExtensionOptions> {
-            Box::new(self.clone())
-        }
-
-        fn set(&mut self, key: &str, value: &str) -> Result<()> {
-            log::debug!("set config, key:{}, value:{}", key, value);
-            match key {
-                "access_key_id" => {
-                    let mut c = self.config.write();
-                    c.access_key_id.set(key, value)?;
-                }
-                "secret_access_key" => {
-                    let mut c = self.config.write();
-                    c.secret_access_key.set(key, value)?;
-                }
-                "session_token" => {
-                    let mut c = self.config.write();
-                    c.session_token.set(key, value)?;
-                }
-                "region" => {
-                    let mut c = self.config.write();
-                    c.region.set(key, value)?;
-                }
-                "endpoint" => {
-                    let mut c = self.config.write();
-                    c.endpoint.set(key, value)?;
-                }
-                "allow_http" => {
-                    let mut c = self.config.write();
-                    c.allow_http.set(key, value)?;
-                }
-                _ => {
-                    log::warn!("Config value {} cant be set to {}", key, value);
-                    return config_err!(
-                        "Config value \"{}\" not found in S3Options",
-                        key
-                    );
-                }
-            }
-            Ok(())
-        }
-
-        fn entries(&self) -> Vec<ConfigEntry> {
-            struct Visitor(Vec<ConfigEntry>);
-
-            impl Visit for Visitor {
-                fn some<V: Display>(
-                    &mut self,
-                    key: &str,
-                    value: V,
-                    description: &'static str,
-                ) {
-                    self.0.push(ConfigEntry {
-                        key: format!("{}.{}", S3Options::PREFIX, key),
-                        value: Some(value.to_string()),
-                        description,
-                    })
-                }
-
-                fn none(&mut self, key: &str, description: &'static str) {
-                    self.0.push(ConfigEntry {
-                        key: format!("{}.{}", S3Options::PREFIX, key),
-                        value: None,
-                        description,
-                    })
-                }
-            }
-            let c = self.config.read();
-
-            let mut v = Visitor(vec![]);
-            c.access_key_id
-                .visit(&mut v, "access_key_id", "S3 Access Key");
-            c.secret_access_key
-                .visit(&mut v, "secret_access_key", "S3 Secret Key");
-            c.session_token
-                .visit(&mut v, "session_token", "S3 Session token");
-            c.region.visit(&mut v, "region", "S3 region");
-            c.endpoint.visit(&mut v, "endpoint", "S3 Endpoint");
-            c.allow_http.visit(&mut v, "allow_http", "S3 Allow Http");
-
-            v.0
-        }
-    }
-
-    impl ConfigExtension for S3Options {
-        const PREFIX: &'static str = "s3";
-    }
-    #[derive(Default, Debug, Clone)]
-    pub struct S3RegistryConfiguration {
-        /// Access Key ID
-        pub access_key_id: Option<String>,
-        /// Secret Access Key
-        pub secret_access_key: Option<String>,
-        /// Session token
-        pub session_token: Option<String>,
-        /// AWS Region
-        pub region: Option<String>,
-        /// OSS or COS Endpoint
-        pub endpoint: Option<String>,
-        /// Allow HTTP (otherwise will always use https)
-        pub allow_http: Option<bool>,
-    }
 }
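
The relocated `CustomObjectStoreRegistry` / `S3Options` pair now lives in
`ballista_examples::object_store` (see the `+use` lines above). As a rough
sketch of how the two are meant to be wired together -- assuming the moved
module keeps the signatures shown in the removed block and makes
`CustomObjectStoreRegistry::new` public -- the pattern from the deleted test
code looks like this:

use std::sync::Arc;

use ballista::extension::SessionContextExt;
use ballista::prelude::SessionConfigExt;
use ballista_examples::object_store::{CustomObjectStoreRegistry, S3Options};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // S3Options keeps its values behind an Arc<RwLock<..>>, so the clone
    // handed to the registry shares state with the copy registered on the
    // session config: a later `SET s3.*` is immediately visible to get_store.
    let s3options = S3Options::default();

    let session_config = SessionConfig::new_with_ballista()
        .with_option_extension(s3options.clone());

    // Route all object store lookups through the custom registry.
    let runtime_config = RuntimeConfig::new().with_object_store_registry(Arc::new(
        CustomObjectStoreRegistry::new(s3options), // assumes a public constructor
    ));

    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_config(session_config)
        .with_runtime_env(Arc::new(RuntimeEnv::new(runtime_config)?))
        .build();

    // Standalone Ballista session, as in the tests above.
    let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;

    // Credentials and endpoint arrive at runtime through SQL; the values and
    // endpoint below are placeholders for whatever MinIO/S3 instance is used.
    ctx.sql("SET s3.access_key_id = 'MINIO'").await?.show().await?;
    ctx.sql("SET s3.secret_access_key = 'MINIOMINIO'").await?.show().await?;
    ctx.sql("SET s3.allow_http = 'true'").await?.show().await?;
    ctx.sql("SET s3.endpoint = 'http://localhost:9000'").await?.show().await?;

    Ok(())
}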
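
Similarly, a hypothetical test driving the new cluster helpers from
common/mod.rs: the `df://` URL scheme is taken from the comments removed
above, and `SessionContext::remote_with_state` is assumed here as the remote
counterpart of the `standalone_with_state` call used in the hunks.

use ballista::prelude::{SessionConfigExt, SessionContextExt};
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::test]
async fn attaches_to_in_process_cluster() -> datafusion::error::Result<()> {
    // Build a Ballista-flavoured session state.
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_config(SessionConfig::new_with_ballista())
        .build();

    // Spin up the standalone scheduler + executor pair; the helper retries
    // the gRPC connection (see connect_to_scheduler) before returning.
    let (host, port) = crate::common::setup_test_cluster_with_state(state.clone()).await;

    // Attach a remote Ballista session and run a smoke query.
    let ctx: SessionContext =
        SessionContext::remote_with_state(&format!("df://{host}:{port}"), state).await?;
    ctx.sql("SELECT 1").await?.show().await?;

    Ok(())
}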