Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: passing QueryContext to RegionServer #3829

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
632 changes: 313 additions & 319 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ etcd-client = { git = "https://github.com/MichaelScofield/etcd-client.git", rev
fst = "0.4.7"
futures = "0.3"
futures-util = "0.3"
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "f699e240f7a6c83f139dabac8669714f08513120" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "a191edaea1089362a86ebc7d8e98ee9a1bd522d1" }
humantime = "2.1"
humantime-serde = "1.1"
itertools = "0.10"
Expand Down
5 changes: 4 additions & 1 deletion src/cmd/src/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,10 @@ impl Repl {
let start = Instant::now();

let output = if let Some(query_engine) = &self.query_engine {
let query_ctx = QueryContext::with(self.database.catalog(), self.database.schema());
let query_ctx = Arc::new(QueryContext::with(
self.database.catalog(),
self.database.schema(),
));

let stmt = QueryLanguageParser::parse_sql(&sql, &query_ctx)
.with_context(|_| ParseSqlSnafu { sql: sql.clone() })?;
Expand Down
4 changes: 2 additions & 2 deletions src/common/function/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl FunctionContext {
#[cfg(any(test, feature = "testing"))]
pub fn mock() -> Self {
Self {
query_ctx: QueryContextBuilder::default().build(),
query_ctx: QueryContextBuilder::default().build().into(),
state: Arc::new(FunctionState::mock()),
}
}
Expand All @@ -44,7 +44,7 @@ impl FunctionContext {
impl Default for FunctionContext {
fn default() -> Self {
Self {
query_ctx: QueryContextBuilder::default().build(),
query_ctx: QueryContextBuilder::default().build().into(),
state: Arc::new(FunctionState::default()),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/function/src/scalars/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mod tests {
#[test]
fn test_create_udf() {
let f = Arc::new(TestAndFunction);
let query_ctx = QueryContextBuilder::default().build();
let query_ctx = QueryContextBuilder::default().build().into();

let args: Vec<VectorRef> = vec![
Arc::new(ConstantVector::new(
Expand Down
3 changes: 2 additions & 1 deletion src/common/function/src/system/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ mod tests {

let query_ctx = QueryContextBuilder::default()
.current_schema("test_db".to_string())
.build();
.build()
.into();

let func_ctx = FunctionContext {
query_ctx,
Expand Down
2 changes: 1 addition & 1 deletion src/common/function/src/system/timezone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ mod tests {
} if valid_types == vec![]
));

let query_ctx = QueryContextBuilder::default().build();
let query_ctx = QueryContextBuilder::default().build().into();

let func_ctx = FunctionContext {
query_ctx,
Expand Down
2 changes: 1 addition & 1 deletion src/datanode/src/region_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ impl RegionServerInner {
let ctx: QueryContextRef = header
.as_ref()
.map(|h| Arc::new(h.into()))
.unwrap_or_else(|| QueryContextBuilder::default().build());
.unwrap_or_else(|| QueryContextBuilder::default().build().into());

// build dummy catalog list
let region_status = self
Expand Down
6 changes: 4 additions & 2 deletions src/operator/src/expr_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,8 @@ mod tests {
// query context with timezone `+08:00`
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string("+08:00").unwrap().into())
.build();
.build()
.into();
let expr = create_to_expr(&create_table, ctx).unwrap();
let ts_column = &expr.column_defs[1];
let constraint = assert_ts_column(ts_column);
Expand Down Expand Up @@ -712,7 +713,8 @@ mod tests {
// query context with timezone `+08:00`
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string("+08:00").unwrap().into())
.build();
.build()
.into();
let expr = to_alter_expr(alter_table, ctx).unwrap();
let kind = expr.kind.unwrap();

Expand Down
3 changes: 2 additions & 1 deletion src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,8 @@ mod tests {
fn check_timestamp_range((start, end): (&str, &str)) -> error::Result<Option<TimestampRange>> {
let query_ctx = QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string("Asia/Shanghai").unwrap()))
.build();
.build()
.into();
let map = OptionMap::from(
[
(COPY_DATABASE_TIME_START_KEY.to_string(), start.to_string()),
Expand Down
2 changes: 1 addition & 1 deletion src/operator/src/statement/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ ENGINE=mito",
r#"[{"column_list":["b","a"],"value_list":["{\"Value\":{\"String\":\"hz\"}}","{\"Value\":{\"Int32\":10}}"]},{"column_list":["b","a"],"value_list":["{\"Value\":{\"String\":\"sh\"}}","{\"Value\":{\"Int32\":20}}"]},{"column_list":["b","a"],"value_list":["\"MaxValue\"","\"MaxValue\""]}]"#,
),
];
let ctx = QueryContextBuilder::default().build();
let ctx = QueryContextBuilder::default().build().into();
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(
sql,
Expand Down
14 changes: 11 additions & 3 deletions src/query/src/dist_plan/merge_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion_physical_expr::EquivalenceProperties;
use datatypes::schema::{Schema, SchemaRef};
use futures_util::StreamExt;
use greptime_proto::v1::region::{QueryRequest, RegionRequestHeader};
use greptime_proto::v1::region::{QueryContext, QueryRequest, RegionRequestHeader};
use meter_core::data::ReadItem;
use meter_macros::read_meter;
use session::context::QueryContextRef;
Expand Down Expand Up @@ -179,7 +179,10 @@ impl MergeScanExec {

let dbname = context.task_id().unwrap_or_default();
let tracing_context = TracingContext::from_json(context.session_id().as_str());
let tz = self.query_ctx.timezone().to_string();
let current_catalog = self.query_ctx.current_catalog().to_string();
let current_schema = self.query_ctx.current_schema().to_string();
let timezone = self.query_ctx.timezone().to_string();
let extensions = self.query_ctx.extensions();

let stream = Box::pin(stream!({
MERGE_SCAN_REGIONS.observe(regions.len() as f64);
Expand All @@ -192,7 +195,12 @@ impl MergeScanExec {
header: Some(RegionRequestHeader {
tracing_context: tracing_context.to_w3c(),
dbname: dbname.clone(),
timezone: tz.clone(),
query_context: Some(QueryContext {
current_catalog,
current_schema,
timezone,
extensions,
killme2008 marked this conversation as resolved.
Show resolved Hide resolved
}),
}),
region_id: region_id.into(),
plan: substrait_plan.clone(),
Expand Down
4 changes: 3 additions & 1 deletion src/query/src/query_engine/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ pub fn validate_catalog_and_schema(
#[cfg(test)]
mod tests {

use std::sync::Arc;

use session::context::QueryContext;

use super::*;

#[test]
fn test_validate_catalog_and_schema() {
let context = QueryContext::with("greptime", "public");
let context = Arc::new(QueryContext::with("greptime", "public"));

validate_catalog_and_schema("greptime", "public", &context).unwrap();
let re = validate_catalog_and_schema("greptime", "private_schema", &context);
Expand Down
8 changes: 5 additions & 3 deletions src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,9 +944,11 @@ mod test {
let stmt = ShowVariables {
variable: ObjectName(vec![Ident::new(variable)]),
};
let ctx = QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.build();
let ctx = Arc::new(
QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.build(),
);
match show_variable(stmt, ctx) {
Ok(Output {
data: OutputData::RecordBatches(record),
Expand Down
2 changes: 1 addition & 1 deletion src/script/src/python/ffi_types/copr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ impl PyQueryEngine {
let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
let handle = rt.handle().clone();
let res = handle.block_on(async {
let ctx = QueryContextBuilder::default().build();
let ctx = Arc::new(QueryContextBuilder::default().build());
let plan = engine
.planner()
.plan(stmt, ctx.clone())
Expand Down
1 change: 1 addition & 0 deletions src/script/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ fn query_ctx(table_info: &TableInfo) -> QueryContextRef {
.current_catalog(table_info.catalog_name.to_string())
.current_schema(table_info.schema_name.to_string())
.build()
.into()
}

/// Builds scripts schema, returns (time index, primary keys, column defs)
Expand Down
3 changes: 2 additions & 1 deletion src/servers/src/export_metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use axum::http::HeaderValue;
Expand Down Expand Up @@ -247,7 +248,7 @@ pub async fn write_system_metric_by_handler(
);
// Pass the first tick. Because the first tick completes immediately.
interval.tick().await;
let ctx = QueryContextBuilder::default().current_schema(db).build();
let ctx = Arc::new(QueryContextBuilder::default().current_schema(db).build());
loop {
interval.tick().await;
let metric_families = prometheus::gather();
Expand Down
3 changes: 2 additions & 1 deletion src/servers/src/grpc/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::pin::Pin;
use std::result::Result as StdResult;
use std::sync::Arc;
use std::task::{Context, Poll};

use auth::UserProviderRef;
Expand Down Expand Up @@ -104,7 +105,7 @@ async fn do_auth<T>(
) -> Result<(), tonic::Status> {
let (catalog, schema) = extract_catalog_and_schema(req);

let query_ctx = QueryContext::with(&catalog, &schema);
let query_ctx = Arc::new(QueryContext::with(&catalog, &schema));

let Some(user_provider) = user_provider else {
query_ctx.set_current_user(Some(auth::userinfo_by_name(None)));
Expand Down
1 change: 1 addition & 0 deletions src/servers/src/grpc/greptime_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte
.current_schema(schema)
.timezone(Arc::new(timezone))
.build()
.into()
}

/// Histogram timer for handling gRPC request.
Expand Down
2 changes: 1 addition & 1 deletion src/servers/src/http/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub async fn inner_auth<B>(
.current_schema(schema.clone())
.timezone(timezone);

let query_ctx = query_ctx_builder.build();
let query_ctx = Arc::new(query_ctx_builder.build());
let need_auth = need_auth(&req);

// 2. check if auth is needed
Expand Down
11 changes: 7 additions & 4 deletions src/servers/src/http/prometheus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

//! prom supply the prometheus HTTP API Server compliance
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use axum::extract::{Path, Query, State};
use axum::{Extension, Form};
Expand Down Expand Up @@ -572,10 +573,12 @@ pub(crate) fn try_update_catalog_schema(
schema: &str,
) -> QueryContextRef {
if ctx.current_catalog() != catalog || ctx.current_schema() != schema {
QueryContextBuilder::from_existing(&ctx)
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build()
Arc::new(
QueryContextBuilder::from_existing(&ctx)
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build(),
)
} else {
ctx
}
Expand Down
5 changes: 3 additions & 2 deletions src/servers/src/http/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use axum::extract::{Query, RawBody, State};
Expand Down Expand Up @@ -76,7 +77,7 @@ pub async fn scripts(
unwrap_or_json_err!(String::from_utf8(bytes.to_vec()).context(InvalidUtf8ValueSnafu));

// Safety: schema and name are already checked above.
let query_ctx = QueryContext::with(&catalog, schema.unwrap());
let query_ctx = Arc::new(QueryContext::with(&catalog, schema.unwrap()));
match script_handler
.insert_script(query_ctx, name.unwrap(), &script)
.await
Expand Down Expand Up @@ -128,7 +129,7 @@ pub async fn run_script(
}

// Safety: schema and name are already checked above.
let query_ctx = QueryContext::with(&catalog, schema.unwrap());
let query_ctx = Arc::new(QueryContext::with(&catalog, schema.unwrap()));
let output = script_handler
.execute_script(query_ctx, name.unwrap(), params.params)
.await;
Expand Down
3 changes: 2 additions & 1 deletion src/servers/src/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,8 @@ mod test {
];
let query_context = QueryContextBuilder::default()
.configuration_parameter(Default::default())
.build();
.build()
.into();
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
encode_value(&query_context, i, &mut builder).unwrap();
Expand Down
10 changes: 6 additions & 4 deletions src/servers/tests/py_script/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def hello() -> vector[str]:

let table = MemTable::table("scripts", recordbatch);

let query_ctx = QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build();
let query_ctx = Arc::new(
QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build(),
);

let instance = create_testing_instance(table);
instance
Expand Down
Loading
Loading