Skip to content

Commit 984d805

Browse files
committed
feat: support different USE statement syntaxes
1 parent 11a6e6f commit 984d805

10 files changed

+391
-9
lines changed

src/ast/mod.rs

+26-6
Original file line numberDiff line numberDiff line change
@@ -2515,11 +2515,13 @@ pub enum Statement {
25152515
/// Note: this is a MySQL-specific statement.
25162516
ShowCollation { filter: Option<ShowStatementFilter> },
25172517
/// ```sql
2518-
/// USE
2518+
/// USE [DATABASE|SCHEMA|CATALOG|...] [<db_name>.<schema_name>|<db_name>|<schema_name>]
25192519
/// ```
2520-
///
2521-
/// Note: This is a MySQL-specific statement.
2522-
Use { db_name: Ident },
2520+
Use {
2521+
db_name: Option<Ident>,
2522+
schema_name: Option<Ident>,
2523+
keyword: Option<String>,
2524+
},
25232525
/// ```sql
25242526
/// START [ TRANSACTION | WORK ] | START TRANSACTION } ...
25252527
/// ```
@@ -4125,8 +4127,26 @@ impl fmt::Display for Statement {
41254127
}
41264128
Ok(())
41274129
}
4128-
Statement::Use { db_name } => {
4129-
write!(f, "USE {db_name}")?;
4130+
Statement::Use {
4131+
db_name,
4132+
schema_name,
4133+
keyword,
4134+
} => {
4135+
write!(f, "USE")?;
4136+
4137+
if let Some(kw) = keyword.as_ref() {
4138+
write!(f, " {}", kw)?;
4139+
}
4140+
4141+
if let Some(db_name) = db_name {
4142+
write!(f, " {}", db_name)?;
4143+
if let Some(schema_name) = schema_name {
4144+
write!(f, ".{}", schema_name)?;
4145+
}
4146+
} else if let Some(schema_name) = schema_name {
4147+
write!(f, " {}", schema_name)?;
4148+
}
4149+
41304150
Ok(())
41314151
}
41324152
Statement::ShowCollation { filter } => {

src/keywords.rs

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ define_keywords!(
137137
CASCADED,
138138
CASE,
139139
CAST,
140+
CATALOG,
140141
CEIL,
141142
CEILING,
142143
CENTURY,

src/parser/mod.rs

+57-2
Original file line numberDiff line numberDiff line change
@@ -9225,8 +9225,63 @@ impl<'a> Parser<'a> {
92259225
}
92269226

92279227
pub fn parse_use(&mut self) -> Result<Statement, ParserError> {
9228-
let db_name = self.parse_identifier(false)?;
9229-
Ok(Statement::Use { db_name })
9228+
// What should be treated as keyword in given dialect
9229+
let allowed_keywords = if dialect_of!(self is HiveDialect) {
9230+
vec![Keyword::DEFAULT]
9231+
} else if dialect_of!(self is DatabricksDialect) {
9232+
vec![Keyword::CATALOG, Keyword::DATABASE, Keyword::SCHEMA]
9233+
} else if dialect_of!(self is SnowflakeDialect) {
9234+
vec![Keyword::DATABASE, Keyword::SCHEMA]
9235+
} else {
9236+
vec![]
9237+
};
9238+
let parsed_keyword = self.parse_one_of_keywords(&allowed_keywords);
9239+
9240+
// Hive dialect accepts USE DEFAULT; statement without any db specified
9241+
if dialect_of!(self is HiveDialect) && parsed_keyword == Some(Keyword::DEFAULT) {
9242+
return Ok(Statement::Use {
9243+
db_name: None,
9244+
schema_name: None,
9245+
keyword: Some("DEFAULT".to_string()),
9246+
});
9247+
}
9248+
9249+
// Parse the object name, which might be a single identifier or fully qualified name (e.g., x.y)
9250+
let parts = self.parse_object_name(false)?.0;
9251+
let (db_name, schema_name) = match parts.len() {
9252+
1 => {
9253+
// Single identifier found
9254+
if dialect_of!(self is DatabricksDialect) {
9255+
if parsed_keyword == Some(Keyword::CATALOG) {
9256+
// Databricks: CATALOG keyword provided, treat as database name
9257+
(Some(parts[0].clone()), None)
9258+
} else {
9259+
// Databricks: DATABASE, SCHEMA or no keyword provided, treat as schema name
9260+
(None, Some(parts[0].clone()))
9261+
}
9262+
} else if dialect_of!(self is SnowflakeDialect)
9263+
&& parsed_keyword == Some(Keyword::SCHEMA)
9264+
{
9265+
// Snowflake: SCHEMA keyword provided, treat as schema name
9266+
(None, Some(parts[0].clone()))
9267+
} else {
9268+
// Other dialects: treat as database name by default
9269+
(Some(parts[0].clone()), None)
9270+
}
9271+
}
9272+
2 => (Some(parts[0].clone()), Some(parts[1].clone())),
9273+
_ => {
9274+
return Err(ParserError::ParserError(
9275+
"Invalid format in the USE statement".to_string(),
9276+
))
9277+
}
9278+
};
9279+
9280+
Ok(Statement::Use {
9281+
db_name,
9282+
schema_name,
9283+
keyword: parsed_keyword.map(|kw| format!("{:?}", kw)),
9284+
})
92309285
}
92319286

92329287
pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {

tests/sqlparser_clickhouse.rs

+36
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,42 @@ fn test_prewhere() {
11601160
}
11611161
}
11621162

1163+
#[test]
1164+
fn parse_use() {
1165+
assert_eq!(
1166+
clickhouse().verified_stmt("USE mydb"),
1167+
Statement::Use {
1168+
db_name: Some(Ident::new("mydb")),
1169+
schema_name: None,
1170+
keyword: None
1171+
}
1172+
);
1173+
assert_eq!(
1174+
clickhouse().verified_stmt("USE DATABASE"),
1175+
Statement::Use {
1176+
db_name: Some(Ident::new("DATABASE")),
1177+
schema_name: None,
1178+
keyword: None
1179+
}
1180+
);
1181+
assert_eq!(
1182+
clickhouse().verified_stmt("USE SCHEMA"),
1183+
Statement::Use {
1184+
db_name: Some(Ident::new("SCHEMA")),
1185+
schema_name: None,
1186+
keyword: None
1187+
}
1188+
);
1189+
assert_eq!(
1190+
clickhouse().verified_stmt("USE CATALOG"),
1191+
Statement::Use {
1192+
db_name: Some(Ident::new("CATALOG")),
1193+
schema_name: None,
1194+
keyword: None
1195+
}
1196+
);
1197+
}
1198+
11631199
#[test]
11641200
fn test_query_with_format_clause() {
11651201
let format_options = vec!["TabSeparated", "JSONCompact", "NULL"];

tests/sqlparser_databricks.rs

+52
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,55 @@ fn test_values_clause() {
189189
// TODO: support this example from https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-values.html#examples
190190
// databricks().verified_query("VALUES 1, 2, 3");
191191
}
192+
193+
#[test]
194+
fn parse_use() {
195+
assert_eq!(
196+
databricks().verified_stmt("USE my_schema"),
197+
Statement::Use {
198+
db_name: None,
199+
schema_name: Some(Ident::new("my_schema")),
200+
keyword: None
201+
}
202+
);
203+
assert_eq!(
204+
databricks().verified_stmt("USE CATALOG my_catalog"),
205+
Statement::Use {
206+
db_name: Some(Ident::new("my_catalog")),
207+
schema_name: None,
208+
keyword: Some("CATALOG".to_string())
209+
}
210+
);
211+
assert_eq!(
212+
databricks().verified_stmt("USE CATALOG 'my_catalog'"),
213+
Statement::Use {
214+
db_name: Some(Ident::with_quote('\'', "my_catalog")),
215+
schema_name: None,
216+
keyword: Some("CATALOG".to_string())
217+
}
218+
);
219+
assert_eq!(
220+
databricks().verified_stmt("USE DATABASE my_schema"),
221+
Statement::Use {
222+
db_name: None,
223+
schema_name: Some(Ident::new("my_schema")),
224+
keyword: Some("DATABASE".to_string())
225+
}
226+
);
227+
assert_eq!(
228+
databricks().verified_stmt("USE SCHEMA my_schema"),
229+
Statement::Use {
230+
db_name: None,
231+
schema_name: Some(Ident::new("my_schema")),
232+
keyword: Some("SCHEMA".to_string())
233+
}
234+
);
235+
236+
let invalid_cases = ["USE SCHEMA", "USE DATABASE", "USE CATALOG"];
237+
for sql in &invalid_cases {
238+
assert_eq!(
239+
databricks().parse_sql_statements(sql).unwrap_err(),
240+
ParserError::ParserError("Expected: identifier, found: EOF".to_string()),
241+
);
242+
}
243+
}

tests/sqlparser_duckdb.rs

+52
Original file line numberDiff line numberDiff line change
@@ -756,3 +756,55 @@ fn test_duckdb_union_datatype() {
756756
stmt
757757
);
758758
}
759+
760+
#[test]
761+
fn parse_use() {
762+
std::assert_eq!(
763+
duckdb().verified_stmt("USE mydb"),
764+
Statement::Use {
765+
db_name: Some(Ident::new("mydb")),
766+
schema_name: None,
767+
keyword: None
768+
}
769+
);
770+
std::assert_eq!(
771+
duckdb().verified_stmt("USE mydb.my_schema"),
772+
Statement::Use {
773+
db_name: Some(Ident::new("mydb")),
774+
schema_name: Some(Ident::new("my_schema")),
775+
keyword: None
776+
}
777+
);
778+
assert_eq!(
779+
duckdb().verified_stmt("USE DATABASE"),
780+
Statement::Use {
781+
db_name: Some(Ident::new("DATABASE")),
782+
schema_name: None,
783+
keyword: None
784+
}
785+
);
786+
assert_eq!(
787+
duckdb().verified_stmt("USE SCHEMA"),
788+
Statement::Use {
789+
db_name: Some(Ident::new("SCHEMA")),
790+
schema_name: None,
791+
keyword: None
792+
}
793+
);
794+
assert_eq!(
795+
duckdb().verified_stmt("USE CATALOG"),
796+
Statement::Use {
797+
db_name: Some(Ident::new("CATALOG")),
798+
schema_name: None,
799+
keyword: None
800+
}
801+
);
802+
assert_eq!(
803+
duckdb().verified_stmt("USE CATALOG.SCHEMA"),
804+
Statement::Use {
805+
db_name: Some(Ident::new("CATALOG")),
806+
schema_name: Some(Ident::new("SCHEMA")),
807+
keyword: None
808+
}
809+
);
810+
}

tests/sqlparser_hive.rs

+44
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,50 @@ fn parse_delimited_identifiers() {
401401
//TODO verified_stmt(r#"UPDATE foo SET "bar" = 5"#);
402402
}
403403

404+
#[test]
405+
fn parse_use() {
406+
assert_eq!(
407+
hive().verified_stmt("USE mydb"),
408+
Statement::Use {
409+
db_name: Some(Ident::new("mydb")),
410+
schema_name: None,
411+
keyword: None
412+
}
413+
);
414+
assert_eq!(
415+
hive().verified_stmt("USE DEFAULT"),
416+
Statement::Use {
417+
db_name: None,
418+
schema_name: None,
419+
keyword: Some("DEFAULT".to_string()) // Yes, as keyword not db_name
420+
}
421+
);
422+
assert_eq!(
423+
hive().verified_stmt("USE DATABASE"),
424+
Statement::Use {
425+
db_name: Some(Ident::new("DATABASE")),
426+
schema_name: None,
427+
keyword: None
428+
}
429+
);
430+
assert_eq!(
431+
hive().verified_stmt("USE SCHEMA"),
432+
Statement::Use {
433+
db_name: Some(Ident::new("SCHEMA")),
434+
schema_name: None,
435+
keyword: None
436+
}
437+
);
438+
assert_eq!(
439+
hive().verified_stmt("USE CATALOG"),
440+
Statement::Use {
441+
db_name: Some(Ident::new("CATALOG")),
442+
schema_name: None,
443+
keyword: None
444+
}
445+
);
446+
}
447+
404448
fn hive() -> TestedDialects {
405449
TestedDialects {
406450
dialects: vec![Box::new(HiveDialect {})],

tests/sqlparser_mssql.rs

+36
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,42 @@ fn parse_mssql_declare() {
621621
);
622622
}
623623

624+
#[test]
625+
fn parse_use() {
626+
assert_eq!(
627+
ms().verified_stmt("USE mydb"),
628+
Statement::Use {
629+
db_name: Some(Ident::new("mydb")),
630+
schema_name: None,
631+
keyword: None
632+
}
633+
);
634+
assert_eq!(
635+
ms().verified_stmt("USE DATABASE"),
636+
Statement::Use {
637+
db_name: Some(Ident::new("DATABASE")),
638+
schema_name: None,
639+
keyword: None
640+
}
641+
);
642+
assert_eq!(
643+
ms().verified_stmt("USE SCHEMA"),
644+
Statement::Use {
645+
db_name: Some(Ident::new("SCHEMA")),
646+
schema_name: None,
647+
keyword: None
648+
}
649+
);
650+
assert_eq!(
651+
ms().verified_stmt("USE CATALOG"),
652+
Statement::Use {
653+
db_name: Some(Ident::new("CATALOG")),
654+
schema_name: None,
655+
keyword: None
656+
}
657+
);
658+
}
659+
624660
fn ms() -> TestedDialects {
625661
TestedDialects {
626662
dialects: vec![Box::new(MsSqlDialect {})],

0 commit comments

Comments
 (0)