Skip to content

Commit

Permalink
feat: support UDF (#140)
Browse files Browse the repository at this point in the history
* feat: support `UDF`

* fix: DataBase build on sqllogictest
  • Loading branch information
KKould authored Feb 15, 2024
1 parent 0270f6b commit 3d8d2eb
Show file tree
Hide file tree
Showing 25 changed files with 904 additions and 612 deletions.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ then use `psql` to enter sql
![pg](./static/images/pg.gif)
Using FnckSQL in code
```rust
let fnck_sql = Database::with_kipdb("./data").await?;
let fnck_sql = DataBaseBuilder::path("./data")
.build()
.await?;
let tuples = fnck_sql.run("select * from t1").await?;
```
Storage Support:
Expand Down Expand Up @@ -80,6 +82,18 @@ implement_from_tuple!(
)
);
```
- User-Defined Function: `features = ["marcos"]`
```rust
function!(TestFunction::test(LogicalType::Integer, LogicalType::Integer) -> LogicalType::Integer => |v1: ValueRef, v2: ValueRef| {
let value = DataValue::binary_op(&v1, &v2, &BinaryOperator::Plus)?;
DataValue::unary_op(&value, &UnaryOperator::Minus)
});

let fnck_sql = DataBaseBuilder::path("./data")
.register_function(TestFunction::new())
.build()
.await?;
```
- Optimizer
- RBO
- CBO based on RBO(Physical Selection)
Expand Down
8 changes: 5 additions & 3 deletions benchmarks/query_benchmark.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use criterion::{criterion_group, criterion_main, Criterion};
use fnck_sql::db::Database;
use fnck_sql::db::{DataBaseBuilder, Database};
use fnck_sql::errors::DatabaseError;
use fnck_sql::execution::volcano;
use fnck_sql::storage::kip::KipStorage;
Expand All @@ -18,7 +18,8 @@ const QUERY_BENCH_SQLITE_PATH: &'static str = "./sqlite_bench";
const TABLE_ROW_NUM: u64 = 2_00_000;

async fn init_fncksql_query_bench() -> Result<(), DatabaseError> {
let database = Database::with_kipdb(QUERY_BENCH_FNCK_SQL_PATH)
let database = DataBaseBuilder::path(QUERY_BENCH_FNCK_SQL_PATH)
.build()
.await
.unwrap();
database
Expand Down Expand Up @@ -96,7 +97,8 @@ fn query_on_execute(c: &mut Criterion) {
init_fncksql_query_bench().await.unwrap();
}

Database::<KipStorage>::with_kipdb(QUERY_BENCH_FNCK_SQL_PATH)
DataBaseBuilder::path(QUERY_BENCH_FNCK_SQL_PATH)
.build()
.await
.unwrap()
});
Expand Down
4 changes: 2 additions & 2 deletions examples/hello_world.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use fnck_sql::db::Database;
use fnck_sql::db::DataBaseBuilder;
use fnck_sql::errors::DatabaseError;
use fnck_sql::implement_from_tuple;
use fnck_sql::types::tuple::Tuple;
Expand Down Expand Up @@ -30,7 +30,7 @@ implement_from_tuple!(
#[cfg(feature = "marcos")]
#[tokio::main]
async fn main() -> Result<(), DatabaseError> {
let database = Database::with_kipdb("./hello_world").await?;
let database = DataBaseBuilder::path("./hello_world").build().await?;

let _ = database
.run("create table if not exists my_struct (c1 int primary key, c2 int)")
Expand Down
4 changes: 2 additions & 2 deletions examples/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use fnck_sql::db::Database;
use fnck_sql::db::DataBaseBuilder;
use fnck_sql::errors::DatabaseError;

#[tokio::main]
async fn main() -> Result<(), DatabaseError> {
let database = Database::with_kipdb("./transaction").await?;
let database = DataBaseBuilder::path("./transaction").build().await?;
let mut tx_1 = database.new_transaction().await?;

let _ = tx_1
Expand Down
4 changes: 2 additions & 2 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_trait::async_trait;
use clap::Parser;
use fnck_sql::db::{DBTransaction, Database};
use fnck_sql::db::{DBTransaction, DataBaseBuilder, Database};
use fnck_sql::errors::DatabaseError;
use fnck_sql::storage::kip::KipStorage;
use fnck_sql::types::tuple::Tuple;
Expand Down Expand Up @@ -79,7 +79,7 @@ impl MakeHandler for FnckSQLBackend {

impl FnckSQLBackend {
pub async fn new(path: impl Into<PathBuf> + Send) -> Result<FnckSQLBackend, DatabaseError> {
let database = Database::with_kipdb(path).await?;
let database = DataBaseBuilder::path(path).build().await?;

Ok(FnckSQLBackend {
inner: Arc::new(database),
Expand Down
11 changes: 7 additions & 4 deletions src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use sqlparser::ast::{Expr, OrderByExpr};
use std::collections::HashSet;

use crate::errors::DatabaseError;
use crate::expression::function::ScalarFunction;
use crate::planner::LogicalPlan;
use crate::storage::Transaction;
use crate::{
Expand Down Expand Up @@ -137,7 +138,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args) => {
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. }) => {
for expr in args {
self.visit_column_agg_expr(expr)?;
}
Expand Down Expand Up @@ -248,7 +250,7 @@ impl<'a, T: Transaction> Binder<'a, T> {

Err(DatabaseError::AggMiss(
format!(
"column {:?} must appear in the GROUP BY clause or be used in an aggregate function",
"expression '{}' must appear in the GROUP BY clause or be used in an aggregate function",
expr
)
))
Expand All @@ -263,7 +265,7 @@ impl<'a, T: Transaction> Binder<'a, T> {

Err(DatabaseError::AggMiss(
format!(
"column {:?} must appear in the GROUP BY clause or be used in an aggregate function",
"expression '{}' must appear in the GROUP BY clause or be used in an aggregate function",
expr
)
))
Expand Down Expand Up @@ -315,7 +317,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
}
ScalarExpression::Constant(_) => Ok(()),
ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(),
ScalarExpression::Tuple(args) => {
ScalarExpression::Tuple(args)
| ScalarExpression::Function(ScalarFunction { args, .. }) => {
for expr in args {
self.validate_having_orderby(expr)?;
}
Expand Down
3 changes: 2 additions & 1 deletion src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ mod tests {
let temp_dir = TempDir::new().expect("unable to create temporary working directory");
let storage = KipStorage::new(temp_dir.path()).await?;
let transaction = storage.transaction().await?;
let functions = Default::default();

let sql = "create table t1 (id int primary key, name varchar(10) null)";
let mut binder = Binder::new(BinderContext::new(&transaction));
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let stmt = crate::parser::parse_sql(sql).unwrap();
let plan1 = binder.bind(&stmt[0]).unwrap();

Expand Down
51 changes: 34 additions & 17 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::slice;
use std::sync::Arc;

use super::{lower_ident, Binder};
use crate::expression::function::{FunctionSummary, ScalarFunction};
use crate::expression::{AliasType, ScalarExpression};
use crate::storage::Transaction;
use crate::types::value::DataValue;
Expand Down Expand Up @@ -267,56 +268,72 @@ impl<'a, T: Transaction> Binder<'a, T> {
_ => todo!(),
}
}
let function_name = func.name.to_string().to_lowercase();

Ok(match func.name.to_string().to_lowercase().as_str() {
"count" => ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Count,
args,
ty: LogicalType::Integer,
},
match function_name.as_str() {
"count" => {
return Ok(ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Count,
args,
ty: LogicalType::Integer,
})
}
"sum" => {
let ty = args[0].return_type();

ScalarExpression::AggCall {
return Ok(ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Sum,
args,
ty,
}
});
}
"min" => {
let ty = args[0].return_type();

ScalarExpression::AggCall {
return Ok(ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Min,
args,
ty,
}
});
}
"max" => {
let ty = args[0].return_type();

ScalarExpression::AggCall {
return Ok(ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Max,
args,
ty,
}
});
}
"avg" => {
let ty = args[0].return_type();

ScalarExpression::AggCall {
return Ok(ScalarExpression::AggCall {
distinct: func.distinct,
kind: AggKind::Avg,
args,
ty,
}
});
}
_ => todo!(),
})
_ => (),
}
let arg_types = args.iter().map(ScalarExpression::return_type).collect_vec();
let summary = FunctionSummary {
name: function_name,
arg_types,
};
if let Some(function) = self.context.functions.get(&summary) {
return Ok(ScalarExpression::Function(ScalarFunction {
args,
inner: function.clone(),
}));
}

Err(DatabaseError::NotFound("function", summary.name))
}

fn bind_is_null(
Expand Down
4 changes: 2 additions & 2 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::binder::{lower_case_name, Binder};
use crate::errors::DatabaseError;
use crate::expression::value_compute::unary_op;
use crate::expression::ScalarExpression;
use crate::planner::operator::insert::InsertOperator;
use crate::planner::operator::values::ValuesOperator;
Expand Down Expand Up @@ -72,7 +71,8 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::Unary { expr, op, .. } => {
if let ScalarExpression::Constant(value) = expr.as_ref() {
row.push(Arc::new(
unary_op(value, op)?.cast(schema_ref[i].datatype())?,
DataValue::unary_op(value, op)?
.cast(schema_ref[i].datatype())?,
))
} else {
unreachable!()
Expand Down
10 changes: 7 additions & 3 deletions src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::collections::HashMap;
use std::sync::Arc;

use crate::catalog::{TableCatalog, TableName};
use crate::db::Functions;
use crate::errors::DatabaseError;
use crate::expression::ScalarExpression;
use crate::planner::operator::join::JoinType;
Expand Down Expand Up @@ -47,7 +48,8 @@ pub enum QueryBindStep {

#[derive(Clone)]
pub struct BinderContext<'a, T: Transaction> {
transaction: &'a T,
functions: &'a Functions,
pub(crate) transaction: &'a T,
pub(crate) bind_table: HashMap<TableName, (&'a TableCatalog, Option<JoinType>)>,
// alias
expr_aliases: HashMap<String, ScalarExpression>,
Expand All @@ -63,8 +65,9 @@ pub struct BinderContext<'a, T: Transaction> {
}

impl<'a, T: Transaction> BinderContext<'a, T> {
pub fn new(transaction: &'a T) -> Self {
pub fn new(transaction: &'a T, functions: &'a Functions) -> Self {
BinderContext {
functions,
transaction,
bind_table: Default::default(),
expr_aliases: Default::default(),
Expand Down Expand Up @@ -336,7 +339,8 @@ pub mod test {
let temp_dir = TempDir::new().expect("unable to create temporary working directory");
let storage = build_test_catalog(temp_dir.path()).await?;
let transaction = storage.transaction().await?;
let mut binder = Binder::new(BinderContext::new(&transaction));
let functions = Default::default();
let mut binder = Binder::new(BinderContext::new(&transaction, &functions));
let stmt = crate::parser::parse_sql(sql)?;

Ok(binder.bind(&stmt[0])?)
Expand Down
Loading

0 comments on commit 3d8d2eb

Please sign in to comment.