fix: In multi-level nested subquery, the field relationships of different tables are wrong (#173)

* fix: In multi-level nested subquery, the field relationships of different tables are wrong

* fix: Correlated subqueries

* style: code simplification

* style: remove unused error

* test: test case for issue_169

* code fmt
KKould authored Mar 21, 2024
1 parent 917804e commit e0de637
Showing 7 changed files with 94 additions and 28 deletions.
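
The heart of the change sits in src/binder/expr.rs and src/binder/mod.rs below: every subquery is now bound by its own child Binder that carries a parent reference to the enclosing binder, and a column that does not resolve in the current scope falls back to the parent's bound tables. The snippet below is a minimal, self-contained sketch of that lookup chain, not the crate's actual API; Context, lookup and resolve_column are simplified stand-ins for BinderContext, bind_table and the resolution code in the diff.

use std::collections::BTreeMap;

// Simplified stand-in for BinderContext: table name -> column names bound in this scope.
struct Context {
    bind_table: BTreeMap<String, Vec<String>>,
}

// Simplified stand-in for Binder<'a, T>, keeping only the new `parent` link.
struct Binder<'a> {
    context: Context,
    parent: Option<&'a Binder<'a>>,
}

// Search one scope for a column, returning (table, column) on a hit.
fn lookup<'c>(ctx: &'c Context, name: &str) -> Option<(&'c str, &'c str)> {
    ctx.bind_table.iter().find_map(|(table, cols)| {
        cols.iter()
            .find(|col| col.as_str() == name)
            .map(|col| (table.as_str(), col.as_str()))
    })
}

impl<'a> Binder<'a> {
    // Current scope first, then the enclosing scope, mirroring the `op` closure
    // plus the `self.parent` check added in expr.rs.
    fn resolve_column(&self, name: &str) -> Option<(&str, &str)> {
        lookup(&self.context, name)
            .or_else(|| self.parent.and_then(|parent| lookup(&parent.context, name)))
    }
}

fn main() {
    let outer = Binder {
        context: Context {
            bind_table: BTreeMap::from([(
                "t2".to_string(),
                vec!["id".to_string(), "a".to_string(), "b".to_string()],
            )]),
        },
        parent: None,
    };
    let inner = Binder {
        context: Context {
            bind_table: BTreeMap::from([(
                "t3".to_string(),
                vec!["id".to_string(), "a".to_string(), "c".to_string()],
            )]),
        },
        parent: Some(&outer),
    };
    // "c" resolves in the inner scope; "b" is only reachable through the parent.
    assert_eq!(inner.resolve_column("c"), Some(("t3", "c")));
    assert_eq!(inner.resolve_column("b"), Some(("t2", "b")));
}

Scoping each nesting level in its own binder, with an explicit parent fallback for correlated references, is what keeps the field-to-table relationships straight where the old single shared context could bind a column to the wrong table.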
6 changes: 5 additions & 1 deletion src/binder/create_table.rs
@@ -142,6 +142,7 @@ mod tests {
use crate::storage::kip::KipStorage;
use crate::storage::Storage;
use crate::types::LogicalType;
use std::sync::atomic::AtomicUsize;
use tempfile::TempDir;

#[tokio::test]
@@ -152,7 +153,10 @@ mod tests {
let functions = Default::default();

let sql = "create table t1 (id int primary key, name varchar(10) null)";
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let mut binder = Binder::new(
BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))),
None,
);
let stmt = crate::parser::parse_sql(sql).unwrap();
let plan1 = binder.bind(&stmt[0]).unwrap();

35 changes: 26 additions & 9 deletions src/binder/expr.rs
@@ -10,7 +10,7 @@ use sqlparser::ast::{
use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder, QueryBindStep, SubQueryType};
use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
@@ -226,7 +226,17 @@ impl<'a, T: Transaction> Binder<'a, T> {
&mut self,
subquery: &Query,
) -> Result<(LogicalPlan, Arc<ColumnCatalog>), DatabaseError> {
let mut sub_query = self.bind_query(subquery)?;
let BinderContext {
transaction,
functions,
temp_table_id,
..
} = &self.context;
let mut binder = Binder::new(
BinderContext::new(*transaction, functions, temp_table_id.clone()),
Some(self),
);
let mut sub_query = binder.bind_query(subquery)?;
let sub_query_schema = sub_query.output_schema();

if sub_query_schema.len() != 1 {
@@ -294,15 +304,22 @@ impl<'a, T: Transaction> Binder<'a, T> {
.ok_or_else(|| DatabaseError::NotFound("column", column_name))?;
Ok(ScalarExpression::ColumnRef(column_catalog.clone()))
} else {
let op = |got_column: &mut Option<&'a ColumnRef>, context: &BinderContext<'a, T>| {
for table_catalog in context.bind_table.values() {
if got_column.is_some() {
break;
}
if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) {
*got_column = Some(column_catalog);
}
}
};
// handle col syntax
let mut got_column = None;
for table_catalog in self.context.bind_table.values().rev() {
if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) {
got_column = Some(column_catalog);
}
if got_column.is_some() {
break;
}

op(&mut got_column, &self.context);
if let Some(parent) = self.parent {
op(&mut got_column, &parent.context);
}
let column_catalog =
got_column.ok_or_else(|| DatabaseError::NotFound("column", column_name))?;
30 changes: 21 additions & 9 deletions src/binder/mod.rs
@@ -18,6 +18,7 @@ mod update;

use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement};
use std::collections::{BTreeMap, HashMap};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
@@ -55,7 +56,7 @@ pub enum SubQueryType {

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
functions: &'a Functions,
pub(crate) functions: &'a Functions,
pub(crate) transaction: &'a T,
// Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain.
pub(crate) bind_table: BTreeMap<(TableName, Option<JoinType>), &'a TableCatalog>,
@@ -69,12 +70,16 @@ pub struct BinderContext<'a, T: Transaction> {
bind_step: QueryBindStep,
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,

temp_table_id: usize,
temp_table_id: Arc<AtomicUsize>,
pub(crate) allow_default: bool,
}

impl<'a, T: Transaction> BinderContext<'a, T> {
pub fn new(transaction: &'a T, functions: &'a Functions) -> Self {
pub fn new(
transaction: &'a T,
functions: &'a Functions,
temp_table_id: Arc<AtomicUsize>,
) -> Self {
BinderContext {
functions,
transaction,
@@ -85,14 +90,16 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
agg_calls: Default::default(),
bind_step: QueryBindStep::From,
sub_queries: Default::default(),
temp_table_id: 0,
temp_table_id,
allow_default: false,
}
}

pub fn temp_table(&mut self) -> TableName {
self.temp_table_id += 1;
Arc::new(format!("_temp_table_{}_", self.temp_table_id))
Arc::new(format!(
"_temp_table_{}_",
self.temp_table_id.fetch_add(1, Ordering::SeqCst)
))
}

pub fn step(&mut self, bind_step: QueryBindStep) {
@@ -167,11 +174,12 @@ impl<'a, T: Transaction> BinderContext<'a, T> {

pub struct Binder<'a, T: Transaction> {
context: BinderContext<'a, T>,
pub(crate) parent: Option<&'a Binder<'a, T>>,
}

impl<'a, T: Transaction> Binder<'a, T> {
pub fn new(context: BinderContext<'a, T>) -> Self {
Binder { context }
pub fn new(context: BinderContext<'a, T>, parent: Option<&'a Binder<'a, T>>) -> Self {
Binder { context, parent }
}

pub fn bind(&mut self, stmt: &Statement) -> Result<LogicalPlan, DatabaseError> {
@@ -305,6 +313,7 @@ pub mod test {
use crate::storage::{Storage, Transaction};
use crate::types::LogicalType::Integer;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tempfile::TempDir;

@@ -358,7 +367,10 @@ pub mod test {
let storage = build_test_catalog(temp_dir.path()).await?;
let transaction = storage.transaction().await?;
let functions = Default::default();
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let mut binder = Binder::new(
BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))),
None,
);
let stmt = crate::parser::parse_sql(sql)?;

Ok(binder.bind(&stmt[0])?)
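
A consequence of spawning a child Binder per subquery is that temp-table naming can no longer rely on a plain per-context usize counter, which is why temp_table_id becomes an Arc<AtomicUsize> shared between parent and child binders (it is cloned into the child in bind_subquery above). A small sketch of the intended behavior, assuming only what the diff shows:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

fn temp_table(counter: &Arc<AtomicUsize>) -> String {
    // Mirrors BinderContext::temp_table in the diff: fetch_add hands out a unique id.
    format!("_temp_table_{}_", counter.fetch_add(1, Ordering::SeqCst))
}

fn main() {
    let shared = Arc::new(AtomicUsize::new(0));
    let outer = Arc::clone(&shared); // outer binder's handle
    let inner = Arc::clone(&shared); // handle cloned for a child (subquery) binder
    assert_eq!(temp_table(&outer), "_temp_table_0_");
    assert_eq!(temp_table(&inner), "_temp_table_1_"); // names stay unique across binders
}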
6 changes: 5 additions & 1 deletion src/db.rs
@@ -1,5 +1,6 @@
use ahash::HashMap;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::binder::{Binder, BinderContext};
@@ -98,7 +99,10 @@ impl<S: Storage> Database<S> {
if stmts.is_empty() {
return Err(DatabaseError::EmptyStatement);
}
let mut binder = Binder::new(BinderContext::new(transaction, functions));
let mut binder = Binder::new(
BinderContext::new(transaction, functions, Arc::new(AtomicUsize::new(0))),
None,
);
/// Build a logical plan.
///
/// SELECT a,b FROM t1 ORDER BY a LIMIT 1;
4 changes: 2 additions & 2 deletions src/errors.rs
@@ -29,12 +29,12 @@ pub enum DatabaseError {
#[source]
csv::Error,
),
#[error("duplicate primary key")]
DuplicatePrimaryKey,
#[error("column: {0} already exists")]
DuplicateColumn(String),
#[error("index: {0} already exists")]
DuplicateIndex(String),
#[error("duplicate primary key")]
DuplicatePrimaryKey,
#[error("the column has been declared unique and the value already exists")]
DuplicateUniqueValue,
#[error("empty plan")]
6 changes: 5 additions & 1 deletion src/optimizer/core/memo.rs
@@ -98,6 +98,7 @@ mod tests {
use crate::types::LogicalType;
use petgraph::stable_graph::NodeIndex;
use std::ops::Bound;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tempfile::TempDir;

@@ -121,7 +122,10 @@

let transaction = database.storage.transaction().await?;
let functions = Default::default();
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let mut binder = Binder::new(
BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))),
None,
);
let stmt = crate::parser::parse_sql(
// FIXME: Only by bracketing (c1 > 40 or c1 = 2) can the filter be pushed down below the join
"select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22",
35 changes: 30 additions & 5 deletions tests/slt/subquery.slt
@@ -39,10 +39,10 @@ select * from t1 where a <= (select 4) and a > (select 1)
----
1 3 4

# query III
# select * from t1 where a <= (select 4) and (-a + 1) < (select 1) - 1
# ----
# 1 3 4
query III
select * from t1 where a <= (select 4) and (-a + 1) < (select 1) - 1
----
1 3 4

statement ok
insert into t1 values (2, 3, 3), (3, 1, 4);
@@ -70,4 +70,29 @@ select * from t1 where a not in (select 1) and b = 3
2 3 3

statement ok
drop table t1;
drop table t1;

# https://github.com/KipData/FnckSQL/issues/169
statement ok
create table t2(id int primary key, a int not null, b int not null);

statement ok
create table t3(id int primary key, a int not null, c int not null);

statement ok
insert into t2 values (0, 1, 2), (3, 4, 5);

statement ok
insert into t3 values (0, 2, 2), (3, 8, 5);

query III
SELECT id, a, b FROM t2 WHERE a*2 in (SELECT a FROM t3 where a/2 in (select a from t2));
----
0 1 2
3 4 5

statement ok
drop table t2;

statement ok
drop table t3;
