From d932d1ea61de2667810a5cce7a16b3bcc8f74757 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Thu, 12 Sep 2024 14:31:29 +0800 Subject: [PATCH] vnode expr context Signed-off-by: Bugen Zhao --- src/expr/core/src/expr_context.rs | 7 ++++++ src/expr/impl/src/scalar/vnode.rs | 7 +++--- src/stream/src/executor/actor.rs | 42 +++++++++++++++++++++---------- 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/expr/core/src/expr_context.rs b/src/expr/core/src/expr_context.rs index 27a888118e318..547aef01fef18 100644 --- a/src/expr/core/src/expr_context.rs +++ b/src/expr/core/src/expr_context.rs @@ -14,6 +14,7 @@ use std::future::Future; +use risingwave_common::hash::VirtualNode; use risingwave_expr::{define_context, Result as ExprResult}; use risingwave_pb::plan_common::ExprContext; @@ -21,6 +22,7 @@ use risingwave_pb::plan_common::ExprContext; define_context! { pub TIME_ZONE: String, pub FRAGMENT_ID: u32, + pub VNODE_COUNT: usize, } pub fn capture_expr_context() -> ExprResult { @@ -28,6 +30,11 @@ pub fn capture_expr_context() -> ExprResult { Ok(ExprContext { time_zone }) } +/// Get the vnode count from the context, or [`VirtualNode::COUNT`] if not set. +pub fn vnode_count() -> usize { + VNODE_COUNT::try_with(|&x| x).unwrap_or(VirtualNode::COUNT) +} + pub async fn expr_context_scope(expr_context: ExprContext, future: Fut) -> Fut::Output where Fut: Future, diff --git a/src/expr/impl/src/scalar/vnode.rs b/src/expr/impl/src/scalar/vnode.rs index edd4caa39970e..960d71ca809d8 100644 --- a/src/expr/impl/src/scalar/vnode.rs +++ b/src/expr/impl/src/scalar/vnode.rs @@ -19,6 +19,7 @@ use risingwave_common::hash::VirtualNode; use risingwave_common::row::OwnedRow; use risingwave_common::types::{DataType, Datum}; use risingwave_expr::expr::{BoxedExpression, Expression}; +use risingwave_expr::expr_context::vnode_count; use risingwave_expr::{build_function, Result}; #[derive(Debug)] @@ -43,8 +44,7 @@ impl Expression for VnodeExpression { } async fn eval(&self, input: &DataChunk) -> Result { - // TODO(var-vnode): get vnode count from context - let vnodes = VirtualNode::compute_chunk(input, &self.dist_key_indices, VirtualNode::COUNT); + let vnodes = VirtualNode::compute_chunk(input, &self.dist_key_indices, vnode_count()); let mut builder = I16ArrayBuilder::new(input.capacity()); vnodes .into_iter() @@ -53,9 +53,8 @@ impl Expression for VnodeExpression { } async fn eval_row(&self, input: &OwnedRow) -> Result { - // TODO(var-vnode): get vnode count from context Ok(Some( - VirtualNode::compute_row(input, &self.dist_key_indices, VirtualNode::COUNT) + VirtualNode::compute_row(input, &self.dist_key_indices, vnode_count()) .to_scalar() .into(), )) diff --git a/src/stream/src/executor/actor.rs b/src/stream/src/executor/actor.rs index 4e56e3b0c2262..1fb13d8518386 100644 --- a/src/stream/src/executor/actor.rs +++ b/src/stream/src/executor/actor.rs @@ -19,13 +19,16 @@ use std::sync::{Arc, LazyLock}; use anyhow::anyhow; use await_tree::InstrumentAwait; use futures::future::join_all; +use futures::FutureExt; use hytra::TrAdder; +use risingwave_common::bitmap::Bitmap; use risingwave_common::catalog::TableId; use risingwave_common::config::StreamingConfig; +use risingwave_common::hash::VirtualNode; use risingwave_common::log::LogSuppresser; use risingwave_common::metrics::{IntGaugeExt, GLOBAL_ERROR_METRICS}; use risingwave_common::util::epoch::EpochPair; -use risingwave_expr::expr_context::{expr_context_scope, FRAGMENT_ID}; +use risingwave_expr::expr_context::{expr_context_scope, FRAGMENT_ID, VNODE_COUNT}; use risingwave_expr::ExprError; use risingwave_pb::plan_common::ExprContext; use risingwave_pb::stream_plan::PbStreamActor; @@ -44,6 +47,7 @@ use crate::task::{ActorId, LocalBarrierManager}; pub struct ActorContext { pub id: ActorId, pub fragment_id: u32, + pub vnode_count: usize, pub mview_definition: String, // TODO(eric): these seem to be useless now? @@ -71,6 +75,7 @@ impl ActorContext { Arc::new(Self { id, fragment_id: 0, + vnode_count: VirtualNode::COUNT_FOR_TEST, mview_definition: "".to_string(), cur_mem_val: Arc::new(0.into()), last_mem_val: Arc::new(0.into()), @@ -97,6 +102,9 @@ impl ActorContext { id: stream_actor.actor_id, fragment_id: stream_actor.fragment_id, mview_definition: stream_actor.mview_definition.clone(), + vnode_count: (stream_actor.vnode_bitmap.as_ref()) + // TODO(var-vnode): use 1 for singleton fragment + .map_or(VirtualNode::COUNT, |b| Bitmap::from(b).len()), cur_mem_val: Arc::new(0.into()), last_mem_val: Arc::new(0.into()), total_mem_val, @@ -177,18 +185,26 @@ where #[inline(always)] pub async fn run(mut self) -> StreamResult<()> { - FRAGMENT_ID::scope( - self.actor_context.fragment_id, - expr_context_scope(self.expr_context.clone(), async move { - tokio::join!( - // Drive the subtasks concurrently. - join_all(std::mem::take(&mut self.subtasks)), - self.run_consumer(), - ) - .1 - }), - ) - .await + let expr_context = self.expr_context.clone(); + let fragment_id = self.actor_context.fragment_id; + let vnode_count = self.actor_context.vnode_count; + + let run = async move { + tokio::join!( + // Drive the subtasks concurrently. + join_all(std::mem::take(&mut self.subtasks)), + self.run_consumer(), + ) + .1 + } + .boxed(); + + // Attach contexts to the future. + let run = expr_context_scope(expr_context, run); + let run = FRAGMENT_ID::scope(fragment_id, run); + let run = VNODE_COUNT::scope(vnode_count, run); + + run.await } async fn run_consumer(self) -> StreamResult<()> {