From e0de6376c4a51d1eedb4d79cc256caa396d85242 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Thu, 21 Mar 2024 14:50:02 +0800 Subject: [PATCH] fix: In multi-level nested subquery, the field relationships of different tables are wrong (#173) * fix: In multi-level nested subquery, the field relationships of different tables are wrong * fix: Correlated subqueries * style: code simplification * style: remove unused error * test: test case for issue_169 * code fmt --- src/binder/create_table.rs | 6 +++++- src/binder/expr.rs | 35 ++++++++++++++++++++++++++--------- src/binder/mod.rs | 30 +++++++++++++++++++++--------- src/db.rs | 6 +++++- src/errors.rs | 4 ++-- src/optimizer/core/memo.rs | 6 +++++- tests/slt/subquery.slt | 35 ++++++++++++++++++++++++++++++----- 7 files changed, 94 insertions(+), 28 deletions(-) diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index a06ec12f..5fabf9ec 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -142,6 +142,7 @@ mod tests { use crate::storage::kip::KipStorage; use crate::storage::Storage; use crate::types::LogicalType; + use std::sync::atomic::AtomicUsize; use tempfile::TempDir; #[tokio::test] @@ -152,7 +153,10 @@ mod tests { let functions = Default::default(); let sql = "create table t1 (id int primary key, name varchar(10) null)"; - let mut binder = Binder::new(BinderContext::new(&transaction, &functions)); + let mut binder = Binder::new( + BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))), + None, + ); let stmt = crate::parser::parse_sql(sql).unwrap(); let plan1 = binder.bind(&stmt[0]).unwrap(); diff --git a/src/binder/expr.rs b/src/binder/expr.rs index a286f7e4..15c48c54 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -10,7 +10,7 @@ use sqlparser::ast::{ use std::slice; use std::sync::Arc; -use super::{lower_ident, Binder, QueryBindStep, SubQueryType}; +use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType}; use crate::expression::function::{FunctionSummary, ScalarFunction}; use crate::expression::{AliasType, ScalarExpression}; use crate::planner::LogicalPlan; @@ -226,7 +226,17 @@ impl<'a, T: Transaction> Binder<'a, T> { &mut self, subquery: &Query, ) -> Result<(LogicalPlan, Arc), DatabaseError> { - let mut sub_query = self.bind_query(subquery)?; + let BinderContext { + transaction, + functions, + temp_table_id, + .. + } = &self.context; + let mut binder = Binder::new( + BinderContext::new(*transaction, functions, temp_table_id.clone()), + Some(self), + ); + let mut sub_query = binder.bind_query(subquery)?; let sub_query_schema = sub_query.output_schema(); if sub_query_schema.len() != 1 { @@ -294,15 +304,22 @@ impl<'a, T: Transaction> Binder<'a, T> { .ok_or_else(|| DatabaseError::NotFound("column", column_name))?; Ok(ScalarExpression::ColumnRef(column_catalog.clone())) } else { + let op = |got_column: &mut Option<&'a ColumnRef>, context: &BinderContext<'a, T>| { + for table_catalog in context.bind_table.values() { + if got_column.is_some() { + break; + } + if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) { + *got_column = Some(column_catalog); + } + } + }; // handle col syntax let mut got_column = None; - for table_catalog in self.context.bind_table.values().rev() { - if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) { - got_column = Some(column_catalog); - } - if got_column.is_some() { - break; - } + + op(&mut got_column, &self.context); + if let Some(parent) = self.parent { + op(&mut got_column, &parent.context); } let column_catalog = got_column.ok_or_else(|| DatabaseError::NotFound("column", column_name))?; diff --git a/src/binder/mod.rs b/src/binder/mod.rs index 226bce7e..94f3059c 100644 --- a/src/binder/mod.rs +++ b/src/binder/mod.rs @@ -18,6 +18,7 @@ mod update; use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement}; use std::collections::{BTreeMap, HashMap}; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use crate::catalog::{TableCatalog, TableName}; @@ -55,7 +56,7 @@ pub enum SubQueryType { #[derive(Clone)] pub struct BinderContext<'a, T: Transaction> { - functions: &'a Functions, + pub(crate) functions: &'a Functions, pub(crate) transaction: &'a T, // Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain. pub(crate) bind_table: BTreeMap<(TableName, Option), &'a TableCatalog>, @@ -69,12 +70,16 @@ pub struct BinderContext<'a, T: Transaction> { bind_step: QueryBindStep, sub_queries: HashMap>, - temp_table_id: usize, + temp_table_id: Arc, pub(crate) allow_default: bool, } impl<'a, T: Transaction> BinderContext<'a, T> { - pub fn new(transaction: &'a T, functions: &'a Functions) -> Self { + pub fn new( + transaction: &'a T, + functions: &'a Functions, + temp_table_id: Arc, + ) -> Self { BinderContext { functions, transaction, @@ -85,14 +90,16 @@ impl<'a, T: Transaction> BinderContext<'a, T> { agg_calls: Default::default(), bind_step: QueryBindStep::From, sub_queries: Default::default(), - temp_table_id: 0, + temp_table_id, allow_default: false, } } pub fn temp_table(&mut self) -> TableName { - self.temp_table_id += 1; - Arc::new(format!("_temp_table_{}_", self.temp_table_id)) + Arc::new(format!( + "_temp_table_{}_", + self.temp_table_id.fetch_add(1, Ordering::SeqCst) + )) } pub fn step(&mut self, bind_step: QueryBindStep) { @@ -167,11 +174,12 @@ impl<'a, T: Transaction> BinderContext<'a, T> { pub struct Binder<'a, T: Transaction> { context: BinderContext<'a, T>, + pub(crate) parent: Option<&'a Binder<'a, T>>, } impl<'a, T: Transaction> Binder<'a, T> { - pub fn new(context: BinderContext<'a, T>) -> Self { - Binder { context } + pub fn new(context: BinderContext<'a, T>, parent: Option<&'a Binder<'a, T>>) -> Self { + Binder { context, parent } } pub fn bind(&mut self, stmt: &Statement) -> Result { @@ -305,6 +313,7 @@ pub mod test { use crate::storage::{Storage, Transaction}; use crate::types::LogicalType::Integer; use std::path::PathBuf; + use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tempfile::TempDir; @@ -358,7 +367,10 @@ pub mod test { let storage = build_test_catalog(temp_dir.path()).await?; let transaction = storage.transaction().await?; let functions = Default::default(); - let mut binder = Binder::new(BinderContext::new(&transaction, &functions)); + let mut binder = Binder::new( + BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))), + None, + ); let stmt = crate::parser::parse_sql(sql)?; Ok(binder.bind(&stmt[0])?) diff --git a/src/db.rs b/src/db.rs index 6f5e7c9f..ea67f3fe 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,6 @@ use ahash::HashMap; use std::path::PathBuf; +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use crate::binder::{Binder, BinderContext}; @@ -98,7 +99,10 @@ impl Database { if stmts.is_empty() { return Err(DatabaseError::EmptyStatement); } - let mut binder = Binder::new(BinderContext::new(transaction, functions)); + let mut binder = Binder::new( + BinderContext::new(transaction, functions, Arc::new(AtomicUsize::new(0))), + None, + ); /// Build a logical plan. /// /// SELECT a,b FROM t1 ORDER BY a LIMIT 1; diff --git a/src/errors.rs b/src/errors.rs index 7a5a4e0e..51cb4964 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -29,12 +29,12 @@ pub enum DatabaseError { #[source] csv::Error, ), - #[error("duplicate primary key")] - DuplicatePrimaryKey, #[error("column: {0} already exists")] DuplicateColumn(String), #[error("index: {0} already exists")] DuplicateIndex(String), + #[error("duplicate primary key")] + DuplicatePrimaryKey, #[error("the column has been declared unique and the value already exists")] DuplicateUniqueValue, #[error("empty plan")] diff --git a/src/optimizer/core/memo.rs b/src/optimizer/core/memo.rs index 951e918c..fea4276b 100644 --- a/src/optimizer/core/memo.rs +++ b/src/optimizer/core/memo.rs @@ -98,6 +98,7 @@ mod tests { use crate::types::LogicalType; use petgraph::stable_graph::NodeIndex; use std::ops::Bound; + use std::sync::atomic::AtomicUsize; use std::sync::Arc; use tempfile::TempDir; @@ -121,7 +122,10 @@ mod tests { let transaction = database.storage.transaction().await?; let functions = Default::default(); - let mut binder = Binder::new(BinderContext::new(&transaction, &functions)); + let mut binder = Binder::new( + BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))), + None, + ); let stmt = crate::parser::parse_sql( // FIXME: Only by bracketing (c1 > 40 or c1 = 2) can the filter be pushed down below the join "select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22", diff --git a/tests/slt/subquery.slt b/tests/slt/subquery.slt index a90fe587..5e7ae7d6 100644 --- a/tests/slt/subquery.slt +++ b/tests/slt/subquery.slt @@ -39,10 +39,10 @@ select * from t1 where a <= (select 4) and a > (select 1) ---- 1 3 4 -# query III -# select * from t1 where a <= (select 4) and (-a + 1) < (select 1) - 1 -# ---- -# 1 3 4 +query III +select * from t1 where a <= (select 4) and (-a + 1) < (select 1) - 1 +---- +1 3 4 statement ok insert into t1 values (2, 3, 3), (3, 1, 4); @@ -70,4 +70,29 @@ select * from t1 where a not in (select 1) and b = 3 2 3 3 statement ok -drop table t1; \ No newline at end of file +drop table t1; + +# https://github.com/KipData/FnckSQL/issues/169 +statement ok +create table t2(id int primary key, a int not null, b int not null); + +statement ok +create table t3(id int primary key, a int not null, c int not null); + +statement ok +insert into t2 values (0, 1, 2), (3, 4, 5); + +statement ok +insert into t3 values (0, 2, 2), (3, 8, 5); + +query III +SELECT id, a, b FROM t2 WHERE a*2 in (SELECT a FROM t3 where a/2 in (select a from t2)); +---- +0 1 2 +3 4 5 + +statement ok +drop table t2; + +statement ok +drop table t3; \ No newline at end of file