Skip to content

Commit c3a71da

Browse files
committed
Change USE Statement to use Enum
1 parent 984d805 commit c3a71da

11 files changed

+316
-329
lines changed

src/ast/dcl.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,30 @@ impl fmt::Display for AlterRoleOperation {
193193
}
194194
}
195195
}
196+
197+
/// A `USE` (`Statement::Use`) operation
198+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
199+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
200+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
201+
pub enum Use {
202+
Catalog(ObjectName), // e.g. `USE CATALOG foo.bar`
203+
Schema(ObjectName), // e.g. `USE SCHEMA foo.bar`
204+
Database(ObjectName), // e.g. `USE DATABASE foo.bar`
205+
Warehouse(ObjectName), // e.g. `USE WAREHOUSE foo.bar`
206+
Object(ObjectName), // e.g. `USE foo.bar`
207+
Default, // e.g. `USE DEFAULT`
208+
}
209+
210+
impl fmt::Display for Use {
211+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
212+
f.write_str("USE ")?;
213+
match self {
214+
Use::Catalog(name) => write!(f, "CATALOG {}", name),
215+
Use::Schema(name) => write!(f, "SCHEMA {}", name),
216+
Use::Database(name) => write!(f, "DATABASE {}", name),
217+
Use::Warehouse(name) => write!(f, "WAREHOUSE {}", name),
218+
Use::Object(name) => write!(f, "{}", name),
219+
Use::Default => write!(f, "DEFAULT"),
220+
}
221+
}
222+
}

