Skip to content

Commit

Permalink
feat(expr): support capturing context in expression (#12747)
Browse files Browse the repository at this point in the history
Signed-off-by: TennyZhuang <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Noel Kwan <[email protected]>
  • Loading branch information
3 people authored Oct 17, 2023
1 parent 61a5bd5 commit 7c37573
Show file tree
Hide file tree
Showing 19 changed files with 461 additions and 97 deletions.
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ message ExprNode {

// Adminitration functions
COL_DESCRIPTION = 2100;
CAST_REGCLASS = 2101;
}
Type function_type = 1;
data.DataType return_type = 3;
Expand Down
18 changes: 16 additions & 2 deletions src/expr/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,20 @@ use thiserror::Error;
/// A specialized Result type for expression operations.
pub type Result<T> = std::result::Result<T, ExprError>;

pub struct ContextUnavailable(&'static str);

impl ContextUnavailable {
pub fn new(field: &'static str) -> Self {
Self(field)
}
}

impl From<ContextUnavailable> for ExprError {
fn from(e: ContextUnavailable) -> Self {
ExprError::Context(e.0)
}
}

/// The error type for expression operations.
#[derive(Error, Debug)]
pub enum ExprError {
Expand Down Expand Up @@ -71,8 +85,8 @@ pub enum ExprError {
#[error("not a constant")]
NotConstant,

#[error("Context not found")]
Context,
#[error("Context {0} not found")]
Context(&'static str),

#[error("field name must not be null")]
FieldNameNull,
Expand Down
2 changes: 1 addition & 1 deletion src/expr/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ pub mod sig;
pub mod table_function;
pub mod window_function;

pub use error::{ExprError, Result};
pub use error::{ContextUnavailable, ExprError, Result};
pub use risingwave_common::{bail, ensure};
pub use risingwave_expr_macro::*;
2 changes: 1 addition & 1 deletion src/expr/impl/src/scalar/proctime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use risingwave_expr::{function, ExprError, Result};
/// Get the processing time in Timestamptz scalar from the task-local epoch.
#[function("proctime() -> timestamptz", volatile)]
fn proctime() -> Result<Timestamptz> {
let epoch = epoch::task_local::curr_epoch().ok_or(ExprError::Context)?;
let epoch = epoch::task_local::curr_epoch().ok_or(ExprError::Context("EPOCH"))?;
Ok(epoch.as_timestamptz())
}

Expand Down
206 changes: 206 additions & 0 deletions src/expr/macro/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use proc_macro2::TokenStream;
use quote::{quote, quote_spanned, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::{Error, FnArg, Ident, ItemFn, Result, Token, Type, Visibility};

/// See [`super::define_context!`].
#[derive(Debug, Clone)]
pub(super) struct DefineContextField {
vis: Visibility,
name: Ident,
ty: Type,
}

/// See [`super::define_context!`].
#[derive(Debug, Clone)]
pub(super) struct DefineContextAttr {
fields: Vec<DefineContextField>,
}

impl Parse for DefineContextField {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let vis: Visibility = input.parse()?;
let name: Ident = input.parse()?;
input.parse::<Token![:]>()?;
let ty: Type = input.parse()?;

Ok(Self { vis, name, ty })
}
}

impl Parse for DefineContextAttr {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let fields = input.parse_terminated(DefineContextField::parse, Token![,])?;
Ok(Self {
fields: fields.into_iter().collect(),
})
}
}

impl DefineContextField {
pub(super) fn gen(self) -> Result<TokenStream> {
let Self { vis, name, ty } = self;

{
let name_s = name.to_string();
if name_s.to_uppercase() != name_s {
return Err(Error::new_spanned(
name,
"the name of context variable should be uppercase",
));
}
}

Ok(quote! {
#[allow(non_snake_case)]
pub mod #name {
use super::*;
pub type Type = #ty;

tokio::task_local! {
static LOCAL_KEY: #ty;
}

#vis fn try_with<F, R>(f: F) -> Result<R, risingwave_expr::ExprError>
where
F: FnOnce(&#ty) -> R
{
LOCAL_KEY.try_with(f).map_err(|_| risingwave_expr::ContextUnavailable::new(stringify!(#name))).map_err(Into::into)
}

pub fn scope<F>(value: #ty, f: F) -> tokio::task::futures::TaskLocalFuture<#ty, F>
where
F: std::future::Future
{
LOCAL_KEY.scope(value, f)
}

pub fn sync_scope<F, R>(value: #ty, f: F) -> R
where
F: FnOnce() -> R
{
LOCAL_KEY.sync_scope(value, f)
}
}
})
}
}

impl DefineContextAttr {
pub(super) fn gen(self) -> Result<TokenStream> {
let generated_fields: Vec<TokenStream> = self
.fields
.into_iter()
.map(DefineContextField::gen)
.try_collect()?;
Ok(quote! {
#(#generated_fields)*
})
}
}

pub struct CaptureContextAttr {
/// The context variables which are captured.
captures: Vec<Ident>,
}

impl Parse for CaptureContextAttr {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let captures = input.parse_terminated(Ident::parse, Token![,])?;
Ok(Self {
captures: captures.into_iter().collect(),
})
}
}

pub(super) fn generate_captured_function(
attr: CaptureContextAttr,
mut user_fn: ItemFn,
) -> Result<TokenStream> {
let CaptureContextAttr { captures } = attr;
let orig_user_fn = user_fn.clone();

let sig = &mut user_fn.sig;

// Modify the name.
{
let new_name = format!("{}_captured", sig.ident);
let new_name = Ident::new(&new_name, sig.ident.span());
sig.ident = new_name;
}

// Modify the inputs of sig.
let inputs = &mut sig.inputs;
if inputs.len() < captures.len() {
return Err(syn::Error::new_spanned(
inputs,
format!("expected at least {} inputs", captures.len()),
));
}

let (captured_inputs, remained_inputs) = {
let mut inputs = inputs.iter().cloned();
let inputs = inputs.by_ref();
let captured_inputs = inputs.take(captures.len()).collect_vec();
let remained_inputs = inputs.collect_vec();
(captured_inputs, remained_inputs)
};
*inputs = remained_inputs.into_iter().collect();

// Modify the body
let body = &mut user_fn.block;
let new_body = {
let mut scoped = quote! {
// TODO: We can call the old function directly here.
#body
};

#[allow(clippy::disallowed_methods)]
for (context, arg) in captures.into_iter().zip(captured_inputs.into_iter()) {
let FnArg::Typed(arg) = arg else {
return Err(syn::Error::new_spanned(
arg,
"receiver is not allowed in captured function",
));
};
let name = arg.pat.into_token_stream();
scoped = quote_spanned! { context.span()=>
// TODO: Can we add an assertion here that `&<<#context::Type> as Deref>::Target` is same as `#arg.ty`?
#context::try_with(|#name| {
#scoped
}).flatten()
}
}
scoped
};
let new_user_fn = {
let vis = user_fn.vis;
let sig = user_fn.sig;
quote! {
#vis #sig {
{#new_body}.map_err(Into::into)
}
}
};

Ok(quote! {
#[allow(dead_code)]
#orig_user_fn
#new_user_fn
})
}
37 changes: 36 additions & 1 deletion src/expr/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
#![feature(lint_reasons)]
#![feature(let_chains)]

use context::DefineContextAttr;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use syn::{Error, Result};
use syn::{Error, ItemFn, Result};

use crate::context::{generate_captured_function, CaptureContextAttr};

mod context;
mod gen;
mod parse;
mod types;
Expand Down Expand Up @@ -606,3 +610,34 @@ impl UserFunctionAttr {
&& self.return_type_kind == ReturnTypeKind::T
}
}

/// Define the context variables which can be used by risingwave expressions.
#[proc_macro]
pub fn define_context(def: TokenStream) -> TokenStream {
fn inner(def: TokenStream) -> Result<TokenStream2> {
let attr: DefineContextAttr = syn::parse(def)?;
attr.gen()
}

match inner(def) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}

/// Capture the context from the local context to the function impl.
/// TODO: The macro will be merged to [`#[function(.., capture_context(..))]`](macro@function) later.
///
/// Currently, we should use the macro separately with a simple wrapper.
#[proc_macro_attribute]
pub fn capture_context(attr: TokenStream, item: TokenStream) -> TokenStream {
fn inner(attr: TokenStream, item: TokenStream) -> Result<TokenStream2> {
let attr: CaptureContextAttr = syn::parse(attr)?;
let user_fn: ItemFn = syn::parse(item)?;
generate_captured_function(attr, user_fn)
}
match inner(attr, item) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
expected_outputs:
- batch_plan
- logical_plan
- sql: |
select ('pg' || '_namespace')::regclass
expected_outputs:
- batch_plan
- logical_plan
- sql: |
select 'boolin'::regproc
expected_outputs:
Expand Down
14 changes: 12 additions & 2 deletions src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,19 @@
- sql: |
select 'pg_namespace'::regclass
logical_plan: |-
LogicalProject { exprs: [2:Int32] }
LogicalProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
batch_plan: 'BatchValues { rows: [[2:Int32]] }'
batch_plan: |-
BatchProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] }
└─BatchValues { rows: [[]] }
- sql: |
select ('pg' || '_namespace')::regclass
logical_plan: |-
LogicalProject { exprs: [CastRegclass(ConcatOp('pg':Varchar, '_namespace':Varchar)) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
batch_plan: |-
BatchProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] }
└─BatchValues { rows: [[]] }
- sql: |
select 'boolin'::regproc
logical_plan: |-
Expand Down
38 changes: 12 additions & 26 deletions src/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,32 +523,18 @@ impl Binder {
// TODO: Add generic expr support when needed
AstDataType::Regclass => {
let input = self.bind_expr_inner(expr)?;
let class_name = match &input {
ExprImpl::Literal(literal)
if literal.return_type() == DataType::Varchar
&& let Some(scalar) = literal.get_data() =>
{
match scalar {
risingwave_common::types::ScalarImpl::Utf8(s) => s,
_ => {
return Err(ErrorCode::BindError(
"Unsupported input type".to_string(),
)
.into())
}
}
}
ExprImpl::Literal(literal) if literal.return_type().is_int() => {
return Ok(ExprImpl::Literal(literal.clone()))
}
_ => {
return Err(
ErrorCode::BindError("Unsupported input type".to_string()).into()
)
}
};
self.resolve_regclass(class_name)
.map(|id| ExprImpl::literal_int(id as i32))
match input.return_type() {
DataType::Varchar => Ok(ExprImpl::FunctionCall(Box::new(
FunctionCall::new_unchecked(
ExprType::CastRegclass,
vec![input],
DataType::Int32,
),
))),
DataType::Int32 => Ok(input),
dt if dt.is_int() => Ok(input.cast_explicit(DataType::Int32)?),
_ => Err(ErrorCode::BindError("Unsupported input type".to_string()).into()),
}
}
AstDataType::Regproc => {
let lhs = self.bind_expr_inner(expr)?;
Expand Down
Loading

0 comments on commit 7c37573

Please sign in to comment.