From 40aa7b29baaa21943ed2de02b5ec879dcd6cc9ab Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Mon, 19 Feb 2024 16:49:49 +0800 Subject: [PATCH] support name style for Fields Signed-off-by: Runji Wang --- src/common/fields-derive/src/lib.rs | 58 ++++++++++++++++++++++++- src/frontend/src/handler/cancel_job.rs | 24 +++++----- src/frontend/src/handler/explain.rs | 30 +++++-------- src/frontend/src/handler/show.rs | 29 ++++++------- src/frontend/src/handler/transaction.rs | 1 + src/frontend/src/handler/variable.rs | 3 ++ 6 files changed, 96 insertions(+), 49 deletions(-) diff --git a/src/common/fields-derive/src/lib.rs b/src/common/fields-derive/src/lib.rs index 86fa229a5adc..b38f57975168 100644 --- a/src/common/fields-derive/src/lib.rs +++ b/src/common/fields-derive/src/lib.rs @@ -16,7 +16,7 @@ use proc_macro2::TokenStream; use quote::quote; use syn::{Data, DeriveInput, Result}; -#[proc_macro_derive(Fields, attributes(primary_key))] +#[proc_macro_derive(Fields, attributes(primary_key, fields))] pub fn fields(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { inner(tokens.into()).into() } @@ -46,6 +46,16 @@ fn gen(tokens: TokenStream) -> Result { )); }; + let style = get_style(&input); + if let Some(style) = &style { + if !["Title Case", "TITLE CASE", "snake_case"].contains(&style.value().as_str()) { + return Err(syn::Error::new_spanned( + style, + "only `Title Case`, `TITLE CASE`, and `snake_case` are supported", + )); + } + } + let fields_rw: Vec = struct_ .fields .iter() @@ -55,6 +65,12 @@ fn gen(tokens: TokenStream) -> Result { if name.starts_with("r#") { name = name[2..].to_string(); } + // cast style + match style.as_ref().map_or(String::new(), |f| f.value()).as_str() { + "Title Case" => name = to_title_case(&name), + "TITLE CASE" => name = to_title_case(&name).to_uppercase(), + _ => {} + } let ty = &field.ty; quote! { (#name, <#ty as ::risingwave_common::types::WithDataType>::default_data_type()) @@ -132,6 +148,46 @@ fn get_primary_key(input: &syn::DeriveInput) -> Option> { None } +/// Get name style from `#[fields(style = "xxx")]` attribute. +fn get_style(input: &syn::DeriveInput) -> Option { + let style = input.attrs.iter().find_map(|attr| match &attr.meta { + syn::Meta::List(list) if list.path.is_ident("fields") => { + let name_value: syn::MetaNameValue = syn::parse2(list.tokens.clone()).ok()?; + if name_value.path.is_ident("style") { + Some(name_value.value) + } else { + None + } + } + _ => None, + })?; + match style { + syn::Expr::Lit(lit) => match lit.lit { + syn::Lit::Str(s) => Some(s), + _ => None, + }, + _ => None, + } +} + +/// Convert `snake_case` to `Title Case`. +fn to_title_case(s: &str) -> String { + let mut title = String::new(); + let mut next_upper = true; + for c in s.chars() { + if c == '_' { + title.push(' '); + next_upper = true; + } else if next_upper { + title.push(c.to_uppercase().next().unwrap()); + next_upper = false; + } else { + title.push(c); + } + } + title +} + #[cfg(test)] mod tests { use indoc::indoc; diff --git a/src/frontend/src/handler/cancel_job.rs b/src/frontend/src/handler/cancel_job.rs index f124a2a030bd..0f4358373a45 100644 --- a/src/frontend/src/handler/cancel_job.rs +++ b/src/frontend/src/handler/cancel_job.rs @@ -12,14 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use itertools::Itertools; -use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_response::{PgResponse, StatementType}; -use pgwire::types::Row; -use risingwave_common::types::DataType; +use risingwave_common::types::Fields; use risingwave_pb::meta::cancel_creating_jobs_request::{CreatingJobIds, PbJobs}; use risingwave_sqlparser::ast::JobIdents; +use super::RwPgResponseBuilderExt; use crate::error::Result; use crate::handler::{HandlerArgs, RwPgResponse}; @@ -36,16 +34,14 @@ pub(super) async fn handle_cancel( .await?; let rows = canceled_jobs .into_iter() - .map(|id| Row::new(vec![Some(id.to_string().into())])) - .collect_vec(); + .map(|id| CancelRow { id: id as i32 }); Ok(PgResponse::builder(StatementType::CANCEL_COMMAND) - .values( - rows.into(), - vec![PgFieldDescriptor::new( - "Id".to_string(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - )], - ) + .rows(rows) .into()) } + +#[derive(Fields)] +#[fields(style = "Title Case")] +struct CancelRow { + id: i32, +} diff --git a/src/frontend/src/handler/explain.rs b/src/frontend/src/handler/explain.rs index c25bf7678bd0..b966cca8f50c 100644 --- a/src/frontend/src/handler/explain.rs +++ b/src/frontend/src/handler/explain.rs @@ -12,12 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use itertools::Itertools; -use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_response::{PgResponse, StatementType}; -use pgwire::types::Row; use risingwave_common::bail_not_implemented; -use risingwave_common::types::DataType; +use risingwave_common::types::Fields; use risingwave_sqlparser::ast::{ExplainOptions, ExplainType, Statement}; use thiserror_ext::AsReport; @@ -27,7 +24,7 @@ use super::create_sink::{gen_sink_plan, get_partition_compute_info}; use super::create_table::ColumnIdGenerator; use super::query::gen_batch_plan_by_statement; use super::util::SourceSchemaCompatExt; -use super::RwPgResponse; +use super::{RwPgResponse, RwPgResponseBuilderExt}; use crate::error::{ErrorCode, Result}; use crate::handler::create_table::handle_create_table_plan; use crate::handler::HandlerArgs; @@ -254,20 +251,17 @@ pub async fn handle_explain( } } - let rows = blocks - .iter() - .flat_map(|b| b.lines().map(|l| l.to_owned())) - .map(|l| Row::new(vec![Some(l.into())])) - .collect_vec(); + let rows = blocks.iter().flat_map(|b| b.lines()).map(|l| ExplainRow { + query_plan: l.into(), + }); Ok(PgResponse::builder(StatementType::EXPLAIN) - .values( - rows.into(), - vec![PgFieldDescriptor::new( - "QUERY PLAN".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - )], - ) + .rows(rows) .into()) } + +#[derive(Fields)] +#[fields(style = "TITLE CASE")] +struct ExplainRow { + query_plan: String, +} diff --git a/src/frontend/src/handler/show.rs b/src/frontend/src/handler/show.rs index 669e711507a4..92f082a61dcb 100644 --- a/src/frontend/src/handler/show.rs +++ b/src/frontend/src/handler/show.rs @@ -19,7 +19,6 @@ use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_protocol::truncated_fmt; use pgwire::pg_response::{PgResponse, StatementType}; use pgwire::pg_server::Session; -use pgwire::types::Row; use risingwave_common::bail_not_implemented; use risingwave_common::catalog::{ColumnCatalog, ColumnDesc, DEFAULT_SCHEMA_NAME}; use risingwave_common::types::{DataType, Fields}; @@ -108,11 +107,13 @@ fn schema_or_default(schema: &Option) -> String { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowObjectRow { name: String, } #[derive(Fields)] +#[fields(style = "Title Case")] pub struct ShowColumnRow { pub name: String, pub r#type: String, @@ -143,6 +144,7 @@ impl ShowColumnRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowConnectionRow { name: String, r#type: String, @@ -150,6 +152,7 @@ struct ShowConnectionRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowFunctionRow { name: String, arguments: String, @@ -159,6 +162,7 @@ struct ShowFunctionRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowIndexRow { name: String, on: String, @@ -182,6 +186,7 @@ impl From> for ShowIndexRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowClusterRow { addr: String, state: String, @@ -192,6 +197,7 @@ struct ShowClusterRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowJobRow { id: i64, statement: String, @@ -199,6 +205,7 @@ struct ShowJobRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowProcessListRow { id: String, user: String, @@ -209,6 +216,7 @@ struct ShowProcessListRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowCreateObjectRow { name: String, create_sql: String, @@ -502,21 +510,10 @@ pub fn handle_show_create_object( let name = format!("{}.{}", schema_name, object_name); Ok(PgResponse::builder(StatementType::SHOW_COMMAND) - .values( - vec![Row::new(vec![Some(name.into()), Some(sql.into())])].into(), - vec![ - PgFieldDescriptor::new( - "Name".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - PgFieldDescriptor::new( - "Create Sql".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - ], - ) + .rows([ShowCreateObjectRow { + name, + create_sql: sql, + }]) .into()) } diff --git a/src/frontend/src/handler/transaction.rs b/src/frontend/src/handler/transaction.rs index 20116a59c4aa..8ab7af36c29c 100644 --- a/src/frontend/src/handler/transaction.rs +++ b/src/frontend/src/handler/transaction.rs @@ -118,6 +118,7 @@ pub async fn handle_set( } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowVariableRow { name: String, } diff --git a/src/frontend/src/handler/variable.rs b/src/frontend/src/handler/variable.rs index 736dc7d78bb4..96fd232215cc 100644 --- a/src/frontend/src/handler/variable.rs +++ b/src/frontend/src/handler/variable.rs @@ -177,11 +177,13 @@ pub fn infer_show_variable(name: &str) -> Vec { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowVariableRow { name: String, } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowVariableAllRow { name: String, setting: String, @@ -189,6 +191,7 @@ struct ShowVariableAllRow { } #[derive(Fields)] +#[fields(style = "Title Case")] struct ShowVariableParamsRow { name: String, value: String,