From 9073845754bf61b066ff1a1a6889f49bcbc41625 Mon Sep 17 00:00:00 2001
From: Kould <2435992353@qq.com>
Date: Thu, 21 Mar 2024 02:37:31 +0800
Subject: [PATCH] fix: In multi-level nested subquery, the field relationships
 of different tables are wrong

---
 src/binder/create_table.rs |  7 ++++-
 src/binder/expr.rs         | 17 ++++++++++--
 src/binder/mod.rs          | 53 +++++++++++++++++++++++++++++++++-----
 src/db.rs                  |  7 ++++-
 src/errors.rs              |  8 ++++--
 src/optimizer/core/memo.rs |  7 ++++-
 6 files changed, 85 insertions(+), 14 deletions(-)

diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs
index a06ec12f..9e263494 100644
--- a/src/binder/create_table.rs
+++ b/src/binder/create_table.rs
@@ -142,6 +142,7 @@ mod tests {
     use crate::storage::kip::KipStorage;
     use crate::storage::Storage;
     use crate::types::LogicalType;
+    use std::sync::atomic::AtomicUsize;
     use tempfile::TempDir;
 
     #[tokio::test]
@@ -152,7 +153,11 @@ mod tests {
         let functions = Default::default();
 
         let sql = "create table t1 (id int primary key, name varchar(10) null)";
-        let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
+        let mut binder = Binder::new(BinderContext::new(
+            &transaction,
+            &functions,
+            Arc::new(AtomicUsize::new(0)),
+        ));
         let stmt = crate::parser::parse_sql(sql).unwrap();
         let plan1 = binder.bind(&stmt[0]).unwrap();
 
diff --git a/src/binder/expr.rs b/src/binder/expr.rs
index a286f7e4..845c112c 100644
--- a/src/binder/expr.rs
+++ b/src/binder/expr.rs
@@ -10,7 +10,7 @@ use sqlparser::ast::{
 use std::slice;
 use std::sync::Arc;
 
-use super::{lower_ident, Binder, QueryBindStep, SubQueryType};
+use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
 use crate::expression::function::{FunctionSummary, ScalarFunction};
 use crate::expression::{AliasType, ScalarExpression};
 use crate::planner::LogicalPlan;
@@ -226,7 +226,19 @@ impl<'a, T: Transaction> Binder<'a, T> {
         &mut self,
         subquery: &Query,
     ) -> Result<(LogicalPlan, Arc), DatabaseError> {
-        let mut sub_query = self.bind_query(subquery)?;
+        let BinderContext {
+            transaction,
+            functions,
+            temp_table_id,
+            ..
+        } = &self.context;
+        let mut binder = Binder::new(BinderContext::new(
+            *transaction,
+            functions,
+            temp_table_id.clone(),
+        ));
+
+        let mut sub_query = binder.bind_query(subquery)?;
         let sub_query_schema = sub_query.output_schema();
 
         if sub_query_schema.len() != 1 {
@@ -236,6 +248,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
             ));
         }
         let column = sub_query_schema[0].clone();
+        self.context.merge_context(binder.context)?;
 
         Ok((sub_query, column))
     }
diff --git a/src/binder/mod.rs b/src/binder/mod.rs
index 226bce7e..73d8e80f 100644
--- a/src/binder/mod.rs
+++ b/src/binder/mod.rs
@@ -18,6 +18,7 @@ mod update;
 
 use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement};
 use std::collections::{BTreeMap, HashMap};
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 
 use crate::catalog::{TableCatalog, TableName};
@@ -55,7 +56,7 @@ pub enum SubQueryType {
 
 #[derive(Clone)]
 pub struct BinderContext<'a, T: Transaction> {
-    functions: &'a Functions,
+    pub(crate) functions: &'a Functions,
     pub(crate) transaction: &'a T,
     // Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain.
     pub(crate) bind_table: BTreeMap<(TableName, Option), &'a TableCatalog>,
@@ -69,12 +70,16 @@ pub struct BinderContext<'a, T: Transaction> {
 
     bind_step: QueryBindStep,
     sub_queries: HashMap>,
-    temp_table_id: usize,
+    temp_table_id: Arc<AtomicUsize>,
     pub(crate) allow_default: bool,
 }
 
 impl<'a, T: Transaction> BinderContext<'a, T> {
-    pub fn new(transaction: &'a T, functions: &'a Functions) -> Self {
+    pub fn new(
+        transaction: &'a T,
+        functions: &'a Functions,
+        temp_table_id: Arc<AtomicUsize>,
+    ) -> Self {
         BinderContext {
             functions,
             transaction,
@@ -85,14 +90,16 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
             agg_calls: Default::default(),
             bind_step: QueryBindStep::From,
             sub_queries: Default::default(),
-            temp_table_id: 0,
+            temp_table_id,
             allow_default: false,
         }
     }
 
     pub fn temp_table(&mut self) -> TableName {
-        self.temp_table_id += 1;
-        Arc::new(format!("_temp_table_{}_", self.temp_table_id))
+        Arc::new(format!(
+            "_temp_table_{}_",
+            self.temp_table_id.fetch_add(1, Ordering::SeqCst)
+        ))
     }
 
     pub fn step(&mut self, bind_step: QueryBindStep) {
@@ -163,6 +170,33 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
     pub fn has_agg_call(&self, expr: &ScalarExpression) -> bool {
         self.group_by_exprs.contains(expr)
     }
+
+    pub fn merge_context(&mut self, context: BinderContext<'a, T>) -> Result<(), DatabaseError> {
+        let BinderContext {
+            expr_aliases,
+            table_aliases,
+            bind_table,
+            ..
+        } = context;
+
+        for (alias, expr) in expr_aliases {
+            if self.expr_aliases.contains_key(&alias) {
+                return Err(DatabaseError::DuplicateAliasExpr(alias));
+            }
+            self.expr_aliases.insert(alias, expr);
+        }
+        for (alias, table_name) in table_aliases {
+            if self.table_aliases.contains_key(&alias) {
+                return Err(DatabaseError::DuplicateAliasTable(alias));
+            }
+            self.table_aliases.insert(alias, table_name);
+        }
+        for (key, table) in bind_table {
+            self.bind_table.insert(key, table);
+        }
+
+        Ok(())
+    }
 }
 
 pub struct Binder<'a, T: Transaction> {
@@ -305,6 +339,7 @@ pub mod test {
     use crate::storage::{Storage, Transaction};
     use crate::types::LogicalType::Integer;
     use std::path::PathBuf;
+    use std::sync::atomic::AtomicUsize;
    use std::sync::Arc;
     use tempfile::TempDir;
 
@@ -358,7 +393,11 @@ pub mod test {
         let storage = build_test_catalog(temp_dir.path()).await?;
         let transaction = storage.transaction().await?;
         let functions = Default::default();
-        let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
+        let mut binder = Binder::new(BinderContext::new(
+            &transaction,
+            &functions,
+            Arc::new(AtomicUsize::new(0)),
+        ));
         let stmt = crate::parser::parse_sql(sql)?;
 
         Ok(binder.bind(&stmt[0])?)
diff --git a/src/db.rs b/src/db.rs
index 6f5e7c9f..0a415832 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -1,5 +1,6 @@
 use ahash::HashMap;
 use std::path::PathBuf;
+use std::sync::atomic::AtomicUsize;
 use std::sync::Arc;
 
 use crate::binder::{Binder, BinderContext};
@@ -98,7 +99,11 @@ impl Database {
         if stmts.is_empty() {
             return Err(DatabaseError::EmptyStatement);
         }
-        let mut binder = Binder::new(BinderContext::new(transaction, functions));
+        let mut binder = Binder::new(BinderContext::new(
+            transaction,
+            functions,
+            Arc::new(AtomicUsize::new(0)),
+        ));
         /// Build a logical plan.
         ///
         /// SELECT a,b FROM t1 ORDER BY a LIMIT 1;
diff --git a/src/errors.rs b/src/errors.rs
index 7a5a4e0e..4c81f06b 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -29,12 +29,16 @@ pub enum DatabaseError {
         #[source]
         csv::Error,
     ),
-    #[error("duplicate primary key")]
-    DuplicatePrimaryKey,
+    #[error("alias expr: {0} already exists")]
+    DuplicateAliasExpr(String),
+    #[error("alias table: {0} already exists")]
+    DuplicateAliasTable(String),
     #[error("column: {0} already exists")]
     DuplicateColumn(String),
     #[error("index: {0} already exists")]
     DuplicateIndex(String),
+    #[error("duplicate primary key")]
+    DuplicatePrimaryKey,
     #[error("the column has been declared unique and the value already exists")]
     DuplicateUniqueValue,
     #[error("empty plan")]
diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs
index 951e918c..be20c64f 100644
--- a/src/optimizer/core/memo.rs
+++ b/src/optimizer/core/memo.rs
@@ -98,6 +98,7 @@ mod tests {
     use crate::types::LogicalType;
     use petgraph::stable_graph::NodeIndex;
     use std::ops::Bound;
+    use std::sync::atomic::AtomicUsize;
     use std::sync::Arc;
     use tempfile::TempDir;
 
@@ -121,7 +122,11 @@ mod tests {
         let transaction = database.storage.transaction().await?;
         let functions = Default::default();
 
-        let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
+        let mut binder = Binder::new(BinderContext::new(
+            &transaction,
+            &functions,
+            Arc::new(AtomicUsize::new(0)),
+        ));
         let stmt = crate::parser::parse_sql(
             // FIXME: Only by bracketing (c1 > 40 or c1 = 2) can the filter be pushed down below the join
             "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22",
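
Note (not part of the patch): the patch makes bind_subquery build a fresh Binder for each nested subquery and then fold its bindings back into the parent via merge_context, while all binders share one Arc<AtomicUsize> temp-table counter. The sketch below is a minimal, self-contained illustration of those two mechanisms; Ctx, merge and the alias values are stand-ins, not the crate's real BinderContext/Binder types.

use std::collections::BTreeMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Stand-in for BinderContext: only the fields relevant to this patch.
struct Ctx {
    temp_table_id: Arc<AtomicUsize>,
    expr_aliases: BTreeMap<String, String>,
}

impl Ctx {
    fn new(temp_table_id: Arc<AtomicUsize>) -> Self {
        Ctx {
            temp_table_id,
            expr_aliases: BTreeMap::new(),
        }
    }

    // Mirrors `BinderContext::temp_table`: every binder draws from the same
    // shared counter, so `_temp_table_N_` names stay unique across the outer
    // query and any level of nested subquery.
    fn temp_table(&self) -> String {
        format!(
            "_temp_table_{}_",
            self.temp_table_id.fetch_add(1, Ordering::SeqCst)
        )
    }

    // Mirrors `merge_context`: aliases bound inside the subquery's own binder
    // flow back to the parent, and a duplicate alias is reported instead of
    // being silently overwritten.
    fn merge(&mut self, child: Ctx) -> Result<(), String> {
        for (alias, expr) in child.expr_aliases {
            if self.expr_aliases.contains_key(&alias) {
                return Err(format!("alias expr: {} already exists", alias));
            }
            self.expr_aliases.insert(alias, expr);
        }
        Ok(())
    }
}

fn main() {
    let counter = Arc::new(AtomicUsize::new(0));
    let mut parent = Ctx::new(counter.clone());
    let mut child = Ctx::new(counter.clone()); // nested subquery gets its own context...
    assert_eq!(parent.temp_table(), "_temp_table_0_");
    assert_eq!(child.temp_table(), "_temp_table_1_"); // ...but shares the id counter
    child
        .expr_aliases
        .insert("sub_max".into(), "MAX(t2.c3)".into());
    parent.merge(child).unwrap(); // subquery bindings become visible to the outer query
}

In the patch itself the same pattern appears in bind_subquery (src/binder/expr.rs): a fresh Binder is constructed from the parent's transaction, functions and shared temp_table_id, and merge_context folds its expr aliases, table aliases and bound tables back into self.context.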