src/ast/mod.rs

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub use self::data_type::{
3131
ArrayElemTypeDef, CharLengthUnits, CharacterLength, DataType, ExactNumberInfo,
3232
StructBracketKind, TimezoneInfo,
3333
};
34-
pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue};
34+
pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue, Use};
3535
pub use self::ddl::{
3636
AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ColumnDef, ColumnOption,
3737
ColumnOptionDef, ConstraintCharacteristics, Deduplicate, DeferrableInitial, GeneratedAs,
@@ -2515,13 +2515,9 @@ pub enum Statement {
25152515
/// Note: this is a MySQL-specific statement.
25162516
ShowCollation { filter: Option<ShowStatementFilter> },
25172517
/// ```sql
2518-
/// USE [DATABASE|SCHEMA|CATALOG|...] [<db_name>.<schema_name>|<db_name>|<schema_name>]
2518+
/// `USE ...`
25192519
/// ```
2520-
Use {
2521-
db_name: Option<Ident>,
2522-
schema_name: Option<Ident>,
2523-
keyword: Option<String>,
2524-
},
2520+
Use(Use),
25252521
/// ```sql
25262522
/// START [ TRANSACTION | WORK ] | START TRANSACTION } ...
25272523
/// ```
@@ -4127,28 +4123,7 @@ impl fmt::Display for Statement {
41274123
}
41284124
Ok(())
41294125
}
4130-
Statement::Use {
4131-
db_name,
4132-
schema_name,
4133-
keyword,
4134-
} => {
4135-
write!(f, "USE")?;
4136-
4137-
if let Some(kw) = keyword.as_ref() {
4138-
write!(f, " {}", kw)?;
4139-
}
4140-
4141-
if let Some(db_name) = db_name {
4142-
write!(f, " {}", db_name)?;
4143-
if let Some(schema_name) = schema_name {
4144-
write!(f, ".{}", schema_name)?;
4145-
}
4146-
} else if let Some(schema_name) = schema_name {
4147-
write!(f, " {}", schema_name)?;
4148-
}
4149-
4150-
Ok(())
4151-
}
4126+
Statement::Use(use_expr) => use_expr.fmt(f),
41524127
Statement::ShowCollation { filter } => {
41534128
write!(f, "SHOW COLLATION")?;
41544129
if let Some(filter) = filter {

src/keywords.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,7 @@ define_keywords!(
804804
VIEW,
805805
VIRTUAL,
806806
VOLATILE,
807+
WAREHOUSE,
807808
WEEK,
808809
WHEN,
809810
WHENEVER,

src/parser/mod.rs

Lines changed: 18 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -9225,63 +9225,31 @@ impl<'a> Parser<'a> {
92259225
}
92269226

92279227
pub fn parse_use(&mut self) -> Result<Statement, ParserError> {
9228-
// What should be treated as keyword in given dialect
9229-
let allowed_keywords = if dialect_of!(self is HiveDialect) {
9230-
vec![Keyword::DEFAULT]
9228+
// Determine which keywords are recognized by the current dialect
9229+
let parsed_keyword = if dialect_of!(self is HiveDialect) {
9230+
// HiveDialect accepts USE DEFAULT; statement without any db specified
9231+
if self.parse_keyword(Keyword::DEFAULT) {
9232+
return Ok(Statement::Use(Use::Default));
9233+
}
9234+
None // HiveDialect doesn't expect any other specific keyword after `USE`
92319235
} else if dialect_of!(self is DatabricksDialect) {
9232-
vec![Keyword::CATALOG, Keyword::DATABASE, Keyword::SCHEMA]
9236+
self.parse_one_of_keywords(&[Keyword::CATALOG, Keyword::DATABASE, Keyword::SCHEMA])
92339237
} else if dialect_of!(self is SnowflakeDialect) {
9234-
vec![Keyword::DATABASE, Keyword::SCHEMA]
9238+
self.parse_one_of_keywords(&[Keyword::DATABASE, Keyword::SCHEMA, Keyword::WAREHOUSE])
92359239
} else {
9236-
vec![]
9240+
None // No specific keywords for other dialects, including GenericDialect
92379241
};
9238-
let parsed_keyword = self.parse_one_of_keywords(&allowed_keywords);
92399242

9240-
// Hive dialect accepts USE DEFAULT; statement without any db specified
9241-
if dialect_of!(self is HiveDialect) && parsed_keyword == Some(Keyword::DEFAULT) {
9242-
return Ok(Statement::Use {
9243-
db_name: None,
9244-
schema_name: None,
9245-
keyword: Some("DEFAULT".to_string()),
9246-
});
9247-
}
9248-
9249-
// Parse the object name, which might be a single identifier or fully qualified name (e.g., x.y)
9250-
let parts = self.parse_object_name(false)?.0;
9251-
let (db_name, schema_name) = match parts.len() {
9252-
1 => {
9253-
// Single identifier found
9254-
if dialect_of!(self is DatabricksDialect) {
9255-
if parsed_keyword == Some(Keyword::CATALOG) {
9256-
// Databricks: CATALOG keyword provided, treat as database name
9257-
(Some(parts[0].clone()), None)
9258-
} else {
9259-
// Databricks: DATABASE, SCHEMA or no keyword provided, treat as schema name
9260-
(None, Some(parts[0].clone()))
9261-
}
9262-
} else if dialect_of!(self is SnowflakeDialect)
9263-
&& parsed_keyword == Some(Keyword::SCHEMA)
9264-
{
9265-
// Snowflake: SCHEMA keyword provided, treat as schema name
9266-
(None, Some(parts[0].clone()))
9267-
} else {
9268-
// Other dialects: treat as database name by default
9269-
(Some(parts[0].clone()), None)
9270-
}
9271-
}
9272-
2 => (Some(parts[0].clone()), Some(parts[1].clone())),
9273-
_ => {
9274-
return Err(ParserError::ParserError(
9275-
"Invalid format in the USE statement".to_string(),
9276-
))
9277-
}
9243+
let obj_name = self.parse_object_name(false)?;
9244+
let result = match parsed_keyword {
9245+
Some(Keyword::CATALOG) => Use::Catalog(obj_name),
9246+
Some(Keyword::DATABASE) => Use::Database(obj_name),
9247+
Some(Keyword::SCHEMA) => Use::Schema(obj_name),
9248+
Some(Keyword::WAREHOUSE) => Use::Warehouse(obj_name),
9249+
_ => Use::Object(obj_name),
92789250
};
92799251

9280-
Ok(Statement::Use {
9281-
db_name,
9282-
schema_name,
9283-
keyword: parsed_keyword.map(|kw| format!("{:?}", kw)),
9284-
})
9252+
Ok(Statement::Use(result))
92859253
}
92869254

92879255
pub fn parse_table_and_joins(&mut self) -> Result<TableWithJoins, ParserError> {

tests/sqlparser_clickhouse.rs

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,38 +1162,35 @@ fn test_prewhere() {
11621162

11631163
#[test]
11641164
fn parse_use() {
1165-
assert_eq!(
1166-
clickhouse().verified_stmt("USE mydb"),
1167-
Statement::Use {
1168-
db_name: Some(Ident::new("mydb")),
1169-
schema_name: None,
1170-
keyword: None
1171-
}
1172-
);
1173-
assert_eq!(
1174-
clickhouse().verified_stmt("USE DATABASE"),
1175-
Statement::Use {
1176-
db_name: Some(Ident::new("DATABASE")),
1177-
schema_name: None,
1178-
keyword: None
1179-
}
1180-
);
1181-
assert_eq!(
1182-
clickhouse().verified_stmt("USE SCHEMA"),
1183-
Statement::Use {
1184-
db_name: Some(Ident::new("SCHEMA")),
1185-
schema_name: None,
1186-
keyword: None
1187-
}
1188-
);
1189-
assert_eq!(
1190-
clickhouse().verified_stmt("USE CATALOG"),
1191-
Statement::Use {
1192-
db_name: Some(Ident::new("CATALOG")),
1193-
schema_name: None,
1194-
keyword: None
1165+
let valid_object_names = [
1166+
"mydb",
1167+
"SCHEMA",
1168+
"DATABASE",
1169+
"CATALOG",
1170+
"WAREHOUSE",
1171+
"DEFAULT",
1172+
];
1173+
let quote_styles = ['"', '`'];
1174+
1175+
for object_name in &valid_object_names {
1176+
// Test single identifier without quotes
1177+
assert_eq!(
1178+
clickhouse().verified_stmt(&format!("USE {}", object_name)),
1179+
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
1180+
object_name.to_string()
1181+
)])))
1182+
);
1183+
for &quote in &quote_styles {
1184+
// Test single identifier with different type of quotes
1185+
assert_eq!(
1186+
clickhouse().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)),
1187+
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
1188+
quote,
1189+
object_name.to_string(),
1190+
)])))
1191+
);
11951192
}
1196-
);
1193+
}
11971194
}
11981195

