Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Rossil2012 committed Feb 29, 2024
1 parent e490f80 commit 9ce4fee
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 47 deletions.
1 change: 0 additions & 1 deletion e2e_test/batch/catalog/pg_settings.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ postmaster barrier_interval_ms
postmaster checkpoint_frequency
postmaster enable_tracing
postmaster max_concurrent_creating_streaming_jobs
postmaster oauth_jwks_url
postmaster pause_on_next_bootstrap
user application_name
user background_ddl
Expand Down
1 change: 0 additions & 1 deletion proto/meta.proto
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,6 @@ message SystemParams {
optional bool pause_on_next_bootstrap = 13;
optional string wasm_storage_url = 14 [deprecated = true];
optional bool enable_tracing = 15;
optional string oauth_jwks_url = 16;
}

message GetSystemParamsRequest {}
Expand Down
1 change: 1 addition & 0 deletions proto/user.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ message AuthInfo {
}
EncryptionType encryption_type = 1;
bytes encrypted_value = 2;
map<string, string> meta_data = 3;
}

// User defines a user in the system.
Expand Down
3 changes: 0 additions & 3 deletions src/common/src/system_param/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ macro_rules! for_all_params {
{ max_concurrent_creating_streaming_jobs, u32, Some(1_u32), true, "Max number of concurrent creating streaming jobs.", },
{ pause_on_next_bootstrap, bool, Some(false), true, "Whether to pause all data sources on next bootstrap.", },
{ enable_tracing, bool, Some(false), true, "Whether to enable distributed tracing.", },
{ oauth_jwks_url, String, Some("".to_string()), true, "Url to get JSON Web Key Set(JWKS) for oauth authentication.", },
}
};
}
Expand Down Expand Up @@ -376,7 +375,6 @@ macro_rules! impl_system_params_for_test {
ret.state_store = Some("hummock+memory".to_string());
ret.backup_storage_url = Some("memory".into());
ret.backup_storage_directory = Some("backup".into());
ret.oauth_jwks_url = Some("https://auth-static.confluent.io/jwks".into());
ret
}
};
Expand Down Expand Up @@ -442,7 +440,6 @@ mod tests {
(MAX_CONCURRENT_CREATING_STREAMING_JOBS_KEY, "1"),
(PAUSE_ON_NEXT_BOOTSTRAP_KEY, "false"),
(ENABLE_TRACING_KEY, "true"),
(OAUTH_JWKS_URL_KEY, "a"),
("a_deprecated_param", "foo"),
];

Expand Down
7 changes: 0 additions & 7 deletions src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,4 @@ where
.enable_tracing
.unwrap_or_else(default::enable_tracing)
}

fn oauth_jwks_url(&self) -> &str {
self.inner()
.oauth_jwks_url
.as_ref()
.unwrap_or(&default::OAUTH_JWKS_URL)
}
}
1 change: 0 additions & 1 deletion src/config/docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ This page is automatically generated by `./risedev generate-example-config`
| data_directory | Remote directory for storing data and metadata objects. | |
| enable_tracing | Whether to enable distributed tracing. | false |
| max_concurrent_creating_streaming_jobs | Max number of concurrent creating streaming jobs. | 1 |
| oauth_jwks_url | Url to get JSON Web Key Set(JWKS) for oauth authentication. | "" |
| parallel_compact_size_mb | | 512 |
| pause_on_next_bootstrap | Whether to pause all data sources on next bootstrap. | false |
| sstable_size_mb | Target size of the Sstable. | 256 |
Expand Down
1 change: 0 additions & 1 deletion src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,3 @@ bloom_false_positive = 0.001
max_concurrent_creating_streaming_jobs = 1
pause_on_next_bootstrap = false
enable_tracing = false
oauth_jwks_url = ""
21 changes: 16 additions & 5 deletions src/frontend/src/handler/alter_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{AlterUserStatement, ObjectName, UserOption, User
use super::RwPgResponse;
use crate::binder::Binder;
use crate::catalog::CatalogError;
use crate::error::ErrorCode::{InternalError, PermissionDenied};
use crate::error::ErrorCode::{self, InternalError, PermissionDenied};
use crate::error::Result;
use crate::handler::HandlerArgs;
use crate::user::user_authentication::{build_oauth_info, encrypted_password};
use crate::user::user_authentication::{
build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY,
};
use crate::user::user_catalog::UserCatalog;

fn alter_prost_user_info(
Expand Down Expand Up @@ -111,8 +113,14 @@ fn alter_prost_user_info(
}
update_fields.push(UpdateField::AuthInfo);
}
UserOption::OAuth => {
user_info.auth_info = build_oauth_info();
UserOption::OAuth(options) => {
let auth_info = build_oauth_info(options).ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"{} and {} must be provided",
OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY
))
})?;
user_info.auth_info = Some(auth_info);
update_fields.push(UpdateField::AuthInfo)
}
}
Expand Down Expand Up @@ -185,6 +193,8 @@ pub async fn handle_alter_user(

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;

Expand Down Expand Up @@ -223,7 +233,8 @@ mod tests {
user_info.auth_info,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec()
encrypted_value: b"9f2fa6a30871a92249bdd2f1eeee4ef6".to_vec(),
meta_data: HashMap::new(),
})
);
}
Expand Down
21 changes: 17 additions & 4 deletions src/frontend/src/handler/create_user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ use risingwave_sqlparser::ast::{CreateUserStatement, UserOption, UserOptions};
use super::RwPgResponse;
use crate::binder::Binder;
use crate::catalog::{CatalogError, DatabaseId};
use crate::error::ErrorCode::PermissionDenied;
use crate::error::ErrorCode::{self, PermissionDenied};
use crate::error::Result;
use crate::handler::HandlerArgs;
use crate::user::user_authentication::{build_oauth_info, encrypted_password};
use crate::user::user_authentication::{
build_oauth_info, encrypted_password, OAUTH_ISSUER_KEY, OAUTH_JWKS_URL_KEY,
};
use crate::user::user_catalog::UserCatalog;

