From a703f9d94f637f65e1bfd57c3887b756ee1a63eb Mon Sep 17 00:00:00 2001 From: Dennis Zhuang Date: Wed, 27 Mar 2024 14:16:40 -0700 Subject: [PATCH] fix: canonicalize catalog and schema names --- src/common/catalog/src/lib.rs | 27 +++++++++++-------- src/servers/src/grpc/authorize.rs | 4 +-- src/servers/src/grpc/greptime_handler.rs | 19 ++++++++----- src/servers/src/http/authorize.rs | 12 ++++----- src/servers/src/http/prometheus.rs | 8 ++++-- src/servers/src/mysql/handler.rs | 14 ++++++---- src/servers/src/postgres/auth_handler.rs | 7 ++--- src/session/src/context.rs | 2 +- .../common/system/information_schema.result | 21 +++++++++++++++ .../common/system/information_schema.sql | 4 +++ 10 files changed, 79 insertions(+), 39 deletions(-) diff --git a/src/common/catalog/src/lib.rs b/src/common/catalog/src/lib.rs index 1a2596371709..e1cf4c201d48 100644 --- a/src/common/catalog/src/lib.rs +++ b/src/common/catalog/src/lib.rs @@ -55,10 +55,10 @@ pub fn build_db_string(catalog: &str, schema: &str) -> String { /// schema name /// - if `[-]` is provided, we split database name with `-` and use /// `` and ``. -pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (&str, &str) { +pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (String, String) { match parse_optional_catalog_and_schema_from_db_string(db) { (Some(catalog), schema) => (catalog, schema), - (None, schema) => (DEFAULT_CATALOG_NAME, schema), + (None, schema) => (DEFAULT_CATALOG_NAME.to_string(), schema), } } @@ -66,12 +66,12 @@ pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (&str, &str) { /// /// Similar to [`parse_catalog_and_schema_from_db_string`] but returns an optional /// catalog if it's not provided in the database name. -pub fn parse_optional_catalog_and_schema_from_db_string(db: &str) -> (Option<&str>, &str) { +pub fn parse_optional_catalog_and_schema_from_db_string(db: &str) -> (Option, String) { let parts = db.splitn(2, '-').collect::>(); if parts.len() == 2 { - (Some(parts[0]), parts[1]) + (Some(parts[0].to_lowercase()), parts[1].to_lowercase()) } else { - (None, db) + (None, db.to_lowercase()) } } @@ -88,32 +88,37 @@ mod tests { #[test] fn test_parse_catalog_and_schema() { assert_eq!( - (DEFAULT_CATALOG_NAME, "fullschema"), + (DEFAULT_CATALOG_NAME.to_string(), "fullschema".to_string()), parse_catalog_and_schema_from_db_string("fullschema") ); assert_eq!( - ("catalog", "schema"), + ("catalog".to_string(), "schema".to_string()), parse_catalog_and_schema_from_db_string("catalog-schema") ); assert_eq!( - ("catalog", "schema1-schema2"), + ("catalog".to_string(), "schema1-schema2".to_string()), parse_catalog_and_schema_from_db_string("catalog-schema1-schema2") ); assert_eq!( - (None, "fullschema"), + (None, "fullschema".to_string()), parse_optional_catalog_and_schema_from_db_string("fullschema") ); assert_eq!( - (Some("catalog"), "schema"), + (Some("catalog".to_string()), "schema".to_string()), parse_optional_catalog_and_schema_from_db_string("catalog-schema") ); assert_eq!( - (Some("catalog"), "schema1-schema2"), + (Some("catalog".to_string()), "schema".to_string()), + parse_optional_catalog_and_schema_from_db_string("CATALOG-SCHEMA") + ); + + assert_eq!( + (Some("catalog".to_string()), "schema1-schema2".to_string()), parse_optional_catalog_and_schema_from_db_string("catalog-schema1-schema2") ); } diff --git a/src/servers/src/grpc/authorize.rs b/src/servers/src/grpc/authorize.rs index 84e203d3730e..ae003640ea4b 100644 --- a/src/servers/src/grpc/authorize.rs +++ b/src/servers/src/grpc/authorize.rs @@ -104,7 +104,7 @@ async fn do_auth( ) -> Result<(), tonic::Status> { let (catalog, schema) = extract_catalog_and_schema(req); - let query_ctx = QueryContext::with(catalog, schema); + let query_ctx = QueryContext::with(&catalog, &schema); let Some(user_provider) = user_provider else { query_ctx.set_current_user(Some(auth::userinfo_by_name(None))); @@ -119,7 +119,7 @@ async fn do_auth( let pwd = auth::Password::PlainText(password); let user_info = user_provider - .auth(id, pwd, catalog, schema) + .auth(id, pwd, &catalog, &schema) .await .map_err(|e| tonic::Status::unauthenticated(e.to_string()))?; diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index 19a4e1d373e0..a79217e6ee09 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -166,23 +166,28 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte } else { ( if !header.catalog.is_empty() { - &header.catalog + header.catalog.to_lowercase() } else { - DEFAULT_CATALOG_NAME + DEFAULT_CATALOG_NAME.to_string() }, if !header.schema.is_empty() { - &header.schema + header.schema.to_lowercase() } else { - DEFAULT_SCHEMA_NAME + DEFAULT_SCHEMA_NAME.to_string() }, ) } }) - .unwrap_or((DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME)); + .unwrap_or_else(|| { + ( + DEFAULT_CATALOG_NAME.to_string(), + DEFAULT_SCHEMA_NAME.to_string(), + ) + }); let timezone = parse_timezone(header.map(|h| h.timezone.as_str())); QueryContextBuilder::default() - .current_catalog(catalog.to_string()) - .current_schema(schema.to_string()) + .current_catalog(catalog) + .current_schema(schema) .timezone(Arc::new(timezone)) .build() } diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index de99828fb33e..12c270c43cda 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -64,8 +64,8 @@ pub async fn inner_auth( // TODO(ruihang): move this out of auth module let timezone = Arc::new(extract_timezone(&req)); let query_ctx_builder = QueryContextBuilder::default() - .current_catalog(catalog.to_string()) - .current_schema(schema.to_string()) + .current_catalog(catalog.clone()) + .current_schema(schema.clone()) .timezone(timezone); let query_ctx = query_ctx_builder.build(); @@ -97,8 +97,8 @@ pub async fn inner_auth( .auth( auth::Identity::UserId(&username, None), auth::Password::PlainText(password), - catalog, - schema, + &catalog, + &schema, ) .await { @@ -132,7 +132,7 @@ fn err_response(err: impl ErrorExt) -> Response { (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response() } -pub fn extract_catalog_and_schema(request: &Request) -> (&str, &str) { +pub fn extract_catalog_and_schema(request: &Request) -> (String, String) { // parse database from header let dbname = request .headers() @@ -414,7 +414,7 @@ mod tests { .unwrap(); let db = extract_catalog_and_schema(&req); - assert_eq!(db, ("greptime", "tomcat")); + assert_eq!(db, ("greptime".to_string(), "tomcat".to_string())); } #[test] diff --git a/src/servers/src/http/prometheus.rs b/src/servers/src/http/prometheus.rs index af5567993fac..21e5b4c2ccd0 100644 --- a/src/servers/src/http/prometheus.rs +++ b/src/servers/src/http/prometheus.rs @@ -255,7 +255,7 @@ pub async fn labels_query( queries = form_params.matches.0; } if queries.is_empty() { - match get_all_column_names(catalog, schema, &handler.catalog_manager()).await { + match get_all_column_names(&catalog, &schema, &handler.catalog_manager()).await { Ok(labels) => { return PrometheusJsonResponse::success(PrometheusResponse::Labels(labels)) } @@ -530,7 +530,11 @@ pub async fn label_values_query( let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); if label_name == METRIC_NAME_LABEL { - let mut table_names = match handler.catalog_manager().table_names(catalog, schema).await { + let mut table_names = match handler + .catalog_manager() + .table_names(&catalog, &schema) + .await + { Ok(table_names) => table_names, Err(e) => { return PrometheusJsonResponse::error(e.status_code().to_string(), e.output_msg()); diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 9fe088cb6604..9e43aea7b42b 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -371,13 +371,17 @@ impl AsyncMysqlShim for MysqlInstanceShi async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> { let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database); - let catalog = if let Some(catalog) = catalog_from_db { - catalog.to_owned() + let catalog = if let Some(catalog) = &catalog_from_db { + catalog.to_string() } else { self.session.get_catalog() }; - if !self.query_handler.is_valid_schema(&catalog, schema).await? { + if !self + .query_handler + .is_valid_schema(&catalog, &schema) + .await? + { return w .error( ErrorKind::ER_WRONG_DB_NAME, @@ -391,7 +395,7 @@ impl AsyncMysqlShim for MysqlInstanceShi if let Some(schema_validator) = &self.user_provider { if let Err(e) = schema_validator - .authorize(&catalog, schema, user_info) + .authorize(&catalog, &schema, user_info) .await { METRIC_AUTH_FAILURE @@ -410,7 +414,7 @@ impl AsyncMysqlShim for MysqlInstanceShi if catalog_from_db.is_some() { self.session.set_catalog(catalog) } - self.session.set_schema(schema.into()); + self.session.set_schema(schema); w.ok().await.map_err(|e| e.into()) } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 3708f6f57a53..da316d04cf42 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -237,14 +237,11 @@ where if let Some(db) = db_ref { let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); if query_handler - .is_valid_schema(catalog, schema) + .is_valid_schema(&catalog, &schema) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))? { - Ok(DbResolution::Resolved( - catalog.to_owned(), - schema.to_owned(), - )) + Ok(DbResolution::Resolved(catalog, schema)) } else { Ok(DbResolution::NotFound(format!("Database not found: {db}"))) } diff --git a/src/session/src/context.rs b/src/session/src/context.rs index d401b0331637..ab1e468dc6e3 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -114,7 +114,7 @@ impl QueryContext { let (catalog, schema) = db_name .map(|db| { let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); - (catalog.to_string(), schema.to_string()) + (catalog, schema) }) .unwrap_or_else(|| { ( diff --git a/tests/cases/standalone/common/system/information_schema.result b/tests/cases/standalone/common/system/information_schema.result index a6dec1d3bafc..1409b2f85a5c 100644 --- a/tests/cases/standalone/common/system/information_schema.result +++ b/tests/cases/standalone/common/system/information_schema.result @@ -673,6 +673,27 @@ DESC TABLE GREPTIME_REGION_PEERS; | down_seconds | Int64 | | YES | | FIELD | +--------------+--------+-----+------+---------+---------------+ +USE INFORMATION_SCHEMA; + +Affected Rows: 0 + +DESC COLUMNS; + ++----------------+--------+-----+------+---------+---------------+ +| Column | Type | Key | Null | Default | Semantic Type | ++----------------+--------+-----+------+---------+---------------+ +| table_catalog | String | | NO | | FIELD | +| table_schema | String | | NO | | FIELD | +| table_name | String | | NO | | FIELD | +| column_name | String | | NO | | FIELD | +| data_type | String | | NO | | FIELD | +| semantic_type | String | | NO | | FIELD | +| column_default | String | | YES | | FIELD | +| is_nullable | String | | NO | | FIELD | +| column_type | String | | NO | | FIELD | +| column_comment | String | | YES | | FIELD | ++----------------+--------+-----+------+---------+---------------+ + drop table my_db.foo; Error: 4001(TableNotFound), Table not found: greptime.my_db.foo diff --git a/tests/cases/standalone/common/system/information_schema.sql b/tests/cases/standalone/common/system/information_schema.sql index 76261d1c665b..d54c2c0ebd51 100644 --- a/tests/cases/standalone/common/system/information_schema.sql +++ b/tests/cases/standalone/common/system/information_schema.sql @@ -119,6 +119,10 @@ DESC TABLE RUNTIME_METRICS; DESC TABLE GREPTIME_REGION_PEERS; +USE INFORMATION_SCHEMA; + +DESC COLUMNS; + drop table my_db.foo; use public;