Skip to content

Commit

Permalink
Refactor the executor to support transaction detachment (#82)
Browse files Browse the repository at this point in the history
* refactor: refactor the executor to support transaction detachment

* style: `inputs` restore

* style: code fmt

* rollback: `Text` LogicalType
  • Loading branch information
KKould authored Oct 10, 2023
1 parent 82d1ce4 commit 8786f0f
Show file tree
Hide file tree
Showing 46 changed files with 779 additions and 1,147 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ lazy_static = "1.4.0"
comfy-table = "7.0.1"
bytes = "1.5.0"
kip_db = "0.1.2-alpha.17"
async-recursion = "1.0.5"
rust_decimal = "1"
csv = "1"

Expand Down
21 changes: 10 additions & 11 deletions src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use std::collections::HashSet;

use crate::binder::{BindError, InputRefType};
use crate::planner::LogicalPlan;
use crate::storage::Storage;
use crate::storage::Transaction;
use crate::{
expression::ScalarExpression,
planner::operator::{aggregate::AggregateOperator, sort::SortField},
};

use super::Binder;

impl<S: Storage> Binder<S> {
impl<'a, T: Transaction> Binder<'a, T> {
pub fn bind_aggregate(
&mut self,
children: LogicalPlan,
Expand All @@ -33,29 +33,28 @@ impl<S: Storage> Binder<S> {
Ok(())
}

pub async fn extract_group_by_aggregate(
pub fn extract_group_by_aggregate(
&mut self,
select_list: &mut [ScalarExpression],
groupby: &[Expr],
) -> Result<(), BindError> {
self.validate_groupby_illegal_column(select_list, groupby)
.await?;
self.validate_groupby_illegal_column(select_list, groupby)?;

for gb in groupby {
let mut expr = self.bind_expr(gb).await?;
let mut expr = self.bind_expr(gb)?;
self.visit_group_by_expr(select_list, &mut expr);
}
Ok(())
}

pub async fn extract_having_orderby_aggregate(
pub fn extract_having_orderby_aggregate(
&mut self,
having: &Option<Expr>,
orderbys: &[OrderByExpr],
) -> Result<(Option<ScalarExpression>, Option<Vec<SortField>>), BindError> {
// Extract having expression.
let return_having = if let Some(having) = having {
let mut having = self.bind_expr(having).await?;
let mut having = self.bind_expr(having)?;
self.visit_column_agg_expr(&mut having, false)?;

Some(having)
Expand All @@ -72,7 +71,7 @@ impl<S: Storage> Binder<S> {
asc,
nulls_first,
} = orderby;
let mut expr = self.bind_expr(expr).await?;
let mut expr = self.bind_expr(expr)?;
self.visit_column_agg_expr(&mut expr, false)?;

return_orderby.push(SortField::new(
Expand Down Expand Up @@ -156,14 +155,14 @@ impl<S: Storage> Binder<S> {
/// e.g. SELECT a,count(b) FROM t GROUP BY a. it's ok.
/// SELECT a,b FROM t GROUP BY a. it's error.
/// SELECT a,count(b) FROM t GROUP BY b. it's error.
async fn validate_groupby_illegal_column(
fn validate_groupby_illegal_column(
&mut self,
select_items: &[ScalarExpression],
groupby: &[Expr],
) -> Result<(), BindError> {
let mut group_raw_exprs = vec![];
for expr in groupby {
let expr = self.bind_expr(expr).await?;
let expr = self.bind_expr(expr)?;

if let ScalarExpression::Alias { alias, .. } = expr {
let alias_expr = select_items.iter().find(|column| {
Expand Down
8 changes: 4 additions & 4 deletions src/binder/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ impl std::fmt::Display for FileFormat {

impl FromStr for ExtSource {
type Err = ();
fn from_str(_s: &str) -> std::result::Result<Self, Self::Err> {
fn from_str(_s: &str) -> Result<Self, Self::Err> {
Err(())
}
}

impl<S: Storage> Binder<S> {
pub(super) async fn bind_copy(
impl<'a, T: Transaction> Binder<'a, T> {
pub(super) fn bind_copy(
&mut self,
source: CopySource,
to: bool,
Expand All @@ -69,7 +69,7 @@ impl<S: Storage> Binder<S> {
}
};

if let Some(table) = self.context.storage.table(&table_name.to_string()).await {
if let Some(table) = self.context.transaction.table(&table_name.to_string()) {
let cols = table.all_columns();
let ext_source = ExtSource {
path: match target {
Expand Down
17 changes: 11 additions & 6 deletions src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ use crate::catalog::ColumnCatalog;
use crate::planner::operator::create_table::CreateTableOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Storage;
use crate::storage::Transaction;

impl<S: Storage> Binder<S> {
impl<'a, T: Transaction> Binder<'a, T> {
// TODO: TableConstraint
pub(crate) fn bind_create_table(
&mut self,
Expand Down Expand Up @@ -60,19 +60,22 @@ mod tests {
use super::*;
use crate::binder::BinderContext;
use crate::catalog::ColumnDesc;
use crate::execution::ExecutorError;
use crate::storage::kip::KipStorage;
use crate::storage::Storage;
use crate::types::LogicalType;
use tempfile::TempDir;

#[tokio::test]
async fn test_create_bind() {
async fn test_create_bind() -> Result<(), ExecutorError> {
let temp_dir = TempDir::new().expect("unable to create temporary working directory");
let storage = KipStorage::new(temp_dir.path()).await.unwrap();
let storage = KipStorage::new(temp_dir.path()).await?;
let transaction = storage.transaction().await?;

let sql = "create table t1 (id int primary key, name varchar(10) null)";
let binder = Binder::new(BinderContext::new(storage));
let binder = Binder::new(BinderContext::new(&transaction));
let stmt = crate::parser::parse_sql(sql).unwrap();
let plan1 = binder.bind(&stmt[0]).await.unwrap();
let plan1 = binder.bind(&stmt[0]).unwrap();

match plan1.operator {
Operator::CreateTable(op) => {
Expand All @@ -92,5 +95,7 @@ mod tests {
}
_ => unreachable!(),
}

Ok(())
}
}
10 changes: 5 additions & 5 deletions src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@ use crate::binder::{lower_case_name, split_name, BindError, Binder};
use crate::planner::operator::delete::DeleteOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Storage;
use crate::storage::Transaction;
use sqlparser::ast::{Expr, TableFactor, TableWithJoins};

impl<S: Storage> Binder<S> {
pub(crate) async fn bind_delete(
impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_delete(
&mut self,
from: &TableWithJoins,
selection: &Option<Expr>,
) -> Result<LogicalPlan, BindError> {
if let TableFactor::Table { name, .. } = &from.relation {
let name = lower_case_name(name);
let (_, name) = split_name(&name)?;
let (table_name, mut plan) = self._bind_single_table_ref(None, name).await?;
let (table_name, mut plan) = self._bind_single_table_ref(None, name)?;

if let Some(predicate) = selection {
plan = self.bind_where(plan, predicate).await?;
plan = self.bind_where(plan, predicate)?;
}

Ok(LogicalPlan {
Expand Down
4 changes: 2 additions & 2 deletions src/binder/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use crate::binder::Binder;
use crate::expression::ScalarExpression;
use crate::planner::operator::aggregate::AggregateOperator;
use crate::planner::LogicalPlan;
use crate::storage::Storage;
use crate::storage::Transaction;

impl<S: Storage> Binder<S> {
impl<'a, T: Transaction> Binder<'a, T> {
pub fn bind_distinct(
&mut self,
children: LogicalPlan,
Expand Down
4 changes: 2 additions & 2 deletions src/binder/drop_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use crate::binder::{lower_case_name, split_name, BindError, Binder};
use crate::planner::operator::drop_table::DropTableOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Storage;
use crate::storage::Transaction;
use sqlparser::ast::ObjectName;
use std::sync::Arc;

impl<S: Storage> Binder<S> {
impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_drop_table(&mut self, name: &ObjectName) -> Result<LogicalPlan, BindError> {
let name = lower_case_name(&name);
let (_, name) = split_name(&name)?;
Expand Down
42 changes: 17 additions & 25 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::binder::BindError;
use crate::expression::agg::AggKind;
use async_recursion::async_recursion;
use itertools::Itertools;
use sqlparser::ast::{
BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, UnaryOperator,
Expand All @@ -10,35 +9,29 @@ use std::sync::Arc;

use super::Binder;
use crate::expression::ScalarExpression;
use crate::storage::Storage;
use crate::storage::Transaction;
use crate::types::value::DataValue;
use crate::types::LogicalType;

impl<S: Storage> Binder<S> {
#[async_recursion]
pub(crate) async fn bind_expr(&mut self, expr: &Expr) -> Result<ScalarExpression, BindError> {
impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_expr(&mut self, expr: &Expr) -> Result<ScalarExpression, BindError> {
match expr {
Expr::Identifier(ident) => {
self.bind_column_ref_from_identifiers(slice::from_ref(ident), None)
.await
}
Expr::CompoundIdentifier(idents) => {
self.bind_column_ref_from_identifiers(idents, None).await
}
Expr::BinaryOp { left, right, op } => {
self.bind_binary_op_internal(left, right, op).await
}
Expr::CompoundIdentifier(idents) => self.bind_column_ref_from_identifiers(idents, None),
Expr::BinaryOp { left, right, op } => self.bind_binary_op_internal(left, right, op),
Expr::Value(v) => Ok(ScalarExpression::Constant(Arc::new(v.into()))),
Expr::Function(func) => self.bind_agg_call(func).await,
Expr::Nested(expr) => self.bind_expr(expr).await,
Expr::UnaryOp { expr, op } => self.bind_unary_op_internal(expr, op).await,
Expr::Function(func) => self.bind_agg_call(func),
Expr::Nested(expr) => self.bind_expr(expr),
Expr::UnaryOp { expr, op } => self.bind_unary_op_internal(expr, op),
_ => {
todo!()
}
}
}

pub async fn bind_column_ref_from_identifiers(
pub fn bind_column_ref_from_identifiers(
&mut self,
idents: &[Ident],
bind_table_name: Option<&String>,
Expand Down Expand Up @@ -66,9 +59,8 @@ impl<S: Storage> Binder<S> {
if let Some(table) = table_name.or(bind_table_name) {
let table_catalog = self
.context
.storage
.transaction
.table(table)
.await
.ok_or_else(|| BindError::InvalidTable(table.to_string()))?;

let column_catalog = table_catalog
Expand Down Expand Up @@ -100,14 +92,14 @@ impl<S: Storage> Binder<S> {
}
}

async fn bind_binary_op_internal(
fn bind_binary_op_internal(
&mut self,
left: &Expr,
right: &Expr,
op: &BinaryOperator,
) -> Result<ScalarExpression, BindError> {
let left_expr = Box::new(self.bind_expr(left).await?);
let right_expr = Box::new(self.bind_expr(right).await?);
let left_expr = Box::new(self.bind_expr(left)?);
let right_expr = Box::new(self.bind_expr(right)?);

let ty = match op {
BinaryOperator::Plus
Expand Down Expand Up @@ -137,12 +129,12 @@ impl<S: Storage> Binder<S> {
})
}

async fn bind_unary_op_internal(
fn bind_unary_op_internal(
&mut self,
expr: &Expr,
op: &UnaryOperator,
) -> Result<ScalarExpression, BindError> {
let expr = Box::new(self.bind_expr(expr).await?);
let expr = Box::new(self.bind_expr(expr)?);
let ty = if let UnaryOperator::Not = op {
LogicalType::Boolean
} else {
Expand All @@ -156,7 +148,7 @@ impl<S: Storage> Binder<S> {
})
}

async fn bind_agg_call(&mut self, func: &Function) -> Result<ScalarExpression, BindError> {
fn bind_agg_call(&mut self, func: &Function) -> Result<ScalarExpression, BindError> {
let mut args = Vec::with_capacity(func.args.len());

for arg in func.args.iter() {
Expand All @@ -165,7 +157,7 @@ impl<S: Storage> Binder<S> {
FunctionArg::Unnamed(arg) => arg,
};
match arg_expr {
FunctionArgExpr::Expr(expr) => args.push(self.bind_expr(expr).await?),
FunctionArgExpr::Expr(expr) => args.push(self.bind_expr(expr)?),
FunctionArgExpr::Wildcard => args.push(Self::wildcard_expr()),
_ => todo!(),
}
Expand Down
21 changes: 9 additions & 12 deletions src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use crate::planner::operator::insert::InsertOperator;
use crate::planner::operator::values::ValuesOperator;
use crate::planner::operator::Operator;
use crate::planner::LogicalPlan;
use crate::storage::Storage;
use crate::storage::Transaction;
use crate::types::value::{DataValue, ValueRef};
use sqlparser::ast::{Expr, Ident, ObjectName};
use std::slice;
use std::sync::Arc;

impl<S: Storage> Binder<S> {
pub(crate) async fn bind_insert(
impl<'a, T: Transaction> Binder<'a, T> {
pub(crate) fn bind_insert(
&mut self,
name: ObjectName,
idents: &[Ident],
Expand All @@ -24,21 +24,18 @@ impl<S: Storage> Binder<S> {
let (_, name) = split_name(&name)?;
let table_name = Arc::new(name.to_string());

if let Some(table) = self.context.storage.table(&table_name).await {
if let Some(table) = self.context.transaction.table(&table_name) {
let mut columns = Vec::new();

if idents.is_empty() {
columns = table.all_columns();
} else {
let bind_table_name = Some(table_name.to_string());
for ident in idents {
match self
.bind_column_ref_from_identifiers(
slice::from_ref(ident),
bind_table_name.as_ref(),
)
.await?
{
match self.bind_column_ref_from_identifiers(
slice::from_ref(ident),
bind_table_name.as_ref(),
)? {
ScalarExpression::ColumnRef(catalog) => columns.push(catalog),
_ => unreachable!(),
}
Expand All @@ -50,7 +47,7 @@ impl<S: Storage> Binder<S> {
let mut row = Vec::with_capacity(expr_row.len());

for (i, expr) in expr_row.into_iter().enumerate() {
match &self.bind_expr(expr).await? {
match &self.bind_expr(expr)? {
ScalarExpression::Constant(value) => {
// Check if the value length is too long
value.check_len(columns[i].datatype())?;
Expand Down
Loading

0 comments on commit 8786f0f

Please sign in to comment.