share schema registry loader for avro/protobuf
Referenced official SDKs:
* python https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/_modules/confluent_kafka/schema_registry/protobuf.html#ProtobufSerializer.__call__
* golang https://github.com/confluentinc/confluent-kafka-go/blob/v2.3.0/schemaregistry/serde/protobuf/protobuf.go#L357C18-L357C18
* java https://github.com/confluentinc/schema-registry/blob/v7.7.0-198/protobuf-serializer/src/main/java/io/confluent/kafka/serializers/protobuf/AbstractKafkaProtobufDeserializer.java#L134

Note:
* The official deserializers do not resolve the subject name strategy; the strategy is only used by the serializer for auto-registration. The deserializer looks the schema up by the schema id embedded in the message.
* The Python deserializer ignores the schema id / message-index array and trusts the caller-provided descriptor, while the Go one decodes them. The Java handling logic is more involved.
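
For context — not part of this commit — a sketch of the wire format these notes refer to: magic byte 0x00, a 4-byte big-endian schema id, and (protobuf only) a message-index array of zigzag varints before the payload. The connector already has its own `extract_schema_id` helper; the index-array encoding below is assumed from the referenced Go SDK, and the function names are illustrative.

```rust
/// Splits off the Confluent framing: 0x00 magic byte + big-endian i32 schema id.
fn split_schema_id(buf: &[u8]) -> Result<(i32, &[u8]), String> {
    match buf {
        [0, a, b, c, d, rest @ ..] => Ok((i32::from_be_bytes([*a, *b, *c, *d]), rest)),
        _ => Err("missing Confluent magic byte / schema id".into()),
    }
}

/// Reads the protobuf message-index array; a single 0x00 is shorthand for `[0]`
/// (the first message declared in the schema file).
fn read_message_indexes(mut rest: &[u8]) -> Result<(Vec<i32>, &[u8]), String> {
    let (len, tail) = read_zigzag_varint(rest)?;
    rest = tail;
    if len == 0 {
        return Ok((vec![0], rest));
    }
    let mut indexes = Vec::with_capacity(len as usize);
    for _ in 0..len {
        let (idx, tail) = read_zigzag_varint(rest)?;
        indexes.push(idx as i32);
        rest = tail;
    }
    Ok((indexes, rest))
}

/// Little-endian base-128 varint followed by zigzag decoding.
fn read_zigzag_varint(buf: &[u8]) -> Result<(i64, &[u8]), String> {
    let mut value: u64 = 0;
    for (i, byte) in buf.iter().enumerate() {
        value |= u64::from(byte & 0x7f) << (7 * i);
        if (byte & 0x80) == 0 {
            let decoded = ((value >> 1) as i64) ^ -((value & 1) as i64);
            return Ok((decoded, &buf[i + 1..]));
        }
    }
    Err("truncated varint".into())
}
```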
xiangjinwu committed Feb 20, 2024
1 parent 6739733 commit 2e8d723
Showing 7 changed files with 228 additions and 95 deletions.
1 change: 0 additions & 1 deletion src/connector/src/parser/protobuf/mod.rs
@@ -14,7 +14,6 @@
 
 mod parser;
 pub use parser::*;
-mod schema_resolver;
 
 #[rustfmt::skip]
 #[cfg(test)]
37 changes: 17 additions & 20 deletions src/connector/src/parser/protobuf/parser.rs
@@ -17,8 +17,8 @@ use std::sync::Arc;
 use anyhow::Context;
 use itertools::Itertools;
 use prost_reflect::{
-    Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, Kind, MessageDescriptor,
-    ReflectMessage, Value,
+    Cardinality, DescriptorPool, DynamicMessage, FieldDescriptor, FileDescriptor, Kind,
+    MessageDescriptor, ReflectMessage, Value,
 };
 use risingwave_common::array::{ListValue, StructValue};
 use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64};
@@ -27,16 +27,14 @@ use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion
 use thiserror::Error;
 use thiserror_ext::{AsReport, Macro};
 
-use super::schema_resolver::*;
 use crate::parser::unified::protobuf::ProtobufAccess;
 use crate::parser::unified::{
     bail_uncategorized, uncategorized, AccessError, AccessImpl, AccessResult,
 };
 use crate::parser::util::bytes_from_url;
 use crate::parser::{AccessBuilder, EncodingProperties};
-use crate::schema::schema_registry::{
-    extract_schema_id, get_subject_by_strategy, handle_sr_list, Client,
-};
+use crate::schema::schema_registry::{extract_schema_id, handle_sr_list, Client};
+use crate::schema::SchemaLoader;
 
 #[derive(Debug)]
 pub struct ProtobufAccessBuilder {
Expand Down Expand Up @@ -99,25 +97,24 @@ impl ProtobufParserConfig {
// https://docs.confluent.io/platform/7.5/control-center/topics/schema.html#c3-schemas-best-practices-key-value-pairs
bail!("protobuf key is not supported");
}
let schema_bytes = if protobuf_config.use_schema_registry {
let schema_value = get_subject_by_strategy(
&protobuf_config.name_strategy,
protobuf_config.topic.as_str(),
Some(message_name.as_ref()),
false,
)?;
tracing::debug!("infer value subject {schema_value}");

let pool = if protobuf_config.use_schema_registry {
let client = Client::new(url, &protobuf_config.client_config)?;
compile_file_descriptor_from_schema_registry(schema_value.as_str(), &client).await?
let loader = SchemaLoader {
client,
name_strategy: protobuf_config.name_strategy,
topic: protobuf_config.topic,
key_record_name: None,
val_record_name: Some(message_name.clone()),
};
let (_, x): (_, FileDescriptor) = loader.load_schema::<_, false>().await.unwrap();
x.parent_pool().clone()
} else {
let url = url.first().unwrap();
bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?
let schema_bytes = bytes_from_url(url, protobuf_config.aws_auth_props.as_ref()).await?;
DescriptorPool::decode(schema_bytes.as_slice())
.with_context(|| format!("cannot build descriptor pool from schema `{location}`"))?
};

let pool = DescriptorPool::decode(schema_bytes.as_slice())
.with_context(|| format!("cannot build descriptor pool from schema `{}`", location))?;

let message_descriptor = pool.get_message_by_name(message_name).with_context(|| {
format!(
"cannot find message `{}` in schema `{}`",
Expand Down
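
For orientation (not from the diff): both branches above now yield a `prost_reflect::DescriptorPool`, so the rest of the parser stays unchanged. A minimal sketch of that shared decode path — the message name and inputs are placeholders:

```rust
use prost_reflect::{DescriptorPool, DynamicMessage};

fn decode_with_pool(descriptor_set: &[u8], payload: &[u8]) -> anyhow::Result<()> {
    // Equivalent of the non-registry branch: build a pool from an encoded FileDescriptorSet.
    let pool = DescriptorPool::decode(descriptor_set)?;
    // Same lookup the parser performs with `message_name`.
    let descriptor = pool
        .get_message_by_name("example.v1.Order")
        .ok_or_else(|| anyhow::anyhow!("message not found in descriptor pool"))?;
    let message = DynamicMessage::decode(descriptor, payload)?;
    println!("{message:#?}");
    Ok(())
}
```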
94 changes: 23 additions & 71 deletions src/connector/src/schema/avro.rs
@@ -16,89 +16,41 @@ use std::collections::BTreeMap;
 use std::sync::Arc;
 
 use apache_avro::Schema as AvroSchema;
-use risingwave_pb::catalog::PbSchemaRegistryNameStrategy;
 
-use super::schema_registry::{
-    get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, ConfluentSchema,
-    SchemaRegistryAuth,
-};
-use super::{
-    invalid_option_error, InvalidOptionError, SchemaFetchError, KEY_MESSAGE_NAME_KEY,
-    MESSAGE_NAME_KEY, NAME_STRATEGY_KEY, SCHEMA_REGISTRY_KEY,
-};
+use super::loader::{LoadedSchema, SchemaLoader};
+use super::schema_registry::Subject;
+use super::SchemaFetchError;
 
 pub struct SchemaWithId {
     pub schema: Arc<AvroSchema>,
     pub id: i32,
 }
 
-impl TryFrom<ConfluentSchema> for SchemaWithId {
-    type Error = SchemaFetchError;
-
-    fn try_from(fetched: ConfluentSchema) -> Result<Self, Self::Error> {
-        let parsed = AvroSchema::parse_str(&fetched.content)
-            .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))?;
-        Ok(Self {
-            schema: Arc::new(parsed),
-            id: fetched.id,
-        })
-    }
-}
-
 /// Schema registry only
 pub async fn fetch_schema(
     format_options: &BTreeMap<String, String>,
     topic: &str,
 ) -> Result<(SchemaWithId, SchemaWithId), SchemaFetchError> {
-    let schema_location = format_options
-        .get(SCHEMA_REGISTRY_KEY)
-        .ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?
-        .clone();
-    let client_config = format_options.into();
-    let name_strategy = format_options
-        .get(NAME_STRATEGY_KEY)
-        .map(|s| {
-            name_strategy_from_str(s)
-                .ok_or_else(|| invalid_option_error!("unrecognized strategy {s}"))
-        })
-        .transpose()?
-        .unwrap_or_default();
-    let key_record_name = format_options
-        .get(KEY_MESSAGE_NAME_KEY)
-        .map(std::ops::Deref::deref);
-    let val_record_name = format_options
-        .get(MESSAGE_NAME_KEY)
-        .map(std::ops::Deref::deref);
-
-    let (key_schema, val_schema) = fetch_schema_inner(
-        &schema_location,
-        &client_config,
-        &name_strategy,
-        topic,
-        key_record_name,
-        val_record_name,
-    )
-    .await?;
-
-    Ok((key_schema.try_into()?, val_schema.try_into()?))
+    let loader = SchemaLoader::from_format_options(topic, format_options)?;
+
+    let (kid, kav) = loader.load_schema::<_, true>().await?;
+    let (vid, vav) = loader.load_schema::<_, false>().await?;
+
+    Ok((
+        SchemaWithId {
+            id: kid,
+            schema: Arc::new(kav),
+        },
+        SchemaWithId {
+            id: vid,
+            schema: Arc::new(vav),
+        },
+    ))
 }
 
-async fn fetch_schema_inner(
-    schema_location: &str,
-    client_config: &SchemaRegistryAuth,
-    name_strategy: &PbSchemaRegistryNameStrategy,
-    topic: &str,
-    key_record_name: Option<&str>,
-    val_record_name: Option<&str>,
-) -> Result<(ConfluentSchema, ConfluentSchema), SchemaFetchError> {
-    let urls = handle_sr_list(schema_location)?;
-    let client = Client::new(urls, client_config)?;
-
-    let key_subject = get_subject_by_strategy(name_strategy, topic, key_record_name, true)?;
-    let key_schema = client.get_schema_by_subject(&key_subject).await?;
-
-    let val_subject = get_subject_by_strategy(name_strategy, topic, val_record_name, false)?;
-    let val_schema = client.get_schema_by_subject(&val_subject).await?;
-
-    Ok((key_schema, val_schema))
+impl LoadedSchema for AvroSchema {
+    fn compile(primary: Subject, _: Vec<Subject>) -> Result<Self, SchemaFetchError> {
+        AvroSchema::parse_str(&primary.schema.content)
+            .map_err(|e| SchemaFetchError::SchemaCompile(e.into()))
+    }
 }
99 changes: 99 additions & 0 deletions src/connector/src/schema/loader.rs
@@ -0,0 +1,99 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;

use risingwave_pb::catalog::PbSchemaRegistryNameStrategy;

use super::schema_registry::{
    get_subject_by_strategy, handle_sr_list, name_strategy_from_str, Client, Subject,
};
use super::{invalid_option_error, InvalidOptionError, SchemaFetchError};

const MESSAGE_NAME_KEY: &str = "message";
const KEY_MESSAGE_NAME_KEY: &str = "key.message";
const SCHEMA_LOCATION_KEY: &str = "schema.location";
const SCHEMA_REGISTRY_KEY: &str = "schema.registry";
const NAME_STRATEGY_KEY: &str = "schema.registry.name.strategy";

pub struct SchemaLoader {
    pub client: Client,
    pub name_strategy: PbSchemaRegistryNameStrategy,
    pub topic: String,
    pub key_record_name: Option<String>,
    pub val_record_name: Option<String>,
}

impl SchemaLoader {
    pub fn from_format_options(
        topic: &str,
        format_options: &BTreeMap<String, String>,
    ) -> Result<Self, SchemaFetchError> {
        let schema_location = format_options
            .get(SCHEMA_REGISTRY_KEY)
            .ok_or_else(|| invalid_option_error!("{SCHEMA_REGISTRY_KEY} required"))?;
        let client_config = format_options.into();
        let urls = handle_sr_list(schema_location)?;
        let client = Client::new(urls, &client_config)?;

        let name_strategy = format_options
            .get(NAME_STRATEGY_KEY)
            .map(|s| {
                name_strategy_from_str(s)
                    .ok_or_else(|| invalid_option_error!("unrecognized strategy {s}"))
            })
            .transpose()?
            .unwrap_or_default();
        let key_record_name = format_options.get(KEY_MESSAGE_NAME_KEY).cloned();
        let val_record_name = format_options.get(MESSAGE_NAME_KEY).cloned();

        Ok(Self {
            client,
            name_strategy,
            topic: topic.into(),
            key_record_name,
            val_record_name,
        })
    }

    pub async fn load_schema<O, const IS_KEY: bool>(&self) -> Result<(i32, O), SchemaFetchError>
    where
        // O: TryFrom<(Subject, Vec<Subject>), Error = SchemaFetchError>,
        O: LoadedSchema,
    {
        let subject = get_subject_by_strategy(
            &self.name_strategy,
            &self.topic,
            self.key_record_name.as_deref(),
            IS_KEY,
        )?;
        // let loaded = self.client.get_schema_by_subject(&subject).await?;
        let (primary_subject, dependency_subjects) =
            self.client.get_subject_and_references(&subject).await?;
        let schema_id = primary_subject.schema.id;
        let o = O::compile(primary_subject, dependency_subjects)?;
        Ok((schema_id, o))
    }
}

pub trait LoadedSchema: Sized {
    fn compile(primary: Subject, references: Vec<Subject>) -> Result<Self, SchemaFetchError>;
}

// load key: returns none for url or pb, some for sr-avro
// load value:

// post-fetch enc-specific parsing
// * url: avro::parse_str or pb::decode
// * sr: avro::parse_str or compile + pb::decode
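
A minimal usage sketch of the new loader, assuming it is called from inside the connector crate; the registry URL is a placeholder and the key/value subjects follow the default TopicNameStrategy (`<topic>-key` / `<topic>-value`):

```rust
use std::collections::BTreeMap;

use apache_avro::Schema as AvroSchema;

use crate::schema::{SchemaFetchError, SchemaLoader};

async fn load_both(topic: &str) -> Result<(), SchemaFetchError> {
    let mut format_options = BTreeMap::new();
    // "schema.registry" is required; "message" / "key.message" are only needed
    // for the record / topic-record name strategies.
    format_options.insert("schema.registry".to_owned(), "http://localhost:8081".to_owned());

    let loader = SchemaLoader::from_format_options(topic, &format_options)?;

    // One loader serves both sides: the const bool picks the key or value
    // subject, and the annotated output type picks the `LoadedSchema` impl.
    let (key_id, key_schema): (i32, AvroSchema) = loader.load_schema::<_, true>().await?;
    let (val_id, val_schema): (i32, AvroSchema) = loader.load_schema::<_, false>().await?;

    println!("key schema id {key_id}, value schema id {val_id}");
    let _ = (key_schema, val_schema);
    Ok(())
}
```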
3 changes: 3 additions & 0 deletions src/connector/src/schema/mod.rs
@@ -13,9 +13,12 @@
 // limitations under the License.
 
 pub mod avro;
+mod loader;
 pub mod protobuf;
 pub mod schema_registry;
 
+pub use loader::SchemaLoader;
+
 const MESSAGE_NAME_KEY: &str = "message";
 const KEY_MESSAGE_NAME_KEY: &str = "key.message";
 const SCHEMA_LOCATION_KEY: &str = "schema.location";