11991196
#[test]

tests/sqlparser_databricks.rs

Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -192,47 +192,69 @@ fn test_values_clause() {
192192

193193
#[test]
194194
fn parse_use() {
195-
assert_eq!(
196-
databricks().verified_stmt("USE my_schema"),
197-
Statement::Use {
198-
db_name: None,
199-
schema_name: Some(Ident::new("my_schema")),
200-
keyword: None
195+
let valid_object_names = ["mydb", "WAREHOUSE", "DEFAULT"];
196+
let quote_styles = ['"', '`'];
197+
198+
for object_name in &valid_object_names {
199+
// Test single identifier without quotes
200+
assert_eq!(
201+
databricks().verified_stmt(&format!("USE {}", object_name)),
202+
Statement::Use(Use::Object(ObjectName(vec![Ident::new(
203+
object_name.to_string()
204+
)])))
205+
);
206+
for &quote in &quote_styles {
207+
// Test single identifier with different type of quotes
208+
assert_eq!(
209+
databricks().verified_stmt(&format!("USE {0}{1}{0}", quote, object_name)),
210+
Statement::Use(Use::Object(ObjectName(vec![Ident::with_quote(
211+
quote,
212+
object_name.to_string(),
213+
)])))
214+
);
201215
}
202-
);
216+
}
217+
218+
for &quote in &quote_styles {
219+
// Test single identifier with keyword and different type of quotes
220+
assert_eq!(
221+
databricks().verified_stmt(&format!("USE CATALOG {0}my_catalog{0}", quote)),
222+
Statement::Use(Use::Catalog(ObjectName(vec![Ident::with_quote(
223+
quote,
224+
"my_catalog".to_string(),
225+
)])))
226+
);
227+
assert_eq!(
228+
databricks().verified_stmt(&format!("USE DATABASE {0}my_database{0}", quote)),
229+
Statement::Use(Use::Database(ObjectName(vec![Ident::with_quote(
230+
quote,
231+
"my_database".to_string(),
232+
)])))
233+
);
234+
assert_eq!(
235+
databricks().verified_stmt(&format!("USE SCHEMA {0}my_schema{0}", quote)),
236+
Statement::Use(Use::Schema(ObjectName(vec![Ident::with_quote(
237+
quote,
238+
"my_schema".to_string(),
239+
)])))
240+
);
241+
}
242+
243+
// Test single identifier with keyword and no quotes
203244
assert_eq!(
204245
databricks().verified_stmt("USE CATALOG my_catalog"),
205-
Statement::Use {
206-
db_name: Some(Ident::new("my_catalog")),
207-
schema_name: None,
208-
keyword: Some("CATALOG".to_string())
209-
}
210-
);
211-
assert_eq!(
212-
databricks().verified_stmt("USE CATALOG 'my_catalog'"),
213-
Statement::Use {
214-
db_name: Some(Ident::with_quote('\'', "my_catalog")),
215-
schema_name: None,
216-
keyword: Some("CATALOG".to_string())
217-
}
246+
Statement::Use(Use::Catalog(ObjectName(vec![Ident::new("my_catalog")])))
218247
);
219248
assert_eq!(
220249
databricks().verified_stmt("USE DATABASE my_schema"),
221-
Statement::Use {
222-
db_name: None,
223-
schema_name: Some(Ident::new("my_schema")),
224-
keyword: Some("DATABASE".to_string())
225-
}
250+
Statement::Use(Use::Database(ObjectName(vec![Ident::new("my_schema")])))
226251
);
227252
assert_eq!(
228253
databricks().verified_stmt("USE SCHEMA my_schema"),
229-
Statement::Use {
230-
db_name: None,
231-
schema_name: Some(Ident::new("my_schema")),
232-
keyword: Some("SCHEMA".to_string())
233-
}
254+
Statement::Use(Use::Schema(ObjectName(vec![Ident::new("my_schema")])))
234255
);
235256

257+
// Test invalid syntax - missing identifier
236258
let invalid_cases = ["USE SCHEMA", "USE DATABASE", "USE CATALOG"];
237259
for sql in &invalid_cases {
238260
assert_eq!(

0 commit comments

Comments
 (0)