From ef7eb1d48ef2f43701455e63b8aa2737a4166f10 Mon Sep 17 00:00:00 2001 From: xiangjinwu <17769960+xiangjinwu@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:16:04 +0800 Subject: [PATCH] refactor(connector): share schema registry loader between avro and protobuf (#14642) --- src/connector/src/parser/protobuf/mod.rs | 1 - src/connector/src/parser/protobuf/parser.rs | 40 ++++---- .../src/parser/protobuf/schema_resolver.rs | 95 ------------------ src/connector/src/schema/avro.rs | 94 +++++------------- src/connector/src/schema/loader.rs | 95 ++++++++++++++++++ src/connector/src/schema/mod.rs | 3 + src/connector/src/schema/protobuf.rs | 97 ++++++++++++++++++- 7 files changed, 237 insertions(+), 188 deletions(-) delete mode 100644 src/connector/src/parser/protobuf/schema_resolver.rs create mode 100644 src/connector/src/schema/loader.rs diff --git a/src/connector/src/parser/protobuf/mod.rs b/src/connector/src/parser/protobuf/mod.rs index c6a7e23357b0d..bfcb0adfe1a18 100644 --- a/src/connector/src/parser/protobuf/mod.rs +++ b/src/connector/src/parser/protobuf/mod.rs @@ -14,7 +14,6 @@ mod parser; pub use parser::*; -mod schema_resolver; #[rustfmt::skip] #[cfg(test)] diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index b09cc24ea59e1..e9ae317ca5ebc 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -17,8 +17,8 @@ use std::sync::Arc; use anyhow::Context; use itertools::Itertools; use prost_reflect::{ - Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor, - ReflectMessage, Value, + Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, FileDescriptor, Kind, + MessageDescriptor, ReflectMessage, Value, }; use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64}; @@ -27,7 +27,6 @@ use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion use thiserror::Error; use thiserror_ext::{AsReport, Macro}; -use super::schema_resolver::*; use crate::error::ConnectorResult; use crate::parser::unified::protobuf::ProtobufAccess; use crate::parser::unified::{ @@ -35,9 +34,8 @@ use crate::parser::unified::{ }; use crate::parser::util::bytes_from_url; use crate::parser::{AccessBuilder, EncodingProperties}; -use crate::schema::schema_registry::{ - extract_schema_id, get_subject_by_strategy, handle_sr_list, Client, WireFormatError, -}; +use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client, WireFormatError}; +use crate::schema::SchemaLoader; #[derive(Debug)] pub struct ProtobufAccessBuilder { @@ -100,25 +98,27 @@ impl ProtobufParserConfig { // https://docs.confluent.io/platform/7.5/control-center/topics/schema.html#c3-schemas-best-practices-key-value-pairs bail!("protobuf key is not supported"); } - let schema_bytes = if protobuf_config.use_schema_registry { - let schema_value = get_subject_by_strategy( - &protobuf_config.name_strategy, - protobuf_config.topic.as_str(), - Some(message_name.as_ref()), - false, - )?; - tracing::debug!("infer value subject {schema_value}"); - + let pool = if protobuf_config.use_schema_registry { let client = Client::new(url, &protobuf_config.client_config)?; - compile_file_descriptor_from_schema_registry(schema_value.as_str(), &client).await? + let loader = SchemaLoader { + client, + name_strategy: protobuf_config.name_strategy, + topic: protobuf_config.topic, + key_record_name: None, + val_record_name: Some(message_name.clone()), + }; + let (_schema_id, root_file_descriptor) = loader + .load_val_schema::() + .await + .context("load schema failed")?; + root_file_descriptor.parent_pool().clone() } else { let url = url.first().unwrap(); - bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await? + let schema_bytes = bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?; + DescriptorPool::decode(schema_bytes.as_slice()) + .with_context(|| format!("cannot build descriptor pool from schema `{location}`"))? }; - let pool = DescriptorPool::decode(schema_bytes.as_slice()) - .with_context(|| format!("cannot build descriptor pool from schema `{}`", location))?; - let message_descriptor = pool.get_message_by_name(message_name).with_context(|| { format!( "cannot find message `{}` in schema `{}`", diff --git a/src/connector/src/parser/protobuf/schema_resolver.rs b/src/connector/src/parser/protobuf/schema_resolver.rs deleted file mode 100644 index 828843842c785..0000000000000 --- a/src/connector/src/parser/protobuf/schema_resolver.rs +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::iter; -use std::path::Path; - -use anyhow::Context; -use itertools::Itertools; -use protobuf_native::compiler::{ - SimpleErrorCollector, SourceTreeDescriptorDatabase, VirtualSourceTree, -}; -use protobuf_native::MessageLite; - -use crate::error::ConnectorResult; -use crate::schema::schema_registry::Client; - -macro_rules! embed_wkts { - [$( $path:literal ),+ $(,)?] => { - &[$( - ( - concat!("google/protobuf/", $path), - include_bytes!(concat!(env!("PROTO_INCLUDE"), "/google/protobuf/", $path)).as_slice(), - ) - ),+] - }; -} -const WELL_KNOWN_TYPES: &[(&str, &[u8])] = embed_wkts![ - "any.proto", - "api.proto", - "compiler/plugin.proto", - "descriptor.proto", - "duration.proto", - "empty.proto", - "field_mask.proto", - "source_context.proto", - "struct.proto", - "timestamp.proto", - "type.proto", - "wrappers.proto", -]; - -// Pull protobuf schema and all it's deps from the confluent schema registry, -// and compile then into one file descriptor -pub(super) async fn compile_file_descriptor_from_schema_registry( - subject_name: &str, - client: &Client, -) -> ConnectorResult> { - let (primary_subject, dependency_subjects) = client - .get_subject_and_references(subject_name) - .await - .with_context(|| format!("failed to resolve subject `{subject_name}`"))?; - - // Compile .proto files into a file descriptor set. - let mut source_tree = VirtualSourceTree::new(); - for subject in iter::once(&primary_subject).chain(dependency_subjects.iter()) { - source_tree.as_mut().add_file( - Path::new(&subject.name), - subject.schema.content.as_bytes().to_vec(), - ); - } - for (path, bytes) in WELL_KNOWN_TYPES { - source_tree - .as_mut() - .add_file(Path::new(path), bytes.to_vec()); - } - - let mut error_collector = SimpleErrorCollector::new(); - // `db` needs to be dropped before we can iterate on `error_collector`. - let fds = { - let mut db = SourceTreeDescriptorDatabase::new(source_tree.as_mut()); - db.as_mut().record_errors_to(error_collector.as_mut()); - db.as_mut() - .build_file_descriptor_set(&[Path::new(&primary_subject.name)]) - } - .with_context(|| { - format!( - "build_file_descriptor_set failed. Errors:\n{}", - error_collector.as_mut().join("\n") - ) - })?; - - let serialized = fds.serialize().context("serialize descriptor set failed")?; - Ok(serialized) -} diff --git a/src/connector/src/schema/avro.rs b/src/connector/src/schema/avro.rs index 553f6e71efca9..22c5fb4acadd1 100644 --- a/src/connector/src/schema/avro.rs +++ b/src/connector/src/schema/avro.rs @@ -16,89 +16,41 @@ use std::collections::BTreeMap; use std::sync::Arc; use apache_avro::Schema as AvroSchema; -use risingwave_pb::catalog::PbSchemaRegistryNameStrategy; -use super::schema_registry::{ - get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, ConfluentSchema, - SchemaRegistryAuth, -}; -use super::{ - invalid_option_error, InvalidOptionError, SchemaFetchError, KEY_MESSAGE_NAME_KEY, - MESSAGE_NAME_KEY, NAME_STRATEGY_KEY, SCHEMA_REGISTRY_KEY, -}; +use super::loader::{LoadedSchema, SchemaLoader}; +use super::schema_registry::Subject; +use super::SchemaFetchError; pub struct SchemaWithId { pub schema: Arc, pub id: i32, } -impl TryFrom for SchemaWithId { - type Error = SchemaFetchError; - - fn try_from(fetched: ConfluentSchema) -> Result { - let parsed = AvroSchema::parse_str(&fetched.content) - .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?; - Ok(Self { - schema: Arc::new(parsed), - id: fetched.id, - }) - } -} - /// Schema registry only pub async fn fetch_schema( format_options: &BTreeMap, topic: &str, ) -> Result<(SchemaWithId, SchemaWithId), SchemaFetchError> { - let schema_location = format_options - .get(SCHEMA_REGISTRY_KEY) - .ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))? - .clone(); - let client_config = format_options.into(); - let name_strategy = format_options - .get(NAME_STRATEGY_KEY) - .map(|s| { - name_strategy_from_str(s) - .ok_or_else(|| invalid_option_error!("unrecognized strategy {s}")) - }) - .transpose()? - .unwrap_or_default(); - let key_record_name = format_options - .get(KEY_MESSAGE_NAME_KEY) - .map(std::ops::Deref::deref); - let val_record_name = format_options - .get(MESSAGE_NAME_KEY) - .map(std::ops::Deref::deref); - - let (key_schema, val_schema) = fetch_schema_inner( - &schema_location, - &client_config, - &name_strategy, - topic, - key_record_name, - val_record_name, - ) - .await?; - - Ok((key_schema.try_into()?, val_schema.try_into()?)) + let loader = SchemaLoader::from_format_options(topic, format_options)?; + + let (key_id, key_avro) = loader.load_key_schema().await?; + let (val_id, val_avro) = loader.load_val_schema().await?; + + Ok(( + SchemaWithId { + id: key_id, + schema: Arc::new(key_avro), + }, + SchemaWithId { + id: val_id, + schema: Arc::new(val_avro), + }, + )) } -async fn fetch_schema_inner( - schema_location: &str, - client_config: &SchemaRegistryAuth, - name_strategy: &PbSchemaRegistryNameStrategy, - topic: &str, - key_record_name: Option<&str>, - val_record_name: Option<&str>, -) -> Result<(ConfluentSchema, ConfluentSchema), SchemaFetchError> { - let urls = handle_sr_list(schema_location)?; - let client = Client::new(urls, client_config)?; - - let key_subject = get_subject_by_strategy(name_strategy, topic, key_record_name, true)?; - let key_schema = client.get_schema_by_subject(&key_subject).await?; - - let val_subject = get_subject_by_strategy(name_strategy, topic, val_record_name, false)?; - let val_schema = client.get_schema_by_subject(&val_subject).await?; - - Ok((key_schema, val_schema)) +impl LoadedSchema for AvroSchema { + fn compile(primary: Subject, _: Vec) -> Result { + AvroSchema::parse_str(&primary.schema.content) + .map_err(|e| SchemaFetchError::SchemaCompile(e.into())) + } } diff --git a/src/connector/src/schema/loader.rs b/src/connector/src/schema/loader.rs new file mode 100644 index 0000000000000..a50d8cced575b --- /dev/null +++ b/src/connector/src/schema/loader.rs @@ -0,0 +1,95 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::BTreeMap; + +use risingwave_pb::catalog::PbSchemaRegistryNameStrategy; + +use super::schema_registry::{ + get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, Subject, +}; +use super::{invalid_option_error, InvalidOptionError, SchemaFetchError}; + +const MESSAGE_NAME_KEY: &str = "message"; +const KEY_MESSAGE_NAME_KEY: &str = "key.message"; +const SCHEMA_LOCATION_KEY: &str = "schema.location"; +const SCHEMA_REGISTRY_KEY: &str = "schema.registry"; +const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy"; + +pub struct SchemaLoader { + pub client: Client, + pub name_strategy: PbSchemaRegistryNameStrategy, + pub topic: String, + pub key_record_name: Option, + pub val_record_name: Option, +} + +impl SchemaLoader { + pub fn from_format_options( + topic: &str, + format_options: &BTreeMap, + ) -> Result { + let schema_location = format_options + .get(SCHEMA_REGISTRY_KEY) + .ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?; + let client_config = format_options.into(); + let urls = handle_sr_list(schema_location)?; + let client = Client::new(urls, &client_config)?; + + let name_strategy = format_options + .get(NAME_STRATEGY_KEY) + .map(|s| { + name_strategy_from_str(s) + .ok_or_else(|| invalid_option_error!("unrecognized strategy {s}")) + }) + .transpose()? + .unwrap_or_default(); + let key_record_name = format_options.get(KEY_MESSAGE_NAME_KEY).cloned(); + let val_record_name = format_options.get(MESSAGE_NAME_KEY).cloned(); + + Ok(Self { + client, + name_strategy, + topic: topic.into(), + key_record_name, + val_record_name, + }) + } + + async fn load_schema( + &self, + record: Option<&str>, + ) -> Result<(i32, Out), SchemaFetchError> { + let subject = get_subject_by_strategy(&self.name_strategy, &self.topic, record, IS_KEY)?; + let (primary_subject, dependency_subjects) = + self.client.get_subject_and_references(&subject).await?; + let schema_id = primary_subject.schema.id; + let out = Out::compile(primary_subject, dependency_subjects)?; + Ok((schema_id, out)) + } + + pub async fn load_key_schema(&self) -> Result<(i32, Out), SchemaFetchError> { + self.load_schema::(self.key_record_name.as_deref()) + .await + } + + pub async fn load_val_schema(&self) -> Result<(i32, Out), SchemaFetchError> { + self.load_schema::(self.val_record_name.as_deref()) + .await + } +} + +pub trait LoadedSchema: Sized { + fn compile(primary: Subject, references: Vec) -> Result; +} diff --git a/src/connector/src/schema/mod.rs b/src/connector/src/schema/mod.rs index 8d2a9ae780572..28151e60895e9 100644 --- a/src/connector/src/schema/mod.rs +++ b/src/connector/src/schema/mod.rs @@ -15,9 +15,12 @@ use crate::error::ConnectorError; pub mod avro; +mod loader; pub mod protobuf; pub mod schema_registry; +pub use loader::SchemaLoader; + const MESSAGE_NAME_KEY: &str = "message"; const KEY_MESSAGE_NAME_KEY: &str = "key.message"; const SCHEMA_LOCATION_KEY: &str = "schema.location"; diff --git a/src/connector/src/schema/protobuf.rs b/src/connector/src/schema/protobuf.rs index 0d6116121977d..6229e2e3e223a 100644 --- a/src/connector/src/schema/protobuf.rs +++ b/src/connector/src/schema/protobuf.rs @@ -14,8 +14,11 @@ use std::collections::BTreeMap; -use prost_reflect::MessageDescriptor; +use itertools::Itertools as _; +use prost_reflect::{DescriptorPool, FileDescriptor, MessageDescriptor}; +use super::loader::LoadedSchema; +use super::schema_registry::Subject; use super::{ invalid_option_error, InvalidOptionError, SchemaFetchError, MESSAGE_NAME_KEY, SCHEMA_LOCATION_KEY, @@ -58,3 +61,95 @@ pub async fn fetch_descriptor( .map_err(SchemaFetchError::YetToMigrate)?; Ok(conf.message_descriptor) } + +impl LoadedSchema for FileDescriptor { + fn compile(primary: Subject, references: Vec) -> Result { + let primary_name = primary.name.clone(); + match compile_pb(primary, references) { + Err(e) => Err(SchemaFetchError::SchemaCompile(e.into())), + Ok(b) => { + let pool = DescriptorPool::decode(b.as_slice()) + .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?; + pool.get_file_by_name(&primary_name).ok_or_else(|| { + SchemaFetchError::SchemaCompile( + anyhow::anyhow!("{primary_name} lost after compilation").into(), + ) + }) + } + } + } +} + +macro_rules! embed_wkts { + [$( $path:literal ),+ $(,)?] => { + &[$( + ( + concat!("google/protobuf/", $path), + include_bytes!(concat!(env!("PROTO_INCLUDE"), "/google/protobuf/", $path)).as_slice(), + ) + ),+] + }; +} +const WELL_KNOWN_TYPES: &[(&str, &[u8])] = embed_wkts![ + "any.proto", + "api.proto", + "compiler/plugin.proto", + "descriptor.proto", + "duration.proto", + "empty.proto", + "field_mask.proto", + "source_context.proto", + "struct.proto", + "timestamp.proto", + "type.proto", + "wrappers.proto", +]; + +#[derive(Debug, thiserror::Error)] +pub enum PbCompileError { + #[error("build_file_descriptor_set failed\n{}", errs.iter().map(|e| format!("\t{e}")).join("\n"))] + Build { + errs: Vec, + }, + #[error("serialize descriptor set failed")] + Serialize, +} + +pub fn compile_pb( + primary_subject: Subject, + dependency_subjects: Vec, +) -> Result, PbCompileError> { + use std::iter; + use std::path::Path; + + use protobuf_native::compiler::{ + SimpleErrorCollector, SourceTreeDescriptorDatabase, VirtualSourceTree, + }; + use protobuf_native::MessageLite; + + let mut source_tree = VirtualSourceTree::new(); + for subject in iter::once(&primary_subject).chain(dependency_subjects.iter()) { + source_tree.as_mut().add_file( + Path::new(&subject.name), + subject.schema.content.as_bytes().to_vec(), + ); + } + for (path, bytes) in WELL_KNOWN_TYPES { + source_tree + .as_mut() + .add_file(Path::new(path), bytes.to_vec()); + } + + let mut error_collector = SimpleErrorCollector::new(); + // `db` needs to be dropped before we can iterate on `error_collector`. + let fds = { + let mut db = SourceTreeDescriptorDatabase::new(source_tree.as_mut()); + db.as_mut().record_errors_to(error_collector.as_mut()); + db.as_mut() + .build_file_descriptor_set(&[Path::new(&primary_subject.name)]) + } + .map_err(|_| PbCompileError::Build { + errs: error_collector.as_mut().collect(), + })?; + fds.serialize().map_err(|_| PbCompileError::Serialize) +}