Skip to content

Commit

Permalink
refactor: change the return type of build() in QueryContextBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
Kelvinyu1117 committed May 2, 2024
1 parent 57142a7 commit 9043dd8
Show file tree
Hide file tree
Showing 30 changed files with 276 additions and 253 deletions.
344 changes: 177 additions & 167 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion src/cmd/src/cli/repl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ impl Repl {
let start = Instant::now();

let output = if let Some(query_engine) = &self.query_engine {
let query_ctx = QueryContext::with(self.database.catalog(), self.database.schema());
let query_ctx = Arc::new(QueryContext::with(
self.database.catalog(),
self.database.schema(),
));

let stmt = QueryLanguageParser::parse_sql(&sql, &query_ctx)
.with_context(|_| ParseSqlSnafu { sql: sql.clone() })?;
Expand Down
4 changes: 2 additions & 2 deletions src/common/function/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl FunctionContext {
#[cfg(any(test, feature = "testing"))]
pub fn mock() -> Self {
Self {
query_ctx: QueryContextBuilder::default().build(),
query_ctx: QueryContextBuilder::default().build().into(),
state: Arc::new(FunctionState::mock()),
}
}
Expand All @@ -44,7 +44,7 @@ impl FunctionContext {
impl Default for FunctionContext {
fn default() -> Self {
Self {
query_ctx: QueryContextBuilder::default().build(),
query_ctx: QueryContextBuilder::default().build().into(),
state: Arc::new(FunctionState::default()),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/common/function/src/scalars/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mod tests {
#[test]
fn test_create_udf() {
let f = Arc::new(TestAndFunction);
let query_ctx = QueryContextBuilder::default().build();
let query_ctx = QueryContextBuilder::default().build().into();

let args: Vec<VectorRef> = vec![
Arc::new(ConstantVector::new(
Expand Down
3 changes: 2 additions & 1 deletion src/common/function/src/system/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ mod tests {

let query_ctx = QueryContextBuilder::default()
.current_schema("test_db".to_string())
.build();
.build()
.into();

let func_ctx = FunctionContext {
query_ctx,
Expand Down
2 changes: 1 addition & 1 deletion src/common/function/src/system/timezone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ mod tests {
} if valid_types == vec![]
));

let query_ctx = QueryContextBuilder::default().build();
let query_ctx = QueryContextBuilder::default().build().into();

let func_ctx = FunctionContext {
query_ctx,
Expand Down
2 changes: 1 addition & 1 deletion src/datanode/src/region_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ impl RegionServerInner {
let ctx: QueryContextRef = header
.as_ref()
.map(|h| Arc::new(h.into()))
.unwrap_or_else(|| QueryContextBuilder::default().build());
.unwrap_or_else(|| QueryContextBuilder::default().build().into());

// build dummy catalog list
let region_status = self
Expand Down
6 changes: 4 additions & 2 deletions src/operator/src/expr_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ mod tests {
// query context with timezone `+08:00`
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string("+08:00").unwrap().into())
.build();
.build()
.into();
let expr = create_to_expr(&create_table, ctx).unwrap();
let ts_column = &expr.column_defs[1];
let constraint = assert_ts_column(ts_column);
Expand Down Expand Up @@ -624,7 +625,8 @@ mod tests {
// query context with timezone `+08:00`
let ctx = QueryContextBuilder::default()
.timezone(Timezone::from_tz_string("+08:00").unwrap().into())
.build();
.build()
.into();
let expr = to_alter_expr(alter_table, ctx).unwrap();
let kind = expr.kind.unwrap();

Expand Down
3 changes: 2 additions & 1 deletion src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,8 @@ mod tests {
fn check_timestamp_range((start, end): (&str, &str)) -> error::Result<Option<TimestampRange>> {
let query_ctx = QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string("Asia/Shanghai").unwrap()))
.build();
.build()
.into();
let map = OptionMap::from(
[
(COPY_DATABASE_TIME_START_KEY.to_string(), start.to_string()),
Expand Down
2 changes: 1 addition & 1 deletion src/operator/src/statement/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ ENGINE=mito",
r#"[{"column_list":["b","a"],"value_list":["{\"Value\":{\"String\":\"hz\"}}","{\"Value\":{\"Int32\":10}}"]},{"column_list":["b","a"],"value_list":["{\"Value\":{\"String\":\"sh\"}}","{\"Value\":{\"Int32\":20}}"]},{"column_list":["b","a"],"value_list":["\"MaxValue\"","\"MaxValue\""]}]"#,
),
];
let ctx = QueryContextBuilder::default().build();
let ctx = QueryContextBuilder::default().build().into();
for (sql, expected) in cases {
let result = ParserContext::create_with_dialect(
sql,
Expand Down
4 changes: 2 additions & 2 deletions src/query/src/dist_plan/merge_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl MergeScanExec {
let current_catalog = self.query_ctx.current_catalog().to_string();
let current_schema = self.query_ctx.current_schema().to_string();
let timezone = self.query_ctx.timezone().to_string();
let extension = self.query_ctx.to_extension();
let extensions = self.query_ctx.extensions();

let stream = Box::pin(stream!({
MERGE_SCAN_REGIONS.observe(regions.len() as f64);
Expand All @@ -199,7 +199,7 @@ impl MergeScanExec {
current_catalog: current_catalog.clone(),
current_schema: current_schema.clone(),
timezone: timezone.clone(),
extension: extension.clone(),
extensions: extensions.clone(),
}),
}),
region_id: region_id.into(),
Expand Down
4 changes: 3 additions & 1 deletion src/query/src/query_engine/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ pub fn validate_catalog_and_schema(
#[cfg(test)]
mod tests {

use std::sync::Arc;

use session::context::QueryContext;

use super::*;

#[test]
fn test_validate_catalog_and_schema() {
let context = QueryContext::with("greptime", "public");
let context = Arc::new(QueryContext::with("greptime", "public"));

validate_catalog_and_schema("greptime", "public", &context).unwrap();
let re = validate_catalog_and_schema("greptime", "private_schema", &context);
Expand Down
8 changes: 5 additions & 3 deletions src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,9 +944,11 @@ mod test {
let stmt = ShowVariables {
variable: ObjectName(vec![Ident::new(variable)]),
};
let ctx = QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.build();
let ctx = Arc::new(
QueryContextBuilder::default()
.timezone(Arc::new(Timezone::from_tz_string(tz).unwrap()))
.build(),
);
match show_variable(stmt, ctx) {
Ok(Output {
data: OutputData::RecordBatches(record),
Expand Down
2 changes: 1 addition & 1 deletion src/script/src/python/ffi_types/copr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ impl PyQueryEngine {
let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
let handle = rt.handle().clone();
let res = handle.block_on(async {
let ctx = QueryContextBuilder::default().build();
let ctx = Arc::new(QueryContextBuilder::default().build());
let plan = engine
.planner()
.plan(stmt, ctx.clone())
Expand Down
1 change: 1 addition & 0 deletions src/script/src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ fn query_ctx(table_info: &TableInfo) -> QueryContextRef {
.current_catalog(table_info.catalog_name.to_string())
.current_schema(table_info.schema_name.to_string())
.build()
.into()
}

/// Builds scripts schema, returns (time index, primary keys, column defs)
Expand Down
3 changes: 2 additions & 1 deletion src/servers/src/export_metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use axum::http::HeaderValue;
Expand Down Expand Up @@ -247,7 +248,7 @@ pub async fn write_system_metric_by_handler(
);
// Pass the first tick. Because the first tick completes immediately.
interval.tick().await;
let ctx = QueryContextBuilder::default().current_schema(db).build();
let ctx = Arc::new(QueryContextBuilder::default().current_schema(db).build());
loop {
interval.tick().await;
let metric_families = prometheus::gather();
Expand Down
1 change: 1 addition & 0 deletions src/servers/src/grpc/greptime_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte
.current_schema(schema)
.timezone(Arc::new(timezone))
.build()
.into()
}

/// Histogram timer for handling gRPC request.
Expand Down
2 changes: 1 addition & 1 deletion src/servers/src/http/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub async fn inner_auth<B>(
.current_schema(schema.clone())
.timezone(timezone);

let query_ctx = query_ctx_builder.build();
let query_ctx = Arc::new(query_ctx_builder.build());
let need_auth = need_auth(&req);

// 2. check if auth is needed
Expand Down
11 changes: 7 additions & 4 deletions src/servers/src/http/prometheus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

//! prom supply the prometheus HTTP API Server compliance
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use axum::extract::{Path, Query, State};
use axum::{Extension, Form};
Expand Down Expand Up @@ -572,10 +573,12 @@ pub(crate) fn try_update_catalog_schema(
schema: &str,
) -> QueryContextRef {
if ctx.current_catalog() != catalog || ctx.current_schema() != schema {
QueryContextBuilder::from_existing(&ctx)
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build()
Arc::new(
QueryContextBuilder::from_existing(&ctx)
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build(),
)
} else {
ctx
}
Expand Down
5 changes: 3 additions & 2 deletions src/servers/src/http/script.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use axum::extract::{Query, RawBody, State};
Expand Down Expand Up @@ -76,7 +77,7 @@ pub async fn scripts(
unwrap_or_json_err!(String::from_utf8(bytes.to_vec()).context(InvalidUtf8ValueSnafu));

// Safety: schema and name are already checked above.
let query_ctx = QueryContext::with(&catalog, schema.unwrap());
let query_ctx = Arc::new(QueryContext::with(&catalog, schema.unwrap()));
match script_handler
.insert_script(query_ctx, name.unwrap(), &script)
.await
Expand Down Expand Up @@ -128,7 +129,7 @@ pub async fn run_script(
}

// Safety: schema and name are already checked above.
let query_ctx = QueryContext::with(&catalog, schema.unwrap());
let query_ctx = Arc::new(QueryContext::with(&catalog, schema.unwrap()));
let output = script_handler
.execute_script(query_ctx, name.unwrap(), params.params)
.await;
Expand Down
4 changes: 3 additions & 1 deletion src/servers/src/opentsdb/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

//! Modified from Tokio's mini-redis example.
use std::sync::Arc;

use common_error::ext::ErrorExt;
use session::context::QueryContextBuilder;
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -62,7 +64,7 @@ impl<S: AsyncWrite + AsyncRead + Unpin> Handler<S> {

pub(crate) async fn run(&mut self) -> Result<()> {
// TODO(shuiyisong): figure out how to auth in tcp connection.
let ctx = QueryContextBuilder::default().build();
let ctx = Arc::new(QueryContextBuilder::default().build());
while !self.shutdown.is_shutdown() {
// While reading a request, also listen for the shutdown signal.
let maybe_line = tokio::select! {
Expand Down
3 changes: 2 additions & 1 deletion src/servers/src/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,8 @@ mod test {
];
let query_context = QueryContextBuilder::default()
.configuration_parameter(Default::default())
.build();
.build()
.into();
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
encode_value(&query_context, i, &mut builder).unwrap();
Expand Down
10 changes: 6 additions & 4 deletions src/servers/tests/py_script/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ def hello() -> vector[str]:

let table = MemTable::table("scripts", recordbatch);

let query_ctx = QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build();
let query_ctx = Arc::new(
QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build(),
);

let instance = create_testing_instance(table);
instance
Expand Down
Loading

0 comments on commit 9043dd8

Please sign in to comment.