Skip to content

Commit

Permalink
fix default and add validation
Browse files Browse the repository at this point in the history
  • Loading branch information
yezizp2012 committed Feb 28, 2024
1 parent 0b85778 commit dc97f26
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 40 deletions.
5 changes: 4 additions & 1 deletion src/common/src/system_param/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ where
}

fn oauth_jwks_url(&self) -> &str {
self.inner().oauth_jwks_url.as_ref().unwrap()
self.inner()
.oauth_jwks_url
.as_ref()
.unwrap_or(&default::OAUTH_JWKS_URL)
}
}
34 changes: 18 additions & 16 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use risingwave_common::session_config::{ConfigMap, ConfigReporter, VisibilityMod
use risingwave_common::system_param::local_manager::{
LocalSystemParamsManager, LocalSystemParamsManagerRef,
};
use risingwave_common::system_param::reader::{SystemParamsRead, SystemParamsReader};
use risingwave_common::system_param::reader::SystemParamsRead;
use risingwave_common::telemetry::manager::TelemetryManager;
use risingwave_common::telemetry::telemetry_env_enabled;
use risingwave_common::types::DataType;
Expand Down Expand Up @@ -927,16 +927,16 @@ pub struct SessionManagerImpl {
impl SessionManager for SessionManagerImpl {
type Session = SessionImpl;

async fn connect(
fn connect(
&self,
database: String,
user_name: String,
database: &str,
user_name: &str,
peer_addr: AddressRef,
) -> std::result::Result<Arc<Self::Session>, BoxedError> {
let database_id = {
let catalog_reader = self.env.catalog_reader().read_guard();
catalog_reader
.get_database_by_name(&database)
.get_database_by_name(database)
.map_err(|_| {
Box::new(Error::new(
ErrorKind::InvalidInput,
Expand All @@ -947,7 +947,7 @@ impl SessionManager for SessionManagerImpl {
};
let user = {
let user_reader = self.env.user_info_reader().read_guard();
user_reader.get_user_by_name(&user_name).cloned()
user_reader.get_user_by_name(user_name).cloned()
};
if let Some(user) = user {
if !user.can_login {
Expand Down Expand Up @@ -981,13 +981,19 @@ impl SessionManager for SessionManagerImpl {
salt,
}
} else if auth_info.encryption_type == EncryptionType::Oauth as i32 {
let reader = self
let oauth_jwks_url = self
.env
.meta_client()
.get_system_params()
.await
.map_err(|e| PsqlError::StartupError(e.into()))?;
let oauth_jwks_url = reader.oauth_jwks_url().to_string();
.system_params_manager
.get_params()
.load()
.oauth_jwks_url()
.to_string();
if oauth_jwks_url.is_empty() {
return Err(Box::new(Error::new(
ErrorKind::PermissionDenied,
"OAuth JWKS URL is not set",
)));
}
UserAuthenticator::OAuth(oauth_jwks_url)
} else {
return Err(Box::new(Error::new(
Expand Down Expand Up @@ -1102,10 +1108,6 @@ impl Session for SessionImpl {
&self.user_authenticator
}

async fn get_system_params(&self) -> std::result::Result<SystemParamsReader, BoxedError> {
Ok(self.env.meta_client.get_system_params().await?)
}

fn id(&self) -> SessionId {
self.id
}
Expand Down
6 changes: 3 additions & 3 deletions src/frontend/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ pub struct LocalFrontend {
impl SessionManager for LocalFrontend {
type Session = SessionImpl;

async fn connect(
fn connect(
&self,
_database: String,
_user_name: String,
_database: &str,
_user_name: &str,
_peer_addr: AddressRef,
) -> std::result::Result<Arc<Self::Session>, BoxedError> {
Ok(self.session_ref())
Expand Down
7 changes: 3 additions & 4 deletions src/utils/pgwire/src/pg_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ where

match msg {
FeMessage::Ssl => self.process_ssl_msg().await?,
FeMessage::Startup(msg) => self.process_startup_msg(msg).await?,
FeMessage::Startup(msg) => self.process_startup_msg(msg)?,
FeMessage::Password(msg) => self.process_password_msg(msg).await?,
FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?,
FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?,
Expand Down Expand Up @@ -469,7 +469,7 @@ where
Ok(())
}

async fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
fn process_startup_msg(&mut self, msg: FeStartupMessage) -> PsqlResult<()> {
let db_name = msg
.config
.get("database")
Expand All @@ -483,8 +483,7 @@ where

let session = self
.session_mgr
.connect(db_name, user_name, self.peer_addr.clone())
.await
.connect(&db_name, &user_name, self.peer_addr.clone())
.map_err(PsqlError::StartupError)?;

let application_name = msg.config.get("application_name");
Expand Down
22 changes: 6 additions & 16 deletions src/utils/pgwire/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use std::time::Instant;
use bytes::Bytes;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use parking_lot::Mutex;
use risingwave_common::system_param::reader::SystemParamsReader;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;
use serde::Deserialize;
Expand All @@ -50,10 +49,10 @@ pub trait SessionManager: Send + Sync + 'static {

fn connect(
&self,
database: String,
user_name: String,
database: &str,
user_name: &str,
peer_addr: AddressRef,
) -> impl Future<Output = Result<Arc<Self::Session>, BoxedError>> + Send;
) -> Result<Arc<Self::Session>, BoxedError>;

fn cancel_queries_in_session(&self, session_id: SessionId);

Expand Down Expand Up @@ -112,10 +111,6 @@ pub trait Session: Send + Sync {

fn user_authenticator(&self) -> &UserAuthenticator;

fn get_system_params(
&self,
) -> impl Future<Output = Result<SystemParamsReader, BoxedError>> + Send;

fn id(&self) -> SessionId;

fn set_config(&self, key: &str, value: String) -> Result<(), BoxedError>;
Expand Down Expand Up @@ -288,7 +283,6 @@ mod tests {
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::StreamExt;
use risingwave_common::system_param::reader::SystemParamsReader;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;
use tokio_postgres::NoTls;
Expand All @@ -310,10 +304,10 @@ mod tests {
impl SessionManager for MockSessionManager {
type Session = MockSession;

async fn connect(
fn connect(
&self,
_database: String,
_user_name: String,
_database: &str,
_user_name: &str,
_peer_addr: crate::net::AddressRef,
) -> Result<Arc<Self::Session>, Box<dyn Error + Send + Sync>> {
Ok(Arc::new(MockSession {}))
Expand Down Expand Up @@ -411,10 +405,6 @@ mod tests {
&UserAuthenticator::None
}

async fn get_system_params(&self) -> Result<SystemParamsReader, BoxedError> {
Ok(SystemParamsReader::new(Default::default()))
}

fn id(&self) -> SessionId {
(0, 0)
}
Expand Down

0 comments on commit dc97f26

Please sign in to comment.