diff --git a/proto/expr.proto b/proto/expr.proto index 1271249b9fdcf..2086554e78975 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -229,6 +229,7 @@ message ExprNode { // Adminitration functions COL_DESCRIPTION = 2100; + CAST_REGCLASS = 2101; } Type function_type = 1; data.DataType return_type = 3; diff --git a/src/expr/core/src/error.rs b/src/expr/core/src/error.rs index a6bb01e45bf54..7204b851e5b2b 100644 --- a/src/expr/core/src/error.rs +++ b/src/expr/core/src/error.rs @@ -21,6 +21,20 @@ use thiserror::Error; /// A specialized Result type for expression operations. pub type Result = std::result::Result; +pub struct ContextUnavailable(&'static str); + +impl ContextUnavailable { + pub fn new(field: &'static str) -> Self { + Self(field) + } +} + +impl From for ExprError { + fn from(e: ContextUnavailable) -> Self { + ExprError::Context(e.0) + } +} + /// The error type for expression operations. #[derive(Error, Debug)] pub enum ExprError { @@ -71,8 +85,8 @@ pub enum ExprError { #[error("not a constant")] NotConstant, - #[error("Context not found")] - Context, + #[error("Context {0} not found")] + Context(&'static str), #[error("field name must not be null")] FieldNameNull, diff --git a/src/expr/core/src/lib.rs b/src/expr/core/src/lib.rs index 32b538ed084eb..c2f46d5632274 100644 --- a/src/expr/core/src/lib.rs +++ b/src/expr/core/src/lib.rs @@ -33,6 +33,6 @@ pub mod sig; pub mod table_function; pub mod window_function; -pub use error::{ExprError, Result}; +pub use error::{ContextUnavailable, ExprError, Result}; pub use risingwave_common::{bail, ensure}; pub use risingwave_expr_macro::*; diff --git a/src/expr/impl/src/scalar/proctime.rs b/src/expr/impl/src/scalar/proctime.rs index 4271fee8ca58f..659a64f4e0c7b 100644 --- a/src/expr/impl/src/scalar/proctime.rs +++ b/src/expr/impl/src/scalar/proctime.rs @@ -19,7 +19,7 @@ use risingwave_expr::{function, ExprError, Result}; /// Get the processing time in Timestamptz scalar from the task-local epoch. #[function("proctime() -> timestamptz", volatile)] fn proctime() -> Result { - let epoch = epoch::task_local::curr_epoch().ok_or(ExprError::Context)?; + let epoch = epoch::task_local::curr_epoch().ok_or(ExprError::Context("EPOCH"))?; Ok(epoch.as_timestamptz()) } diff --git a/src/expr/macro/src/context.rs b/src/expr/macro/src/context.rs new file mode 100644 index 0000000000000..152b59761492c --- /dev/null +++ b/src/expr/macro/src/context.rs @@ -0,0 +1,206 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::{Parse, ParseStream}; +use syn::{Error, FnArg, Ident, ItemFn, Result, Token, Type, Visibility}; + +/// See [`super::define_context!`]. +#[derive(Debug, Clone)] +pub(super) struct DefineContextField { + vis: Visibility, + name: Ident, + ty: Type, +} + +/// See [`super::define_context!`]. +#[derive(Debug, Clone)] +pub(super) struct DefineContextAttr { + fields: Vec, +} + +impl Parse for DefineContextField { + fn parse(input: ParseStream<'_>) -> Result { + let vis: Visibility = input.parse()?; + let name: Ident = input.parse()?; + input.parse::()?; + let ty: Type = input.parse()?; + + Ok(Self { vis, name, ty }) + } +} + +impl Parse for DefineContextAttr { + fn parse(input: ParseStream<'_>) -> Result { + let fields = input.parse_terminated(DefineContextField::parse, Token![,])?; + Ok(Self { + fields: fields.into_iter().collect(), + }) + } +} + +impl DefineContextField { + pub(super) fn gen(self) -> Result { + let Self { vis, name, ty } = self; + + { + let name_s = name.to_string(); + if name_s.to_uppercase() != name_s { + return Err(Error::new_spanned( + name, + "the name of context variable should be uppercase", + )); + } + } + + Ok(quote! { + #[allow(non_snake_case)] + pub mod #name { + use super::*; + pub type Type = #ty; + + tokio::task_local! { + static LOCAL_KEY: #ty; + } + + #vis fn try_with(f: F) -> Result + where + F: FnOnce(&#ty) -> R + { + LOCAL_KEY.try_with(f).map_err(|_| risingwave_expr::ContextUnavailable::new(stringify!(#name))).map_err(Into::into) + } + + pub fn scope(value: #ty, f: F) -> tokio::task::futures::TaskLocalFuture<#ty, F> + where + F: std::future::Future + { + LOCAL_KEY.scope(value, f) + } + + pub fn sync_scope(value: #ty, f: F) -> R + where + F: FnOnce() -> R + { + LOCAL_KEY.sync_scope(value, f) + } + } + }) + } +} + +impl DefineContextAttr { + pub(super) fn gen(self) -> Result { + let generated_fields: Vec = self + .fields + .into_iter() + .map(DefineContextField::gen) + .try_collect()?; + Ok(quote! { + #(#generated_fields)* + }) + } +} + +pub struct CaptureContextAttr { + /// The context variables which are captured. + captures: Vec, +} + +impl Parse for CaptureContextAttr { + fn parse(input: ParseStream<'_>) -> Result { + let captures = input.parse_terminated(Ident::parse, Token![,])?; + Ok(Self { + captures: captures.into_iter().collect(), + }) + } +} + +pub(super) fn generate_captured_function( + attr: CaptureContextAttr, + mut user_fn: ItemFn, +) -> Result { + let CaptureContextAttr { captures } = attr; + let orig_user_fn = user_fn.clone(); + + let sig = &mut user_fn.sig; + + // Modify the name. + { + let new_name = format!("{}_captured", sig.ident); + let new_name = Ident::new(&new_name, sig.ident.span()); + sig.ident = new_name; + } + + // Modify the inputs of sig. + let inputs = &mut sig.inputs; + if inputs.len() < captures.len() { + return Err(syn::Error::new_spanned( + inputs, + format!("expected at least {} inputs", captures.len()), + )); + } + + let (captured_inputs, remained_inputs) = { + let mut inputs = inputs.iter().cloned(); + let inputs = inputs.by_ref(); + let captured_inputs = inputs.take(captures.len()).collect_vec(); + let remained_inputs = inputs.collect_vec(); + (captured_inputs, remained_inputs) + }; + *inputs = remained_inputs.into_iter().collect(); + + // Modify the body + let body = &mut user_fn.block; + let new_body = { + let mut scoped = quote! { + // TODO: We can call the old function directly here. + #body + }; + + #[allow(clippy::disallowed_methods)] + for (context, arg) in captures.into_iter().zip(captured_inputs.into_iter()) { + let FnArg::Typed(arg) = arg else { + return Err(syn::Error::new_spanned( + arg, + "receiver is not allowed in captured function", + )); + }; + let name = arg.pat.into_token_stream(); + scoped = quote_spanned! { context.span()=> + // TODO: Can we add an assertion here that `&<<#context::Type> as Deref>::Target` is same as `#arg.ty`? + #context::try_with(|#name| { + #scoped + }).flatten() + } + } + scoped + }; + let new_user_fn = { + let vis = user_fn.vis; + let sig = user_fn.sig; + quote! { + #vis #sig { + {#new_body}.map_err(Into::into) + } + } + }; + + Ok(quote! { + #[allow(dead_code)] + #orig_user_fn + #new_user_fn + }) +} diff --git a/src/expr/macro/src/lib.rs b/src/expr/macro/src/lib.rs index 24760d06f4341..363fc958b557d 100644 --- a/src/expr/macro/src/lib.rs +++ b/src/expr/macro/src/lib.rs @@ -15,10 +15,14 @@ #![feature(lint_reasons)] #![feature(let_chains)] +use context::DefineContextAttr; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; -use syn::{Error, Result}; +use syn::{Error, ItemFn, Result}; +use crate::context::{generate_captured_function, CaptureContextAttr}; + +mod context; mod gen; mod parse; mod types; @@ -606,3 +610,34 @@ impl UserFunctionAttr { && self.return_type_kind == ReturnTypeKind::T } } + +/// Define the context variables which can be used by risingwave expressions. +#[proc_macro] +pub fn define_context(def: TokenStream) -> TokenStream { + fn inner(def: TokenStream) -> Result { + let attr: DefineContextAttr = syn::parse(def)?; + attr.gen() + } + + match inner(def) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +/// Capture the context from the local context to the function impl. +/// TODO: The macro will be merged to [`#[function(.., capture_context(..))]`](macro@function) later. +/// +/// Currently, we should use the macro separately with a simple wrapper. +#[proc_macro_attribute] +pub fn capture_context(attr: TokenStream, item: TokenStream) -> TokenStream { + fn inner(attr: TokenStream, item: TokenStream) -> Result { + let attr: CaptureContextAttr = syn::parse(attr)?; + let user_fn: ItemFn = syn::parse(item)?; + generate_captured_function(attr, user_fn) + } + match inner(attr, item) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} diff --git a/src/frontend/planner_test/tests/testdata/input/pg_catalog.yaml b/src/frontend/planner_test/tests/testdata/input/pg_catalog.yaml index 7560d90aae8e9..7f13113afb086 100644 --- a/src/frontend/planner_test/tests/testdata/input/pg_catalog.yaml +++ b/src/frontend/planner_test/tests/testdata/input/pg_catalog.yaml @@ -23,6 +23,11 @@ expected_outputs: - batch_plan - logical_plan +- sql: | + select ('pg' || '_namespace')::regclass + expected_outputs: + - batch_plan + - logical_plan - sql: | select 'boolin'::regproc expected_outputs: diff --git a/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml b/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml index 0ce056fa1c48a..5c5a88fb472b6 100644 --- a/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml +++ b/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml @@ -221,9 +221,19 @@ - sql: | select 'pg_namespace'::regclass logical_plan: |- - LogicalProject { exprs: [2:Int32] } + LogicalProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } - batch_plan: 'BatchValues { rows: [[2:Int32]] }' + batch_plan: |- + BatchProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] } + └─BatchValues { rows: [[]] } +- sql: | + select ('pg' || '_namespace')::regclass + logical_plan: |- + LogicalProject { exprs: [CastRegclass(ConcatOp('pg':Varchar, '_namespace':Varchar)) as $expr1] } + └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } + batch_plan: |- + BatchProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] } + └─BatchValues { rows: [[]] } - sql: | select 'boolin'::regproc logical_plan: |- diff --git a/src/frontend/src/binder/expr/mod.rs b/src/frontend/src/binder/expr/mod.rs index e9f10f572763a..6da590c2d315d 100644 --- a/src/frontend/src/binder/expr/mod.rs +++ b/src/frontend/src/binder/expr/mod.rs @@ -523,32 +523,18 @@ impl Binder { // TODO: Add generic expr support when needed AstDataType::Regclass => { let input = self.bind_expr_inner(expr)?; - let class_name = match &input { - ExprImpl::Literal(literal) - if literal.return_type() == DataType::Varchar - && let Some(scalar) = literal.get_data() => - { - match scalar { - risingwave_common::types::ScalarImpl::Utf8(s) => s, - _ => { - return Err(ErrorCode::BindError( - "Unsupported input type".to_string(), - ) - .into()) - } - } - } - ExprImpl::Literal(literal) if literal.return_type().is_int() => { - return Ok(ExprImpl::Literal(literal.clone())) - } - _ => { - return Err( - ErrorCode::BindError("Unsupported input type".to_string()).into() - ) - } - }; - self.resolve_regclass(class_name) - .map(|id| ExprImpl::literal_int(id as i32)) + match input.return_type() { + DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new( + FunctionCall::new_unchecked( + ExprType::CastRegclass, + vec![input], + DataType::Int32, + ), + ))), + DataType::Int32 => Ok(input), + dt if dt.is_int() => Ok(input.cast_explicit(DataType::Int32)?), + _ => Err(ErrorCode::BindError("Unsupported input type".to_string()).into()), + } } AstDataType::Regproc => { let lhs = self.bind_expr_inner(expr)?; diff --git a/src/frontend/src/binder/relation/table_or_source.rs b/src/frontend/src/binder/relation/table_or_source.rs index 480fcd20faf36..b05b5db42b300 100644 --- a/src/frontend/src/binder/relation/table_or_source.rs +++ b/src/frontend/src/binder/relation/table_or_source.rs @@ -21,7 +21,6 @@ use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::session_config::USER_NAME_WILD_CARD; use risingwave_sqlparser::ast::{Statement, TableAlias}; use risingwave_sqlparser::parser::Parser; -use risingwave_sqlparser::tokenizer::{Token, Tokenizer}; use super::BoundShare; use crate::binder::relation::BoundSubquery; @@ -377,44 +376,4 @@ impl Binder { Ok(table) } - - pub(crate) fn resolve_regclass(&self, class_name: &str) -> Result { - let obj = Self::parse_object_name(class_name)?; - - if obj.0.len() == 1 { - let class_name = obj.0[0].real_value(); - let schema_path = SchemaPath::Path(&self.search_path, &self.auth_context.user_name); - Ok(self - .catalog - .get_id_by_class_name(&self.db_name, schema_path, &class_name)?) - } else { - let schema = obj.0[0].real_value(); - let class_name = obj.0[1].real_value(); - let schema_path = SchemaPath::Name(&schema); - Ok(self - .catalog - .get_id_by_class_name(&self.db_name, schema_path, &class_name)?) - } - } - - /// Attempt to parse the value of a varchar Literal into an - /// [`ObjectName`](risingwave_sqlparser::ast::ObjectName). - fn parse_object_name(name: &str) -> Result { - // We use the full parser here because this function needs to accept every legal way - // of identifying an object in PG SQL as a valid value for the varchar - // literal. For example: 'foo', 'public.foo', '"my table"', and - // '"my schema".foo' must all work as values passed pg_table_size. - let mut tokenizer = Tokenizer::new(name); - let tokens = tokenizer - .tokenize_with_location() - .map_err(|e| ErrorCode::BindError(e.to_string()))?; - let mut parser = Parser::new(tokens); - let object = parser - .parse_object_name() - .map_err(|e| ErrorCode::BindError(e.to_string()))?; - if parser.next_token().token != Token::EOF { - Err(ErrorCode::BindError("Invalid name syntax".to_string()))? - } - Ok(object) - } } diff --git a/src/frontend/src/catalog/mod.rs b/src/frontend/src/catalog/mod.rs index 211f8ff1bd07b..ad4e3ae18c954 100644 --- a/src/frontend/src/catalog/mod.rs +++ b/src/frontend/src/catalog/mod.rs @@ -37,6 +37,7 @@ pub(crate) mod system_catalog; pub(crate) mod table_catalog; pub(crate) mod view_catalog; +pub(crate) use catalog_service::CatalogReader; pub use index_catalog::IndexCatalog; pub use table_catalog::TableCatalog; diff --git a/src/frontend/src/expr/function_impl/cast_regclass.rs b/src/frontend/src/expr/function_impl/cast_regclass.rs new file mode 100644 index 0000000000000..e0f8670d791fb --- /dev/null +++ b/src/frontend/src/expr/function_impl/cast_regclass.rs @@ -0,0 +1,102 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::session_config::SearchPath; +use risingwave_expr::{capture_context, function, ExprError}; +use risingwave_sqlparser::parser::{Parser, ParserError}; +use risingwave_sqlparser::tokenizer::{Token, Tokenizer}; +use thiserror::Error; + +use super::context::{AUTH_CONTEXT, CATALOG_READER, DB_NAME, SEARCH_PATH}; +use crate::catalog::root_catalog::SchemaPath; +use crate::catalog::{CatalogError, CatalogReader}; +use crate::session::AuthContext; + +#[derive(Error, Debug)] +enum ResolveRegclassError { + #[error("parse object name failed: {0}")] + Parser(#[from] ParserError), + #[error("catalog error: {0}")] + Catalog(#[from] CatalogError), +} + +impl From for ExprError { + fn from(e: ResolveRegclassError) -> Self { + match e { + ResolveRegclassError::Parser(e) => ExprError::Parse(e.to_string().into_boxed_str()), + ResolveRegclassError::Catalog(e) => ExprError::InvalidParam { + name: "name", + reason: e.to_string().into_boxed_str(), + }, + } + } +} + +#[capture_context(CATALOG_READER, AUTH_CONTEXT, SEARCH_PATH, DB_NAME)] +fn resolve_regclass_impl( + catalog: &CatalogReader, + auth_context: &AuthContext, + search_path: &SearchPath, + db_name: &str, + class_name: &str, +) -> Result { + resolve_regclass_inner(catalog, auth_context, search_path, db_name, class_name) + .map_err(Into::into) +} + +fn resolve_regclass_inner( + catalog: &CatalogReader, + auth_context: &AuthContext, + search_path: &SearchPath, + db_name: &str, + class_name: &str, +) -> Result { + let obj = parse_object_name(class_name)?; + + if obj.0.len() == 1 { + let class_name = obj.0[0].real_value(); + let schema_path = SchemaPath::Path(search_path, &auth_context.user_name); + Ok(catalog + .read_guard() + .get_id_by_class_name(db_name, schema_path, &class_name)?) + } else { + let schema = obj.0[0].real_value(); + let class_name = obj.0[1].real_value(); + let schema_path = SchemaPath::Name(&schema); + Ok(catalog + .read_guard() + .get_id_by_class_name(db_name, schema_path, &class_name)?) + } +} + +fn parse_object_name(name: &str) -> Result { + // We use the full parser here because this function needs to accept every legal way + // of identifying an object in PG SQL as a valid value for the varchar + // literal. For example: 'foo', 'public.foo', '"my table"', and + // '"my schema".foo' must all work as values passed pg_table_size. + let mut tokenizer = Tokenizer::new(name); + let tokens = tokenizer + .tokenize_with_location() + .map_err(ParserError::from)?; + let mut parser = Parser::new(tokens); + let object = parser.parse_object_name()?; + parser.expect_token(&Token::EOF)?; + Ok(object) +} + +#[function("cast_regclass(varchar) -> int4")] +fn cast_regclass(class_name: &str) -> Result { + let oid = resolve_regclass_impl_captured(class_name)?; + Ok(oid as i32) +} diff --git a/src/frontend/src/expr/function_impl/context.rs b/src/frontend/src/expr/function_impl/context.rs new file mode 100644 index 0000000000000..e3fb5f05191ef --- /dev/null +++ b/src/frontend/src/expr/function_impl/context.rs @@ -0,0 +1,27 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use risingwave_common::session_config::SearchPath; +use risingwave_expr::define_context; + +use crate::session::AuthContext; + +define_context! { + pub(in crate::expr::function_impl) CATALOG_READER: crate::catalog::CatalogReader, + pub(in crate::expr::function_impl) AUTH_CONTEXT: Arc, + pub(in crate::expr::function_impl) DB_NAME: String, + pub(in crate::expr::function_impl) SEARCH_PATH: SearchPath, +} diff --git a/src/frontend/src/expr/function_impl/mod.rs b/src/frontend/src/expr/function_impl/mod.rs index 33d402b4bb6af..1f31b7f307dac 100644 --- a/src/frontend/src/expr/function_impl/mod.rs +++ b/src/frontend/src/expr/function_impl/mod.rs @@ -12,4 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod cast_regclass; mod col_description; +pub mod context; diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index f7dc01e2eef35..ba36ad0514d28 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -44,7 +44,7 @@ pub use order_by_expr::{OrderBy, OrderByExpr}; mod expr_mutator; mod expr_rewriter; mod expr_visitor; -mod function_impl; +pub mod function_impl; mod session_timezone; mod type_inference; mod utils; diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 5fb96675cfe3c..71b3a2e20f475 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -215,7 +215,8 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::PgSleep | expr_node::Type::PgSleepFor | expr_node::Type::PgSleepUntil - | expr_node::Type::ColDescription => true, + | expr_node::Type::ColDescription + | expr_node::Type::CastRegclass => true, } } } diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs index 79289071dd889..e11562bccb467 100644 --- a/src/frontend/src/handler/query.rs +++ b/src/frontend/src/handler/query.rs @@ -464,18 +464,12 @@ async fn distribute_execute( #[expect(clippy::unused_async)] async fn local_execute(session: Arc, query: Query) -> Result { let front_env = session.env(); + // TODO: if there's no table scan, we don't need to acquire snapshot. let snapshot = session.pinned_snapshot(); // TODO: Passing sql here - let execution = LocalQueryExecution::new( - query, - front_env.clone(), - "", - snapshot, - session.auth_context(), - session.reset_cancel_query_flag(), - ); + let execution = LocalQueryExecution::new(query, front_env.clone(), "", snapshot, session); Ok(execution.stream_rows()) } diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index caf4b02c29c9c..0a036b8e96233 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -33,6 +33,7 @@ #![feature(type_alias_impl_trait)] #![feature(impl_trait_in_assoc_type)] #![feature(async_fn_in_trait)] +#![feature(result_flattening)] #![recursion_limit = "256"] #[cfg(test)] diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index f3906ffbcc755..28cfa25b70bf1 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -52,7 +52,7 @@ use crate::scheduler::plan_fragmenter::{ExecutionPlanNode, Query, StageId}; use crate::scheduler::task_context::FrontendBatchTaskContext; use crate::scheduler::worker_node_manager::WorkerNodeSelector; use crate::scheduler::{ReadSnapshot, SchedulerError, SchedulerResult}; -use crate::session::{AuthContext, FrontendEnv}; +use crate::session::{AuthContext, FrontendEnv, SessionImpl}; pub type LocalQueryStream = ReceiverStream>; @@ -63,8 +63,7 @@ pub struct LocalQueryExecution { // The snapshot will be released when LocalQueryExecution is dropped. // TODO snapshot: ReadSnapshot, - auth_context: Arc, - shutdown_rx: ShutdownToken, + session: Arc, worker_node_manager: WorkerNodeSelector, } @@ -74,8 +73,7 @@ impl LocalQueryExecution { front_env: FrontendEnv, sql: S, snapshot: ReadSnapshot, - auth_context: Arc, - shutdown_rx: ShutdownToken, + session: Arc, ) -> Self { let sql = sql.into(); let worker_node_manager = WorkerNodeSelector::new( @@ -88,18 +86,24 @@ impl LocalQueryExecution { query, front_env, snapshot, - auth_context, - shutdown_rx, + session, worker_node_manager, } } + fn auth_context(&self) -> Arc { + self.session.auth_context() + } + + fn shutdown_rx(&self) -> ShutdownToken { + self.session.reset_cancel_query_flag() + } + #[try_stream(ok = DataChunk, error = RwError)] pub async fn run_inner(self) { debug!(%self.query.query_id, self.sql, "Starting to run query"); - let context = - FrontendBatchTaskContext::new(self.front_env.clone(), self.auth_context.clone()); + let context = FrontendBatchTaskContext::new(self.front_env.clone(), self.auth_context()); let task_id = TaskId { query_id: self.query.query_id.id.clone(), @@ -115,7 +119,7 @@ impl LocalQueryExecution { &task_id, context, self.snapshot.batch_query_epoch(), - self.shutdown_rx.clone(), + self.shutdown_rx().clone(), ); let executor = executor.build().await?; @@ -137,9 +141,14 @@ impl LocalQueryExecution { pub fn stream_rows(self) -> LocalQueryStream { let compute_runtime = self.front_env.compute_runtime(); let (sender, receiver) = mpsc::channel(10); - let shutdown_rx = self.shutdown_rx.clone(); + let shutdown_rx = self.shutdown_rx().clone(); + + let catalog_reader = self.front_env.catalog_reader().clone(); + let auth_context = self.session.auth_context().clone(); + let db_name = self.session.database().to_string(); + let search_path = self.session.config().get_search_path().clone(); - compute_runtime.spawn(async move { + let exec = async move { let mut data_stream = self.run().map(|r| r.map_err(|e| Box::new(e) as BoxedError)); while let Some(mut r) = data_stream.next().await { // append a query cancelled error if the query is cancelled. @@ -151,7 +160,18 @@ impl LocalQueryExecution { return; } } - }); + }; + + use crate::expr::function_impl::context::{ + AUTH_CONTEXT, CATALOG_READER, DB_NAME, SEARCH_PATH, + }; + + let exec = async move { CATALOG_READER::scope(catalog_reader, exec).await }; + let exec = async move { DB_NAME::scope(db_name, exec).await }; + let exec = async move { SEARCH_PATH::scope(search_path, exec).await }; + let exec = async move { AUTH_CONTEXT::scope(auth_context, exec).await }; + + compute_runtime.spawn(exec); ReceiverStream::new(receiver) }