Skip to content

Commit

Permalink
feat(source): add NATS source consumer parameters (#17615)
Browse files Browse the repository at this point in the history
Co-authored-by: benjamin-awd <[email protected]>
  • Loading branch information
2 people authored and kwannoel committed Aug 27, 2024
1 parent 2c5d15e commit 80ff556
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/connector/src/connector_common/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ impl NatsCommon {
stream: String,
split_id: String,
start_sequence: NatsOffset,
mut config: jetstream::consumer::pull::Config,
) -> ConnectorResult<
async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
> {
Expand All @@ -649,10 +650,6 @@ impl NatsCommon {
.replace(',', "-")
.replace(['.', '>', '*', ' ', '\t'], "_");
let name = format!("risingwave-consumer-{}-{}", subject_name, split_id);
let mut config = jetstream::consumer::pull::Config {
ack_policy: jetstream::consumer::AckPolicy::None,
..Default::default()
};

let deliver_policy = match start_sequence {
NatsOffset::Earliest => DeliverPolicy::All,
Expand All @@ -671,6 +668,7 @@ impl NatsCommon {
},
NatsOffset::None => DeliverPolicy::All,
};

let consumer = stream
.get_or_create_consumer(&name, {
config.deliver_policy = deliver_policy;
Expand Down
19 changes: 19 additions & 0 deletions src/connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ where
}
}

pub(crate) fn deserialize_optional_u64_seq_from_string<'de, D>(
deserializer: D,
) -> std::result::Result<Option<Vec<u64>>, D::Error>
where
D: de::Deserializer<'de>,
{
let s: Option<String> = de::Deserialize::deserialize(deserializer)?;
if let Some(s) = s {
let numbers = s
.split(',')
.map(|s| s.trim().parse())
.collect::<Result<Vec<u64>, _>>()
.map_err(|_| de::Error::invalid_value(de::Unexpected::Str(&s), &"invalid number"));
Ok(Some(numbers?))
} else {
Ok(None)
}
}

pub(crate) fn deserialize_bool_from_string<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
D: de::Deserializer<'de>,
Expand Down
292 changes: 292 additions & 0 deletions src/connector/src/source/nats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,57 @@ pub mod source;
pub mod split;

use std::collections::HashMap;
use std::time::Duration;

use async_nats::jetstream::consumer::pull::Config;
use async_nats::jetstream::consumer::{AckPolicy, ReplayPolicy};
use serde::Deserialize;
use serde_with::{serde_as, DisplayFromStr};
use with_options::WithOptions;

use crate::connector_common::NatsCommon;
use crate::source::nats::enumerator::NatsSplitEnumerator;
use crate::source::nats::source::{NatsSplit, NatsSplitReader};
use crate::source::SourceProperties;
use crate::{
deserialize_optional_string_seq_from_string, deserialize_optional_u64_seq_from_string,
};

pub const NATS_CONNECTOR: &str = "nats";

/// Namespace type holding the string-to-[`AckPolicy`] conversion used for the
/// `consumer.ack_policy` option.
pub struct AckPolicyWrapper;

impl AckPolicyWrapper {
    /// Converts the lowercase policy name `s` into an [`AckPolicy`].
    ///
    /// Accepted values are exactly `"none"`, `"all"` and `"explicit"`; matching is
    /// case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected value when `s` is anything else.
    pub fn parse_str(s: &str) -> Result<AckPolicy, String> {
        let policy = match s {
            "none" => AckPolicy::None,
            "all" => AckPolicy::All,
            "explicit" => AckPolicy::Explicit,
            unknown => return Err(format!("Invalid AckPolicy '{}'", unknown)),
        };
        Ok(policy)
    }
}

/// Namespace type holding the string-to-[`ReplayPolicy`] conversion used for the
/// `consumer.replay_policy` option.
pub struct ReplayPolicyWrapper;

impl ReplayPolicyWrapper {
    /// Converts the lowercase policy name `s` into a [`ReplayPolicy`].
    ///
    /// Accepted values are exactly `"instant"` and `"original"`; matching is
    /// case-sensitive.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected value when `s` is anything else.
    pub fn parse_str(s: &str) -> Result<ReplayPolicy, String> {
        let policy = match s {
            "instant" => ReplayPolicy::Instant,
            "original" => ReplayPolicy::Original,
            unknown => return Err(format!("Invalid ReplayPolicy '{}'", unknown)),
        };
        Ok(policy)
    }
}

#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsProperties {
#[serde(flatten)]
pub common: NatsCommon,

#[serde(flatten)]
pub nats_properties_consumer: NatsPropertiesConsumer,

#[serde(rename = "scan.startup.mode")]
pub scan_startup_mode: Option<String>,

Expand All @@ -49,6 +84,173 @@ pub struct NatsProperties {
pub unknown_fields: HashMap<String, String>,
}

impl NatsProperties {
pub fn set_config(&self, c: &mut Config) {
self.nats_properties_consumer.set_config(c);
}
}

/// Properties for the async-nats library.
/// See <https://docs.rs/async-nats/latest/async_nats/jetstream/consumer/struct.Config.html>
///
/// Every field is optional and is read from the `consumer.*` key given in its
/// `rename` attribute. Values arrive as strings: numeric and boolean fields are parsed
/// through `DisplayFromStr`, list fields are parsed from a comma-separated string.
/// Fields whose key ends in `.sec` hold whole seconds and are turned into `Duration`s
/// by `set_config`.
///
/// `deliver_subject` and `deliver_policy` are accepted here but are not applied by
/// `set_config`.
#[serde_as]
#[derive(Clone, Debug, Deserialize, WithOptions)]
pub struct NatsPropertiesConsumer {
    #[serde(rename = "consumer.deliver_subject")]
    pub deliver_subject: Option<String>,

    #[serde(rename = "consumer.durable_name")]
    pub durable_name: Option<String>,

    #[serde(rename = "consumer.name")]
    pub name: Option<String>,

    #[serde(rename = "consumer.description")]
    pub description: Option<String>,

    #[serde(rename = "consumer.deliver_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub deliver_policy: Option<String>,

    #[serde(rename = "consumer.ack_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub ack_policy: Option<String>,

    #[serde(rename = "consumer.ack_wait.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub ack_wait: Option<u64>,

    #[serde(rename = "consumer.max_deliver")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_deliver: Option<i64>,

    #[serde(rename = "consumer.filter_subject")]
    pub filter_subject: Option<String>,

    // `default` is required here: with `deserialize_with`, serde reports a
    // "missing field" error for an absent key even though the type is `Option`.
    #[serde(rename = "consumer.filter_subjects")]
    #[serde(
        default,
        deserialize_with = "deserialize_optional_string_seq_from_string"
    )]
    pub filter_subjects: Option<Vec<String>>,

    #[serde(rename = "consumer.replay_policy")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub replay_policy: Option<String>,

    #[serde(rename = "consumer.rate_limit")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub rate_limit: Option<u64>,

    #[serde(rename = "consumer.sample_frequency")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub sample_frequency: Option<u8>,

    #[serde(rename = "consumer.max_waiting")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_waiting: Option<i64>,

    #[serde(rename = "consumer.max_ack_pending")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_ack_pending: Option<i64>,

    #[serde(rename = "consumer.headers_only")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub headers_only: Option<bool>,

    #[serde(rename = "consumer.max_batch")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_batch: Option<i64>,

    #[serde(rename = "consumer.max_bytes")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_bytes: Option<i64>,

    #[serde(rename = "consumer.max_expires.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_expires: Option<u64>,

    #[serde(rename = "consumer.inactive_threshold.sec")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub inactive_threshold: Option<u64>,

    #[serde(rename = "consumer.num.replicas", alias = "consumer.num_replicas")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub num_replicas: Option<usize>,

    #[serde(rename = "consumer.memory_storage")]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub memory_storage: Option<bool>,

    // `default` is required for the same reason as on `filter_subjects`: without it
    // an absent `consumer.backoff.sec` key fails deserialization.
    #[serde(rename = "consumer.backoff.sec")]
    #[serde(
        default,
        deserialize_with = "deserialize_optional_u64_seq_from_string"
    )]
    pub backoff: Option<Vec<u64>>,
}

