Skip to content

Commit

Permalink
impl protobuf borrow
Browse files Browse the repository at this point in the history
Signed-off-by: Bugen Zhao <[email protected]>
  • Loading branch information
BugenZhao committed Jun 10, 2024
1 parent 7f6402a commit f888615
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 24 deletions.
7 changes: 6 additions & 1 deletion src/common/src/types/jsonb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,15 @@ impl<'a> JsonbRef<'a> {
}

/// Returns a jsonb `null` value.
pub fn null() -> Self {
pub const fn null() -> Self {
Self(ValueRef::Null)
}

/// Returns a value for empty string.
pub const fn empty_string() -> Self {
Self(ValueRef::String(""))
}

/// Returns true if this is a jsonb `null`.
pub fn is_jsonb_null(&self) -> bool {
self.0.is_null()
Expand Down
53 changes: 34 additions & 19 deletions src/connector/src/parser/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use prost_reflect::{
MessageDescriptor, ReflectMessage, Value,
};
use risingwave_common::array::{ListValue, StructValue};
use risingwave_common::types::{DataType, Datum, Decimal, JsonbVal, ScalarImpl, F32, F64};
use risingwave_common::types::{
DataType, Datum, DatumCow, Decimal, JsonbRef, JsonbVal, ScalarImpl, ScalarRefImpl, ToDatumRef,
ToOwnedDatum, F32, F64,
};
use risingwave_common::{bail, try_match_expand};
use risingwave_pb::plan_common::{AdditionalColumn, ColumnDesc, ColumnDescVersion};
use thiserror::Error;
Expand Down Expand Up @@ -344,22 +347,28 @@ fn recursive_parse_json(
serde_json::Value::Object(ret)
}

pub fn from_protobuf_value(
pub fn from_protobuf_value<'a>(
field_desc: &FieldDescriptor,
value: &Value,
value: &'a Value,
descriptor_pool: &Arc<DescriptorPool>,
) -> AccessResult {
) -> AccessResult<DatumCow<'a>> {
let kind = field_desc.kind();

let v = match value {
macro_rules! borrowed {
($v:expr) => {
return Ok(DatumCow::Borrowed(Some($v.into())))
};
}

let v: ScalarImpl = match value {
Value::Bool(v) => ScalarImpl::Bool(*v),
Value::I32(i) => ScalarImpl::Int32(*i),
Value::U32(i) => ScalarImpl::Int64(*i as i64),
Value::I64(i) => ScalarImpl::Int64(*i),
Value::U64(i) => ScalarImpl::Decimal(Decimal::from(*i)),
Value::F32(f) => ScalarImpl::Float32(F32::from(*f)),
Value::F64(f) => ScalarImpl::Float64(F64::from(*f)),
Value::String(s) => ScalarImpl::Utf8(s.as_str().into()),
Value::String(s) => borrowed!(s.as_str()),
Value::EnumNumber(idx) => {
let enum_desc = kind.as_enum().ok_or_else(|| AccessError::TypeError {
expected: "enum".to_owned(),
Expand All @@ -375,9 +384,7 @@ pub fn from_protobuf_value(
if dyn_msg.descriptor().full_name() == "google.protobuf.Any" {
// If the fields are not presented, default value is an empty string
if !dyn_msg.has_field_by_name("type_url") || !dyn_msg.has_field_by_name("value") {
return Ok(Some(ScalarImpl::Jsonb(JsonbVal::from(
serde_json::json! {""},
))));
borrowed!(JsonbRef::empty_string());
}

// Sanity check
Expand All @@ -391,9 +398,8 @@ pub fn from_protobuf_value(

let payload_field_desc = dyn_msg.descriptor().get_field_by_name("value").unwrap();

let Some(ScalarImpl::Bytea(payload)) =
from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)?
else {
let payload = from_protobuf_value(&payload_field_desc, &payload, descriptor_pool)?;
let Some(ScalarRefImpl::Bytea(payload)) = payload.to_datum_ref() else {
bail_uncategorized!("expected bytes for dynamic message payload");
};

Expand All @@ -413,12 +419,13 @@ pub fn from_protobuf_value(
let full_name = msg_desc.clone().full_name().to_string();

// Decode the payload based on the `msg_desc`
let decoded_value = DynamicMessage::decode(msg_desc, payload.as_ref()).unwrap();
let decoded_value = DynamicMessage::decode(msg_desc, payload).unwrap();
let decoded_value = from_protobuf_value(
field_desc,
&Value::Message(decoded_value),
descriptor_pool,
)?
.to_owned_datum()
.unwrap();

// Extract the struct value
Expand Down Expand Up @@ -447,7 +454,9 @@ pub fn from_protobuf_value(
}
// use default value if dyn_msg doesn't has this field
let value = dyn_msg.get_field(&field_desc);
rw_values.push(from_protobuf_value(&field_desc, &value, descriptor_pool)?);
rw_values.push(
from_protobuf_value(&field_desc, &value, descriptor_pool)?.to_owned_datum(),
);
}
ScalarImpl::Struct(StructValue::new(rw_values))
}
Expand All @@ -461,14 +470,14 @@ pub fn from_protobuf_value(
}
ScalarImpl::List(ListValue::new(builder.finish()))
}
Value::Bytes(value) => ScalarImpl::Bytea(value.to_vec().into_boxed_slice()),
Value::Bytes(value) => borrowed!(&**value),
_ => {
return Err(AccessError::UnsupportedType {
ty: format!("{kind:?}"),
});
}
};
Ok(Some(v))
Ok(Some(v).into())
}

/// Maps protobuf type to RW type.
Expand Down Expand Up @@ -965,7 +974,9 @@ mod test {
let field = value.fields().next().unwrap().0;

if let Some(ret) =
from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap()
from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
.unwrap()
.to_owned_datum()
{
println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret);
println!("---------------------------");
Expand Down Expand Up @@ -1026,7 +1037,9 @@ mod test {
let field = value.fields().next().unwrap().0;

if let Some(ret) =
from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap()
from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
.unwrap()
.to_owned_datum()
{
println!("Decoded Value for ANY_GEN_PROTO_DATA: {:#?}", ret);
println!("---------------------------");
Expand Down Expand Up @@ -1098,7 +1111,9 @@ mod test {
let field = value.fields().next().unwrap().0;

if let Some(ret) =
from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool).unwrap()
from_protobuf_value(&field, &Value::Message(value), &conf.descriptor_pool)
.unwrap()
.to_owned_datum()
{
println!("Decoded Value for ANY_RECURSIVE_GEN_PROTO_DATA: {:#?}", ret);
println!("---------------------------");
Expand Down
14 changes: 10 additions & 4 deletions src/connector/src/parser/unified/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::sync::{Arc, LazyLock};

use prost_reflect::{DescriptorPool, DynamicMessage, ReflectMessage};
use risingwave_common::log::LogSuppresser;
use risingwave_common::types::{DataType, DatumCow};
use risingwave_common::types::{DataType, DatumCow, ToOwnedDatum};
use thiserror_ext::AsReport;

use super::{Access, AccessResult};
Expand Down Expand Up @@ -56,9 +57,14 @@ impl Access for ProtobufAccess {
tracing::error!(suppressed_count, "{}", e.as_report());
}
})?;
let value = self.message.get_field(&field_desc);

// TODO: may borrow the value directly
from_protobuf_value(&field_desc, &value, &self.descriptor_pool).map(Into::into)
match self.message.get_field(&field_desc) {
Cow::Borrowed(value) => from_protobuf_value(&field_desc, value, &self.descriptor_pool),

// `Owned` variant occurs only if there's no such field and the default value is returned.
Cow::Owned(value) => from_protobuf_value(&field_desc, &value, &self.descriptor_pool)
// enforce `Owned` variant to avoid returning a reference to a temporary value
.map(|d| d.to_owned_datum().into()),
}
}
}

0 comments on commit f888615

Please sign in to comment.