Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: canonicalize catalog and schema names #3600

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/common/catalog/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,23 @@ pub fn build_db_string(catalog: &str, schema: &str) -> String {
/// schema name
/// - if `[<catalog>-]` is provided, we split database name with `-` and use
/// `<catalog>` and `<schema>`.
pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (&str, &str) {
pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (String, String) {
match parse_optional_catalog_and_schema_from_db_string(db) {
(Some(catalog), schema) => (catalog, schema),
(None, schema) => (DEFAULT_CATALOG_NAME, schema),
(None, schema) => (DEFAULT_CATALOG_NAME.to_string(), schema),
}
}

/// Attempt to parse catalog and schema from given database name
///
/// Similar to [`parse_catalog_and_schema_from_db_string`] but returns an optional
/// catalog if it's not provided in the database name.
pub fn parse_optional_catalog_and_schema_from_db_string(db: &str) -> (Option<&str>, &str) {
pub fn parse_optional_catalog_and_schema_from_db_string(db: &str) -> (Option<String>, String) {
let parts = db.splitn(2, '-').collect::<Vec<&str>>();
if parts.len() == 2 {
(Some(parts[0]), parts[1])
(Some(parts[0].to_lowercase()), parts[1].to_lowercase())
} else {
(None, db)
(None, db.to_lowercase())
}
}

Expand All @@ -88,32 +88,37 @@ mod tests {
#[test]
fn test_parse_catalog_and_schema() {
assert_eq!(
(DEFAULT_CATALOG_NAME, "fullschema"),
(DEFAULT_CATALOG_NAME.to_string(), "fullschema".to_string()),
parse_catalog_and_schema_from_db_string("fullschema")
);

assert_eq!(
("catalog", "schema"),
("catalog".to_string(), "schema".to_string()),
parse_catalog_and_schema_from_db_string("catalog-schema")
);

assert_eq!(
("catalog", "schema1-schema2"),
("catalog".to_string(), "schema1-schema2".to_string()),
parse_catalog_and_schema_from_db_string("catalog-schema1-schema2")
);

assert_eq!(
(None, "fullschema"),
(None, "fullschema".to_string()),
parse_optional_catalog_and_schema_from_db_string("fullschema")
);

assert_eq!(
(Some("catalog"), "schema"),
(Some("catalog".to_string()), "schema".to_string()),
parse_optional_catalog_and_schema_from_db_string("catalog-schema")
);

assert_eq!(
(Some("catalog"), "schema1-schema2"),
(Some("catalog".to_string()), "schema".to_string()),
parse_optional_catalog_and_schema_from_db_string("CATALOG-SCHEMA")
);

assert_eq!(
(Some("catalog".to_string()), "schema1-schema2".to_string()),
parse_optional_catalog_and_schema_from_db_string("catalog-schema1-schema2")
);
}
Expand Down
4 changes: 2 additions & 2 deletions src/servers/src/grpc/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async fn do_auth<T>(
) -> Result<(), tonic::Status> {
let (catalog, schema) = extract_catalog_and_schema(req);

let query_ctx = QueryContext::with(catalog, schema);
let query_ctx = QueryContext::with(&catalog, &schema);

let Some(user_provider) = user_provider else {
query_ctx.set_current_user(Some(auth::userinfo_by_name(None)));
Expand All @@ -119,7 +119,7 @@ async fn do_auth<T>(
let pwd = auth::Password::PlainText(password);

let user_info = user_provider
.auth(id, pwd, catalog, schema)
.auth(id, pwd, &catalog, &schema)
.await
.map_err(|e| tonic::Status::unauthenticated(e.to_string()))?;

Expand Down
19 changes: 12 additions & 7 deletions src/servers/src/grpc/greptime_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,23 +166,28 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte
} else {
(
if !header.catalog.is_empty() {
&header.catalog
header.catalog.to_lowercase()
} else {
DEFAULT_CATALOG_NAME
DEFAULT_CATALOG_NAME.to_string()
},
if !header.schema.is_empty() {
&header.schema
header.schema.to_lowercase()
} else {
DEFAULT_SCHEMA_NAME
DEFAULT_SCHEMA_NAME.to_string()
},
)
}
})
.unwrap_or((DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME));
.unwrap_or_else(|| {
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
)
});
let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.current_catalog(catalog)
.current_schema(schema)
.timezone(Arc::new(timezone))
.build()
}
Expand Down
12 changes: 6 additions & 6 deletions src/servers/src/http/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ pub async fn inner_auth<B>(
// TODO(ruihang): move this out of auth module
let timezone = Arc::new(extract_timezone(&req));
let query_ctx_builder = QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.current_catalog(catalog.clone())
.current_schema(schema.clone())
.timezone(timezone);

let query_ctx = query_ctx_builder.build();
Expand Down Expand Up @@ -97,8 +97,8 @@ pub async fn inner_auth<B>(
.auth(
auth::Identity::UserId(&username, None),
auth::Password::PlainText(password),
catalog,
schema,
&catalog,
&schema,
)
.await
{
Expand Down Expand Up @@ -132,7 +132,7 @@ fn err_response(err: impl ErrorExt) -> Response {
(StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
}

pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (&str, &str) {
pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
// parse database from header
let dbname = request
.headers()
Expand Down Expand Up @@ -414,7 +414,7 @@ mod tests {
.unwrap();

let db = extract_catalog_and_schema(&req);
assert_eq!(db, ("greptime", "tomcat"));
assert_eq!(db, ("greptime".to_string(), "tomcat".to_string()));
}

#[test]
Expand Down
8 changes: 6 additions & 2 deletions src/servers/src/http/prometheus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ pub async fn labels_query(
queries = form_params.matches.0;
}
if queries.is_empty() {
match get_all_column_names(catalog, schema, &handler.catalog_manager()).await {
match get_all_column_names(&catalog, &schema, &handler.catalog_manager()).await {
Ok(labels) => {
return PrometheusJsonResponse::success(PrometheusResponse::Labels(labels))
}
Expand Down Expand Up @@ -530,7 +530,11 @@ pub async fn label_values_query(
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);

if label_name == METRIC_NAME_LABEL {
let mut table_names = match handler.catalog_manager().table_names(catalog, schema).await {
let mut table_names = match handler
.catalog_manager()
.table_names(&catalog, &schema)
.await
{
Ok(table_names) => table_names,
Err(e) => {
return PrometheusJsonResponse::error(e.status_code().to_string(), e.output_msg());
Expand Down
14 changes: 9 additions & 5 deletions src/servers/src/mysql/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,17 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi

async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
let (catalog_from_db, schema) = parse_optional_catalog_and_schema_from_db_string(database);
let catalog = if let Some(catalog) = catalog_from_db {
catalog.to_owned()
let catalog = if let Some(catalog) = &catalog_from_db {
catalog.to_string()
} else {
self.session.get_catalog()
};

if !self.query_handler.is_valid_schema(&catalog, schema).await? {
if !self
.query_handler
.is_valid_schema(&catalog, &schema)
.await?
{
return w
.error(
ErrorKind::ER_WRONG_DB_NAME,
Expand All @@ -391,7 +395,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi

if let Some(schema_validator) = &self.user_provider {
if let Err(e) = schema_validator
.authorize(&catalog, schema, user_info)
.authorize(&catalog, &schema, user_info)
.await
{
METRIC_AUTH_FAILURE
Expand All @@ -410,7 +414,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
if catalog_from_db.is_some() {
self.session.set_catalog(catalog)
}
self.session.set_schema(schema.into());
self.session.set_schema(schema);

w.ok().await.map_err(|e| e.into())
}
Expand Down
7 changes: 2 additions & 5 deletions src/servers/src/postgres/auth_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,11 @@ where
if let Some(db) = db_ref {
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
if query_handler
.is_valid_schema(catalog, schema)
.is_valid_schema(&catalog, &schema)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
Ok(DbResolution::Resolved(
catalog.to_owned(),
schema.to_owned(),
))
Ok(DbResolution::Resolved(catalog, schema))
} else {
Ok(DbResolution::NotFound(format!("Database not found: {db}")))
}
Expand Down
2 changes: 1 addition & 1 deletion src/session/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl QueryContext {
let (catalog, schema) = db_name
.map(|db| {
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
(catalog.to_string(), schema.to_string())
(catalog, schema)
})
.unwrap_or_else(|| {
(
Expand Down
21 changes: 21 additions & 0 deletions tests/cases/standalone/common/system/information_schema.result
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,27 @@ DESC TABLE GREPTIME_REGION_PEERS;
| down_seconds | Int64 | | YES | | FIELD |
+--------------+--------+-----+------+---------+---------------+

USE INFORMATION_SCHEMA;

Affected Rows: 0

DESC COLUMNS;

+----------------+--------+-----+------+---------+---------------+
| Column | Type | Key | Null | Default | Semantic Type |
+----------------+--------+-----+------+---------+---------------+
| table_catalog | String | | NO | | FIELD |
| table_schema | String | | NO | | FIELD |
| table_name | String | | NO | | FIELD |
| column_name | String | | NO | | FIELD |
| data_type | String | | NO | | FIELD |
| semantic_type | String | | NO | | FIELD |
| column_default | String | | YES | | FIELD |
| is_nullable | String | | NO | | FIELD |
| column_type | String | | NO | | FIELD |
| column_comment | String | | YES | | FIELD |
+----------------+--------+-----+------+---------+---------------+

drop table my_db.foo;

Error: 4001(TableNotFound), Table not found: greptime.my_db.foo
Expand Down
4 changes: 4 additions & 0 deletions tests/cases/standalone/common/system/information_schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ DESC TABLE RUNTIME_METRICS;

DESC TABLE GREPTIME_REGION_PEERS;

USE INFORMATION_SCHEMA;

DESC COLUMNS;

drop table my_db.foo;

use public;
Loading