From 2e8d7232463843f920ad81c7e513cc6a03fc7339 Mon Sep 17 00:00:00 2001
From: Xiangjin
Date: Fri, 29 Dec 2023 16:22:23 +0800
Subject: [PATCH] share schema registry loader for avro/protobuf

Referenced official SDKs:
* python https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/_modules/confluent_kafka/schema_registry/protobuf.html#ProtobufSerializer.__call__
* golang https://github.com/confluentinc/confluent-kafka-go/blob/v2.3.0/schemaregistry/serde/protobuf/protobuf.go#L357C18-L357C18
* java https://github.com/confluentinc/schema-registry/blob/v7.7.0-198/protobuf-serializer/src/main/java/io/confluent/kafka/serializers/protobuf/AbstractKafkaProtobufDeserializer.java#L134

Note:
* The official deserializers do not resolve the name strategy; it is only
  meant for serializer auto-registration. Deserializers read the schema by
  schema_id.
* The Python deserializer ignores the schema_id / index array and trusts the
  caller-provided descriptor, while the golang one decodes them. The Java
  handling logic is more complicated.
---
 src/connector/src/parser/protobuf/mod.rs    |  1 -
 src/connector/src/parser/protobuf/parser.rs | 37 ++++---
 src/connector/src/schema/avro.rs            | 94 +++++------------
 src/connector/src/schema/loader.rs          | 99 +++++++++++++++++++
 src/connector/src/schema/mod.rs             |  3 +
 src/connector/src/schema/protobuf.rs        | 87 +++++++++++++++-
 .../src/schema/schema_registry/util.rs      |  2 -
 7 files changed, 228 insertions(+), 95 deletions(-)
 create mode 100644 src/connector/src/schema/loader.rs

diff --git a/src/connector/src/parser/protobuf/mod.rs b/src/connector/src/parser/protobuf/mod.rs
index c6a7e23357b0d..bfcb0adfe1a18 100644
--- a/src/connector/src/parser/protobuf/mod.rs
+++ b/src/connector/src/parser/protobuf/mod.rs
@@ -14,7 +14,6 @@
 
 mod parser;
 pub use parser::*;
-mod schema_resolver;
 
 #[rustfmt::skip]
 #[cfg(test)]
diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs
index 922705e3d3f8f..558918e11129a 100644
--- a/src/connector/src/parser/protobuf/parser.rs
+++ b/src/connector/src/parser/protobuf/parser.rs
@@ -17,8 +17,8 @@ use std::sync::Arc;
 use anyhow::Context;
 use itertools::Itertools;
 use prost_reflect::{
-    Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor,
-    ReflectMessage, Value,
+    Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, FileDescriptor, Kind,
+    MessageDescriptor, ReflectMessage, Value,
 };
 use risingwave_common::array::{ListValue, StructValue};
 use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64};
@@ -27,16 +27,14 @@ use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion
 use thiserror::Error;
 use thiserror_ext::{AsReport, Macro};
 
-use super::schema_resolver::*;
 use crate::parser::unified::protobuf::ProtobufAccess;
 use crate::parser::unified::{
     bail_uncategorized, uncategorized, AccessError, AccessImpl, AccessResult,
 };
 use crate::parser::util::bytes_from_url;
 use crate::parser::{AccessBuilder, EncodingProperties};
-use crate::schema::schema_registry::{
-    extract_schema_id, get_subject_by_strategy, handle_sr_list, Client,
-};
+use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client};
+use crate::schema::SchemaLoader;
 
 #[derive(Debug)]
 pub struct ProtobufAccessBuilder {
@@ -99,25 +97,24 @@ impl ProtobufParserConfig {
             // https://docs.confluent.io/platform/7.5/control-center/topics/schema.html#c3-schemas-best-practices-key-value-pairs
             bail!("protobuf key is not supported");
         }
-        let schema_bytes = if protobuf_config.use_schema_registry {
-            let schema_value = get_subject_by_strategy(
-                &protobuf_config.name_strategy,
-                protobuf_config.topic.as_str(),
-                Some(message_name.as_ref()),
-                false,
-            )?;
-            tracing::debug!("infer value subject {schema_value}");
-
+        let pool = if protobuf_config.use_schema_registry {
             let client = Client::new(url, &protobuf_config.client_config)?;
-            compile_file_descriptor_from_schema_registry(schema_value.as_str(), &client).await?
+            let loader = SchemaLoader {
+                client,
+                name_strategy: protobuf_config.name_strategy,
+                topic: protobuf_config.topic,
+                key_record_name: None,
+                val_record_name: Some(message_name.clone()),
+            };
+            let (_, x): (_, FileDescriptor) = loader.load_schema::<_, false>().await.unwrap();
+            x.parent_pool().clone()
         } else {
             let url = url.first().unwrap();
-            bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?
+            let schema_bytes = bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?;
+            DescriptorPool::decode(schema_bytes.as_slice())
+                .with_context(|| format!("cannot build descriptor pool from schema `{location}`"))?
         };
 
-        let pool = DescriptorPool::decode(schema_bytes.as_slice())
-            .with_context(|| format!("cannot build descriptor pool from schema `{}`", location))?;
-
         let message_descriptor = pool.get_message_by_name(message_name).with_context(|| {
             format!(
                 "cannot find message `{}` in schema `{}`",
diff --git a/src/connector/src/schema/avro.rs b/src/connector/src/schema/avro.rs
index 553f6e71efca9..01855d7b41c2e 100644
--- a/src/connector/src/schema/avro.rs
+++ b/src/connector/src/schema/avro.rs
@@ -16,89 +16,41 @@ use std::collections::BTreeMap;
 use std::sync::Arc;
 
 use apache_avro::Schema as AvroSchema;
-use risingwave_pb::catalog::PbSchemaRegistryNameStrategy;
 
-use super::schema_registry::{
-    get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, ConfluentSchema,
-    SchemaRegistryAuth,
-};
-use super::{
-    invalid_option_error, InvalidOptionError, SchemaFetchError, KEY_MESSAGE_NAME_KEY,
-    MESSAGE_NAME_KEY, NAME_STRATEGY_KEY, SCHEMA_REGISTRY_KEY,
-};
+use super::loader::{LoadedSchema, SchemaLoader};
+use super::schema_registry::Subject;
+use super::SchemaFetchError;
 
 pub struct SchemaWithId {
     pub schema: Arc<AvroSchema>,
     pub id: i32,
 }
 
-impl TryFrom<ConfluentSchema> for SchemaWithId {
-    type Error = SchemaFetchError;
-
-    fn try_from(fetched: ConfluentSchema) -> Result<Self, Self::Error> {
-        let parsed = AvroSchema::parse_str(&fetched.content)
-            .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?;
-        Ok(Self {
-            schema: Arc::new(parsed),
-            id: fetched.id,
-        })
-    }
-}
-
 /// Schema registry only
 pub async fn fetch_schema(
     format_options: &BTreeMap<String, String>,
     topic: &str,
 ) -> Result<(SchemaWithId, SchemaWithId), SchemaFetchError> {
-    let schema_location = format_options
-        .get(SCHEMA_REGISTRY_KEY)
-        .ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?
-        .clone();
-    let client_config = format_options.into();
-    let name_strategy = format_options
-        .get(NAME_STRATEGY_KEY)
-        .map(|s| {
-            name_strategy_from_str(s)
-                .ok_or_else(|| invalid_option_error!("unrecognized strategy {s}"))
-        })
-        .transpose()?
-        .unwrap_or_default();
-    let key_record_name = format_options
-        .get(KEY_MESSAGE_NAME_KEY)
-        .map(std::ops::Deref::deref);
-    let val_record_name = format_options
-        .get(MESSAGE_NAME_KEY)
-        .map(std::ops::Deref::deref);
-
-    let (key_schema, val_schema) = fetch_schema_inner(
-        &schema_location,
-        &client_config,
-        &name_strategy,
-        topic,
-        key_record_name,
-        val_record_name,
-    )
-    .await?;
-
-    Ok((key_schema.try_into()?, val_schema.try_into()?))
+    let loader = SchemaLoader::from_format_options(topic, format_options)?;
+
+    let (kid, kav) = loader.load_schema::<_, true>().await?;
+    let (vid, vav) = loader.load_schema::<_, false>().await?;
+
+    Ok((
+        SchemaWithId {
+            id: kid,
+            schema: Arc::new(kav),
+        },
+        SchemaWithId {
+            id: vid,
+            schema: Arc::new(vav),
+        },
+    ))
 }
 
-async fn fetch_schema_inner(
-    schema_location: &str,
-    client_config: &SchemaRegistryAuth,
-    name_strategy: &PbSchemaRegistryNameStrategy,
-    topic: &str,
-    key_record_name: Option<&str>,
-    val_record_name: Option<&str>,
-) -> Result<(ConfluentSchema, ConfluentSchema), SchemaFetchError> {
-    let urls = handle_sr_list(schema_location)?;
-    let client = Client::new(urls, client_config)?;
-
-    let key_subject = get_subject_by_strategy(name_strategy, topic, key_record_name, true)?;
-    let key_schema = client.get_schema_by_subject(&key_subject).await?;
-
-    let val_subject = get_subject_by_strategy(name_strategy, topic, val_record_name, false)?;
-    let val_schema = client.get_schema_by_subject(&val_subject).await?;
-
-    Ok((key_schema, val_schema))
+impl LoadedSchema for AvroSchema {
+    fn compile(primary: Subject, _: Vec<Subject>) -> Result<Self, SchemaFetchError> {
+        AvroSchema::parse_str(&primary.schema.content)
+            .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))
+    }
 }
diff --git a/src/connector/src/schema/loader.rs b/src/connector/src/schema/loader.rs
new file mode 100644
index 0000000000000..96e11973d121b
--- /dev/null
+++ b/src/connector/src/schema/loader.rs
@@ -0,0 +1,99 @@
+// Copyright 2024 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::BTreeMap;
+
+use risingwave_pb::catalog::PbSchemaRegistryNameStrategy;
+
+use super::schema_registry::{
+    get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, Subject,
+};
+use super::{invalid_option_error, InvalidOptionError, SchemaFetchError};
+
+const MESSAGE_NAME_KEY: &str = "message";
+const KEY_MESSAGE_NAME_KEY: &str = "key.message";
+const SCHEMA_LOCATION_KEY: &str = "schema.location";
+const SCHEMA_REGISTRY_KEY: &str = "schema.registry";
+const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy";
+
+pub struct SchemaLoader {
+    pub client: Client,
+    pub name_strategy: PbSchemaRegistryNameStrategy,
+    pub topic: String,
+    pub key_record_name: Option<String>,
+    pub val_record_name: Option<String>,
+}
+
+impl SchemaLoader {
+    pub fn from_format_options(
+        topic: &str,
+        format_options: &BTreeMap<String, String>,
+    ) -> Result<Self, SchemaFetchError> {
+        let schema_location = format_options
+            .get(SCHEMA_REGISTRY_KEY)
+            .ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?;
+        let client_config = format_options.into();
+        let urls = handle_sr_list(schema_location)?;
+        let client = Client::new(urls, &client_config)?;
+
+        let name_strategy = format_options
+            .get(NAME_STRATEGY_KEY)
+            .map(|s| {
+                name_strategy_from_str(s)
+                    .ok_or_else(|| invalid_option_error!("unrecognized strategy {s}"))
+            })
+            .transpose()?
+            .unwrap_or_default();
+        let key_record_name = format_options.get(KEY_MESSAGE_NAME_KEY).cloned();
+        let val_record_name = format_options.get(MESSAGE_NAME_KEY).cloned();
+
+        Ok(Self {
+            client,
+            name_strategy,
+            topic: topic.into(),
+            key_record_name,
+            val_record_name,
+        })
+    }
+
+    pub async fn load_schema<O, const IS_KEY: bool>(&self) -> Result<(i32, O), SchemaFetchError>
+    where
+        // O: TryFrom<(Subject, Vec<Subject>), Error = SchemaFetchError>,
+        O: LoadedSchema,
+    {
+        let record_name = match IS_KEY {
+            true => self.key_record_name.as_deref(),
+            false => self.val_record_name.as_deref(),
+        };
+        let subject =
+            get_subject_by_strategy(&self.name_strategy, &self.topic, record_name, IS_KEY)?;
+        // let loaded = self.client.get_schema_by_subject(&subject).await?;
+        let (primary_subject, dependency_subjects) =
+            self.client.get_subject_and_references(&subject).await?;
+        let schema_id = primary_subject.schema.id;
+        let o = O::compile(primary_subject, dependency_subjects)?;
+        Ok((schema_id, o))
+    }
+}
+
+pub trait LoadedSchema: Sized {
+    fn compile(primary: Subject, references: Vec<Subject>) -> Result<Self, SchemaFetchError>;
+}
+
+// load key: returns none for url or pb, some for sr-avro
+// load value:
+
+// post-fetch enc-specific parsing
+// * url: avro::parse_str or pb::decode
+// * sr: avro::parse_str or compile + pb::decode
diff --git a/src/connector/src/schema/mod.rs b/src/connector/src/schema/mod.rs
index 0f64ee85e57dd..37c8ae550a002 100644
--- a/src/connector/src/schema/mod.rs
+++ b/src/connector/src/schema/mod.rs
@@ -13,9 +13,12 @@
 // limitations under the License.
 
 pub mod avro;
+mod loader;
 pub mod protobuf;
 pub mod schema_registry;
 
+pub use loader::SchemaLoader;
+
 const MESSAGE_NAME_KEY: &str = "message";
 const KEY_MESSAGE_NAME_KEY: &str = "key.message";
 const SCHEMA_LOCATION_KEY: &str = "schema.location";
diff --git a/src/connector/src/schema/protobuf.rs b/src/connector/src/schema/protobuf.rs
index 0d6116121977d..41eed545145c0 100644
--- a/src/connector/src/schema/protobuf.rs
+++ b/src/connector/src/schema/protobuf.rs
@@ -14,8 +14,10 @@
 
 use std::collections::BTreeMap;
 
-use prost_reflect::MessageDescriptor;
+use prost_reflect::{DescriptorPool, FileDescriptor, MessageDescriptor};
 
+use super::loader::LoadedSchema;
+use super::schema_registry::Subject;
 use super::{
     invalid_option_error, InvalidOptionError, SchemaFetchError, MESSAGE_NAME_KEY,
     SCHEMA_LOCATION_KEY,
@@ -58,3 +60,86 @@ pub async fn fetch_descriptor(
         .map_err(SchemaFetchError::YetToMigrate)?;
     Ok(conf.message_descriptor)
 }
+
+impl LoadedSchema for FileDescriptor {
+    fn compile(primary: Subject, references: Vec<Subject>) -> Result<Self, SchemaFetchError> {
+        let primary_name = primary.name.clone();
+        match compile_pb(primary, references) {
+            Err(e) => Err(SchemaFetchError::SchemaCompile(e.into())),
+            Ok(b) => {
+                let pool = DescriptorPool::decode(b.as_slice()).unwrap();
+                Ok(pool.get_file_by_name(&primary_name).unwrap())
+            }
+        }
+    }
+}
+
+macro_rules! embed_wkts {
+    [$( $path:literal ),+ $(,)?] => {
+        &[$(
+            (
+                concat!("google/protobuf/", $path),
+                include_bytes!(concat!(env!("PROTO_INCLUDE"), "/google/protobuf/", $path)).as_slice(),
+            )
+        ),+]
+    };
+}
+const WELL_KNOWN_TYPES: &[(&str, &[u8])] = embed_wkts![
+    "any.proto",
+    "api.proto",
+    "compiler/plugin.proto",
+    "descriptor.proto",
+    "duration.proto",
+    "empty.proto",
+    "field_mask.proto",
+    "source_context.proto",
+    "struct.proto",
+    "timestamp.proto",
+    "type.proto",
+    "wrappers.proto",
+];
+
+#[derive(Debug, thiserror::Error)]
+pub enum PbCompileError {
+    #[error("build_file_descriptor_set failed")]
+    Build(Vec),
+    #[error("serialize descriptor set failed")]
+    Serialize,
+}
+
+pub fn compile_pb(
+    primary_subject: Subject,
+    dependency_subjects: Vec<Subject>,
+) -> Result<Vec<u8>, PbCompileError> {
+    use std::iter;
+    use std::path::Path;
+
+    use protobuf_native::compiler::{
+        SimpleErrorCollector, SourceTreeDescriptorDatabase, VirtualSourceTree,
+    };
+    use protobuf_native::MessageLite;
+
+    let mut source_tree = VirtualSourceTree::new();
+    for subject in iter::once(&primary_subject).chain(dependency_subjects.iter()) {
+        source_tree.as_mut().add_file(
+            Path::new(&subject.name),
+            subject.schema.content.as_bytes().to_vec(),
+        );
+    }
+    for (path, bytes) in WELL_KNOWN_TYPES {
+        source_tree
+            .as_mut()
+            .add_file(Path::new(path), bytes.to_vec());
+    }
+
+    let mut error_collector = SimpleErrorCollector::new();
+    // `db` needs to be dropped before we can iterate on `error_collector`.
+    let fds = {
+        let mut db = SourceTreeDescriptorDatabase::new(source_tree.as_mut());
+        db.as_mut().record_errors_to(error_collector.as_mut());
+        db.as_mut()
+            .build_file_descriptor_set(&[Path::new(&primary_subject.name)])
+    }
+    .map_err(|_| PbCompileError::Build(error_collector.as_mut().collect()))?;
+    fds.serialize().map_err(|_| PbCompileError::Serialize)
+}
diff --git a/src/connector/src/schema/schema_registry/util.rs b/src/connector/src/schema/schema_registry/util.rs
index 407534b1a5671..baed438da1269 100644
--- a/src/connector/src/schema/schema_registry/util.rs
+++ b/src/connector/src/schema/schema_registry/util.rs
@@ -147,8 +147,6 @@ pub struct Subject {
 /// (e.g., import "other.proto" in protobuf)
 #[derive(Debug, Deserialize)]
 pub struct SchemaReference {
-    /// The name of the reference.
-    pub name: String,
     /// The subject that the referenced schema belongs to
     pub subject: String,
     /// The version of the referenced schema
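
For reviewers, a few illustrative sketches follow; none of this is part of the patch itself. First, a minimal usage sketch of the shared loader, assuming it is called from inside the connector crate (so `crate::schema::{SchemaFetchError, SchemaLoader}` resolve); the function name and argument handling are made up for illustration:

    use std::collections::BTreeMap;

    use apache_avro::Schema as AvroSchema;
    use prost_reflect::FileDescriptor;

    use crate::schema::{SchemaFetchError, SchemaLoader};

    /// Illustrative only: fetch key/value Avro schemas and a protobuf file
    /// descriptor for one topic through the shared loader.
    async fn load_all(
        topic: &str,
        format_options: &BTreeMap<String, String>,
    ) -> Result<(), SchemaFetchError> {
        // One loader per topic; it owns the registry client plus the
        // name-strategy and record-name options parsed from the WITH options.
        let loader = SchemaLoader::from_format_options(topic, format_options)?;

        // Avro: the subject is resolved per name strategy, fetched, and parsed.
        let (key_id, key_schema): (i32, AvroSchema) = loader.load_schema::<_, true>().await?;
        let (val_id, val_schema): (i32, AvroSchema) = loader.load_schema::<_, false>().await?;

        // Protobuf: the subject and its references are compiled into a
        // descriptor set; the returned FileDescriptor exposes the whole pool.
        let (_, fd): (i32, FileDescriptor) = loader.load_schema::<_, false>().await?;
        let _pool = fd.parent_pool().clone();

        let _ = (key_id, key_schema, val_id, val_schema);
        Ok(())
    }

The same `load_schema` call serves both codecs; only the `LoadedSchema` impl selected by the output type differs (parse for Avro, compile-and-decode for protobuf).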
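
The subjects resolved by `get_subject_by_strategy` follow the usual Confluent naming strategies. As a rough, self-contained reference only (the real function takes `PbSchemaRegistryNameStrategy` and returns proper errors):

    /// Illustrative only: how the three Confluent naming strategies map a
    /// (topic, record name, key/value side) triple to a registry subject.
    fn subject_for(
        strategy: &str,
        topic: &str,
        record: Option<&str>,
        is_key: bool,
    ) -> Option<String> {
        let side = if is_key { "key" } else { "value" };
        match strategy {
            // Default: the subject depends only on the topic.
            "topic_name_strategy" => Some(format!("{topic}-{side}")),
            // Subject is the fully-qualified record/message name.
            "record_name_strategy" => record.map(str::to_owned),
            // Subject combines both.
            "topic_record_name_strategy" => record.map(|r| format!("{topic}-{r}")),
            _ => None,
        }
    }

This is also why the loader still threads the strategy and record names through even though, per the commit note, official deserializers only dereference the schema id: the strategy decides which subject a serializer registers under and which subject we fetch at plan time.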
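
Background for the schema_id / index-array note in the commit message: Confluent-framed Kafka payloads begin with a magic byte, a big-endian schema id, and, for protobuf, a zig-zag varint array of message indexes. A self-contained sketch of that framing (the connector's own `extract_schema_id` plays the id-extraction role; this is not its implementation):

    /// Splits a Confluent-framed payload into (schema_id, remaining bytes).
    fn parse_confluent_header(payload: &[u8]) -> Result<(i32, &[u8]), &'static str> {
        if payload.len() < 5 || payload[0] != 0 {
            return Err("missing magic byte 0x00");
        }
        let schema_id = i32::from_be_bytes([payload[1], payload[2], payload[3], payload[4]]);
        Ok((schema_id, &payload[5..]))
    }

    /// For protobuf, the remainder then starts with a zig-zag varint array of
    /// message indexes; a lone 0 is shorthand for `[0]` (the first message).
    fn parse_message_indexes(mut rest: &[u8]) -> Result<(Vec<i32>, &[u8]), &'static str> {
        let (count, tail) = read_zigzag_varint(rest)?;
        rest = tail;
        let mut indexes = Vec::new();
        for _ in 0..count {
            let (idx, tail) = read_zigzag_varint(rest)?;
            indexes.push(idx);
            rest = tail;
        }
        Ok((indexes, rest))
    }

    fn read_zigzag_varint(buf: &[u8]) -> Result<(i32, &[u8]), &'static str> {
        let mut raw: u64 = 0;
        let mut shift = 0u32;
        for (i, byte) in buf.iter().enumerate() {
            if shift >= 64 {
                return Err("varint too long");
            }
            raw |= u64::from(byte & 0x7f) << shift;
            if byte & 0x80 == 0 {
                // Zig-zag decode: (n >> 1) ^ -(n & 1).
                let decoded = ((raw >> 1) as i64 ^ -((raw & 1) as i64)) as i32;
                return Ok((decoded, &buf[i + 1..]));
            }
            shift += 7;
        }
        Err("truncated varint")
    }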
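
The `embed_wkts!` macro reads a `PROTO_INCLUDE` environment variable at compile time to embed the well-known types. Purely as an illustration of one way such a variable could be exported from a build script (an assumption; the repository's actual build wiring may differ):

    // build.rs sketch (hypothetical): forward a protoc include path to rustc so
    // that include_bytes! can locate google/protobuf/*.proto at compile time.
    fn main() {
        let include = std::env::var("PROTOC_INCLUDE")
            .unwrap_or_else(|_| "/usr/include".to_string());
        println!("cargo:rustc-env=PROTO_INCLUDE={include}");
        println!("cargo:rerun-if-env-changed=PROTOC_INCLUDE");
    }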