Skip to content

Commit

Permalink
refactor(expr): separate user-facing rw_vnode from the internal one (
Browse files Browse the repository at this point in the history
…#18815)

Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao authored Oct 10, 2024
1 parent b4976e0 commit c68aa5e
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 44 deletions.
20 changes: 0 additions & 20 deletions e2e_test/batch/functions/internal.slt.part

This file was deleted.

45 changes: 45 additions & 0 deletions e2e_test/batch/functions/vnode.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
query error takes at least 2 arguments \(0 given\)
select rw_vnode();

query error takes at least 2 arguments \(1 given\)
select rw_vnode(256);

query I
select rw_vnode(256, 114, 514);
----
97

query I
select rw_vnode(4096, 114, 514);
----
1377

# VirtualNode::MAX_COUNT
query I
select rw_vnode(32768, 114, 514);
----
21857

query error the first argument \(vnode count\) must not be NULL
select rw_vnode(NULL, 114, 514);

query error the first argument \(vnode count\) must be in range 1..=32768
select rw_vnode(0, 114, 514);

query error the first argument \(vnode count\) must be in range 1..=32768
select rw_vnode(32769, 114, 514);

statement ok
create table vnodes (vnode int);

statement ok
insert into vnodes values (256), (4096);

statement ok
flush;

query error the first argument \(vnode count\) must be a constant
select rw_vnode(vnode, 114, 514) from vnodes;

statement ok
drop table vnodes;
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ message ExprNode {
// Internal functions
VNODE = 1101;
TEST_PAID_TIER = 1102;
VNODE_USER = 1103;
// Non-deterministic functions
PROCTIME = 2023;
PG_SLEEP = 2024;
Expand Down
11 changes: 5 additions & 6 deletions src/expr/core/src/expr_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

use std::future::Future;

use risingwave_common::hash::VirtualNode;
use risingwave_expr::{define_context, Result as ExprResult};
use risingwave_pb::plan_common::ExprContext;

Expand All @@ -30,11 +29,11 @@ pub fn capture_expr_context() -> ExprResult<ExprContext> {
Ok(ExprContext { time_zone })
}

/// Get the vnode count from the context, or [`VirtualNode::COUNT_FOR_COMPAT`] if not set.
// TODO(var-vnode): the only case where this is not set is for batch queries, is it still
// necessary to support `rw_vnode` expression in batch queries?
pub fn vnode_count() -> usize {
VNODE_COUNT::try_with(|&x| x).unwrap_or(VirtualNode::COUNT_FOR_COMPAT)
/// Get the vnode count from the context.
///
/// Always returns `Ok` in streaming mode and `Err` in batch mode.
pub fn vnode_count() -> ExprResult<usize> {
VNODE_COUNT::try_with(|&x| x)
}

pub async fn expr_context_scope<Fut>(expr_context: ExprContext, future: Fut) -> Fut::Output
Expand Down
53 changes: 49 additions & 4 deletions src/expr/impl/src/scalar/vnode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,21 @@

use std::sync::Arc;

use anyhow::Context;
use itertools::Itertools;
use risingwave_common::array::{ArrayBuilder, ArrayImpl, ArrayRef, DataChunk, I16ArrayBuilder};
use risingwave_common::hash::VirtualNode;
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_expr::expr::{BoxedExpression, Expression};
use risingwave_expr::expr_context::vnode_count;
use risingwave_expr::{build_function, Result};
use risingwave_expr::{build_function, expr_context, Result};

#[derive(Debug)]
struct VnodeExpression {
/// `Some` if it's from the first argument of user-facing function `VnodeUser` (`rw_vnode`),
/// `None` if it's from the internal function `Vnode`.
vnode_count: Option<usize>,

/// A list of expressions to get the distribution key columns. Typically `InputRef`.
children: Vec<BoxedExpression>,

Expand All @@ -36,6 +41,36 @@ struct VnodeExpression {
#[build_function("vnode(...) -> int2")]
fn build(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
Ok(Box::new(VnodeExpression {
vnode_count: None,
all_indices: (0..children.len()).collect(),
children,
}))
}

#[build_function("vnode_user(...) -> int2")]
fn build_user(_: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression> {
let mut children = children.into_iter();

let vnode_count = children
.next()
.unwrap() // always exist, argument number enforced in binder
.eval_const() // required to be constant
.context("the first argument (vnode count) must be a constant")?
.context("the first argument (vnode count) must not be NULL")?
.into_int32(); // always int32, casted during type inference

if !(1i32..=VirtualNode::MAX_COUNT as i32).contains(&vnode_count) {
return Err(anyhow::anyhow!(
"the first argument (vnode count) must be in range 1..={}",
VirtualNode::MAX_COUNT
)
.into());
}

let children = children.collect_vec();

Ok(Box::new(VnodeExpression {
vnode_count: Some(vnode_count.try_into().unwrap()),
all_indices: (0..children.len()).collect(),
children,
}))
Expand All @@ -54,7 +89,7 @@ impl Expression for VnodeExpression {
}
let input = DataChunk::new(arrays, input.visibility().clone());

let vnodes = VirtualNode::compute_chunk(&input, &self.all_indices, vnode_count());
let vnodes = VirtualNode::compute_chunk(&input, &self.all_indices, self.vnode_count()?);
let mut builder = I16ArrayBuilder::new(input.capacity());
vnodes
.into_iter()
Expand All @@ -70,13 +105,23 @@ impl Expression for VnodeExpression {
let input = OwnedRow::new(datums);

Ok(Some(
VirtualNode::compute_row(input, &self.all_indices, vnode_count())
VirtualNode::compute_row(input, &self.all_indices, self.vnode_count()?)
.to_scalar()
.into(),
))
}
}

impl VnodeExpression {
fn vnode_count(&self) -> Result<usize> {
if let Some(vnode_count) = self.vnode_count {
Ok(vnode_count)
} else {
expr_context::vnode_count()
}
}
}

#[cfg(test)]
mod tests {
use risingwave_common::array::{DataChunk, DataChunkTestExt};
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/planner_test/tests/testdata/input/cse_expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
- stream_plan
- name: Common sub expression shouldn't extract impure function
sql: |
create table t(v1 varchar, v2 int, v3 int);
select rw_vnode(v2) + 1 as vnode, rw_vnode(v2) + 1 as vnode2, v2 + 1 x, v2 + 1 y from t;
create table t(v1 varchar, v2 double, v3 double);
select pg_sleep(v2) + 1 as a, pg_sleep(v2) + 1 as b, v2 + 1 x, v2 + 1 y from t;
expected_outputs:
- batch_plan
- stream_plan
Expand Down
14 changes: 7 additions & 7 deletions src/frontend/planner_test/tests/testdata/output/cse_expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: Common sub expression shouldn't extract impure function
sql: |
create table t(v1 varchar, v2 int, v3 int);
select rw_vnode(v2) + 1 as vnode, rw_vnode(v2) + 1 as vnode2, v2 + 1 x, v2 + 1 y from t;
create table t(v1 varchar, v2 double, v3 double);
select pg_sleep(v2) + 1 as a, pg_sleep(v2) + 1 as b, v2 + 1 x, v2 + 1 y from t;
batch_plan: |-
BatchExchange { order: [], dist: Single }
└─BatchProject { exprs: [(Vnode(t.v2) + 1:Int32) as $expr2, (Vnode(t.v2) + 1:Int32) as $expr3, $expr1, $expr1] }
└─BatchProject { exprs: [t.v2, (t.v2 + 1:Int32) as $expr1] }
└─BatchProject { exprs: [(PgSleep(t.v2) + 1:Int32) as $expr2, (PgSleep(t.v2) + 1:Int32) as $expr3, $expr1, $expr1] }
└─BatchProject { exprs: [t.v2, (t.v2 + 1:Float64) as $expr1] }
└─BatchScan { table: t, columns: [t.v2], distribution: SomeShard }
stream_plan: |-
StreamMaterialize { columns: [vnode, vnode2, x, y, t._row_id(hidden)], stream_key: [t._row_id], pk_columns: [t._row_id], pk_conflict: NoCheck }
└─StreamProject { exprs: [(Vnode(t.v2) + 1:Int32) as $expr2, (Vnode(t.v2) + 1:Int32) as $expr3, $expr1, $expr1, t._row_id] }
└─StreamProject { exprs: [t.v2, (t.v2 + 1:Int32) as $expr1, t._row_id] }
StreamMaterialize { columns: [a, b, x, y, t._row_id(hidden)], stream_key: [t._row_id], pk_columns: [t._row_id], pk_conflict: NoCheck }
└─StreamProject { exprs: [(PgSleep(t.v2) + 1:Int32) as $expr2, (PgSleep(t.v2) + 1:Int32) as $expr3, $expr1, $expr1, t._row_id] }
└─StreamProject { exprs: [t.v2, (t.v2 + 1:Float64) as $expr1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v2, t._row_id], stream_scan_type: ArrangementBackfill, stream_key: [t._row_id], pk: [_row_id], dist: UpstreamHashShard(t._row_id) }
- name: Common sub expression shouldn't extract const
sql: |
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/binder/expr/function/builtin_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ impl Binder {
("pg_is_in_recovery", raw_call(ExprType::PgIsInRecovery)),
("rw_recovery_status", raw_call(ExprType::RwRecoveryStatus)),
// internal
("rw_vnode", raw_call(ExprType::Vnode)),
("rw_vnode", raw_call(ExprType::VnodeUser)),
("rw_test_paid_tier", raw_call(ExprType::TestPaidTier)), // for testing purposes
// TODO: choose which pg version we should return.
("version", raw_literal(ExprImpl::literal_varchar(current_cluster_version()))),
Expand Down
5 changes: 3 additions & 2 deletions src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ impl ExprVisitor for ImpureAnalyzer {
| Type::MapContains
| Type::MapDelete
| Type::MapInsert
| Type::MapLength =>
| Type::MapLength
| Type::VnodeUser =>
// expression output is deterministic(same result for the same input)
{
func_call
Expand All @@ -270,7 +271,7 @@ impl ExprVisitor for ImpureAnalyzer {
.for_each(|expr| self.visit_expr(expr));
}
// expression output is not deterministic
Type::Vnode
Type::Vnode // obtain vnode count from the context
| Type::TestPaidTier
| Type::Proctime
| Type::PgSleep
Expand Down
8 changes: 6 additions & 2 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,12 @@ fn infer_type_for_special(
.into()),
}
}
ExprType::Vnode => {
ensure_arity!("vnode", 1 <= | inputs |);
// internal use only
ExprType::Vnode => Ok(Some(VirtualNode::RW_TYPE)),
// user-facing `rw_vnode`
ExprType::VnodeUser => {
ensure_arity!("rw_vnode", 2 <= | inputs |);
inputs[0].cast_explicit_mut(DataType::Int32)?; // vnode count
Ok(Some(VirtualNode::RW_TYPE))
}
ExprType::Greatest | ExprType::Least => {
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/optimizer/plan_expr_visitor/strong.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ impl Strong {
| ExprType::MapInsert
| ExprType::MapLength
| ExprType::Vnode
| ExprType::VnodeUser
| ExprType::TestPaidTier
| ExprType::Proctime
| ExprType::PgSleep
Expand Down

0 comments on commit c68aa5e

Please sign in to comment.