Skip to content

Commit

Permalink
Minor: Change from &mut SessionContext to &SessionContext in subs…
Browse files Browse the repository at this point in the history
…trait (apache#7965)

* Lower &mut SessionContext in substrait

* rm mut ctx in tests
  • Loading branch information
my-vegetable-has-exploded authored Oct 29, 2023
1 parent 9ee055a commit d24228a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(

/// Convert Substrait Plan to DataFusion DataFrame
pub async fn from_substrait_plan(
ctx: &mut SessionContext,
ctx: &SessionContext,
plan: &Plan,
) -> Result<LogicalPlan> {
// Register function extension
Expand Down Expand Up @@ -219,7 +219,7 @@ pub async fn from_substrait_plan(
/// Convert Substrait Rel to DataFusion DataFrame
#[async_recursion]
pub async fn from_substrait_rel(
ctx: &mut SessionContext,
ctx: &SessionContext,
rel: &Rel,
extensions: &HashMap<u32, &String>,
) -> Result<LogicalPlan> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/substrait/src/physical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use substrait::proto::{
/// Convert Substrait Rel to DataFusion ExecutionPlan
#[async_recursion]
pub async fn from_substrait_rel(
_ctx: &mut SessionContext,
_ctx: &SessionContext,
rel: &Rel,
_extensions: &HashMap<u32, &String>,
) -> Result<Arc<dyn ExecutionPlan>> {
Expand Down
30 changes: 15 additions & 15 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ async fn new_test_grammar() -> Result<()> {

#[tokio::test]
async fn extension_logical_plan() -> Result<()> {
let mut ctx = create_context().await?;
let ctx = create_context().await?;
let validation_bytes = "MockUserDefinedLogicalPlan".as_bytes().to_vec();
let ext_plan = LogicalPlan::Extension(Extension {
node: Arc::new(MockUserDefinedLogicalPlan {
Expand All @@ -617,7 +617,7 @@ async fn extension_logical_plan() -> Result<()> {
});

let proto = to_substrait_plan(&ext_plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;

let plan1str = format!("{ext_plan:?}");
let plan2str = format!("{plan2:?}");
Expand Down Expand Up @@ -712,23 +712,23 @@ async fn verify_post_join_filter_value(proto: Box<Plan>) -> Result<()> {
}

async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> {
let mut ctx = create_context().await?;
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
let plan2str = format!("{plan2:?}");
assert_eq!(expected_plan_str, &plan2str);
Ok(())
}

async fn roundtrip_fill_na(sql: &str) -> Result<()> {
let mut ctx = create_context().await?;
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan1 = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan1, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;

// Format plan string and replace all None's with 0
Expand All @@ -743,15 +743,15 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
// Since we ignore the SubqueryAlias in the producer, the result should be
// the same as producing a Substrait plan from the same query without aliases
// sql_with_alias -> substrait -> logical plan = sql_no_alias -> substrait -> logical plan
let mut ctx = create_context().await?;
let ctx = create_context().await?;

let df_a = ctx.sql(sql_with_alias).await?;
let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?;
let plan_with_alias = from_substrait_plan(&mut ctx, &proto_a).await?;
let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?;

let df = ctx.sql(sql_no_alias).await?;
let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?;
let plan = from_substrait_plan(&mut ctx, &proto).await?;
let plan = from_substrait_plan(&ctx, &proto).await?;

println!("{plan_with_alias:#?}");
println!("{plan:#?}");
Expand All @@ -763,11 +763,11 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
}

async fn roundtrip(sql: &str) -> Result<()> {
let mut ctx = create_context().await?;
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;

println!("{plan:#?}");
Expand All @@ -780,11 +780,11 @@ async fn roundtrip(sql: &str) -> Result<()> {
}

async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
let mut ctx = create_context().await?;
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;

println!("{plan:#?}");
Expand All @@ -799,11 +799,11 @@ async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
}

async fn roundtrip_all_types(sql: &str) -> Result<()> {
let mut ctx = create_all_type_context().await?;
let ctx = create_all_type_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
let plan2 = from_substrait_plan(&mut ctx, &proto).await?;
let plan2 = from_substrait_plan(&ctx, &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;

println!("{plan:#?}");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/substrait/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ async fn parquet_exec() -> Result<()> {
let substrait_rel =
producer::to_substrait_rel(parquet_exec.as_ref(), &mut extension_info)?;

let mut ctx = SessionContext::new();
let ctx = SessionContext::new();

let parquet_exec_roundtrip =
consumer::from_substrait_rel(&mut ctx, substrait_rel.as_ref(), &HashMap::new())
consumer::from_substrait_rel(&ctx, substrait_rel.as_ref(), &HashMap::new())
.await?;

let expected = format!("{}", displayable(parquet_exec.as_ref()).indent(true));
Expand Down
4 changes: 2 additions & 2 deletions datafusion/substrait/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mod tests {

#[tokio::test]
async fn serialize_simple_select() -> Result<()> {
let mut ctx = create_context().await?;
let ctx = create_context().await?;
let path = "tests/simple_select.bin";
let sql = "SELECT a, b FROM data";
// Test reference
Expand All @@ -42,7 +42,7 @@ mod tests {
// Read substrait plan from file
let proto = serializer::deserialize(path).await?;
// Check plan equality
let plan = from_substrait_plan(&mut ctx, &proto).await?;
let plan = from_substrait_plan(&ctx, &proto).await?;
let plan_str_ref = format!("{plan_ref:?}");
let plan_str = format!("{plan:?}");
assert_eq!(plan_str_ref, plan_str);
Expand Down

0 comments on commit d24228a

Please sign in to comment.