Skip to content

Commit

Permalink
refactor: use rwlock for modifiable session data (#4232)
Browse files Browse the repository at this point in the history
* chore: update sqlness results

* refactor: use rwlock for modifiable data in session and querycontext

* chore: format toml

* refactor: use mutable_inner structure for mutable fields

* refactor: remove arc wrapper
  • Loading branch information
sunng87 authored Jul 4, 2024
1 parent 6e2c21d commit 8399dca
Show file tree
Hide file tree
Showing 18 changed files with 226 additions and 535 deletions.
490 changes: 106 additions & 384 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/auth/src/permission.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ pub enum PermissionResp {
pub trait PermissionChecker: Send + Sync {
fn check_permission(
&self,
user_info: Option<UserInfoRef>,
user_info: UserInfoRef,
req: PermissionReq,
) -> Result<PermissionResp>;
}

impl PermissionChecker for Option<&PermissionCheckerRef> {
fn check_permission(
&self,
user_info: Option<UserInfoRef>,
user_info: UserInfoRef,
req: PermissionReq,
) -> Result<PermissionResp> {
match self {
Expand Down
9 changes: 5 additions & 4 deletions src/auth/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct DummyPermissionChecker;
impl PermissionChecker for DummyPermissionChecker {
fn check_permission(
&self,
_user_info: Option<UserInfoRef>,
_user_info: UserInfoRef,
req: PermissionReq,
) -> auth::error::Result<PermissionResp> {
match req {
Expand All @@ -45,20 +45,21 @@ fn test_permission_checker() {
let checker: PermissionCheckerRef = Arc::new(DummyPermissionChecker);

let grpc_result = checker.check_permission(
None,
auth::userinfo_by_name(None),
PermissionReq::GrpcRequest(&Request::Query(Default::default())),
);
assert_matches!(grpc_result, Ok(PermissionResp::Allow));

let sql_result = checker.check_permission(
None,
auth::userinfo_by_name(None),
PermissionReq::SqlStatement(&Statement::ShowDatabases(ShowDatabases::new(
ShowKind::All,
false,
))),
);
assert_matches!(sql_result, Ok(PermissionResp::Reject));

let err_result = checker.check_permission(None, PermissionReq::Opentsdb);
let err_result =
checker.check_permission(auth::userinfo_by_name(None), PermissionReq::Opentsdb);
assert_matches!(err_result, Err(InternalState { msg }) if msg == "testing");
}
4 changes: 3 additions & 1 deletion src/common/datasource/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ derive_builder.workspace = true
futures.workspace = true
lazy_static.workspace = true
object-store.workspace = true
orc-rust = { git = "https://github.com/datafusion-contrib/datafusion-orc.git", rev = "502217315726314c4008808fe169764529640599" }
orc-rust = { git = "https://github.com/datafusion-contrib/datafusion-orc.git", rev = "502217315726314c4008808fe169764529640599", default-features = false, features = [
"async",
] }
parquet.workspace = true
paste = "1.0"
rand.workspace = true
Expand Down
4 changes: 2 additions & 2 deletions src/operator/src/expr_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ mod tests {

// query context with timezone `+08:00`
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string("+08:00").unwrap().into())
.timezone(Timezone::from_tz_string("+08:00").unwrap())
.build()
.into();
let expr = create_to_expr(&create_table, &ctx).unwrap();
Expand Down Expand Up @@ -735,7 +735,7 @@ mod tests {
//
// query context with timezone `+08:00`
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string("+08:00").unwrap().into())
.timezone(Timezone::from_tz_string("+08:00").unwrap())
.build()
.into();
let expr = to_alter_expr(alter_table, &ctx).unwrap();
Expand Down
3 changes: 1 addition & 2 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,6 @@ fn idents_to_full_database_name(
mod tests {
use std::assert_matches::assert_matches;
use std::collections::HashMap;
use std::sync::Arc;

use common_time::range::TimestampRange;
use common_time::{Timestamp, Timezone};
Expand All @@ -509,7 +508,7 @@ mod tests {

fn check_timestamp_range((start, end): (&str, &str)) -> error::Result<Option<TimestampRange>> {
let query_ctx = QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string("Asia/Shanghai").unwrap()))
.timezone(Timezone::from_tz_string("Asia/Shanghai").unwrap())
.build()
.into();
let map = OptionMap::from(
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/optimizer/type_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl TypeConverter {
) -> Result<ScalarValue> {
match (target_type, value) {
(DataType::Timestamp(_, _), ScalarValue::Utf8(Some(v))) => {
string_to_timestamp_ms(v, Some(self.query_ctx.timezone().as_ref()))
string_to_timestamp_ms(v, Some(&self.query_ctx.timezone()))
}
(DataType::Boolean, ScalarValue::Utf8(Some(v))) => match v.to_lowercase().as_str() {
"true" => Ok(ScalarValue::Boolean(Some(true))),
Expand Down
7 changes: 2 additions & 5 deletions src/query/src/range_select/plan_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,8 @@ impl<'a> TreeNodeRewriter for RangeExprRewriter<'a> {
.map_err(|e| DataFusionError::Plan(e.to_string()))?;
let by = parse_expr_list(&func.args, 4, byc)?;
let align = parse_duration_expr(&func.args, byc + 4)?;
let align_to = parse_align_to(
&func.args,
byc + 5,
Some(self.query_ctx.timezone().as_ref()),
)?;
let align_to =
parse_align_to(&func.args, byc + 5, Some(&self.query_ctx.timezone()))?;
let mut data_type = range_expr.get_type(self.input_plan.schema())?;
let mut need_cast = false;
let fill = Fill::try_from_str(parse_str_expr(&func.args, 2)?, &data_type)?;
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ mod test {
};
let ctx = Arc::new(
QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.timezone(Timezone::from_tz_string(tz).unwrap())
.build(),
);
match show_variable(stmt, ctx) {
Expand Down
6 changes: 3 additions & 3 deletions src/servers/src/grpc/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async fn do_auth<T>(
let query_ctx = Arc::new(QueryContext::with(&catalog, &schema));

let Some(user_provider) = user_provider else {
query_ctx.set_current_user(Some(auth::userinfo_by_name(None)));
query_ctx.set_current_user(auth::userinfo_by_name(None));
let _ = req.extensions_mut().insert(query_ctx);
return Ok(());
};
Expand All @@ -124,7 +124,7 @@ async fn do_auth<T>(
.await
.map_err(|e| tonic::Status::unauthenticated(e.to_string()))?;

query_ctx.set_current_user(Some(user_info));
query_ctx.set_current_user(user_info);
let _ = req.extensions_mut().insert(query_ctx);

Ok(())
Expand Down Expand Up @@ -201,7 +201,7 @@ mod tests {
assert_eq!(expected_catalog, ctx.current_catalog());
assert_eq!(expected_schema, ctx.current_schema());

let user_info = ctx.current_user().unwrap();
let user_info = ctx.current_user();
assert_eq!(expected_user_name, user_info.username());
}
}
7 changes: 3 additions & 4 deletions src/servers/src/grpc/greptime_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ pub(crate) async fn auth(
user_provider: Option<UserProviderRef>,
header: Option<&RequestHeader>,
query_ctx: &QueryContextRef,
) -> Result<Option<UserInfoRef>> {
) -> Result<UserInfoRef> {
let Some(user_provider) = user_provider else {
return Ok(None);
return Ok(auth::userinfo_by_name(None));
};

let auth_scheme = header
Expand All @@ -156,7 +156,6 @@ pub(crate) async fn auth(
name: "Token AuthScheme".to_string(),
}),
}
.map(Some)
.map_err(|e| {
METRIC_AUTH_FAILURE
.with_label_values(&[e.status_code().as_ref()])
Expand Down Expand Up @@ -197,7 +196,7 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte
QueryContextBuilder::default()
.current_catalog(catalog)
.current_schema(schema)
.timezone(Arc::new(timezone))
.timezone(timezone)
.build()
.into()
}
Expand Down
6 changes: 3 additions & 3 deletions src/servers/src/http/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub async fn inner_auth<B>(
// 1. prepare
let (catalog, schema) = extract_catalog_and_schema(&req);
// TODO(ruihang): move this out of auth module
let timezone = Arc::new(extract_timezone(&req));
let timezone = extract_timezone(&req);
let query_ctx_builder = QueryContextBuilder::default()
.current_catalog(catalog.clone())
.current_schema(schema.clone())
Expand All @@ -75,7 +75,7 @@ pub async fn inner_auth<B>(
let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
user_provider
} else {
query_ctx.set_current_user(Some(auth::userinfo_by_name(None)));
query_ctx.set_current_user(auth::userinfo_by_name(None));
let _ = req.extensions_mut().insert(query_ctx);
return Ok(req);
};
Expand Down Expand Up @@ -103,7 +103,7 @@ pub async fn inner_auth<B>(
.await
{
Ok(userinfo) => {
query_ctx.set_current_user(Some(userinfo));
query_ctx.set_current_user(userinfo);
let _ = req.extensions_mut().insert(query_ctx);
Ok(req)
}
Expand Down
4 changes: 1 addition & 3 deletions src/servers/src/mysql/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,7 @@ impl MysqlInstanceShim {
{
vec![Ok(output)]
} else {
let output = self.query_handler.do_query(query, query_ctx.clone()).await;
query_ctx.update_session(&self.session);
output
self.query_handler.do_query(query, query_ctx.clone()).await
}
}

Expand Down
1 change: 0 additions & 1 deletion src/servers/src/postgres/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ impl SimpleQueryHandler for PostgresServerHandler {
.with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
.start_timer();
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
query_ctx.update_session(&self.session);

let mut results = Vec::with_capacity(outputs.len());

Expand Down
6 changes: 3 additions & 3 deletions src/servers/tests/http/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async fn test_http_auth() {
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let req = inner_auth(None, req).await.unwrap();
let ctx: &QueryContextRef = req.extensions().get().unwrap();
let user_info = ctx.current_user().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());

Expand All @@ -39,7 +39,7 @@ async fn test_http_auth() {
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
let ctx: &QueryContextRef = req.extensions().get().unwrap();
let user_info = ctx.current_user().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());

Expand Down Expand Up @@ -80,7 +80,7 @@ async fn test_schema_validating() {
.unwrap();
let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
let ctx: &QueryContextRef = req.extensions().get().unwrap();
let user_info = ctx.current_user().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());

Expand Down
8 changes: 4 additions & 4 deletions src/servers/tests/http/http_handler_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use crate::{
async fn test_sql_not_provided() {
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let ctx = QueryContext::arc();
ctx.set_current_user(Some(auth::userinfo_by_name(None)));
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
script_handler: None,
Expand Down Expand Up @@ -75,7 +75,7 @@ async fn test_sql_output_rows() {
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());

let ctx = QueryContext::arc();
ctx.set_current_user(Some(auth::userinfo_by_name(None)));
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
script_handler: None,
Expand Down Expand Up @@ -181,7 +181,7 @@ async fn test_sql_output_rows() {
async fn test_dashboard_sql_limit() {
let sql_handler = create_testing_sql_query_handler(MemTable::specified_numbers_table(2000));
let ctx = QueryContext::arc();
ctx.set_current_user(Some(auth::userinfo_by_name(None)));
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
script_handler: None,
Expand Down Expand Up @@ -227,7 +227,7 @@ async fn test_sql_form() {
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());

let ctx = QueryContext::arc();
ctx.set_current_user(Some(auth::userinfo_by_name(None)));
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
script_handler: None,
Expand Down
Loading

0 comments on commit 8399dca

Please sign in to comment.