impl NatsPropertiesConsumer {
    /// Copies every option the user set into the pull-consumer config `c`.
    ///
    /// Options that are `None` leave the corresponding field of `c` untouched.
    /// `deliver_subject` and `deliver_policy` are not applied here.
    ///
    /// # Panics
    ///
    /// Panics if `ack_policy` or `replay_policy` holds a value that
    /// `AckPolicyWrapper::parse_str` / `ReplayPolicyWrapper::parse_str` rejects.
    /// Use [`Self::try_set_config`] to receive that as an error instead.
    pub fn set_config(&self, c: &mut Config) {
        if let Err(reason) = self.try_set_config(c) {
            panic!("invalid NATS consumer properties: {}", reason);
        }
    }

    /// Fallible variant of [`Self::set_config`].
    ///
    /// Both policy strings are validated before anything is written, so `c` is left
    /// unmodified when an error is returned.
    ///
    /// # Errors
    ///
    /// Returns the message produced by the policy parser when `ack_policy` or
    /// `replay_policy` is not a recognized value.
    pub fn try_set_config(&self, c: &mut Config) -> Result<(), String> {
        // Validate user-supplied policy names up front; these are the only fallible steps.
        let ack_policy = self
            .ack_policy
            .as_deref()
            .map(AckPolicyWrapper::parse_str)
            .transpose()?;
        let replay_policy = self
            .replay_policy
            .as_deref()
            .map(ReplayPolicyWrapper::parse_str)
            .transpose()?;

        if let Some(name) = &self.name {
            c.name = Some(name.clone());
        }
        if let Some(durable_name) = &self.durable_name {
            c.durable_name = Some(durable_name.clone());
        }
        if let Some(description) = &self.description {
            c.description = Some(description.clone());
        }
        if let Some(policy) = ack_policy {
            c.ack_policy = policy;
        }
        if let Some(secs) = self.ack_wait {
            c.ack_wait = Duration::from_secs(secs);
        }
        if let Some(max_deliver) = self.max_deliver {
            c.max_deliver = max_deliver;
        }
        if let Some(subject) = &self.filter_subject {
            c.filter_subject = subject.clone();
        }
        if let Some(subjects) = &self.filter_subjects {
            c.filter_subjects = subjects.clone();
        }
        if let Some(policy) = replay_policy {
            c.replay_policy = policy;
        }
        if let Some(rate_limit) = self.rate_limit {
            c.rate_limit = rate_limit;
        }
        if let Some(sample_frequency) = self.sample_frequency {
            c.sample_frequency = sample_frequency;
        }
        if let Some(max_waiting) = self.max_waiting {
            c.max_waiting = max_waiting;
        }
        if let Some(max_ack_pending) = self.max_ack_pending {
            c.max_ack_pending = max_ack_pending;
        }
        if let Some(headers_only) = self.headers_only {
            c.headers_only = headers_only;
        }
        if let Some(max_batch) = self.max_batch {
            c.max_batch = max_batch;
        }
        if let Some(max_bytes) = self.max_bytes {
            c.max_bytes = max_bytes;
        }
        if let Some(secs) = self.max_expires {
            c.max_expires = Duration::from_secs(secs);
        }
        if let Some(secs) = self.inactive_threshold {
            c.inactive_threshold = Duration::from_secs(secs);
        }
        if let Some(num_replicas) = self.num_replicas {
            c.num_replicas = num_replicas;
        }
        if let Some(memory_storage) = self.memory_storage {
            c.memory_storage = memory_storage;
        }
        if let Some(backoff_secs) = &self.backoff {
            c.backoff = backoff_secs
                .iter()
                .map(|&secs| Duration::from_secs(secs))
                .collect();
        }
        Ok(())
    }
}

