Skip to content

Commit

Permalink
feat: support BETWEEN on Where (#133)
Browse files Browse the repository at this point in the history
* style: pass clippy

* feat: support `BETWEEN` on `Where`

* code fmt
  • Loading branch information
KKould authored Feb 9, 2024
1 parent bb9ccef commit 4b4c8e6
Show file tree
Hide file tree
Showing 33 changed files with 293 additions and 120 deletions.
11 changes: 4 additions & 7 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl SimpleQueryHandler for SessionBackend {
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
guard.replace(transaction);

Ok(vec![Response::Execution(Tag::new("OK").into())])
Ok(vec![Response::Execution(Tag::new("OK"))])
}
"COMMIT;" | "COMMIT" | "COMMIT WORK;" | "COMMIT WORK" => {
let mut guard = self.tx.lock().await;
Expand All @@ -124,7 +124,7 @@ impl SimpleQueryHandler for SessionBackend {
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

Ok(vec![Response::Execution(Tag::new("OK").into())])
Ok(vec![Response::Execution(Tag::new("OK"))])
} else {
Err(PgWireError::ApiError(Box::new(
DatabaseError::NoTransactionBegin,
Expand All @@ -141,7 +141,7 @@ impl SimpleQueryHandler for SessionBackend {
}
drop(guard.take());

Ok(vec![Response::Execution(Tag::new("OK").into())])
Ok(vec![Response::Execution(Tag::new("OK"))])
}
_ => {
let mut guard = self.tx.lock().await;
Expand Down Expand Up @@ -210,10 +210,7 @@ fn encode_tuples<'a>(tuples: Vec<Tuple>) -> PgWireResult<QueryResponse<'a>> {
results.push(encoder.finish());
}

Ok(QueryResponse::new(
schema,
stream::iter(results.into_iter()),
))
Ok(QueryResponse::new(schema, stream::iter(results)))
}

fn into_pg_type(data_type: &LogicalType) -> PgWireResult<Type> {
Expand Down
22 changes: 21 additions & 1 deletion src/binder/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ impl<'a, T: Transaction> Binder<'a, T> {
self.visit_column_agg_expr(arg)?;
}
}
ScalarExpression::Between {
expr,
left_expr,
right_expr,
..
} => {
self.visit_column_agg_expr(expr)?;
self.visit_column_agg_expr(left_expr)?;
self.visit_column_agg_expr(right_expr)?;
}
ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => {}
}

Expand Down Expand Up @@ -257,7 +267,17 @@ impl<'a, T: Transaction> Binder<'a, T> {
self.validate_having_orderby(right_expr)?;
Ok(())
}

ScalarExpression::Between {
expr,
left_expr,
right_expr,
..
} => {
self.validate_having_orderby(expr)?;
self.validate_having_orderby(left_expr)?;
self.validate_having_orderby(right_expr)?;
Ok(())
}
ScalarExpression::Constant(_) => Ok(()),
}
}
Expand Down
11 changes: 11 additions & 0 deletions src/binder/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ impl<'a, T: Transaction> Binder<'a, T> {

Ok(ScalarExpression::Constant(Arc::new(value)))
}
Expr::Between {
expr,
negated,
low,
high,
} => Ok(ScalarExpression::Between {
negated: *negated,
expr: Box::new(self.bind_expr(expr)?),
left_expr: Box::new(self.bind_expr(low)?),
right_expr: Box::new(self.bind_expr(high)?),
}),
_ => {
todo!()
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::Constant(dv) => match dv.as_ref() {
DataValue::Int32(Some(v)) if *v >= 0 => limit = Some(*v as usize),
DataValue::Int64(Some(v)) if *v >= 0 => limit = Some(*v as usize),
_ => return Err(DatabaseError::from(DatabaseError::InvalidType)),
_ => return Err(DatabaseError::InvalidType),
},
_ => {
return Err(DatabaseError::InvalidColumn(
Expand All @@ -424,7 +424,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
ScalarExpression::Constant(dv) => match dv.as_ref() {
DataValue::Int32(Some(v)) if *v > 0 => offset = Some(*v as usize),
DataValue::Int64(Some(v)) if *v > 0 => offset = Some(*v as usize),
_ => return Err(DatabaseError::from(DatabaseError::InvalidType)),
_ => return Err(DatabaseError::InvalidType),
},
_ => {
return Err(DatabaseError::InvalidColumn(
Expand Down
9 changes: 4 additions & 5 deletions src/catalog/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl TableCatalog {
}

pub(crate) fn clone_columns(&self) -> Vec<ColumnRef> {
self.columns.values().map(Arc::clone).collect()
self.columns.values().cloned().collect()
}

pub(crate) fn columns_with_id(&self) -> Iter<'_, ColumnId, ColumnRef> {
Expand All @@ -66,17 +66,16 @@ impl TableCatalog {

pub(crate) fn primary_key(&self) -> Result<(usize, &ColumnRef), DatabaseError> {
self.columns
.iter()
.map(|(_, column)| column)
.values()
.enumerate()
.find(|(_, column)| column.desc.is_primary)
.ok_or(DatabaseError::PrimaryKeyNotFound)
}

pub(crate) fn types(&self) -> Vec<LogicalType> {
self.columns
.iter()
.map(|(_, column)| *column.datatype())
.values()
.map(|column| *column.datatype())
.collect_vec()
}

Expand Down
2 changes: 1 addition & 1 deletion src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl<S: Storage> DBTransaction<S> {
let (plan, _) = Database::<S>::build_plan::<T, S::TransactionType>(sql, &self.inner)?;
let mut stream = build_write(plan, &mut self.inner);

Ok(try_collect(&mut stream).await?)
try_collect(&mut stream).await
}

pub async fn commit(self) -> Result<(), DatabaseError> {
Expand Down
2 changes: 1 addition & 1 deletion src/execution/codegen/dql/aggregate/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl UserData for HashAggStatus {
Ok(())
});
methods.add_method_mut("to_tuples", |_, agg_status, ()| {
Ok(agg_status.to_tuples().unwrap())
Ok(agg_status.as_tuples().unwrap())
});
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/execution/volcano/ddl/add_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl AddColumn {
column,
if_not_exists,
} = &self.op;
let mut unique_values = column.desc().is_unique.then(|| Vec::new());
let mut unique_values = column.desc().is_unique.then(Vec::new);
let mut tuple_columns = None;
let mut tuples = Vec::new();

Expand Down Expand Up @@ -78,7 +78,7 @@ impl AddColumn {
id: unique_meta.id,
column_values: vec![value],
};
transaction.add_index(&table_name, index, vec![tuple_id], true)?;
transaction.add_index(table_name, index, vec![tuple_id], true)?;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution/volcano/dml/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl fmt::Display for AnalyzeOperator {
let columns = self
.columns
.iter()
.map(|column| format!("{}", column.name()))
.map(|column| column.name().to_string())
.join(", ");

write!(f, "Analyze {} -> [{}]", self.table_name, columns)?;
Expand Down
11 changes: 5 additions & 6 deletions src/execution/volcano/dql/aggregate/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,10 @@ impl HashAggStatus {
Ok(())
}

pub(crate) fn to_tuples(&mut self) -> Result<Vec<Tuple>, DatabaseError> {
let group_columns = Arc::new(mem::replace(&mut self.group_columns, vec![]));
pub(crate) fn as_tuples(&mut self) -> Result<Vec<Tuple>, DatabaseError> {
let group_columns = Arc::new(mem::take(&mut self.group_columns));

Ok(self
.group_hash_accs
self.group_hash_accs
.drain()
.map(|(group_keys, accs)| {
// Tips: Accumulator First
Expand All @@ -129,7 +128,7 @@ impl HashAggStatus {
values,
})
})
.try_collect()?)
.try_collect()
}
}

Expand All @@ -149,7 +148,7 @@ impl HashAggExecutor {
agg_status.update(tuple?)?;
}

for tuple in agg_status.to_tuples()? {
for tuple in agg_status.as_tuples()? {
yield tuple;
}
}
Expand Down
14 changes: 5 additions & 9 deletions src/execution/volcano/dql/join/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,7 @@ impl HashJoinStatus {
let _ = mem::replace(left_init_flag, true);
}

build_map
.entry(hash)
.or_insert_with(|| Vec::new())
.push(tuple);
build_map.entry(hash).or_insert_with(Vec::new).push(tuple);

Ok(())
}
Expand All @@ -134,7 +131,7 @@ impl HashJoinStatus {
} = self;

let right_cols_len = tuple.columns.len();
let hash = Self::hash_row(&on_right_keys, &hash_random_state, &tuple)?;
let hash = Self::hash_row(on_right_keys, hash_random_state, &tuple)?;

if !*right_init_flag {
Self::columns_filling(&tuple, join_columns, *right_force_nullable);
Expand Down Expand Up @@ -240,14 +237,14 @@ impl HashJoinStatus {
build_map
.drain()
.filter(|(hash, _)| !used_set.contains(hash))
.map(|(_, mut tuples)| {
.flat_map(|(_, mut tuples)| {
for Tuple {
values,
columns,
id,
} in tuples.iter_mut()
{
let _ = mem::replace(id, None);
let _ = id.take();
let (right_values, full_columns) = buf.get_or_insert_with(|| {
let (right_values, mut right_columns): (
Vec<ValueRef>,
Expand All @@ -269,10 +266,9 @@ impl HashJoinStatus {
}
tuples
})
.flatten()
.collect_vec()
})
.unwrap_or_else(|| vec![])
.unwrap_or_else(Vec::new)
}

fn columns_filling(tuple: &Tuple, join_columns: &mut Vec<ColumnRef>, force_nullable: bool) {
Expand Down
6 changes: 1 addition & 5 deletions src/execution/volcano/dql/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,7 @@ pub(crate) fn radix_sort<T>(mut tuples: Vec<(T, Vec<u8>)>) -> Vec<T> {
temp_buckets[index as usize].push((t, bytes));
}

tuples = temp_buckets
.iter_mut()
.map(|group| mem::replace(group, vec![]))
.flatten()
.collect_vec();
tuples = temp_buckets.iter_mut().flat_map(mem::take).collect_vec();
}
return tuples.into_iter().map(|(tuple, _)| tuple).collect_vec();
}
Expand Down
34 changes: 33 additions & 1 deletion src/expression/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::types::tuple::Tuple;
use crate::types::value::{DataValue, ValueRef};
use itertools::Itertools;
use lazy_static::lazy_static;
use std::cmp::Ordering;
use std::sync::Arc;

lazy_static! {
Expand Down Expand Up @@ -68,9 +69,17 @@ impl ScalarExpression {
negated,
} => {
let value = expr.eval(tuple)?;
if value.is_null() {
return Ok(Arc::new(DataValue::Boolean(None)));
}
let mut is_in = false;
for arg in args {
if arg.eval(tuple)? == value {
let arg_value = arg.eval(tuple)?;

if arg_value.is_null() {
return Ok(Arc::new(DataValue::Boolean(None)));
}
if arg_value == value {
is_in = true;
break;
}
Expand All @@ -92,6 +101,29 @@ impl ScalarExpression {

Ok(value)
}
ScalarExpression::Between {
expr,
left_expr,
right_expr,
negated,
} => {
let value = expr.eval(tuple)?;
let left = left_expr.eval(tuple)?;
let right = right_expr.eval(tuple)?;

let mut is_between = match (
value.partial_cmp(&left).map(Ordering::is_ge),
value.partial_cmp(&right).map(Ordering::is_le),
) {
(Some(true), Some(true)) => true,
(None, _) | (_, None) => return Ok(Arc::new(DataValue::Boolean(None))),
_ => false,
};
if *negated {
is_between = !is_between;
}
Ok(Arc::new(DataValue::Boolean(Some(is_between))))
}
}
}

Expand Down
Loading

0 comments on commit 4b4c8e6

Please sign in to comment.