fn make_prost_user_info(
Expand Down Expand Up @@ -91,7 +93,15 @@ fn make_prost_user_info(
user_info.auth_info = encrypted_password(&user_info.name, &password.0);
}
}
UserOption::OAuth => user_info.auth_info = build_oauth_info(),
UserOption::OAuth(options) => {
let auth_info = build_oauth_info(options).ok_or_else(|| {
ErrorCode::InvalidParameterValue(format!(
"{} and {} must be provided",
OAUTH_JWKS_URL_KEY, OAUTH_ISSUER_KEY
))
})?;
user_info.auth_info = Some(auth_info);
}
}
}

Expand Down Expand Up @@ -131,6 +141,8 @@ pub async fn handle_create_user(

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use risingwave_common::catalog::DEFAULT_DATABASE_NAME;
use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;
Expand Down Expand Up @@ -158,7 +170,8 @@ mod tests {
user_info.auth_info,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec()
encrypted_value: b"827ccb0eea8a706c4c34a16891f84e7b".to_vec(),
meta_data: HashMap::new(),
})
);
frontend
Expand Down
16 changes: 1 addition & 15 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod
use risingwave_common::system_param::local_manager::{
LocalSystemParamsManager, LocalSystemParamsManagerRef,
};
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_common::telemetry::manager::TelemetryManager;
use risingwave_common::telemetry::telemetry_env_enabled;
use risingwave_common::types::DataType;
Expand Down Expand Up @@ -978,20 +977,7 @@ impl SessionManager for SessionManagerImpl {
salt,
}
} else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
let oauth_jwks_url = self
.env
.system_params_manager
.get_params()
.load()
.oauth_jwks_url()
.to_string();
if oauth_jwks_url.is_empty() {
return Err(Box::new(Error::new(
ErrorKind::PermissionDenied,
"OAuth JWKS URL is not set",
)));
}
UserAuthenticator::OAuth(oauth_jwks_url)
UserAuthenticator::OAuth(auth_info.meta_data.clone())
} else {
return Err(Box::new(Error::new(
ErrorKind::Unsupported,
Expand Down
22 changes: 21 additions & 1 deletion src/frontend/src/user/user_authentication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use risingwave_pb::user::auth_info::EncryptionType;
use risingwave_pb::user::AuthInfo;
use risingwave_sqlparser::ast::SqlOption;
use sha2::{Digest, Sha256};

// SHA-256 is not supported in PostgreSQL protocol. We need to implement SCRAM-SHA-256 instead
Expand All @@ -24,12 +27,23 @@ const MD5_ENCRYPTED_PREFIX: &str = "md5";
const VALID_SHA256_ENCRYPTED_LEN: usize = SHA256_ENCRYPTED_PREFIX.len() + 64;
const VALID_MD5_ENCRYPTED_LEN: usize = MD5_ENCRYPTED_PREFIX.len() + 32;

pub const OAUTH_JWKS_URL_KEY: &str = "jwks_url";
pub const OAUTH_ISSUER_KEY: &str = "issuer";

/// Build `AuthInfo` for `OAuth`.
#[inline(always)]
pub fn build_oauth_info() -> Option<AuthInfo> {
pub fn build_oauth_info(options: &Vec<SqlOption>) -> Option<AuthInfo> {
let meta_data: HashMap<String, String> = options
.iter()
.map(|opt| (opt.name.real_value(), opt.value.to_string()))
.collect();
if !meta_data.contains_key(OAUTH_JWKS_URL_KEY) || !meta_data.contains_key(OAUTH_ISSUER_KEY) {
return None;
}
Some(AuthInfo {
encryption_type: EncryptionType::Oauth as i32,
encrypted_value: Vec::new(),
meta_data,
})
}

Expand Down Expand Up @@ -62,11 +76,13 @@ pub fn encrypted_password(name: &str, password: &str) -> Option<AuthInfo> {
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: password.trim_start_matches(SHA256_ENCRYPTED_PREFIX).into(),
meta_data: HashMap::new(),
})
} else if valid_md5_password(password) {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: password.trim_start_matches(MD5_ENCRYPTED_PREFIX).into(),
meta_data: HashMap::new(),
})
} else {
Some(encrypt_default(name, password))
Expand All @@ -79,6 +95,7 @@ fn encrypt_default(name: &str, password: &str) -> AuthInfo {
AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(name, password),
meta_data: HashMap::new(),
}
}

Expand Down Expand Up @@ -166,15 +183,18 @@ mod tests {
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
meta_data: HashMap::new(),
}),
None,
Some(AuthInfo {
encryption_type: EncryptionType::Md5 as i32,
encrypted_value: md5_hash(user_name, password),
meta_data: HashMap::new(),
}),
Some(AuthInfo {
encryption_type: EncryptionType::Sha256 as i32,
encrypted_value: sha256_hash(user_name, password),
meta_data: HashMap::new(),
}),
];
let output_passwords = input_passwords
Expand Down
11 changes: 8 additions & 3 deletions src/sqlparser/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ pub enum UserOption {
NoLogin,
EncryptedPassword(AstString),
Password(Option<AstString>),
OAuth,
OAuth(Vec<SqlOption>),
}

