diff --git a/src/common/src/types/jsonb.rs b/src/common/src/types/jsonb.rs index 824020fac3123..642b363a8c67e 100644 --- a/src/common/src/types/jsonb.rs +++ b/src/common/src/types/jsonb.rs @@ -239,10 +239,15 @@ impl<'a> JsonbRef<'a> { } /// Returns a jsonb `null` value. - pub fn null() -> Self { + pub const fn null() -> Self { Self(ValueRef::Null) } + /// Returns a value for empty string. + pub const fn empty_string() -> Self { + Self(ValueRef::String("")) + } + /// Returns true if this is a jsonb `null`. pub fn is_jsonb_null(&self) -> bool { self.0.is_null() diff --git a/src/connector/src/parser/protobuf/parser.rs b/src/connector/src/parser/protobuf/parser.rs index b5df0aeb83909..8be25074f6295 100644 --- a/src/connector/src/parser/protobuf/parser.rs +++ b/src/connector/src/parser/protobuf/parser.rs @@ -21,7 +21,10 @@ use prost_reflect::{ MessageDescriptor, ReflectMessage, Value, }; use risingwave_common::array::{ListValue, StructValue}; -use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64}; +use risingwave_common::types::{ + DataType, Datum, DatumCow, Decimal, JsonbRef, JsonbVal, ScalarImpl, ScalarRefImpl, ToDatumRef, + ToOwnedDatum, F32, F64, +}; use risingwave_common::{bail, try_match_expand}; use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion}; use thiserror::Error; @@ -344,14 +347,20 @@ fn recursive_parse_json( serde_json::Value::Object(ret) } -pub fn from_protobuf_value( +pub fn from_protobuf_value<'a>( field_desc: &FieldDescriptor, - value: &Value, + value: &'a Value, descriptor_pool: &Arc, -) -> AccessResult { +) -> AccessResult> { let kind = field_desc.kind(); - let v = match value { + macro_rules! borrowed { + ($v:expr) => { + return Ok(DatumCow::Borrowed(Some($v.into()))) + }; + } + + let v: ScalarImpl = match value { Value::Bool(v) => ScalarImpl::Bool(*v), Value::I32(i) => ScalarImpl::Int32(*i), Value::U32(i) => ScalarImpl::Int64(*i as i64), @@ -359,7 +368,7 @@ pub fn from_protobuf_value( Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)), Value::F32(f) => ScalarImpl::Float32(F32::from(*f)), Value::F64(f) => ScalarImpl::Float64(F64::from(*f)), - Value::String(s) => ScalarImpl::Utf8(s.as_str().into()), + Value::String(s) => borrowed!(s.as_str()), Value::EnumNumber(idx) => { let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError { expected: "enum".to_owned(), @@ -375,9 +384,7 @@ pub fn from_protobuf_value( if dyn_msg.descriptor().full_name() == "google.protobuf.Any" { // If the fields are not presented, default value is an empty string if !dyn_msg.has_field_by_name("type_url") || !dyn_msg.has_field_by_name("value") { - return Ok(Some(ScalarImpl::Jsonb(JsonbVal::from( - serde_json::json! {""}, - )))); + borrowed!(JsonbRef::empty_string()); } // Sanity check @@ -391,9 +398,8 @@ pub fn from_protobuf_value( let payload_field_desc = dyn_msg.descriptor().get_field_by_name("value").unwrap(); - let Some(ScalarImpl::Bytea(payload)) = - from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)? - else { + let payload = from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)?; + let Some(ScalarRefImpl::Bytea(payload)) = payload.to_datum_ref() else { bail_uncategorized!("expected bytes for dynamic message payload"); }; @@ -413,12 +419,13 @@ pub fn from_protobuf_value( let full_name = msg_desc.clone().full_name().to_string(); // Decode the payload based on the `msg_desc` - let decoded_value = DynamicMessage::decode(msg_desc, payload.as_ref()).unwrap(); + let decoded_value = DynamicMessage::decode(msg_desc, payload).unwrap(); let decoded_value = from_protobuf_value( field_desc, &Value::Message(decoded_value), descriptor_pool, )? + .to_owned_datum() .unwrap(); // Extract the struct value @@ -447,7 +454,9 @@ pub fn from_protobuf_value( } // use default value if dyn_msg doesn't has this field let value = dyn_msg.get_field(&field_desc); - rw_values.push(from_protobuf_value(&field_desc, &value, descriptor_pool)?); + rw_values.push( + from_protobuf_value(&field_desc, &value, descriptor_pool)?.to_owned_datum(), + ); } ScalarImpl::Struct(StructValue::new(rw_values)) } @@ -461,14 +470,14 @@ pub fn from_protobuf_value( } ScalarImpl::List(ListValue::new(builder.finish())) } - Value::Bytes(value) => ScalarImpl::Bytea(value.to_vec().into_boxed_slice()), + Value::Bytes(value) => borrowed!(&**value), _ => { return Err(AccessError::UnsupportedType { ty: format!("{kind:?}"), }); } }; - Ok(Some(v)) + Ok(Some(v).into()) } /// Maps protobuf type to RW type. @@ -965,7 +974,9 @@ mod test { let field = value.fields().next().unwrap().0; if let Some(ret) = - from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap() + from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool) + .unwrap() + .to_owned_datum() { println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret); println!("---------------------------"); @@ -1026,7 +1037,9 @@ mod test { let field = value.fields().next().unwrap().0; if let Some(ret) = - from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap() + from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool) + .unwrap() + .to_owned_datum() { println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret); println!("---------------------------"); @@ -1098,7 +1111,9 @@ mod test { let field = value.fields().next().unwrap().0; if let Some(ret) = - from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap() + from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool) + .unwrap() + .to_owned_datum() { println!("Decoded Value for ANY_RECURSIVE_GEN_PROTO_DATA: {:#?}", ret); println!("---------------------------"); diff --git a/src/connector/src/parser/unified/protobuf.rs b/src/connector/src/parser/unified/protobuf.rs index 4bea2cbab306b..02febc22db247 100644 --- a/src/connector/src/parser/unified/protobuf.rs +++ b/src/connector/src/parser/unified/protobuf.rs @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::borrow::Cow; use std::sync::{Arc, LazyLock}; use prost_reflect::{DescriptorPool, DynamicMessage, ReflectMessage}; use risingwave_common::log::LogSuppresser; -use risingwave_common::types::{DataType, DatumCow}; +use risingwave_common::types::{DataType, DatumCow, ToOwnedDatum}; use thiserror_ext::AsReport; use super::{Access, AccessResult}; @@ -56,9 +57,14 @@ impl Access for ProtobufAccess { tracing::error!(suppressed_count, "{}", e.as_report()); } })?; - let value = self.message.get_field(&field_desc); - // TODO: may borrow the value directly - from_protobuf_value(&field_desc, &value, &self.descriptor_pool).map(Into::into) + match self.message.get_field(&field_desc) { + Cow::Borrowed(value) => from_protobuf_value(&field_desc, value, &self.descriptor_pool), + + // `Owned` variant occurs only if there's no such field and the default value is returned. + Cow::Owned(value) => from_protobuf_value(&field_desc, &value, &self.descriptor_pool) + // enforce `Owned` variant to avoid returning a reference to a temporary value + .map(|d| d.to_owned_datum().into()), + } } }