Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: In multi-level nested subquery, the field relationships of different tables are wrong #173

Merged
merged 6 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ mod tests {
use crate::storage::kip::KipStorage;
use crate::storage::Storage;
use crate::types::LogicalType;
use std::sync::atomic::AtomicUsize;
use tempfile::TempDir;

#[tokio::test]
Expand All @@ -152,7 +153,10 @@ mod tests {
let functions = Default::default();

let sql = "create table t1 (id int primary key, name varchar(10) null)";
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let mut binder = Binder::new(
BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))),
None,
);
let stmt = crate::parser::parse_sql(sql).unwrap();
let plan1 = binder.bind(&stmt[0]).unwrap();

Expand Down
35 changes: 26 additions & 9 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use sqlparser::ast::{
use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder, QueryBindStep, SubQueryType};
use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
Expand Down Expand Up @@ -226,7 +226,17 @@ impl<'a, T: Transaction> Binder<'a, T> {
&mut self,
subquery: &Query,
) -> Result<(LogicalPlan, Arc<ColumnCatalog>), DatabaseError> {
let mut sub_query = self.bind_query(subquery)?;
let BinderContext {
transaction,
functions,
temp_table_id,
..
} = &self.context;
let mut binder = Binder::new(
BinderContext::new(*transaction, functions, temp_table_id.clone()),
Some(self),
);
let mut sub_query = binder.bind_query(subquery)?;
let sub_query_schema = sub_query.output_schema();

if sub_query_schema.len() != 1 {
Expand Down Expand Up @@ -294,15 +304,22 @@ impl<'a, T: Transaction> Binder<'a, T> {
.ok_or_else(|| DatabaseError::NotFound("column", column_name))?;
Ok(ScalarExpression::ColumnRef(column_catalog.clone()))
} else {
let op = |got_column: &mut Option<&'a ColumnRef>, context: &BinderContext<'a, T>| {
for table_catalog in context.bind_table.values() {
if got_column.is_some() {
break;
}
if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) {
*got_column = Some(column_catalog);
}
}
};
// handle col syntax
let mut got_column = None;
for table_catalog in self.context.bind_table.values().rev() {
if let Some(column_catalog) = table_catalog.get_column_by_name(&column_name) {
got_column = Some(column_catalog);
}
if got_column.is_some() {
break;
}

op(&mut got_column, &self.context);
if let Some(parent) = self.parent {
op(&mut got_column, &parent.context);
}
let column_catalog =
got_column.ok_or_else(|| DatabaseError::NotFound("column", column_name))?;
Expand Down
30 changes: 21 additions & 9 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod update;

use sqlparser::ast::{Ident, ObjectName, ObjectType, SetExpr, Statement};
use std::collections::{BTreeMap, HashMap};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
Expand Down Expand Up @@ -55,7 +56,7 @@ pub enum SubQueryType {

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
functions: &'a Functions,
pub(crate) functions: &'a Functions,
pub(crate) transaction: &'a T,
// Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain.
pub(crate) bind_table: BTreeMap<(TableName, Option<JoinType>), &'a TableCatalog>,
Expand All @@ -69,12 +70,16 @@ pub struct BinderContext<'a, T: Transaction> {
bind_step: QueryBindStep,
sub_queries: HashMap<QueryBindStep, Vec<SubQueryType>>,

temp_table_id: usize,
temp_table_id: Arc<AtomicUsize>,
pub(crate) allow_default: bool,
}

impl<'a, T: Transaction> BinderContext<'a, T> {
pub fn new(transaction: &'a T, functions: &'a Functions) -> Self {
pub fn new(
transaction: &'a T,
functions: &'a Functions,
temp_table_id: Arc<AtomicUsize>,
) -> Self {
BinderContext {
functions,
transaction,
Expand All @@ -85,14 +90,16 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
agg_calls: Default::default(),
bind_step: QueryBindStep::From,
sub_queries: Default::default(),
temp_table_id: 0,
temp_table_id,
allow_default: false,
}
}

pub fn temp_table(&mut self) -> TableName {
self.temp_table_id += 1;
Arc::new(format!("_temp_table_{}_", self.temp_table_id))
Arc::new(format!(
"_temp_table_{}_",
self.temp_table_id.fetch_add(1, Ordering::SeqCst)
))
}

pub fn step(&mut self, bind_step: QueryBindStep) {
Expand Down Expand Up @@ -167,11 +174,12 @@ impl<'a, T: Transaction> BinderContext<'a, T> {

pub struct Binder<'a, T: Transaction> {
context: BinderContext<'a, T>,
pub(crate) parent: Option<&'a Binder<'a, T>>,
}

impl<'a, T: Transaction> Binder<'a, T> {
pub fn new(context: BinderContext<'a, T>) -> Self {
Binder { context }
pub fn new(context: BinderContext<'a, T>, parent: Option<&'a Binder<'a, T>>) -> Self {
Binder { context, parent }
}

pub fn bind(&mut self, stmt: &Statement) -> Result<LogicalPlan, DatabaseError> {
Expand Down Expand Up @@ -305,6 +313,7 @@ pub mod test {
use crate::storage::{Storage, Transaction};
use crate::types::LogicalType::Integer;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tempfile::TempDir;

Expand Down Expand Up @@ -358,7 +367,10 @@ pub mod test {
let storage = build_test_catalog(temp_dir.path()).await?;
let transaction = storage.transaction().await?;
let functions = Default::default();
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let mut binder = Binder::new(
BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))),
None,
);
let stmt = crate::parser::parse_sql(sql)?;

Ok(binder.bind(&stmt[0])?)
Expand Down
6 changes: 5 additions & 1 deletion src/db.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use ahash::HashMap;
use std::path::PathBuf;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use crate::binder::{Binder, BinderContext};
Expand Down Expand Up @@ -98,7 +99,10 @@ impl<S: Storage> Database<S> {
if stmts.is_empty() {
return Err(DatabaseError::EmptyStatement);
}
let mut binder = Binder::new(BinderContext::new(transaction, functions));
let mut binder = Binder::new(
BinderContext::new(transaction, functions, Arc::new(AtomicUsize::new(0))),
None,
);
/// Build a logical plan.
///
/// SELECT a,b FROM t1 ORDER BY a LIMIT 1;
Expand Down
4 changes: 2 additions & 2 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ pub enum DatabaseError {
#[source]
csv::Error,
),
#[error("duplicate primary key")]
DuplicatePrimaryKey,
#[error("column: {0} already exists")]
DuplicateColumn(String),
#[error("index: {0} already exists")]
DuplicateIndex(String),
#[error("duplicate primary key")]
DuplicatePrimaryKey,
#[error("the column has been declared unique and the value already exists")]
DuplicateUniqueValue,
#[error("empty plan")]
Expand Down
6 changes: 5 additions & 1 deletion src/optimizer/core/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ mod tests {
use crate::types::LogicalType;
use petgraph::stable_graph::NodeIndex;
use std::ops::Bound;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use tempfile::TempDir;

Expand All @@ -121,7 +122,10 @@ mod tests {

let transaction = database.storage.transaction().await?;
let functions = Default::default();
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let mut binder = Binder::new(
BinderContext::new(&transaction, &functions, Arc::new(AtomicUsize::new(0))),
None,
);
let stmt = crate::parser::parse_sql(
// FIXME: Only by bracketing (c1 > 40 or c1 = 2) can the filter be pushed down below the join
"select c1, c3 from t1 inner join t2 on c1 = c3 where (c1 > 40 or c1 = 2) and c3 > 22",
Expand Down
Loading