diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index bb72470c61336..8e3d5f384242d 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -123,8 +123,13 @@ impl ProtobufParserConfig { pub fn map_to_columns(&self) -> Result> { let mut columns = Vec::with_capacity(self.message_descriptor.fields().len()); let mut index = 0; + let mut parse_trace: Vec = vec![]; for field in self.message_descriptor.fields() { - columns.push(Self::pb_field_to_col_desc(&field, &mut index)?); + columns.push(Self::pb_field_to_col_desc( + &field, + &mut index, + &mut parse_trace, + )?); } Ok(columns) @@ -134,14 +139,15 @@ impl ProtobufParserConfig { fn pb_field_to_col_desc( field_descriptor: &FieldDescriptor, index: &mut i32, + parse_trace: &mut Vec, ) -> Result { - let field_type = protobuf_type_mapping(field_descriptor)?; + let field_type = protobuf_type_mapping(field_descriptor, parse_trace)?; if let Kind::Message(m) = field_descriptor.kind() { let field_descs = if let DataType::List { .. } = field_type { vec![] } else { m.fields() - .map(|f| Self::pb_field_to_col_desc(&f, index)) + .map(|f| Self::pb_field_to_col_desc(&f, index, parse_trace)) .collect::>>()? }; *index += 1; @@ -219,6 +225,20 @@ impl ProtobufParser { } } +fn detect_loop_and_push(trace: &mut Vec, fd: &FieldDescriptor) -> Result<()> { + let identifier = format!("{}({})", fd.name(), fd.full_name()); + if trace.iter().any(|s| s == identifier.as_str()) { + return Err(RwError::from(ProtocolError(format!( + "circular reference detected: {}, conflict with {}, kind {:?}", + trace.iter().join("->"), + identifier, + fd.kind(), + )))); + } + trace.push(identifier); + Ok(()) +} + impl ByteStreamSourceParser for ProtobufParser { fn columns(&self) -> &[SourceColumnDesc] { &self.rw_columns @@ -302,7 +322,11 @@ fn from_protobuf_value(field_desc: &FieldDescriptor, value: &Value) -> Result Result { +fn protobuf_type_mapping( + field_descriptor: &FieldDescriptor, + parse_trace: &mut Vec, +) -> Result { + detect_loop_and_push(parse_trace, field_descriptor)?; let field_type = field_descriptor.kind(); let mut t = match field_type { Kind::Bool => DataType::Boolean, @@ -317,7 +341,7 @@ fn protobuf_type_mapping(field_descriptor: &FieldDescriptor) -> Result Kind::Message(m) => { let fields = m .fields() - .map(|f| protobuf_type_mapping(&f)) + .map(|f| protobuf_type_mapping(&f, parse_trace)) .collect::>>()?; let field_names = m.fields().map(|f| f.name().to_string()).collect_vec(); DataType::new_struct(fields, field_names) @@ -334,6 +358,7 @@ fn protobuf_type_mapping(field_descriptor: &FieldDescriptor) -> Result if field_descriptor.cardinality() == Cardinality::Repeated { t = DataType::List(Box::new(t)) } + _ = parse_trace.pop(); Ok(t) } @@ -456,4 +481,21 @@ mod test { ); Ok(()) } + + #[tokio::test] + async fn test_refuse_recursive_proto_message() { + let location = schema_dir() + "/proto_recursive/recursive.pb"; + let message_name = "recursive.ComplexRecursiveMessage"; + let conf = ProtobufParserConfig::new(&HashMap::new(), &location, message_name, false) + .await + .unwrap(); + let columns = conf.map_to_columns(); + // expect error message: + // "Err(Protocol error: circular reference detected: + // parent(recursive.ComplexRecursiveMessage.parent)->siblings(recursive. + // ComplexRecursiveMessage.Parent.siblings), conflict with + // parent(recursive.ComplexRecursiveMessage.parent), kind + // recursive.ComplexRecursiveMessage.Parent" + assert!(columns.is_err()); + } } diff --git a/src/connector/src/test_data/proto_recursive/recursive.pb b/src/connector/src/test_data/proto_recursive/recursive.pb new file mode 100644 index 0000000000000..85912baa4e804 --- /dev/null +++ b/src/connector/src/test_data/proto_recursive/recursive.pb @@ -0,0 +1,20 @@ + +‡ +recursive.proto recursive"à +ComplexRecursiveMessage + node_name ( RnodeName +node_id (RnodeIdM + +attributes ( 2-.recursive.ComplexRecursiveMessage.AttributesR +attributesA +parent ( 2).recursive.ComplexRecursiveMessage.ParentRparent> +children ( 2".recursive.ComplexRecursiveMessageRchildren4 + +Attributes +key ( Rkey +value ( Rvalue† +Parent + parent_name ( R +parentName + parent_id (RparentId> +siblings ( 2".recursive.ComplexRecursiveMessageRsiblingsbproto3 \ No newline at end of file diff --git a/src/connector/src/test_data/proto_recursive/recursive.proto b/src/connector/src/test_data/proto_recursive/recursive.proto new file mode 100644 index 0000000000000..657de75532ba5 --- /dev/null +++ b/src/connector/src/test_data/proto_recursive/recursive.proto @@ -0,0 +1,25 @@ +syntax = "proto3"; + +package recursive; + +message ComplexRecursiveMessage { + string node_name = 1; + int32 node_id = 2; + + message Attributes { + string key = 1; + string value = 2; + } + + repeated Attributes attributes = 3; + + message Parent { + string parent_name = 1; + int32 parent_id = 2; + repeated ComplexRecursiveMessage siblings = 3; + } + + Parent parent = 4; + repeated ComplexRecursiveMessage children = 5; +} +