From c25609832c7d6ce6f3f59c94fca2cc406360c5b2 Mon Sep 17 00:00:00 2001 From: TennyZhuang Date: Tue, 17 Oct 2023 22:19:36 +0800 Subject: [PATCH] refactor(expr): allow defining context visibility using restricted relative path (#12919) Signed-off-by: TennyZhuang --- src/expr/macro/Cargo.toml | 2 +- src/expr/macro/src/context.rs | 5 ++ src/expr/macro/src/utils.rs | 70 +++++++++++++++++++ .../src/expr/function_impl/context.rs | 8 +-- 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/src/expr/macro/Cargo.toml b/src/expr/macro/Cargo.toml index c73d9c723dd69..bf761b142061f 100644 --- a/src/expr/macro/Cargo.toml +++ b/src/expr/macro/Cargo.toml @@ -11,7 +11,7 @@ proc-macro = true itertools = "0.11" proc-macro2 = "1" quote = "1" -syn = "2" +syn = { version = "2", features = ["full", "extra-traits"] } [lints] workspace = true diff --git a/src/expr/macro/src/context.rs b/src/expr/macro/src/context.rs index 152b59761492c..e55c5adee6de2 100644 --- a/src/expr/macro/src/context.rs +++ b/src/expr/macro/src/context.rs @@ -18,6 +18,8 @@ use quote::{quote, quote_spanned, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::{Error, FnArg, Ident, ItemFn, Result, Token, Type, Visibility}; +use crate::utils::extend_vis_with_super; + /// See [`super::define_context!`]. #[derive(Debug, Clone)] pub(super) struct DefineContextField { @@ -56,6 +58,9 @@ impl DefineContextField { pub(super) fn gen(self) -> Result { let Self { vis, name, ty } = self; + // We create a sub mod, so we need to extend the vis of getter. + let vis: Visibility = extend_vis_with_super(vis); + { let name_s = name.to_string(); if name_s.to_uppercase() != name_s { diff --git a/src/expr/macro/src/utils.rs b/src/expr/macro/src/utils.rs index 788d09857cc93..74fddf4680db9 100644 --- a/src/expr/macro/src/utils.rs +++ b/src/expr/macro/src/utils.rs @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use proc_macro2::Ident; +use syn::spanned::Spanned; +use syn::{Token, VisRestricted, Visibility}; + /// Convert a string from `snake_case` to `CamelCase`. pub fn to_camel_case(input: &str) -> String { input @@ -27,3 +31,69 @@ pub fn to_camel_case(input: &str) -> String { }) .collect() } + +pub(crate) fn extend_vis_with_super(vis: Visibility) -> Visibility { + let Visibility::Restricted(vis) = vis else { + return vis; + }; + let VisRestricted { + pub_token, + paren_token, + mut in_token, + mut path, + } = vis; + let first_segment = path.segments.first_mut().unwrap(); + if first_segment.ident == "self" { + *first_segment = Ident::new("super", first_segment.span()).into(); + } else if first_segment.ident == "super" { + let span = first_segment.span(); + path.segments.insert(0, Ident::new("super", span).into()); + in_token.get_or_insert(Token![in](in_token.span())); + } + Visibility::Restricted(VisRestricted { + pub_token, + paren_token, + in_token, + path, + }) +} + +#[cfg(test)] +mod tests { + use quote::ToTokens; + use syn::Visibility; + + use crate::utils::extend_vis_with_super; + + #[test] + fn test_extend_vis_with_super() { + let cases = [ + ("pub", "pub"), + ("pub(crate)", "pub(crate)"), + ("pub(self)", "pub(super)"), + ("pub(super)", "pub(in super::super)"), + ("pub(in self)", "pub(in super)"), + ( + "pub(in self::context::data)", + "pub(in super::context::data)", + ), + ( + "pub(in super::context::data)", + "pub(in super::super::context::data)", + ), + ("pub(in crate::func::impl_)", "pub(in crate::func::impl_)"), + ( + "pub(in ::risingwave_expr::func::impl_)", + "pub(in ::risingwave_expr::func::impl_)", + ), + ]; + for (input, expected) in cases { + let input: Visibility = syn::parse_str(input).unwrap(); + let expected: Visibility = syn::parse_str(expected).unwrap(); + let output = extend_vis_with_super(input); + let expected = expected.into_token_stream().to_string(); + let output = output.into_token_stream().to_string(); + assert_eq!(expected, output); + } + } +} diff --git a/src/frontend/src/expr/function_impl/context.rs b/src/frontend/src/expr/function_impl/context.rs index e3fb5f05191ef..13a7175fabb54 100644 --- a/src/frontend/src/expr/function_impl/context.rs +++ b/src/frontend/src/expr/function_impl/context.rs @@ -20,8 +20,8 @@ use risingwave_expr::define_context; use crate::session::AuthContext; define_context! { - pub(in crate::expr::function_impl) CATALOG_READER: crate::catalog::CatalogReader, - pub(in crate::expr::function_impl) AUTH_CONTEXT: Arc, - pub(in crate::expr::function_impl) DB_NAME: String, - pub(in crate::expr::function_impl) SEARCH_PATH: SearchPath, + pub(super) CATALOG_READER: crate::catalog::CatalogReader, + pub(super) AUTH_CONTEXT: Arc, + pub(super) DB_NAME: String, + pub(super) SEARCH_PATH: SearchPath, }