impl fmt::Display for UserOption {
Expand All @@ -758,7 +758,9 @@ impl fmt::Display for UserOption {
UserOption::EncryptedPassword(p) => write!(f, "ENCRYPTED PASSWORD {}", p),
UserOption::Password(None) => write!(f, "PASSWORD NULL"),
UserOption::Password(Some(p)) => write!(f, "PASSWORD {}", p),
UserOption::OAuth => write!(f, "OAUTH"),
UserOption::OAuth(options) => {
write!(f, "({})", display_comma_separated(options.as_slice()))
}
}
}
}
Expand Down Expand Up @@ -846,7 +848,10 @@ impl ParseTo for UserOptions {
UserOption::EncryptedPassword(AstString::parse_to(parser)?),
)
}
Keyword::OAUTH => (&mut builder.password, UserOption::OAuth),
Keyword::OAUTH => {
let options = parser.parse_options()?;
(&mut builder.password, UserOption::OAuth(options))
}
_ => {
parser.expected(
"SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \
Expand Down
19 changes: 14 additions & 5 deletions src/utils/pgwire/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ pub enum UserAuthenticator {
encrypted_password: Vec<u8>,
salt: [u8; 4],
},
OAuth(String),
OAuth(HashMap<String, String>),
}

#[derive(Debug, Deserialize)]
Expand All @@ -181,7 +181,11 @@ async fn fetch_jwks(url: &str) -> Result<Jwks, reqwest::Error> {
Ok(resp)
}

async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result<bool, BoxedError> {
async fn validate_jwt(
jwt: &str,
jwks_url: &str,
meta_data: &HashMap<String, String>,
) -> Result<bool, BoxedError> {
let header = decode_header(jwt)?;
let jwks = fetch_jwks(jwks_url).await?;

Expand All @@ -194,8 +198,11 @@ async fn validate_jwt(jwt: &str, jwks_url: &str) -> Result<bool, BoxedError> {

let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)?;
let validation = Validation::new(Algorithm::from_str(&jwk.alg)?);
let token_data = decode::<HashMap<String, String>>(jwt, &decoding_key, &validation)?;

Ok(decode::<HashMap<String, String>>(jwt, &decoding_key, &validation).is_ok())
Ok(meta_data
.iter()
.all(|(k, v)| token_data.claims.get(k) == Some(v)))
}

impl UserAuthenticator {
Expand All @@ -206,8 +213,10 @@ impl UserAuthenticator {
UserAuthenticator::Md5WithSalt {
encrypted_password, ..
} => encrypted_password == password,
UserAuthenticator::OAuth(oauth_jwks_url) => {
validate_jwt(&String::from_utf8_lossy(password), oauth_jwks_url)
UserAuthenticator::OAuth(meta_data) => {
let mut meta_data = meta_data.clone();
let jwks_url = meta_data.remove("jwks_url").unwrap();
validate_jwt(&String::from_utf8_lossy(password), &jwks_url, &meta_data)
.await
.map_err(PsqlError::StartupError)?
}
Expand Down

0 comments on commit 9ce4fee

Please sign in to comment.