Skip to content

Commit

Permalink
refactor(frontend): use #[derive(Fields)] in statement handlers (#1…
Browse files Browse the repository at this point in the history
…5130)

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Feb 23, 2024
1 parent 316f180 commit 91d97ac
Show file tree
Hide file tree
Showing 12 changed files with 449 additions and 623 deletions.
58 changes: 57 additions & 1 deletion src/common/fields-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Result};

#[proc_macro_derive(Fields, attributes(primary_key))]
#[proc_macro_derive(Fields, attributes(primary_key, fields))]
pub fn fields(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
inner(tokens.into()).into()
}
Expand Down Expand Up @@ -46,6 +46,16 @@ fn gen(tokens: TokenStream) -> Result<TokenStream> {
));
};

let style = get_style(&input);
if let Some(style) = &style {
if !["Title Case", "TITLE CASE", "snake_case"].contains(&style.value().as_str()) {
return Err(syn::Error::new_spanned(
style,
"only `Title Case`, `TITLE CASE`, and `snake_case` are supported",
));
}
}

let fields_rw: Vec<TokenStream> = struct_
.fields
.iter()
Expand All @@ -55,6 +65,12 @@ fn gen(tokens: TokenStream) -> Result<TokenStream> {
if name.starts_with("r#") {
name = name[2..].to_string();
}
// cast style
match style.as_ref().map_or(String::new(), |f| f.value()).as_str() {
"Title Case" => name = to_title_case(&name),
"TITLE CASE" => name = to_title_case(&name).to_uppercase(),
_ => {}
}
let ty = &field.ty;
quote! {
(#name, <#ty as ::risingwave_common::types::WithDataType>::default_data_type())
Expand Down Expand Up @@ -132,6 +148,46 @@ fn get_primary_key(input: &syn::DeriveInput) -> Option<Vec<usize>> {
None
}

/// Get name style from `#[fields(style = "xxx")]` attribute.
fn get_style(input: &syn::DeriveInput) -> Option<syn::LitStr> {
let style = input.attrs.iter().find_map(|attr| match &attr.meta {
syn::Meta::List(list) if list.path.is_ident("fields") => {
let name_value: syn::MetaNameValue = syn::parse2(list.tokens.clone()).ok()?;
if name_value.path.is_ident("style") {
Some(name_value.value)
} else {
None
}
}
_ => None,
})?;
match style {
syn::Expr::Lit(lit) => match lit.lit {
syn::Lit::Str(s) => Some(s),
_ => None,
},
_ => None,
}
}

/// Convert `snake_case` to `Title Case`.
fn to_title_case(s: &str) -> String {
let mut title = String::new();
let mut next_upper = true;
for c in s.chars() {
if c == '_' {
title.push(' ');
next_upper = true;
} else if next_upper {
title.push(c.to_uppercase().next().unwrap());
next_upper = false;
} else {
title.push(c);
}
}
title
}

#[cfg(test)]
mod tests {
use indoc::indoc;
Expand Down
24 changes: 10 additions & 14 deletions src/frontend/src/handler/cancel_job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Row;
use risingwave_common::types::DataType;
use risingwave_common::types::Fields;
use risingwave_pb::meta::cancel_creating_jobs_request::{CreatingJobIds, PbJobs};
use risingwave_sqlparser::ast::JobIdents;

use super::RwPgResponseBuilderExt;
use crate::error::Result;
use crate::handler::{HandlerArgs, RwPgResponse};

Expand All @@ -36,16 +34,14 @@ pub(super) async fn handle_cancel(
.await?;
let rows = canceled_jobs
.into_iter()
.map(|id| Row::new(vec![Some(id.to_string().into())]))
.collect_vec();
.map(|id| CancelRow { id: id.to_string() });
Ok(PgResponse::builder(StatementType::CANCEL_COMMAND)
.values(
rows.into(),
vec![PgFieldDescriptor::new(
"Id".to_string(),
DataType::Varchar.to_oid(),
DataType::Varchar.type_len(),
)],
)
.rows(rows)
.into())
}

#[derive(Fields)]
#[fields(style = "Title Case")]
struct CancelRow {
id: String,
}
118 changes: 46 additions & 72 deletions src/frontend/src/handler/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@ use std::fmt::Display;
use itertools::Itertools;
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Row;
use risingwave_common::catalog::{ColumnCatalog, ColumnDesc};
use risingwave_common::types::DataType;
use risingwave_common::types::Fields;
use risingwave_sqlparser::ast::{display_comma_separated, ObjectName};

use super::RwPgResponse;
use super::show::ShowColumnRow;
use super::{fields_to_descriptors, RwPgResponse};
use crate::binder::{Binder, Relation};
use crate::catalog::CatalogError;
use crate::error::Result;
use crate::handler::util::col_descs_to_rows;
use crate::handler::HandlerArgs;
use crate::handler::{HandlerArgs, RwPgResponseBuilderExt};

pub fn handle_describe(handler_args: HandlerArgs, object_name: ObjectName) -> Result<RwPgResponse> {
let session = handler_args.session;
Expand Down Expand Up @@ -156,7 +155,10 @@ pub fn handle_describe(handler_args: HandlerArgs, object_name: ObjectName) -> Re
};

// Convert all column descs to rows
let mut rows = col_descs_to_rows(columns);
let mut rows = columns
.into_iter()
.flat_map(ShowColumnRow::from_catalog)
.collect_vec();

fn concat<T>(display_elems: impl IntoIterator<Item = T>) -> String
where
Expand All @@ -170,96 +172,68 @@ pub fn handle_describe(handler_args: HandlerArgs, object_name: ObjectName) -> Re

// Convert primary key to rows
if !pk_columns.is_empty() {
rows.push(Row::new(vec![
Some("primary key".into()),
Some(concat(pk_columns.iter().map(|x| &x.name)).into()),
None, // Is Hidden
None, // Description
]));
rows.push(ShowColumnRow {
name: "primary key".into(),
r#type: concat(pk_columns.iter().map(|x| &x.name)),
is_hidden: None,
description: None,
});
}

// Convert distribution keys to rows
if !dist_columns.is_empty() {
rows.push(Row::new(vec![
Some("distribution key".into()),
Some(concat(dist_columns.iter().map(|x| &x.name)).into()),
None, // Is Hidden
None, // Description
]));
rows.push(ShowColumnRow {
name: "distribution key".into(),
r#type: concat(dist_columns.iter().map(|x| &x.name)),
is_hidden: None,
description: None,
});
}

// Convert all indexes to rows
rows.extend(indices.iter().map(|index| {
let index_display = index.display();

Row::new(vec![
Some(index.name.clone().into()),
if index_display.include_columns.is_empty() {
Some(
format!(
"index({}) distributed by({})",
display_comma_separated(&index_display.index_columns_with_ordering),
display_comma_separated(&index_display.distributed_by_columns),
)
.into(),
ShowColumnRow {
name: index.name.clone(),
r#type: if index_display.include_columns.is_empty() {
format!(
"index({}) distributed by({})",
display_comma_separated(&index_display.index_columns_with_ordering),
display_comma_separated(&index_display.distributed_by_columns),
)
} else {
Some(
format!(
"index({}) include({}) distributed by({})",
display_comma_separated(&index_display.index_columns_with_ordering),
display_comma_separated(&index_display.include_columns),
display_comma_separated(&index_display.distributed_by_columns),
)
.into(),
format!(
"index({}) include({}) distributed by({})",
display_comma_separated(&index_display.index_columns_with_ordering),
display_comma_separated(&index_display.include_columns),
display_comma_separated(&index_display.distributed_by_columns),
)
},
// Is Hidden
None,
// Description
is_hidden: None,
// TODO: index description
None,
])
description: None,
}
}));

rows.push(Row::new(vec![
Some("table description".into()),
Some(relname.into()),
None, // Is Hidden
description.map(Into::into), // Description
]));
rows.push(ShowColumnRow {
name: "table description".into(),
r#type: relname,
is_hidden: None,
description: description.map(Into::into),
});

// TODO: table name and description as title of response
// TODO: recover the original user statement
Ok(PgResponse::builder(StatementType::DESCRIBE)
.values(
rows.into(),
vec![
PgFieldDescriptor::new(
"Name".to_owned(),
DataType::Varchar.to_oid(),
DataType::Varchar.type_len(),
),
PgFieldDescriptor::new(
"Type".to_owned(),
DataType::Varchar.to_oid(),
DataType::Varchar.type_len(),
),
PgFieldDescriptor::new(
"Is Hidden".to_owned(),
DataType::Varchar.to_oid(),
DataType::Varchar.type_len(),
),
PgFieldDescriptor::new(
"Description".to_owned(),
DataType::Varchar.to_oid(),
DataType::Varchar.type_len(),
),
],
)
.rows(rows)
.into())
}

pub fn infer_describe() -> Vec<PgFieldDescriptor> {
fields_to_descriptors(ShowColumnRow::fields())
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
Expand Down
30 changes: 12 additions & 18 deletions src/frontend/src/handler/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Row;
use risingwave_common::bail_not_implemented;
use risingwave_common::types::DataType;
use risingwave_common::types::Fields;
use risingwave_sqlparser::ast::{ExplainOptions, ExplainType, Statement};
use thiserror_ext::AsReport;

Expand All @@ -27,7 +24,7 @@ use super::create_sink::{gen_sink_plan, get_partition_compute_info};
use super::create_table::ColumnIdGenerator;
use super::query::gen_batch_plan_by_statement;
use super::util::SourceSchemaCompatExt;
use super::RwPgResponse;
use super::{RwPgResponse, RwPgResponseBuilderExt};
use crate::error::{ErrorCode, Result};
use crate::handler::create_table::handle_create_table_plan;
use crate::handler::HandlerArgs;
Expand Down Expand Up @@ -254,20 +251,17 @@ pub async fn handle_explain(
}
}

let rows = blocks
.iter()
.flat_map(|b| b.lines().map(|l| l.to_owned()))
.map(|l| Row::new(vec![Some(l.into())]))
.collect_vec();
let rows = blocks.iter().flat_map(|b| b.lines()).map(|l| ExplainRow {
query_plan: l.into(),
});

Ok(PgResponse::builder(StatementType::EXPLAIN)
.values(
rows.into(),
vec![PgFieldDescriptor::new(
"QUERY PLAN".to_owned(),
DataType::Varchar.to_oid(),
DataType::Varchar.type_len(),
)],
)
.rows(rows)
.into())
}

#[derive(Fields)]
#[fields(style = "TITLE CASE")]
struct ExplainRow {
query_plan: String,
}
Loading

0 comments on commit 91d97ac

Please sign in to comment.