impl SourceProperties for NatsProperties {
type Split = NatsSplit;
type SplitEnumerator = NatsSplitEnumerator;
Expand All @@ -62,3 +264,93 @@ impl crate::source::UnknownFields for NatsProperties {
self.unknown_fields.clone()
}
}

#[cfg(test)]
mod test {
    use std::collections::BTreeMap;

    use maplit::btreemap;

    use super::*;

    /// Deserializes a full set of `consumer.*` options from a string map and checks
    /// that every consumer option that is configured lands in the matching field
    /// with the expected parsed value.
    #[test]
    fn test_parse_config_consumer() {
        let config: BTreeMap<String, String> = btreemap! {
            "stream".to_string() => "risingwave".to_string(),

            // NATS common
            "subject".to_string() => "subject1".to_string(),
            "server_url".to_string() => "nats-server:4222".to_string(),
            "connect_mode".to_string() => "plain".to_string(),
            "type".to_string() => "append-only".to_string(),

            // NATS properties consumer
            "consumer.name".to_string() => "foobar".to_string(),
            "consumer.durable_name".to_string() => "durable_foobar".to_string(),
            "consumer.description".to_string() => "A description".to_string(),
            "consumer.ack_policy".to_string() => "all".to_string(),
            "consumer.ack_wait.sec".to_string() => "10".to_string(),
            "consumer.max_deliver".to_string() => "10".to_string(),
            "consumer.filter_subject".to_string() => "subject".to_string(),
            "consumer.filter_subjects".to_string() => "subject1,subject2".to_string(),
            "consumer.replay_policy".to_string() => "instant".to_string(),
            "consumer.rate_limit".to_string() => "100".to_string(),
            "consumer.sample_frequency".to_string() => "1".to_string(),
            "consumer.max_waiting".to_string() => "5".to_string(),
            "consumer.max_ack_pending".to_string() => "100".to_string(),
            "consumer.headers_only".to_string() => "true".to_string(),
            "consumer.max_batch".to_string() => "10".to_string(),
            "consumer.max_bytes".to_string() => "1024".to_string(),
            "consumer.max_expires.sec".to_string() => "24".to_string(),
            "consumer.inactive_threshold.sec".to_string() => "10".to_string(),
            // Uses the `consumer.num_replicas` alias rather than `consumer.num.replicas`.
            "consumer.num_replicas".to_string() => "3".to_string(),
            "consumer.memory_storage".to_string() => "true".to_string(),
            "consumer.backoff.sec".to_string() => "2,10,15".to_string(),

        };

        let props: NatsProperties =
            serde_json::from_value(serde_json::to_value(config).unwrap()).unwrap();

        assert_eq!(
            props.nats_properties_consumer.name,
            Some("foobar".to_string())
        );
        assert_eq!(
            props.nats_properties_consumer.durable_name,
            Some("durable_foobar".to_string())
        );
        assert_eq!(
            props.nats_properties_consumer.description,
            Some("A description".to_string())
        );
        assert_eq!(
            props.nats_properties_consumer.ack_policy,
            Some("all".to_string())
        );
        assert_eq!(props.nats_properties_consumer.ack_wait, Some(10));
        assert_eq!(props.nats_properties_consumer.max_deliver, Some(10));
        assert_eq!(
            props.nats_properties_consumer.filter_subject,
            Some("subject".to_string())
        );
        assert_eq!(
            props.nats_properties_consumer.filter_subjects,
            Some(vec!["subject1".to_string(), "subject2".to_string()])
        );
        assert_eq!(
            props.nats_properties_consumer.replay_policy,
            Some("instant".to_string())
        );
        assert_eq!(props.nats_properties_consumer.rate_limit, Some(100));
        assert_eq!(props.nats_properties_consumer.sample_frequency, Some(1));
        assert_eq!(props.nats_properties_consumer.max_waiting, Some(5));
        assert_eq!(props.nats_properties_consumer.max_ack_pending, Some(100));
        assert_eq!(props.nats_properties_consumer.headers_only, Some(true));
        assert_eq!(props.nats_properties_consumer.max_batch, Some(10));
        assert_eq!(props.nats_properties_consumer.max_bytes, Some(1024));
        assert_eq!(props.nats_properties_consumer.max_expires, Some(24));
        assert_eq!(props.nats_properties_consumer.inactive_threshold, Some(10));
        assert_eq!(props.nats_properties_consumer.num_replicas, Some(3));
        assert_eq!(props.nats_properties_consumer.memory_storage, Some(true));
        assert_eq!(
            props.nats_properties_consumer.backoff,
            Some(vec![2, 10, 15])
        );
    }
}
Loading

0 comments on commit 80ff556

Please sign in to comment.