From 765a699aa146e72d6e71bc8b664e2df4d4a181ce Mon Sep 17 00:00:00 2001 From: Ggiggle <47661277+Ggiggle@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:36:00 +0800 Subject: [PATCH] fix(pilota-build): keep the leading underscore in idl method name (#282) --- Cargo.lock | 2 +- pilota-build/Cargo.toml | 2 +- pilota-build/src/parser/thrift/mod.rs | 31 +- pilota-build/test_data/thrift/underscore.rs | 583 ++++++++++++++++++ .../test_data/thrift/underscore.thrift | 5 + 5 files changed, 608 insertions(+), 15 deletions(-) create mode 100644 pilota-build/test_data/thrift/underscore.rs create mode 100644 pilota-build/test_data/thrift/underscore.thrift diff --git a/Cargo.lock b/Cargo.lock index e9c368eb..4a7cbab4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -845,7 +845,7 @@ dependencies = [ [[package]] name = "pilota-build" -version = "0.11.22" +version = "0.11.23" dependencies = [ "ahash", "anyhow", diff --git a/pilota-build/Cargo.toml b/pilota-build/Cargo.toml index 57643a49..a878c4b8 100644 --- a/pilota-build/Cargo.toml +++ b/pilota-build/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pilota-build" -version = "0.11.22" +version = "0.11.23" edition = "2021" description = "Compile thrift and protobuf idl into rust code at compile-time." documentation = "https://docs.rs/pilota-build" diff --git a/pilota-build/src/parser/thrift/mod.rs b/pilota-build/src/parser/thrift/mod.rs index a8e73ab5..9564202f 100644 --- a/pilota-build/src/parser/thrift/mod.rs +++ b/pilota-build/src/parser/thrift/mod.rs @@ -12,11 +12,11 @@ use thrift_parser::Annotations; use crate::{ index::Idx, - ir, - ir::{Arg, Enum, EnumVariant, FieldKind, File, Item, ItemKind, Path}, + ir::{self, Arg, Enum, EnumVariant, FieldKind, File, Item, ItemKind, Path}, symbol::{EnumRepr, FileId, Ident}, tags::{Annotation, PilotaName, RustWrapperArc, Tags}, util::error_abort, + IdentName, }; #[salsa::query_group(SourceDatabaseStorage)] @@ -133,15 +133,16 @@ impl ThriftLower { .get::() .map(|name| name.0.to_string()) .unwrap_or_else(|| func.name.to_string()); + let ident = Ident::from(name.clone()); function_names - .entry(name.to_upper_camel_case()) + .entry(ident.to_upper_camel_case()) .or_default() .push(name); }); let function_name_duplicates = function_names .iter() .filter(|(_, v)| v.len() > 1) - .map(|(k, _)| k) + .map(|(k, _)| k.as_str()) .collect::>(); let kind = ir::ItemKind::Service(ir::Service { @@ -188,12 +189,14 @@ impl ThriftLower { let tags = self.extract_tags(&f.annotations); let name = tags .get::() - .map(|name| name.0.to_string()) - .unwrap_or_else(|| f.name.to_string()); - let method_name = if function_name_duplicates.contains(&name.to_upper_camel_case()) { + .map(|name| name.0.clone()) + .unwrap_or_else(|| FastStr::new(f.name.0.clone())); + + let upper_camel_ident = f.name.as_str().upper_camel_ident(); + let method_name = if function_name_duplicates.contains(upper_camel_ident.as_str()) { name } else { - name.to_upper_camel_case() + upper_camel_ident }; let name: Ident = format!("{}{}ResultRecv", service_name, method_name).into(); @@ -292,17 +295,19 @@ impl ThriftLower { &self, service_name: &String, method: &thrift_parser::Function, - function_name_duplicates: &FxHashSet<&String>, + function_name_duplicates: &FxHashSet<&str>, ) -> ir::Method { let tags = self.extract_tags(&method.annotations); let name = tags .get::() - .map(|name| name.0.to_string()) - .unwrap_or_else(|| method.name.to_string()); - let method_name = if function_name_duplicates.contains(&name.to_upper_camel_case()) { + .map(|name| name.0.clone()) + .unwrap_or_else(|| FastStr::new(method.name.0.clone())); + + let upper_camel_ident = method.name.as_str().upper_camel_ident(); + let method_name = if function_name_duplicates.contains(upper_camel_ident.as_str()) { name } else { - name.to_upper_camel_case() + upper_camel_ident }; ir::Method { diff --git a/pilota-build/test_data/thrift/underscore.rs b/pilota-build/test_data/thrift/underscore.rs new file mode 100644 index 00000000..fa87a463 --- /dev/null +++ b/pilota-build/test_data/thrift/underscore.rs @@ -0,0 +1,583 @@ +pub mod underscore { + #![allow(warnings, clippy::all)] + + pub mod underscore { + #[derive(PartialOrd, Hash, Eq, Ord, Debug, ::pilota::derivative::Derivative)] + #[derivative(Default)] + #[derive(Clone, PartialEq)] + pub enum Test_UnderscoredResultRecv { + #[derivative(Default)] + Ok(::pilota::FastStr), + } + + impl ::pilota::thrift::Message for Test_UnderscoredResultRecv { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + __protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredResultRecv", + })?; + match self { + Test_UnderscoredResultRecv::Ok(ref value) => { + __protocol.write_faststr_field(0, (value).clone())?; + } + } + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + let mut ret = None; + __protocol.read_struct_begin()?; + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = __protocol.read_faststr()?; + __protocol.faststr_len(&field_ident); + ret = Some(Test_UnderscoredResultRecv::Ok(field_ident)); + } else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message", + ), + ); + } + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + } + __protocol.read_field_end()?; + __protocol.read_struct_end()?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut ret = None; + __protocol.read_struct_begin().await?; + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = __protocol.read_faststr().await?; + + ret = Some(Test_UnderscoredResultRecv::Ok(field_ident)); + } else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message" + )); + } + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + } + __protocol.read_field_end().await?; + __protocol.read_struct_end().await?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredResultRecv", + }) + match self { + Test_UnderscoredResultRecv::Ok(ref value) => { + __protocol.faststr_field_len(Some(0), value) + } + } + __protocol.field_stop_len() + + __protocol.struct_end_len() + } + } + #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] + pub struct Test_UnderscoredArgsSend { + pub param: ::pilota::FastStr, + } + impl ::pilota::thrift::Message for Test_UnderscoredArgsSend { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredArgsSend", + }; + + __protocol.write_struct_begin(&struct_ident)?; + __protocol.write_faststr_field(1, (&self.param).clone())?; + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) + if field_ident.field_type == ::pilota::thrift::TType::Binary => + { + var_1 = Some(__protocol.read_faststr()?); + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!("decode struct `Test_UnderscoredArgsSend` field(#{}) failed, caused by: ", field_id)); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field param is required".to_string(), + )); + }; + + let data = Self { param: var_1 }; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) + if field_ident.field_type + == ::pilota::thrift::TType::Binary => + { + var_1 = Some(__protocol.read_faststr().await?); + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!("decode struct `Test_UnderscoredArgsSend` field(#{}) failed, caused by: ", field_id)); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field param is required".to_string(), + ), + ); + }; + + let data = Self { param: var_1 }; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredArgsSend", + }) + __protocol.faststr_field_len(Some(1), &self.param) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } + } + #[derive(PartialOrd, Hash, Eq, Ord, Debug, Default, Clone, PartialEq)] + pub struct Test_UnderscoredArgsRecv { + pub param: ::pilota::FastStr, + } + impl ::pilota::thrift::Message for Test_UnderscoredArgsRecv { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + let struct_ident = ::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredArgsRecv", + }; + + __protocol.write_struct_begin(&struct_ident)?; + __protocol.write_faststr_field(1, (&self.param).clone())?; + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin()?; + if let ::std::result::Result::Err(mut err) = (|| { + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) + if field_ident.field_type == ::pilota::thrift::TType::Binary => + { + var_1 = Some(__protocol.read_faststr()?); + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + + __protocol.read_field_end()?; + __protocol.field_end_len(); + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + })() { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!("decode struct `Test_UnderscoredArgsRecv` field(#{}) failed, caused by: ", field_id)); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end()?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field param is required".to_string(), + )); + }; + + let data = Self { param: var_1 }; + ::std::result::Result::Ok(data) + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut var_1 = None; + + let mut __pilota_decoding_field_id = None; + + __protocol.read_struct_begin().await?; + if let ::std::result::Result::Err(mut err) = async { + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + __pilota_decoding_field_id = field_ident.id; + match field_ident.id { + Some(1) + if field_ident.field_type + == ::pilota::thrift::TType::Binary => + { + var_1 = Some(__protocol.read_faststr().await?); + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + + __protocol.read_field_end().await?; + } + ::std::result::Result::Ok::<_, ::pilota::thrift::ThriftException>(()) + } + .await + { + if let Some(field_id) = __pilota_decoding_field_id { + err.prepend_msg(&format!("decode struct `Test_UnderscoredArgsRecv` field(#{}) failed, caused by: ", field_id)); + } + return ::std::result::Result::Err(err); + }; + __protocol.read_struct_end().await?; + + let Some(var_1) = var_1 else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "field param is required".to_string(), + ), + ); + }; + + let data = Self { param: var_1 }; + ::std::result::Result::Ok(data) + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredArgsRecv", + }) + __protocol.faststr_field_len(Some(1), &self.param) + + __protocol.field_stop_len() + + __protocol.struct_end_len() + } + } + #[derive(PartialOrd, Hash, Eq, Ord, Debug, ::pilota::derivative::Derivative)] + #[derivative(Default)] + #[derive(Clone, PartialEq)] + pub enum Test_UnderscoredResultSend { + #[derivative(Default)] + Ok(::pilota::FastStr), + } + + impl ::pilota::thrift::Message for Test_UnderscoredResultSend { + fn encode( + &self, + __protocol: &mut T, + ) -> ::std::result::Result<(), ::pilota::thrift::ThriftException> { + #[allow(unused_imports)] + use ::pilota::thrift::TOutputProtocolExt; + __protocol.write_struct_begin(&::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredResultSend", + })?; + match self { + Test_UnderscoredResultSend::Ok(ref value) => { + __protocol.write_faststr_field(0, (value).clone())?; + } + } + __protocol.write_field_stop()?; + __protocol.write_struct_end()?; + ::std::result::Result::Ok(()) + } + + fn decode( + __protocol: &mut T, + ) -> ::std::result::Result { + #[allow(unused_imports)] + use ::pilota::{thrift::TLengthProtocolExt, Buf}; + let mut ret = None; + __protocol.read_struct_begin()?; + loop { + let field_ident = __protocol.read_field_begin()?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + __protocol.field_stop_len(); + break; + } else { + __protocol.field_begin_len(field_ident.field_type, field_ident.id); + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = __protocol.read_faststr()?; + __protocol.faststr_len(&field_ident); + ret = Some(Test_UnderscoredResultSend::Ok(field_ident)); + } else { + return ::std::result::Result::Err( + ::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message", + ), + ); + } + } + _ => { + __protocol.skip(field_ident.field_type)?; + } + } + } + __protocol.read_field_end()?; + __protocol.read_struct_end()?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + } + + fn decode_async<'a, T: ::pilota::thrift::TAsyncInputProtocol>( + __protocol: &'a mut T, + ) -> ::std::pin::Pin< + ::std::boxed::Box< + dyn ::std::future::Future< + Output = ::std::result::Result, + > + Send + + 'a, + >, + > { + ::std::boxed::Box::pin(async move { + let mut ret = None; + __protocol.read_struct_begin().await?; + loop { + let field_ident = __protocol.read_field_begin().await?; + if field_ident.field_type == ::pilota::thrift::TType::Stop { + break; + } else { + } + match field_ident.id { + Some(0) => { + if ret.is_none() { + let field_ident = __protocol.read_faststr().await?; + + ret = Some(Test_UnderscoredResultSend::Ok(field_ident)); + } else { + return ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received multiple fields for union from remote Message" + )); + } + } + _ => { + __protocol.skip(field_ident.field_type).await?; + } + } + } + __protocol.read_field_end().await?; + __protocol.read_struct_end().await?; + if let Some(ret) = ret { + ::std::result::Result::Ok(ret) + } else { + ::std::result::Result::Err(::pilota::thrift::new_protocol_exception( + ::pilota::thrift::ProtocolExceptionKind::InvalidData, + "received empty union from remote Message", + )) + } + }) + } + + fn size(&self, __protocol: &mut T) -> usize { + #[allow(unused_imports)] + use ::pilota::thrift::TLengthProtocolExt; + __protocol.struct_begin_len(&::pilota::thrift::TStructIdentifier { + name: "Test_UnderscoredResultSend", + }) + match self { + Test_UnderscoredResultSend::Ok(ref value) => { + __protocol.faststr_field_len(Some(0), value) + } + } + __protocol.field_stop_len() + + __protocol.struct_end_len() + } + } + pub trait Test {} + } +} diff --git a/pilota-build/test_data/thrift/underscore.thrift b/pilota-build/test_data/thrift/underscore.thrift new file mode 100644 index 00000000..dd335b5b --- /dev/null +++ b/pilota-build/test_data/thrift/underscore.thrift @@ -0,0 +1,5 @@ +service Test { + string _underscored( + 1: string param + ) +}