Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update our cross schema check to cross catalog #3123

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions src/catalog/src/table_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
use std::collections::HashMap;
use std::sync::Arc;

use common_catalog::consts::INFORMATION_SCHEMA_NAME;
use common_catalog::format_full_table_name;
use datafusion::common::{ResolvedTableReference, TableReference};
use datafusion::datasource::provider_as_source;
Expand All @@ -30,20 +29,20 @@ use crate::CatalogManagerRef;
pub struct DfTableSourceProvider {
catalog_manager: CatalogManagerRef,
resolved_tables: HashMap<String, Arc<dyn TableSource>>,
disallow_cross_schema_query: bool,
disallow_cross_catalog_query: bool,
default_catalog: String,
default_schema: String,
}

impl DfTableSourceProvider {
pub fn new(
catalog_manager: CatalogManagerRef,
disallow_cross_schema_query: bool,
disallow_cross_catalog_query: bool,
query_ctx: &QueryContext,
) -> Self {
Self {
catalog_manager,
disallow_cross_schema_query,
disallow_cross_catalog_query,
resolved_tables: HashMap::new(),
default_catalog: query_ctx.current_catalog().to_owned(),
default_schema: query_ctx.current_schema().to_owned(),
Expand All @@ -54,29 +53,18 @@ impl DfTableSourceProvider {
&'a self,
table_ref: TableReference<'a>,
) -> Result<ResolvedTableReference<'a>> {
if self.disallow_cross_schema_query {
if self.disallow_cross_catalog_query {
match &table_ref {
TableReference::Bare { .. } => (),
TableReference::Partial { schema, .. } => {
ensure!(
schema.as_ref() == self.default_schema
|| schema.as_ref() == INFORMATION_SCHEMA_NAME,
QueryAccessDeniedSnafu {
catalog: &self.default_catalog,
schema: schema.as_ref(),
}
);
}
TableReference::Partial { .. } => {}
TableReference::Full {
catalog, schema, ..
} => {
ensure!(
catalog.as_ref() == self.default_catalog
&& (schema.as_ref() == self.default_schema
|| schema.as_ref() == INFORMATION_SCHEMA_NAME),
catalog.as_ref() == self.default_catalog,
QueryAccessDeniedSnafu {
catalog: catalog.as_ref(),
schema: schema.as_ref()
schema: schema.as_ref(),
}
);
}
Expand Down Expand Up @@ -136,29 +124,29 @@ mod tests {
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());

let table_ref = TableReference::Partial {
schema: Cow::Borrowed("public"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());

let table_ref = TableReference::Partial {
schema: Cow::Borrowed("wrong_schema"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
assert!(result.is_err());
assert!(result.is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("public"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("wrong_catalog"),
Expand All @@ -172,20 +160,28 @@ mod tests {
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
let _ = table_provider.resolve_table_ref(table_ref).unwrap();
let result = table_provider.resolve_table_ref(table_ref);
assert!(result.is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
let _ = table_provider.resolve_table_ref(table_ref).unwrap();
assert!(table_provider.resolve_table_ref(table_ref).is_ok());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("dummy"),
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
assert!(table_provider.resolve_table_ref(table_ref).is_err());

let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("greptime_private"),
table: Cow::Borrowed("columns"),
};
assert!(table_provider.resolve_table_ref(table_ref).is_ok());
}
}
10 changes: 4 additions & 6 deletions src/frontend/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ pub fn check_permission(
) -> Result<()> {
let need_validate = plugins
.get::<QueryOptions>()
.map(|opts| opts.disallow_cross_schema_query)
.map(|opts| opts.disallow_cross_catalog_query)
.unwrap_or_default();

if !need_validate {
Expand Down Expand Up @@ -520,7 +520,7 @@ mod tests {
let query_ctx = QueryContext::arc();
let plugins: Plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
disallow_cross_catalog_query: true,
});

let sql = r#"
Expand Down Expand Up @@ -556,8 +556,6 @@ mod tests {
}

let wrong = vec![
("", "wrongschema."),
("greptime.", "wrongschema."),
("wrongcatalog.", "public."),
("wrongcatalog.", "wrongschema."),
];
Expand Down Expand Up @@ -607,10 +605,10 @@ mod tests {
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
check_permission(plugins.clone(), &stmt[0], &query_ctx).unwrap();

let sql = "SHOW TABLES FROM wrongschema";
let sql = "SHOW TABLES FROM private";
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());
assert!(re.is_ok());

// test describe table
let sql = "DESC TABLE {catalog}{schema}demo;";
Expand Down
2 changes: 1 addition & 1 deletion src/query/src/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl DfContextProviderAdapter {

let mut table_provider = DfTableSourceProvider::new(
engine_state.catalog_manager().clone(),
engine_state.disallow_cross_schema_query(),
engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);

Expand Down
4 changes: 2 additions & 2 deletions src/query/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl DfLogicalPlanner {

let table_provider = DfTableSourceProvider::new(
self.engine_state.catalog_manager().clone(),
self.engine_state.disallow_cross_schema_query(),
self.engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);

Expand Down Expand Up @@ -91,7 +91,7 @@ impl DfLogicalPlanner {
async fn plan_pql(&self, stmt: EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
let table_provider = DfTableSourceProvider::new(
self.engine_state.catalog_manager().clone(),
self.engine_state.disallow_cross_schema_query(),
self.engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);
PromPlanner::stmt_to_plan(table_provider, stmt)
Expand Down
14 changes: 4 additions & 10 deletions src/query/src/query_engine/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use common_catalog::consts::INFORMATION_SCHEMA_NAME;
use session::context::QueryContextRef;
use snafu::ensure;

use crate::error::{QueryAccessDeniedSnafu, Result};

#[derive(Default, Clone)]
pub struct QueryOptions {
pub disallow_cross_schema_query: bool,
pub disallow_cross_catalog_query: bool,
}

// TODO(shuiyisong): remove one method after #559 is done
Expand All @@ -29,13 +28,8 @@ pub fn validate_catalog_and_schema(
schema: &str,
query_ctx: &QueryContextRef,
) -> Result<()> {
// information_schema is an exception
if schema.eq_ignore_ascii_case(INFORMATION_SCHEMA_NAME) {
return Ok(());
}

ensure!(
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
catalog == query_ctx.current_catalog(),
QueryAccessDeniedSnafu {
catalog: catalog.to_string(),
schema: schema.to_string(),
Expand All @@ -57,8 +51,8 @@ mod tests {
let context = QueryContext::with("greptime", "public");

validate_catalog_and_schema("greptime", "public", &context).unwrap();
let re = validate_catalog_and_schema("greptime", "wrong_schema", &context);
assert!(re.is_err());
let re = validate_catalog_and_schema("greptime", "private_schema", &context);
assert!(re.is_ok());
let re = validate_catalog_and_schema("wrong_catalog", "public", &context);
assert!(re.is_err());
let re = validate_catalog_and_schema("wrong_catalog", "wrong_schema", &context);
Expand Down
4 changes: 2 additions & 2 deletions src/query/src/query_engine/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ impl QueryEngineState {
self.table_mutation_handler.as_ref()
}

pub(crate) fn disallow_cross_schema_query(&self) -> bool {
pub(crate) fn disallow_cross_catalog_query(&self) -> bool {
self.plugins
.map::<QueryOptions, _, _>(|x| x.disallow_cross_schema_query)
.map::<QueryOptions, _, _>(|x| x.disallow_cross_catalog_query)
.unwrap_or(false)
}

Expand Down
2 changes: 1 addition & 1 deletion src/query/src/tests/query_engine_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async fn test_query_validate() -> Result<()> {
// set plugins
let plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
disallow_cross_catalog_query: true,
});

let factory = QueryEngineFactory::new_with_plugins(catalog_list, None, None, false, plugins);
Expand Down
Loading