From 06c68ff3260507045bee8a57c92198bb705ed4a4 Mon Sep 17 00:00:00 2001 From: Michael Victor Zink Date: Tue, 4 Feb 2025 11:45:51 -0800 Subject: [PATCH] Parse SET NAMES syntax --- src/ast/mod.rs | 7 ++----- src/dialect/generic.rs | 4 ++++ src/dialect/mod.rs | 10 ++++++++++ src/dialect/mysql.rs | 4 ++++ src/dialect/postgresql.rs | 4 ++++ src/parser/mod.rs | 8 ++++---- tests/sqlparser_common.rs | 8 ++++++++ tests/sqlparser_mysql.rs | 6 +++--- 8 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 72be3ff6c..554ec19b7 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2956,10 +2956,8 @@ pub enum Statement { /// ```sql /// SET NAMES 'charset_name' [COLLATE 'collation_name'] /// ``` - /// - /// Note: this is a MySQL-specific statement. SetNames { - charset_name: String, + charset_name: Ident, collation_name: Option, }, /// ```sql @@ -4684,8 +4682,7 @@ impl fmt::Display for Statement { charset_name, collation_name, } => { - f.write_str("SET NAMES ")?; - f.write_str(charset_name)?; + write!(f, "SET NAMES {}", charset_name)?; if let Some(collation) = collation_name { f.write_str(" COLLATE ")?; diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index 041d44bb2..c13d5aa69 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -155,4 +155,8 @@ impl Dialect for GenericDialect { fn supports_match_against(&self) -> bool { true } + + fn supports_set_names(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 1c32bc513..a8c981d9b 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -953,6 +953,16 @@ pub trait Dialect: Debug + Any { fn supports_order_by_all(&self) -> bool { false } + + /// Returns true if the dialect supports `SET NAMES [COLLATE ]`. + /// + /// - [MySQL](https://dev.mysql.com/doc/refman/8.4/en/set-names.html) + /// - [Postgres](https://www.postgresql.org/docs/17/sql-set.html) + /// + /// Note: Postgres doesn't support the `COLLATE` clause, but we permissively parse it anyway. + fn supports_set_names(&self) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs index 8a0da87e4..4465dcc11 100644 --- a/src/dialect/mysql.rs +++ b/src/dialect/mysql.rs @@ -133,6 +133,10 @@ impl Dialect for MySqlDialect { fn supports_match_against(&self) -> bool { true } + + fn supports_set_names(&self) -> bool { + true + } } /// `LOCK TABLES` diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 57ed0b684..9b08b8f32 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -254,4 +254,8 @@ impl Dialect for PostgreSqlDialect { fn supports_geometric_types(&self) -> bool { true } + + fn supports_set_names(&self) -> bool { + true + } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index f234fcc07..b11e57791 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10962,14 +10962,14 @@ impl<'a> Parser<'a> { OneOrManyWithParens::One(self.parse_object_name(false)?) }; - if matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES") - && dialect_of!(self is MySqlDialect | GenericDialect)) - { + let names = matches!(&variables, OneOrManyWithParens::One(variable) if variable.to_string().eq_ignore_ascii_case("NAMES")); + + if names && self.dialect.supports_set_names() { if self.parse_keyword(Keyword::DEFAULT) { return Ok(Statement::SetNamesDefault {}); } - let charset_name = self.parse_literal_string()?; + let charset_name = self.parse_identifier()?; let collation_name = if self.parse_one_of_keywords(&[Keyword::COLLATE]).is_some() { Some(self.parse_literal_string()?) } else { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 0a68d31e8..e6451df64 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14617,3 +14617,11 @@ fn parse_array_type_def_with_brackets() { dialects.verified_stmt("SELECT x::INT[]"); dialects.verified_stmt("SELECT STRING_TO_ARRAY('1,2,3', ',')::INT[3]"); } + +#[test] +fn parse_set_names() { + let dialects = all_dialects_where(|d| d.supports_set_names()); + dialects.verified_stmt("SET NAMES 'UTF8'"); + dialects.verified_stmt("SET NAMES 'utf8'"); + dialects.verified_stmt("SET NAMES UTF8 COLLATE bogus"); +} diff --git a/tests/sqlparser_mysql.rs b/tests/sqlparser_mysql.rs index 15f79b4c2..0007c8a4c 100644 --- a/tests/sqlparser_mysql.rs +++ b/tests/sqlparser_mysql.rs @@ -2685,7 +2685,7 @@ fn parse_set_names() { assert_eq!( stmt, Statement::SetNames { - charset_name: "utf8mb4".to_string(), + charset_name: "utf8mb4".into(), collation_name: None, } ); @@ -2694,7 +2694,7 @@ fn parse_set_names() { assert_eq!( stmt, Statement::SetNames { - charset_name: "utf8mb4".to_string(), + charset_name: "utf8mb4".into(), collation_name: Some("bogus".to_string()), } ); @@ -2705,7 +2705,7 @@ fn parse_set_names() { assert_eq!( stmt, vec![Statement::SetNames { - charset_name: "utf8mb4".to_string(), + charset_name: "utf8mb4".into(), collation_name: Some("bogus".to_string()), }] );