Skip to content

Commit

Permalink
refactor(connector): use AwsAuthProps instead of HashMap in config (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan authored Nov 22, 2023
1 parent f43d2f1 commit e612ef6
Show file tree
Hide file tree
Showing 15 changed files with 202 additions and 216 deletions.
139 changes: 0 additions & 139 deletions src/connector/src/aws_auth.rs

This file was deleted.

18 changes: 2 additions & 16 deletions src/connector/src/aws_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,9 @@ use risingwave_common::error::ErrorCode::InternalError;
use risingwave_common::error::{Result, RwError};
use url::Url;

use crate::aws_auth::AwsAuthProps;
use crate::common::AwsAuthProps;

pub const REGION: &str = "region";
pub const ACCESS_KEY: &str = "access_key";
pub const SECRET_ACCESS: &str = "secret_access";

pub const AWS_DEFAULT_CONFIG: [&str; 7] = [
REGION,
"arn",
"profile",
ACCESS_KEY,
SECRET_ACCESS,
"session_token",
"endpoint_url",
];
pub const AWS_CUSTOM_CONFIG_KEY: [&str; 3] = ["retry_times", "conn_timeout", "read_timeout"];
const AWS_CUSTOM_CONFIG_KEY: [&str; 3] = ["retry_times", "conn_timeout", "read_timeout"];

pub fn default_conn_config() -> HashMap<String, u64> {
let mut default_conn_config = HashMap::new();
Expand Down Expand Up @@ -118,7 +105,6 @@ pub fn s3_client(
}

// TODO(Tao): Probably we should never allow to use S3 URI.
/// properties require keys: refer to [`AWS_DEFAULT_CONFIG`]
pub async fn load_file_descriptor_from_s3(
location: &Url,
config: &AwsAuthProps,
Expand Down
108 changes: 95 additions & 13 deletions src/connector/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use time::OffsetDateTime;
use url::Url;
use with_options::WithOptions;

use crate::aws_auth::AwsAuthProps;
use crate::aws_utils::load_file_descriptor_from_s3;
use crate::deserialize_duration_from_string;
use crate::sink::SinkError;
Expand All @@ -51,6 +50,98 @@ pub struct AwsPrivateLinkItem {
pub port: u16,
}

use aws_config::default_provider::region::DefaultRegionChain;
use aws_config::sts::AssumeRoleProvider;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_types::region::Region;
use aws_types::SdkConfig;

/// A flattened set of AWS authentication options.
///
/// Every option is optional when deserializing; `build_config` decides which
/// combinations are actually usable (for example, it fails unless both
/// `access_key` and `secret_key` are present).
#[derive(Deserialize, Serialize, Debug, Clone, WithOptions)]
pub struct AwsAuthProps {
    // Explicit AWS region name. When absent, `build_region` falls back to the
    // SDK's default region chain.
    pub region: Option<String>,
    // Custom endpoint URL handed to the SDK config loader. Also accepted under
    // the key `endpoint_url` when deserializing.
    #[serde(alias = "endpoint_url")]
    pub endpoint: Option<String>,
    // Access key of the static credentials; must be set together with `secret_key`.
    pub access_key: Option<String>,
    // Secret key of the static credentials; must be set together with `access_key`.
    pub secret_key: Option<String>,
    // Optional session token attached to the static credentials.
    pub session_token: Option<String>,
    // When set, the static credentials are used to assume this role
    // (see `with_role_provider`).
    pub arn: Option<String>,
    /// This field was added for kinesis. Not sure if it's useful for other connectors.
    /// Please ignore it in the documentation for now.
    pub external_id: Option<String>,
    // Profile name given to the default region chain when `region` is not set.
    pub profile: Option<String>,
}

impl AwsAuthProps {
    /// Resolves the AWS region to use.
    ///
    /// An explicitly configured `region` takes precedence. Otherwise the SDK's
    /// default region chain is consulted, scoped to `profile` when one is set.
    ///
    /// # Errors
    ///
    /// Returns an error when no region is configured and the default chain
    /// cannot find one either.
    async fn build_region(&self) -> anyhow::Result<Region> {
        if let Some(region_name) = &self.region {
            return Ok(Region::new(region_name.clone()));
        }

        let mut region_chain = DefaultRegionChain::builder();
        if let Some(profile_name) = &self.profile {
            region_chain = region_chain.profile_name(profile_name);
        }

        region_chain
            .build()
            .region()
            .await
            .ok_or_else(|| anyhow::format_err!("region should be provided"))
    }

    /// Builds a static credentials provider from `access_key` and `secret_key`,
    /// attaching `session_token` when present.
    ///
    /// # Errors
    ///
    /// Returns an error unless both `access_key` and `secret_key` are set.
    fn build_credential_provider(&self) -> anyhow::Result<SharedCredentialsProvider> {
        match (&self.access_key, &self.secret_key) {
            (Some(access_key), Some(secret_key)) => Ok(SharedCredentialsProvider::new(
                aws_credential_types::Credentials::from_keys(
                    access_key,
                    secret_key,
                    self.session_token.clone(),
                ),
            )),
            // The message names the actual option keys of this struct so that the
            // user can tell which setting is missing.
            _ => Err(anyhow!(
                "Both \"access_key\" and \"secret_key\" are required."
            )),
        }
    }

    /// Wraps `credential` in an assume-role provider when `arn` is configured.
    ///
    /// The role is assumed with the session name `"RisingWave"`, in the region
    /// returned by `build_region`, and with `external_id` when one is set.
    /// Without an `arn`, `credential` is returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates the error from `build_region` when a role is requested but no
    /// region can be resolved.
    async fn with_role_provider(
        &self,
        credential: SharedCredentialsProvider,
    ) -> anyhow::Result<SharedCredentialsProvider> {
        let Some(role_name) = &self.arn else {
            return Ok(credential);
        };

        let region = self.build_region().await?;
        let mut role = AssumeRoleProvider::builder(role_name)
            .session_name("RisingWave")
            .region(region);
        if let Some(id) = &self.external_id {
            role = role.external_id(id);
        }
        let provider = role.build_from_provider(credential).await;
        Ok(SharedCredentialsProvider::new(provider))
    }

    /// Builds the AWS SDK configuration described by these properties.
    ///
    /// Starts from the environment-based loader, then overrides the region, the
    /// credentials provider (optionally wrapped by an assumed role), and the
    /// endpoint URL when `endpoint` is set.
    ///
    /// # Errors
    ///
    /// Returns an error when the region cannot be resolved or when the static
    /// credentials are incomplete.
    pub async fn build_config(&self) -> anyhow::Result<SdkConfig> {
        let region = self.build_region().await?;
        let credentials_provider = self
            .with_role_provider(self.build_credential_provider()?)
            .await?;
        let mut config_loader = aws_config::from_env()
            .region(region)
            .credentials_provider(credentials_provider);

        if let Some(endpoint) = self.endpoint.as_ref() {
            config_loader = config_loader.endpoint_url(endpoint);
        }

        Ok(config_loader.load().await)
    }
}

#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, WithOptions)]
pub struct KafkaCommon {
Expand Down Expand Up @@ -282,8 +373,7 @@ pub struct PulsarOauthCommon {
pub scope: Option<String>,

#[serde(flatten)]
/// required keys refer to [`crate::aws_utils::AWS_DEFAULT_CONFIG`]
pub s3_credentials: HashMap<String, String>,
pub aws_auth_props: AwsAuthProps,
}

impl PulsarCommon {
Expand All @@ -294,16 +384,8 @@ impl PulsarCommon {
let url = Url::parse(&oauth.credentials_url)?;
match url.scheme() {
"s3" => {
let credentials = load_file_descriptor_from_s3(
&url,
&AwsAuthProps::from_pairs(
oauth
.s3_credentials
.iter()
.map(|(k, v)| (k.as_str(), v.as_str())),
),
)
.await?;
let credentials =
load_file_descriptor_from_s3(&url, &oauth.aws_auth_props).await?;
let mut f = NamedTempFile::new()?;
f.write_all(&credentials)?;
f.as_file().sync_all()?;
Expand Down
1 change: 0 additions & 1 deletion src/connector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ use risingwave_pb::connector_service::SinkPayloadFormat;
use risingwave_rpc_client::ConnectorClient;
use serde::de;

pub mod aws_auth;
pub mod aws_utils;
pub mod error;
mod macros;
Expand Down
11 changes: 3 additions & 8 deletions src/connector/src/parser/avro/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ mod test {
read_schema_from_http, read_schema_from_local, read_schema_from_s3, AvroAccessBuilder,
AvroParserConfig,
};
use crate::aws_auth::AwsAuthProps;
use crate::common::AwsAuthProps;
use crate::parser::bytes_parser::BytesAccessBuilder;
use crate::parser::plain_parser::PlainParser;
use crate::parser::unified::avro::unix_epoch_days;
Expand Down Expand Up @@ -257,14 +257,9 @@ mod test {
#[ignore]
async fn test_load_schema_from_s3() {
let schema_location = "s3://mingchao-schemas/complex-schema.avsc".to_string();
let mut s3_config_props = HashMap::new();
s3_config_props.insert("region".to_string(), "ap-southeast-1".to_string());
let url = Url::parse(&schema_location).unwrap();
let aws_auth_config = AwsAuthProps::from_pairs(
s3_config_props
.iter()
.map(|(k, v)| (k.as_str(), v.as_str())),
);
let aws_auth_config: AwsAuthProps =
serde_json::from_str(r#"region":"ap-southeast-1"#).unwrap();
let schema_content = read_schema_from_s3(&url, &aws_auth_config).await;
assert!(schema_content.is_ok());
let schema = Schema::parse_str(&schema_content.unwrap());
Expand Down
20 changes: 13 additions & 7 deletions src/connector/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use self::simd_json_parser::DebeziumJsonAccessBuilder;
use self::unified::{AccessImpl, AccessResult};
use self::upsert_parser::UpsertParser;
use self::util::get_kafka_topic;
use crate::aws_auth::AwsAuthProps;
use crate::common::AwsAuthProps;
use crate::parser::maxwell::MaxwellParser;
use crate::schema::schema_registry::SchemaRegistryAuth;
use crate::source::{
Expand Down Expand Up @@ -916,9 +916,12 @@ impl SpecificParserConfig {
config.topic = get_kafka_topic(props)?.clone();
config.client_config = SchemaRegistryAuth::from(props);
} else {
config.aws_auth_props = Some(AwsAuthProps::from_pairs(
props.iter().map(|(k, v)| (k.as_str(), v.as_str())),
));
config.aws_auth_props = Some(
serde_json::from_value::<AwsAuthProps>(
serde_json::to_value(props).unwrap(),
)
.map_err(|e| anyhow::anyhow!(e))?,
);
}
EncodingProperties::Avro(config)
}
Expand All @@ -945,9 +948,12 @@ impl SpecificParserConfig {
config.topic = get_kafka_topic(props)?.clone();
config.client_config = SchemaRegistryAuth::from(props);
} else {
config.aws_auth_props = Some(AwsAuthProps::from_pairs(
props.iter().map(|(k, v)| (k.as_str(), v.as_str())),
));
config.aws_auth_props = Some(
serde_json::from_value::<AwsAuthProps>(
serde_json::to_value(props).unwrap(),
)
.map_err(|e| anyhow::anyhow!(e))?,
);
}
EncodingProperties::Protobuf(config)
}
Expand Down
2 changes: 1 addition & 1 deletion src/connector/src/parser/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use risingwave_common::error::ErrorCode::{
};
use risingwave_common::error::{Result, RwError};

use crate::aws_auth::AwsAuthProps;
use crate::aws_utils::{default_conn_config, s3_client};
use crate::common::AwsAuthProps;

const AVRO_SCHEMA_LOCATION_S3_REGION: &str = "region";

Expand Down
Loading

0 comments on commit e612ef6

Please sign in to comment.