Skip to content

Commit

Permalink
refactor(expr): allow defining context visibility using restricted re…
Browse files Browse the repository at this point in the history
…lative path (#12919)

Signed-off-by: TennyZhuang <[email protected]>
  • Loading branch information
TennyZhuang authored Oct 17, 2023
1 parent 34ec260 commit c256098
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/expr/macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ proc-macro = true
itertools = "0.11"
proc-macro2 = "1"
quote = "1"
syn = "2"
syn = { version = "2", features = ["full", "extra-traits"] }

[lints]
workspace = true
5 changes: 5 additions & 0 deletions src/expr/macro/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use quote::{quote, quote_spanned, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::{Error, FnArg, Ident, ItemFn, Result, Token, Type, Visibility};

use crate::utils::extend_vis_with_super;

/// See [`super::define_context!`].
#[derive(Debug, Clone)]
pub(super) struct DefineContextField {
Expand Down Expand Up @@ -56,6 +58,9 @@ impl DefineContextField {
pub(super) fn gen(self) -> Result<TokenStream> {
let Self { vis, name, ty } = self;

// We create a sub mod, so we need to extend the vis of getter.
let vis: Visibility = extend_vis_with_super(vis);

{
let name_s = name.to_string();
if name_s.to_uppercase() != name_s {
Expand Down
70 changes: 70 additions & 0 deletions src/expr/macro/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use proc_macro2::Ident;
use syn::spanned::Spanned;
use syn::{Token, VisRestricted, Visibility};

/// Convert a string from `snake_case` to `CamelCase`.
pub fn to_camel_case(input: &str) -> String {
input
Expand All @@ -27,3 +31,69 @@ pub fn to_camel_case(input: &str) -> String {
})
.collect()
}

pub(crate) fn extend_vis_with_super(vis: Visibility) -> Visibility {
let Visibility::Restricted(vis) = vis else {
return vis;
};
let VisRestricted {
pub_token,
paren_token,
mut in_token,
mut path,
} = vis;
let first_segment = path.segments.first_mut().unwrap();
if first_segment.ident == "self" {
*first_segment = Ident::new("super", first_segment.span()).into();
} else if first_segment.ident == "super" {
let span = first_segment.span();
path.segments.insert(0, Ident::new("super", span).into());
in_token.get_or_insert(Token![in](in_token.span()));
}
Visibility::Restricted(VisRestricted {
pub_token,
paren_token,
in_token,
path,
})
}

#[cfg(test)]
mod tests {
use quote::ToTokens;
use syn::Visibility;

use crate::utils::extend_vis_with_super;

#[test]
fn test_extend_vis_with_super() {
let cases = [
("pub", "pub"),
("pub(crate)", "pub(crate)"),
("pub(self)", "pub(super)"),
("pub(super)", "pub(in super::super)"),
("pub(in self)", "pub(in super)"),
(
"pub(in self::context::data)",
"pub(in super::context::data)",
),
(
"pub(in super::context::data)",
"pub(in super::super::context::data)",
),
("pub(in crate::func::impl_)", "pub(in crate::func::impl_)"),
(
"pub(in ::risingwave_expr::func::impl_)",
"pub(in ::risingwave_expr::func::impl_)",
),
];
for (input, expected) in cases {
let input: Visibility = syn::parse_str(input).unwrap();
let expected: Visibility = syn::parse_str(expected).unwrap();
let output = extend_vis_with_super(input);
let expected = expected.into_token_stream().to_string();
let output = output.into_token_stream().to_string();
assert_eq!(expected, output);
}
}
}
8 changes: 4 additions & 4 deletions src/frontend/src/expr/function_impl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use risingwave_expr::define_context;
use crate::session::AuthContext;

define_context! {
pub(in crate::expr::function_impl) CATALOG_READER: crate::catalog::CatalogReader,
pub(in crate::expr::function_impl) AUTH_CONTEXT: Arc<AuthContext>,
pub(in crate::expr::function_impl) DB_NAME: String,
pub(in crate::expr::function_impl) SEARCH_PATH: SearchPath,
pub(super) CATALOG_READER: crate::catalog::CatalogReader,
pub(super) AUTH_CONTEXT: Arc<AuthContext>,
pub(super) DB_NAME: String,
pub(super) SEARCH_PATH: SearchPath,
}

0 comments on commit c256098

Please sign in to comment.