Skip to content

Commit

Permalink
feat(stream): add nkey and jwt auth methods for nats connector (#12227)
Browse files Browse the repository at this point in the history
Co-authored-by: yufansong <[email protected]>
  • Loading branch information
yufansong and yufansong authored Sep 19, 2023
1 parent 4dadb7c commit b17569d
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 227 deletions.
256 changes: 87 additions & 169 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ mysql_common = { version = "0.30", default-features = false, features = [
"chrono",
] }
nexmark = { version = "0.2", features = ["serde"] }
nkeys = "0.3.2"
num-bigint = "0.4"
opendal = "0.39"
parking_lot = "0.12"
Expand Down Expand Up @@ -101,6 +102,7 @@ serde_with = { version = "3", features = ["json"] }
simd-json = "0.10.6"
tempfile = "3"
thiserror = "1"
time = "0.3.28"
tokio = { version = "0.2", package = "madsim-tokio", features = [
"rt",
"rt-multi-thread",
Expand All @@ -117,6 +119,7 @@ tonic = { workspace = true }
tracing = "0.1"
url = "2"
urlencoding = "2"

[target.'cfg(not(madsim))'.dependencies]
workspace-hack = { path = "../workspace-hack" }

Expand Down
118 changes: 87 additions & 31 deletions src/connector/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ use risingwave_common::error::anyhow_error;
use serde_derive::{Deserialize, Serialize};
use serde_with::json::JsonString;
use serde_with::{serde_as, DisplayFromStr};
use time::OffsetDateTime;

use crate::aws_auth::AwsAuthProps;
use crate::deserialize_duration_from_string;
use crate::sink::SinkError;

use crate::source::nats::source::NatsOffset;
// This file describes the common abstractions for each connector, which can be used in both
// sources and sinks.

Expand Down Expand Up @@ -342,37 +343,73 @@ pub struct UpsertMessage<'a> {
#[serde_as]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct NatsCommon {
#[serde(rename = "nats.server_url")]
#[serde(rename = "server_url")]
pub server_url: String,
#[serde(rename = "nats.subject")]
#[serde(rename = "subject")]
pub subject: String,
#[serde(rename = "nats.user")]
#[serde(rename = "connect_mode")]
pub connect_mode: Option<String>,
#[serde(rename = "username")]
pub user: Option<String>,
#[serde(rename = "nats.password")]
#[serde(rename = "password")]
pub password: Option<String>,
#[serde(rename = "nats.max_bytes")]
#[serde(rename = "jwt")]
pub jwt: Option<String>,
#[serde(rename = "nkey")]
pub nkey: Option<String>,
#[serde(rename = "max_bytes")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_bytes: Option<i64>,
#[serde(rename = "nats.max_messages")]
#[serde(rename = "max_messages")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_messages: Option<i64>,
#[serde(rename = "nats.max_messages_per_subject")]
#[serde(rename = "max_messages_per_subject")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_messages_per_subject: Option<i64>,
#[serde(rename = "nats.max_consumers")]
#[serde(rename = "max_consumers")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_consumers: Option<i32>,
#[serde(rename = "nats.max_message_size")]
#[serde(rename = "max_message_size")]
#[serde_as(as = "Option<DisplayFromStr>")]
pub max_message_size: Option<i32>,
}

impl NatsCommon {
pub(crate) async fn build_client(&self) -> anyhow::Result<async_nats::Client> {
let mut connect_options = async_nats::ConnectOptions::new();
if let (Some(v_user), Some(v_password)) = (self.user.as_ref(), self.password.as_ref()) {
connect_options = connect_options.user_and_password(v_user.into(), v_password.into());
}
match self.connect_mode.as_deref() {
Some("user_and_password") => {
if let (Some(v_user), Some(v_password)) =
(self.user.as_ref(), self.password.as_ref())
{
connect_options =
connect_options.user_and_password(v_user.into(), v_password.into())
} else {
return Err(anyhow_error!(
"nats connect mode is user_and_password, but user or password is empty"
));
}
}

Some("credential") => {
if let (Some(v_nkey), Some(v_jwt)) = (self.nkey.as_ref(), self.jwt.as_ref()) {
connect_options = connect_options
.credentials(&self.create_credential(v_nkey, v_jwt)?)
.expect("failed to parse static creds")
} else {
return Err(anyhow_error!(
"nats connect mode is credential, but nkey or jwt is empty"
));
}
}
Some("plain") => {}
_ => {
return Err(anyhow_error!(
"nats connect mode only accept user_and_password/credential/plain"
));
}
};

let servers = self.server_url.split(',').collect::<Vec<&str>>();
let client = connect_options
.connect(
Expand All @@ -394,8 +431,8 @@ impl NatsCommon {

pub(crate) async fn build_consumer(
&self,
split_id: i32,
start_sequence: Option<u64>,
split_id: String,
start_sequence: NatsOffset,
) -> anyhow::Result<
async_nats::jetstream::consumer::Consumer<async_nats::jetstream::consumer::pull::Config>,
> {
Expand All @@ -406,23 +443,28 @@ impl NatsCommon {
ack_policy: jetstream::consumer::AckPolicy::None,
..Default::default()
};
match start_sequence {
Some(v) => {
let consumer = stream
.get_or_create_consumer(&name, {
config.deliver_policy = DeliverPolicy::ByStartSequence {
start_sequence: v + 1,
};
config
})
.await?;
Ok(consumer)
}
None => {
let consumer = stream.get_or_create_consumer(&name, config).await?;
Ok(consumer)

let deliver_policy = match start_sequence {
NatsOffset::Earliest => DeliverPolicy::All,
NatsOffset::Latest => DeliverPolicy::Last,
NatsOffset::SequenceNumber(v) => {
let parsed = v.parse::<u64>()?;
DeliverPolicy::ByStartSequence {
start_sequence: 1 + parsed,
}
}
}
NatsOffset::Timestamp(v) => DeliverPolicy::ByStartTime {
start_time: OffsetDateTime::from_unix_timestamp_nanos(v * 1_000_000)?,
},
NatsOffset::None => DeliverPolicy::All,
};
let consumer = stream
.get_or_create_consumer(&name, {
config.deliver_policy = deliver_policy;
config
})
.await?;
Ok(consumer)
}

pub(crate) async fn build_or_get_stream(
Expand All @@ -432,6 +474,7 @@ impl NatsCommon {
let mut config = jetstream::stream::Config {
// use the subject as the stream name (a stream's subjects default to its name)
name: self.subject.clone(),
max_bytes: 1000000,
..Default::default()
};
if let Some(v) = self.max_bytes {
Expand All @@ -452,4 +495,17 @@ impl NatsCommon {
let stream = jetstream.get_or_create_stream(config).await?;
Ok(stream)
}

/// Assembles the text of a NATS credentials document from an NKEY seed and a user JWT.
///
/// The JWT and the seed are each wrapped in their `BEGIN`/`END` delimiter lines, with
/// the warning banner about NKEY secrecy placed between the two sections. Nothing is
/// validated here, so this function always returns `Ok`.
pub(crate) fn create_credential(&self, seed: &str, jwt: &str) -> anyhow::Result<String> {
    // Section carrying the user JWT.
    let jwt_section = format!(
        "-----BEGIN NATS USER JWT-----\n{}\n------END NATS USER JWT------\n\n",
        jwt
    );
    // Fixed banner that separates the JWT from the seed.
    let notice = concat!(
        "************************* IMPORTANT *************************\n",
        "NKEY Seed printed below can be used to sign and prove identity.\n",
        "NKEYs are sensitive and should be treated as secrets.\n\n",
    );
    // Section carrying the NKEY seed.
    let seed_section = format!(
        "-----BEGIN USER NKEY SEED-----\n{}\n------END USER NKEY SEED------\n\n",
        seed
    );
    let footer = "*************************************************************";

    Ok([jwt_section.as_str(), notice, seed_section.as_str(), footer].concat())
}
}
14 changes: 8 additions & 6 deletions src/connector/src/source/nats/enumerator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use anyhow;
use async_trait::async_trait;

use super::source::NatsSplit;
use super::source::{NatsOffset, NatsSplit};
use super::NatsProperties;
use crate::source::{SourceEnumeratorContextRef, SplitEnumerator};
use crate::source::{SourceEnumeratorContextRef, SplitEnumerator, SplitId};

#[derive(Debug, Clone, Eq, PartialEq)]
pub struct NatsSplitEnumerator {
subject: String,
split_num: i32,
split_id: SplitId,
}

#[async_trait]
Expand All @@ -36,16 +38,16 @@ impl SplitEnumerator for NatsSplitEnumerator {
) -> anyhow::Result<NatsSplitEnumerator> {
Ok(Self {
subject: properties.common.subject,
split_num: 0,
split_id: Arc::from("0"),
})
}

async fn list_splits(&mut self) -> anyhow::Result<Vec<NatsSplit>> {
// TODO: to simplify the logic, return 1 split for first version
let nats_split = NatsSplit {
subject: self.subject.clone(),
split_num: 0, // be the same as `from_nats_jetstream_message`
start_sequence: None,
split_id: Arc::from("0"), // be the same as `from_nats_jetstream_message`
start_sequence: NatsOffset::None,
};

Ok(vec![nats_split])
Expand Down
6 changes: 6 additions & 0 deletions src/connector/src/source/nats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ pub const NATS_CONNECTOR: &str = "nats";
/// Properties of the NATS source connector.
pub struct NatsProperties {
// Options from `NatsCommon`, flattened so they are read from the same property map.
#[serde(flatten)]
pub common: NatsCommon,

/// Where to start reading when the split carries no offset of its own.
/// The split reader accepts `earliest`, `latest` and `timestamp_millis`, and uses
/// the earliest position when this option is absent.
#[serde(rename = "scan.startup.mode")]
pub scan_startup_mode: Option<String>,

/// Start timestamp kept as a string (milliseconds, per the option name); the split
/// reader parses it and requires it when the startup mode is `timestamp_millis`.
#[serde(rename = "scan.startup.timestamp_millis")]
pub start_time: Option<String>,
}

impl NatsProperties {}
Expand Down
30 changes: 24 additions & 6 deletions src/connector/src/source/nats/source/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,37 @@
// limitations under the License.

use async_nats;
use async_nats::jetstream::Message;

use crate::source::base::SourceMessage;
use crate::source::SourceMeta;
use crate::source::{SourceMeta, SplitId};

impl SourceMessage {
pub fn from_nats_jetstream_message(message: async_nats::jetstream::message::Message) -> Self {
/// A NATS JetStream message reduced to the fields needed to build a `SourceMessage`.
#[derive(Clone, Debug)]
pub struct NatsMessage {
/// Id of the split the message was read from; becomes `SourceMessage::split_id`.
pub split_id: SplitId,
/// JetStream stream sequence of the message, rendered as a string; becomes the
/// `SourceMessage` offset.
pub sequence_number: String,
/// Raw message payload bytes.
pub payload: Vec<u8>,
}

impl From<NatsMessage> for SourceMessage {
fn from(message: NatsMessage) -> Self {
SourceMessage {
key: None,
payload: Some(message.message.payload.to_vec()),
payload: Some(message.payload),
// For nats jetstream, use sequence id as offset
offset: message.info().unwrap().stream_sequence.to_string(),
split_id: "0".into(),
offset: message.sequence_number,
split_id: message.split_id,
meta: SourceMeta::Empty,
}
}
}

impl NatsMessage {
pub fn new(split_id: SplitId, message: Message) -> Self {
NatsMessage {
split_id,
sequence_number: message.info().unwrap().stream_sequence.to_string(),
payload: message.message.payload.to_vec(),
}
}
}
43 changes: 38 additions & 5 deletions src/connector/src/source/nats/source/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Result;
use anyhow::{anyhow, Result};
use async_nats::jetstream::consumer;
use async_trait::async_trait;
use futures::StreamExt;
use futures_async_stream::try_stream;

use super::message::NatsMessage;
use super::{NatsOffset, NatsSplit};
use crate::parser::ParserConfig;
use crate::source::common::{into_chunk_stream, CommonSplitReader};
use crate::source::nats::split::NatsSplit;
use crate::source::nats::NatsProperties;
use crate::source::{
BoxSourceWithStateStream, Column, SourceContextRef, SourceMessage, SplitReader,
BoxSourceWithStateStream, Column, SourceContextRef, SourceMessage, SplitId, SplitReader,
};

/// Split reader for the NATS source; it handles exactly one split in this version.
pub struct NatsSplitReader {
/// JetStream pull consumer obtained in `new` through `NatsCommon::build_consumer`.
consumer: consumer::Consumer<consumer::pull::Config>,
/// Connector properties the reader was created with.
properties: NatsProperties,
/// Parser configuration handed to `new`.
parser_config: ParserConfig,
/// Source context handed to `new`.
source_ctx: SourceContextRef,
/// Resolved start offset: the split's own offset when it has one, otherwise the
/// position derived from the `scan.startup.mode` property.
start_position: NatsOffset,
/// Id of the split being read; attached to every message the reader emits.
split_id: SplitId,
}

#[async_trait]
Expand All @@ -47,15 +50,42 @@ impl SplitReader for NatsSplitReader {
) -> Result<Self> {
// TODO: to simplify the logic, return 1 split for first version
assert!(splits.len() == 1);
let split = splits.into_iter().next().unwrap();
let split_id = split.split_id;
let start_position = match &split.start_sequence {
NatsOffset::None => match &properties.scan_startup_mode {
None => NatsOffset::Earliest,
Some(mode) => match mode.as_str() {
"latest" => NatsOffset::Latest,
"earliest" => NatsOffset::Earliest,
"timestamp_millis" => {
if let Some(time) = &properties.start_time {
NatsOffset::Timestamp(time.parse()?)
} else {
return Err(anyhow!("scan_startup_timestamp_millis is required"));
}
}
_ => {
return Err(anyhow!(
"invalid scan_startup_mode, accept earliest/latest/timestamp_millis"
))
}
},
},
start_position => start_position.to_owned(),
};

let consumer = properties
.common
.build_consumer(0, splits[0].start_sequence)
.build_consumer(split_id.to_string(), start_position.clone())
.await?;
Ok(Self {
consumer,
properties,
parser_config,
source_ctx,
start_position,
split_id,
})
}

Expand All @@ -75,7 +105,10 @@ impl CommonSplitReader for NatsSplitReader {
for msgs in messages.ready_chunks(capacity) {
let mut msg_vec = Vec::with_capacity(capacity);
for msg in msgs {
msg_vec.push(SourceMessage::from_nats_jetstream_message(msg?));
msg_vec.push(SourceMessage::from(NatsMessage::new(
self.split_id.clone(),
msg?,
)));
}
yield msg_vec;
}
Expand Down
Loading

0 comments on commit b17569d

Please sign in to comment.