Feat/udtf (#215)
* feat: support udtf

* chore: clean macro

* docs: add UDTF
KKould authored Aug 27, 2024
1 parent 811d182 commit 88c68b5
Showing 51 changed files with 1,378 additions and 440 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -81,4 +81,5 @@ pprof = { version = "0.13", features = ["flamegraph", "criterion"] }
[workspace]
members = [
"tests/sqllogictest",
"tests/macros-test"
]
27 changes: 23 additions & 4 deletions README.md
@@ -91,7 +91,7 @@ kould23333/fncksql:latest
~~~

### Features
- ORM Mapping: `features = ["marcos"]`
- ORM Mapping: `features = ["macros"]`
```rust
#[derive(Default, Debug, PartialEq)]
struct MyStruct {
@@ -114,9 +114,9 @@ implement_from_tuple!(
)
);
```
- User-Defined Function: `features = ["marcos"]`
- User-Defined Function: `features = ["macros"]`
```rust
function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| {
scala_function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| {
let plus_binary_evaluator = EvaluatorFactory::binary_create(LogicalType::Integer, BinaryOperator::Plus)?;
let value = plus_binary_evaluator.binary_eval(&v1, &v2);

@@ -125,9 +125,28 @@ function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> Logi
});

let fnck_sql = DataBaseBuilder::path("./data")
.register_function(TestFunction::new())
.register_scala_function(TestFunction::new())
.build()?;
```
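A quick, hypothetical usage sketch (the SQL text and the `run` call are illustrative assumptions, not part of this commit): once registered, the scalar UDF can be invoked like a built-in function.
```rust
// Hypothetical: call the scalar UDF `test` registered above.
// Assumes `run` executes a SQL statement against the built instance.
let _ = fnck_sql.run("select test(1, 2)")?;
```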
- User-Defined Table Function: `features = ["macros"]`
```rust
table_function!(MyTableFunction::test_numbers(LogicalType::Integer) -> [c1: LogicalType::Integer, c2: LogicalType::Integer] => (|v1: ValueRef| {
let num = v1.i32().unwrap();

Ok(Box::new((0..num)
.into_iter()
.map(|i| Ok(Tuple {
id: None,
values: vec![
Arc::new(DataValue::Int32(Some(i))),
Arc::new(DataValue::Int32(Some(i))),
]
}))) as Box<dyn Iterator<Item = Result<Tuple, DatabaseError>>>)
}));
let fnck_sql = DataBaseBuilder::path("./data")
.register_table_function(MyTableFunction::new())
.build()?;
```
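A hypothetical usage sketch for the table function (the SQL text and the `run` call are illustrative assumptions): a registered UDTF is referenced in the FROM clause and yields the declared `c1`, `c2` columns.
```rust
// Hypothetical: scan the rows generated by `test_numbers`.
// Assumes `run` executes a SQL statement against the built instance.
let _ = fnck_sql.run("select c1, c2 from test_numbers(5)")?;
```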
- Optimizer
- RBO
- CBO based on RBO (Physical Selection)
2 changes: 0 additions & 2 deletions examples/hello_world.rs
@@ -1,9 +1,7 @@
use fnck_sql::db::DataBaseBuilder;
use fnck_sql::errors::DatabaseError;
use fnck_sql::implement_from_tuple;
use fnck_sql::types::tuple::{SchemaRef, Tuple};
use fnck_sql::types::value::DataValue;
use fnck_sql::types::LogicalType;
use itertools::Itertools;

#[derive(Default, Debug, PartialEq)]
8 changes: 5 additions & 3 deletions src/binder/aggregate.rs
@@ -4,7 +4,7 @@ use sqlparser::ast::{Expr, OrderByExpr};
use std::collections::HashSet;

use crate::errors::DatabaseError;
use crate::expression::function::ScalarFunction;
use crate::expression::function::scala::ScalarFunction;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
use crate::{
@@ -153,7 +153,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. })
| ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
| ScalarExpression::Coalesce { exprs: args, .. } => {
for expr in args {
self.visit_column_agg_expr(expr)?;
@@ -199,6 +199,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
self.visit_column_agg_expr(expr)?;
}
}
ScalarExpression::TableFunction(_) => unreachable!(),
}

Ok(())
@@ -389,7 +390,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
ScalarExpression::Constant(_) => Ok(()),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. })
| ScalarExpression::ScalaFunction(ScalarFunction { args, .. })
| ScalarExpression::Coalesce { exprs: args, .. } => {
for expr in args {
self.validate_having_orderby(expr)?;
@@ -442,6 +443,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {

Ok(())
}
ScalarExpression::TableFunction(_) => unreachable!(),
}
}
}
6 changes: 3 additions & 3 deletions src/binder/alter_table.rs
@@ -7,7 +7,7 @@ use crate::binder::lower_case_name;
use crate::errors::DatabaseError;
use crate::planner::operator::alter_table::add_column::AddColumnOperator;
use crate::planner::operator::alter_table::drop_column::DropColumnOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
@@ -29,7 +29,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
if_not_exists,
column_def,
} => {
let plan = ScanOperator::build(table_name.clone(), table);
let plan = TableScanOperator::build(table_name.clone(), table);
let column = self.bind_column(column_def)?;

if !is_valid_identifier(column.name()) {
@@ -51,7 +51,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
if_exists,
..
} => {
let plan = ScanOperator::build(table_name.clone(), table);
let plan = TableScanOperator::build(table_name.clone(), table);
let column_name = column_name.value.clone();

LogicalPlan::new(
4 changes: 2 additions & 2 deletions src/binder/analyze.rs
@@ -1,7 +1,7 @@
use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::planner::operator::analyze::AnalyzeOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
@@ -17,7 +17,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
.table_and_bind(table_name.clone(), None, None)?;
let index_metas = table_catalog.indexes.clone();

let scan_op = ScanOperator::build(table_name.clone(), table_catalog);
let scan_op = TableScanOperator::build(table_name.clone(), table_catalog);
Ok(LogicalPlan::new(
Operator::Analyze(AnalyzeOperator {
table_name,
4 changes: 2 additions & 2 deletions src/binder/create_index.rs
@@ -2,7 +2,7 @@ use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::expression::ScalarExpression;
use crate::planner::operator::create_index::CreateIndexOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
@@ -32,7 +32,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
let table = self
.context
.table_and_bind(table_name.clone(), None, None)?;
let plan = ScanOperator::build(table_name.clone(), table);
let plan = TableScanOperator::build(table_name.clone(), table);
let mut columns = Vec::with_capacity(exprs.len());

for expr in exprs {
6 changes: 4 additions & 2 deletions src/binder/create_table.rs
@@ -158,14 +158,16 @@ mod tests {
let storage = RocksStorage::new(temp_dir.path())?;
let transaction = storage.transaction()?;
let table_cache = Arc::new(ShardingLruCache::new(128, 16, RandomState::new())?);
let functions = Default::default();
let scala_functions = Default::default();
let table_functions = Default::default();

let sql = "create table t1 (id int primary key, name varchar(10) null)";
let mut binder = Binder::new(
BinderContext::new(
&table_cache,
&transaction,
&functions,
&scala_functions,
&table_functions,
Arc::new(AtomicUsize::new(0)),
),
None,
4 changes: 2 additions & 2 deletions src/binder/delete.rs
@@ -1,7 +1,7 @@
use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::planner::operator::delete::DeleteOperator;
use crate::planner::operator::scan::ScanOperator;
use crate::planner::operator::table_scan::TableScanOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
@@ -31,7 +31,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
.find(|column| column.desc.is_primary)
.cloned()
.unwrap();
let mut plan = ScanOperator::build(table_name.clone(), table_catalog);
let mut plan = TableScanOperator::build(table_name.clone(), table_catalog);

if let Some(alias_idents) = alias_idents {
plan =
34 changes: 26 additions & 8 deletions src/binder/expr.rs
@@ -11,7 +11,9 @@ use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder, BinderContext, QueryBindStep, SubQueryType};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::function::scala::ScalarFunction;
use crate::expression::function::table::TableFunction;
use crate::expression::function::FunctionSummary;
use crate::expression::{AliasType, ScalarExpression};
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
@@ -219,9 +221,7 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
ty,
})
}
expr => {
todo!("{}", expr)
}
expr => Err(DatabaseError::UnsupportedStmt(expr.to_string())),
}
}

@@ -250,12 +250,19 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
let BinderContext {
table_cache,
transaction,
functions,
scala_functions,
table_functions,
temp_table_id,
..
} = &self.context;
let mut binder = Binder::new(
BinderContext::new(table_cache, *transaction, functions, temp_table_id.clone()),
BinderContext::new(
table_cache,
*transaction,
scala_functions,
table_functions,
temp_table_id.clone(),
),
Some(self),
);
let mut sub_query = binder.bind_query(subquery)?;
@@ -429,6 +436,11 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
}

fn bind_function(&mut self, func: &Function) -> Result<ScalarExpression, DatabaseError> {
if !matches!(self.context.step_now(), QueryBindStep::From) {
return Err(DatabaseError::UnsupportedStmt(
"`TableFunction` cannot bind in non-From step".to_string(),
));
}
let mut args = Vec::with_capacity(func.args.len());

for arg in func.args.iter() {
@@ -586,8 +598,14 @@ impl<'a, 'b, T: Transaction> Binder<'a, 'b, T> {
name: function_name,
arg_types,
};
if let Some(function) = self.context.functions.get(&summary) {
return Ok(ScalarExpression::Function(ScalarFunction {
if let Some(function) = self.context.scala_functions.get(&summary) {
return Ok(ScalarExpression::ScalaFunction(ScalarFunction {
args,
inner: function.clone(),
}));
}
if let Some(function) = self.context.table_functions.get(&summary) {
return Ok(ScalarExpression::TableFunction(TableFunction {
args,
inner: function.clone(),
}));
17 changes: 11 additions & 6 deletions src/binder/mod.rs
@@ -22,7 +22,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
use crate::db::Functions;
use crate::db::{ScalaFunctions, TableFunctions};
use crate::errors::DatabaseError;
use crate::expression::ScalarExpression;
use crate::planner::operator::join::JoinType;
@@ -82,7 +82,8 @@ pub enum SubQueryType {

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
pub(crate) functions: &'a Functions,
pub(crate) scala_functions: &'a ScalaFunctions,
pub(crate) table_functions: &'a TableFunctions,
pub(crate) table_cache: &'a TableCache,
pub(crate) transaction: &'a T,
// Tips: When there are multiple tables and Wildcard, use BTreeMap to ensure that the order of the output tables is certain.
@@ -108,11 +109,13 @@ impl<'a, T: Transaction> BinderContext<'a, T> {
pub fn new(
table_cache: &'a TableCache,
transaction: &'a T,
functions: &'a Functions,
scala_functions: &'a ScalaFunctions,
table_functions: &'a TableFunctions,
temp_table_id: Arc<AtomicUsize>,
) -> Self {
BinderContext {
functions,
scala_functions,
table_functions,
table_cache,
transaction,
bind_table: Default::default(),
@@ -445,12 +448,14 @@ pub mod test {
let table_cache = Arc::new(ShardingLruCache::new(128, 16, RandomState::new())?);
let storage = build_test_catalog(&table_cache, temp_dir.path())?;
let transaction = storage.transaction()?;
let functions = Default::default();
let scala_functions = Default::default();
let table_functions = Default::default();
let mut binder = Binder::new(
BinderContext::new(
&table_cache,
&transaction,
&functions,
&scala_functions,
&table_functions,
Arc::new(AtomicUsize::new(0)),
),
None,