Skip to content

Commit

Permalink
feat: support udtf
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Aug 27, 2024
1 parent 811d182 commit 713bc56
Show file tree
Hide file tree
Showing 41 changed files with 978 additions and 151 deletions.
8 changes: 5 additions & 3 deletions src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use sqlparser::ast::{Expr, OrderByExpr};
use std::collections::HashSet;

use crate::errors::DatabaseError;
use crate::expression::function::ScalarFunction;
use crate::expression::function::scala::ScalarFunction;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
use crate::{
Expand Down Expand Up @@ -153,7 +153,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. })
| ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
| ScalarExpression::Coalesce { exprs: args, .. } => {
for expr in args {
self.visit_column_agg_expr(expr)?;
Expand Down Expand Up @@ -199,6 +199,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
self.visit_column_agg_expr(expr)?;
}
}
ScalarExpression::TableFunction(_) => unreachable!(),
}

Ok(())
Expand Down Expand Up @@ -389,7 +390,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
ScalarExpression::Constant(_) => Ok(()),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. })
| ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
| ScalarExpression::Coalesce { exprs: args, .. } => {
for expr in args {
self.validate_having_orderby(expr)?;
Expand Down Expand Up @@ -442,6 +443,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {

Ok(())
}
ScalarExpression::TableFunction(_) => unreachable!(),
}
}
}
6 changes: 3 additions & 3 deletions src/binder/alter_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::binder::lower_case_name;
use crate::errors::DatabaseError;
use crate::planner::operator::alter_table::add_column::AddColumnOperator;
use crate::planner::operator::alter_table::drop_column::DropColumnOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
Expand All @@ -29,7 +29,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
if_not_exists,
column_def,
} => {
let plan = ScanOperator::build(table_name.clone(), table);
let plan = TableScanOperator::build(table_name.clone(), table);
let column = self.bind_column(column_def)?;

if !is_valid_identifier(column.name()) {
Expand All @@ -51,7 +51,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
if_exists,
..
} => {
let plan = ScanOperator::build(table_name.clone(), table);
let plan = TableScanOperator::build(table_name.clone(), table);
let column_name = column_name.value.clone();

LogicalPlan::new(
Expand Down
4 changes: 2 additions & 2 deletions src/binder/analyze.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::planner::operator::analyze::AnalyzeOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
Expand All @@ -17,7 +17,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
.table_and_bind(table_name.clone(), None, None)?;
let index_metas = table_catalog.indexes.clone();

let scan_op = ScanOperator::build(table_name.clone(), table_catalog);
let scan_op = TableScanOperator::build(table_name.clone(), table_catalog);
Ok(LogicalPlan::new(
Operator::Analyze(AnalyzeOperator {
table_name,
Expand Down
4 changes: 2 additions & 2 deletions src/binder/create_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::expression::ScalarExpression;
use crate::planner::operator::create_index::CreateIndexOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
Expand Down Expand Up @@ -32,7 +32,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
let table = self
.context
.table_and_bind(table_name.clone(), None, None)?;
let plan = ScanOperator::build(table_name.clone(), table);
let plan = TableScanOperator::build(table_name.clone(), table);
let mut columns = Vec::with_capacity(exprs.len());

for expr in exprs {
Expand Down
6 changes: 4 additions & 2 deletions src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,16 @@ mod tests {
let storage = RocksStorage::new(temp_dir.path())?;
let transaction = storage.transaction()?;
let table_cache = Arc::new(ShardingLruCache::new(128, 16, RandomState::new())?);
let functions = Default::default();
let scala_functions = Default::default();
let table_functions = Default::default();

let sql = "create table t1 (id int primary key, name varchar(10) null)";
let mut binder = Binder::new(
BinderContext::new(
&table_cache,
&transaction,
&functions,
&scala_functions,
&table_functions,
Arc::new(AtomicUsize::new(0)),
),
None,
Expand Down
4 changes: 2 additions & 2 deletions src/binder/delete.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::planner::operator::delete::DeleteOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
Expand Down Expand Up @@ -31,7 +31,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
.find(|column| column.desc.is_primary)
.cloned()
.unwrap();
let mut plan = ScanOperator::build(table_name.clone(), table_catalog);
let mut plan = TableScanOperator::build(table_name.clone(), table_catalog);

if let Some(alias_idents) = alias_idents {
plan =
Expand Down
34 changes: 26 additions & 8 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::function::scala::ScalarFunction;
use crate::expression::function::table::TableFunction;
use crate::expression::function::FunctionSummary;
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
Expand Down Expand Up @@ -219,9 +221,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
ty,
})
}
expr => {
todo!("{}", expr)
}
expr => Err(DatabaseError::UnsupportedStmt(expr.to_string())),
}
}

Expand Down Expand Up @@ -250,12 +250,19 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
let BinderContext {
table_cache,
transaction,
functions,
scala_functions,
table_functions,
temp_table_id,
..
} = &self.context;
let mut binder = Binder::new(
BinderContext::new(table_cache, *transaction, functions, temp_table_id.clone()),
BinderContext::new(
table_cache,
*transaction,
scala_functions,
table_functions,
temp_table_id.clone(),
),
Some(self),
);
let mut sub_query = binder.bind_query(subquery)?;
Expand Down Expand Up @@ -429,6 +436,11 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
}

fn bind_function(&mut self, func: &Function) -> Result<ScalarExpression, DatabaseError> {
if !matches!(self.context.step_now(), QueryBindStep::From) {
return Err(DatabaseError::UnsupportedStmt(
"`TableFunction` cannot bind in non-From step".to_string(),
));
}
let mut args = Vec::with_capacity(func.args.len());

for arg in func.args.iter() {
Expand Down Expand Up @@ -586,8 +598,14 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
name: function_name,
arg_types,
};
if let Some(function) = self.context.functions.get(&summary) {
return Ok(ScalarExpression::Function(ScalarFunction {
if let Some(function) = self.context.scala_functions.get(&summary) {
return Ok(ScalarExpression::ScalaFunction(ScalarFunction {
args,
inner: function.clone(),
}));
}
if let Some(function) = self.context.table_functions.get(&summary) {
return Ok(ScalarExpression::TableFunction(TableFunction {
args,
inner: function.clone(),
}));
Expand Down
17 changes: 11 additions & 6 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
use crate::db::Functions;
use crate::db::{ScalaFunctions, TableFunctions};
use crate::errors::DatabaseError;
use crate::expression::ScalarExpression;
use crate::planner::operator::join::JoinType;
Expand Down Expand Up @@ -82,7 +82,8 @@ pub enum SubQueryType {

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
pub(crate) functions: &'a Functions,
pub(crate) scala_functions: &'a ScalaFunctions,
pub(crate) table_functions: &'a TableFunctions,
pub(crate) table_cache: &'a TableCache,
pub(crate) transaction: &'a T,
// Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain.
Expand All @@ -108,11 +109,13 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
pub fn new(
table_cache: &'a TableCache,
transaction: &'a T,
functions: &'a Functions,
scala_functions: &'a ScalaFunctions,
table_functions: &'a TableFunctions,
temp_table_id: Arc<AtomicUsize>,
) -> Self {
BinderContext {
functions,
scala_functions,
table_functions,
table_cache,
transaction,
bind_table: Default::default(),
Expand Down Expand Up @@ -445,12 +448,14 @@ pub mod test {
let table_cache = Arc::new(ShardingLruCache::new(128, 16, RandomState::new())?);
let storage = build_test_catalog(&table_cache, temp_dir.path())?;
let transaction = storage.transaction()?;
let functions = Default::default();
let scala_functions = Default::default();
let table_functions = Default::default();
let mut binder = Binder::new(
BinderContext::new(
&table_cache,
&transaction,
&functions,
&scala_functions,
&table_functions,
Arc::new(AtomicUsize::new(0)),
),
None,
Expand Down
50 changes: 45 additions & 5 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
filter::FilterOperator, join::JoinOperator as LJoinOperator, limit::LimitOperator,
project::ProjectOperator, Operator,
},
operator::{join::JoinType, scan::ScanOperator},
operator::{join::JoinType, table_scan::TableScanOperator},
},
types::value::DataValue,
};
Expand All @@ -20,6 +20,7 @@ use crate::catalog::{ColumnCatalog, ColumnSummary, TableName};
use crate::errors::DatabaseError;
use crate::execution::dql::join::joins_nullable;
use crate::expression::{AliasType, BinaryOperator};
use crate::planner::operator::function_scan::FunctionScanOperator;
use crate::planner::operator::insert::InsertOperator;
use crate::planner::operator::join::JoinCondition;
use crate::planner::operator::sort::{SortField, SortOperator};
Expand Down Expand Up @@ -129,7 +130,9 @@ impl<'a: 'b, 'b, T: Transaction> Binder<'a, 'b, T> {
plan = self.bind_sort(plan, orderby);
}

plan = self.bind_project(plan, select_list)?;
if !select_list.is_empty() {
plan = self.bind_project(plan, select_list)?;
}

if let Some(SelectInto {
name,
Expand Down Expand Up @@ -284,6 +287,36 @@ impl<'a: 'b, 'b, T: Transaction> Binder<'a, 'b, T> {
}
plan
}
TableFactor::TableFunction { expr, alias } => {
if let ScalarExpression::TableFunction(function) = self.bind_expr(expr)? {
let mut table_alias = None;
let table_name = Arc::new(function.summary().name.clone());
let table = function.table();
let mut plan = FunctionScanOperator::build(function);

if let Some(TableAlias {
name,
columns: alias_column,
}) = alias
{
table_alias = Some(Arc::new(name.value.to_lowercase()));

plan = self.bind_alias(
plan,
alias_column,
table_alias.clone().unwrap(),
table_name.clone(),
)?;
}

self.context
.bind_table
.insert((table_name, table_alias, joint_type), table);
plan
} else {
unreachable!()
}
}
_ => unimplemented!(),
};

Expand Down Expand Up @@ -356,7 +389,7 @@ impl<'a: 'b, 'b, T: Transaction> Binder<'a, 'b, T> {
let table_catalog =
self.context
.table_and_bind(table_name.clone(), table_alias.clone(), join_type)?;
let mut scan_op = ScanOperator::build(table_name.clone(), table_catalog);
let mut scan_op = TableScanOperator::build(table_name.clone(), table_catalog);

if let Some(idents) = alias_idents {
scan_op = self.bind_alias(scan_op, idents, table_alias.unwrap(), table_name.clone())?;
Expand Down Expand Up @@ -496,12 +529,19 @@ impl<'a: 'b, 'b, T: Transaction> Binder<'a, 'b, T> {
let BinderContext {
table_cache,
transaction,
functions,
scala_functions,
table_functions,
temp_table_id,
..
} = &self.context;
let mut binder = Binder::new(
BinderContext::new(table_cache, *transaction, functions, temp_table_id.clone()),
BinderContext::new(
table_cache,
*transaction,
scala_functions,
table_functions,
temp_table_id.clone(),
),
Some(self),
);
let mut right = binder.bind_single_table_ref(relation, Some(join_type))?;
Expand Down
Loading

0 comments on commit 713bc56

Please sign in to comment.