From 0998398d9969329e3acc448848b113b55d31014a Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Mon, 18 Nov 2024 22:36:16 +0800 Subject: [PATCH 1/5] Add support for PostgreSQL `UNLISTEN` syntax --- src/ast/mod.rs | 11 +++++++++++ src/dialect/mod.rs | 5 +++++ src/dialect/postgresql.rs | 5 +++++ src/keywords.rs | 1 + src/parser/mod.rs | 18 +++++++++++++++++- tests/sqlparser_common.rs | 31 +++++++++++++++++++++++++++++++ 6 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 89e70bdd4..9185c9df4 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3340,6 +3340,13 @@ pub enum Statement { /// See Postgres LISTEN { channel: Ident }, /// ```sql + /// UNLISTEN + /// ``` + /// stop listening for a notification + /// + /// See Postgres + UNLISTEN { channel: Ident }, + /// ```sql /// NOTIFY channel [ , payload ] /// ``` /// send a notification event together with an optional “payload” string to channel @@ -4948,6 +4955,10 @@ impl fmt::Display for Statement { write!(f, "LISTEN {channel}")?; Ok(()) } + Statement::UNLISTEN { channel } => { + write!(f, "UNLISTEN {channel}")?; + Ok(()) + } Statement::NOTIFY { channel, payload } => { write!(f, "NOTIFY {channel}")?; if let Some(payload) = payload { diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 39ea98c69..5758a9fbc 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -638,6 +638,11 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports the `UNLISTEN` statement + fn supports_unlisten(&self) -> bool { + false + } + /// Returns true if the dialect supports the `NOTIFY` statement fn supports_notify(&self) -> bool { false diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 5af1ab853..e95510d03 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -195,6 +195,11 @@ impl Dialect for PostgreSqlDialect { true } + /// see + fn supports_unlisten(&self) -> bool { + true + } + /// see fn supports_notify(&self) -> bool { true diff --git a/src/keywords.rs b/src/keywords.rs index 29115a0d2..fc2a2927c 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -799,6 +799,7 @@ define_keywords!( UNION, UNIQUE, UNKNOWN, + UNLISTEN, UNLOAD, UNLOCK, UNLOGGED, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 35ad95803..abd6a796b 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -532,9 +532,10 @@ impl<'a> Parser<'a> { Keyword::EXECUTE | Keyword::EXEC => self.parse_execute(), Keyword::PREPARE => self.parse_prepare(), Keyword::MERGE => self.parse_merge(), - // `LISTEN` and `NOTIFY` are Postgres-specific + // `LISTEN`, `UNLISTEN` and `NOTIFY` are Postgres-specific // syntaxes. They are used for Postgres statement. Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(), + Keyword::UNLISTEN if self.dialect.supports_unlisten() => self.parse_unlisten(), Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(), // `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html Keyword::PRAGMA => self.parse_pragma(), @@ -999,6 +1000,21 @@ impl<'a> Parser<'a> { Ok(Statement::LISTEN { channel }) } + pub fn parse_unlisten(&mut self) -> Result { + let channel = if self.consume_token(&Token::Mul) { + Ident::new(Expr::Wildcard.to_string()) + } else { + match self.parse_identifier(false) { + Ok(expr) => expr, + _ => { + self.prev_token(); + return self.expected("wildcard or identent", self.peek_token()); + } + } + }; + Ok(Statement::UNLISTEN { channel }) + } + pub fn parse_notify(&mut self) -> Result { let channel = self.parse_identifier(false)?; let payload = if self.consume_token(&Token::Comma) { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index ecdca6b1b..06c0435b8 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11617,6 +11617,37 @@ fn parse_listen_channel() { ); } +#[test] +fn parse_unlisten_channel() { + let dialects = all_dialects_where(|d| d.supports_unlisten()); + + match dialects.verified_stmt("UNLISTEN test1") { + Statement::UNLISTEN { channel } => { + assert_eq!(Ident::new("test1"), channel); + } + _ => unreachable!(), + }; + + match dialects.verified_stmt("UNLISTEN *") { + Statement::UNLISTEN { channel } => { + assert_eq!(Ident::new("*"), channel); + } + _ => unreachable!(), + }; + + assert_eq!( + dialects.parse_sql_statements("UNLISTEN +").unwrap_err(), + ParserError::ParserError("Expected: wildcard or identent, found: +".to_string()) + ); + + let dialects = all_dialects_where(|d| !d.supports_listen()); + + assert_eq!( + dialects.parse_sql_statements("UNLISTEN test1").unwrap_err(), + ParserError::ParserError("Expected: an SQL statement, found: UNLISTEN".to_string()) + ); +} + #[test] fn parse_notify_channel() { let dialects = all_dialects_where(|d| d.supports_notify()); From a83c9dbd09a0867fbdffe5a0c5509e020aa5085c Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Mon, 18 Nov 2024 23:03:29 +0800 Subject: [PATCH 2/5] Add support for Postgres `LOAD extension` expr --- src/dialect/postgresql.rs | 5 ++++ tests/sqlparser_common.rs | 50 +++++++++++++++++++++++++++++++++++++++ tests/sqlparser_duckdb.rs | 14 ----------- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index e95510d03..009af32ca 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -214,6 +214,11 @@ impl Dialect for PostgreSqlDialect { fn supports_comment_on(&self) -> bool { true } + + /// See + fn supports_load_extension(&self) -> bool { + true + } } pub fn parse_create(parser: &mut Parser) -> Option> { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 06c0435b8..2c8685ca2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11895,6 +11895,56 @@ fn parse_load_data() { ); } +#[test] +fn test_load_extension() { + let dialects = all_dialects_where(|d| d.supports_load_extension()); + let only_supports_load_data_dialects = + all_dialects_where(|d| !d.supports_load_extension() && d.supports_load_data()); + let not_supports_load_dialects = + all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension()); + let sql = "LOAD my_extension"; + + match dialects.verified_stmt(sql) { + Statement::Load { extension_name } => { + assert_eq!(Ident::new("my_extension"), extension_name); + } + _ => unreachable!(), + }; + + assert_eq!( + only_supports_load_data_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string() + ) + ); + + assert_eq!( + not_supports_load_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string() + ) + ); + + let sql = "LOAD 'filename'"; + + match dialects.verified_stmt(sql) { + Statement::Load { extension_name } => { + assert_eq!( + Ident { + value: "filename".to_string(), + quote_style: Some('\'') + }, + extension_name + ); + } + _ => unreachable!(), + }; +} + #[test] fn test_select_top() { let dialects = all_dialects_where(|d| d.supports_top_before_distinct()); diff --git a/tests/sqlparser_duckdb.rs b/tests/sqlparser_duckdb.rs index d68f37713..a2db5c282 100644 --- a/tests/sqlparser_duckdb.rs +++ b/tests/sqlparser_duckdb.rs @@ -359,20 +359,6 @@ fn test_duckdb_install() { ); } -#[test] -fn test_duckdb_load_extension() { - let stmt = duckdb().verified_stmt("LOAD my_extension"); - assert_eq!( - Statement::Load { - extension_name: Ident { - value: "my_extension".to_string(), - quote_style: None - } - }, - stmt - ); -} - #[test] fn test_duckdb_struct_literal() { //struct literal syntax https://duckdb.org/docs/sql/data_types/struct#creating-structs From 137d61b9f4f827b299908fccf12b6655f2ef4434 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 19 Nov 2024 12:45:20 +0800 Subject: [PATCH 3/5] Update src/parser/mod.rs good Co-authored-by: Ifeanyi Ubah --- src/parser/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index abd6a796b..6cd4d391a 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1008,7 +1008,7 @@ impl<'a> Parser<'a> { Ok(expr) => expr, _ => { self.prev_token(); - return self.expected("wildcard or identent", self.peek_token()); + return self.expected("wildcard or identifier", self.peek_token()); } } }; From fac5f6a0a12db56d2481e94c2d675711c84cba80 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 19 Nov 2024 12:59:51 +0800 Subject: [PATCH 4/5] Remove redundant structures and use more descriptive error messages. --- tests/sqlparser_common.rs | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 2c8685ca2..3bdc7895e 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11637,7 +11637,7 @@ fn parse_unlisten_channel() { assert_eq!( dialects.parse_sql_statements("UNLISTEN +").unwrap_err(), - ParserError::ParserError("Expected: wildcard or identent, found: +".to_string()) + ParserError::ParserError("Expected: wildcard or identifier, found: +".to_string()) ); let dialects = all_dialects_where(|d| !d.supports_listen()); @@ -11898,10 +11898,7 @@ fn parse_load_data() { #[test] fn test_load_extension() { let dialects = all_dialects_where(|d| d.supports_load_extension()); - let only_supports_load_data_dialects = - all_dialects_where(|d| !d.supports_load_extension() && d.supports_load_data()); - let not_supports_load_dialects = - all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension()); + let not_supports_load_extension_dialects = all_dialects_where(|d| !d.supports_load_extension()); let sql = "LOAD my_extension"; match dialects.verified_stmt(sql) { @@ -11912,16 +11909,7 @@ fn test_load_extension() { }; assert_eq!( - only_supports_load_data_dialects - .parse_sql_statements(sql) - .unwrap_err(), - ParserError::ParserError( - "Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string() - ) - ); - - assert_eq!( - not_supports_load_dialects + not_supports_load_extension_dialects .parse_sql_statements(sql) .unwrap_err(), ParserError::ParserError( From d97426e85b6b6d59bf2102c81957adb1851764a3 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 19 Nov 2024 21:52:58 +0800 Subject: [PATCH 5/5] use `supports_listen_notify` to both support `listen`, `unlisten` and `notify` statements --- src/dialect/mod.rs | 14 ++------------ src/dialect/postgresql.rs | 10 +--------- src/parser/mod.rs | 6 +++--- tests/sqlparser_common.rs | 12 ++++++------ 4 files changed, 12 insertions(+), 30 deletions(-) diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 5758a9fbc..985cad749 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -633,18 +633,8 @@ pub trait Dialect: Debug + Any { false } - /// Returns true if the dialect supports the `LISTEN` statement - fn supports_listen(&self) -> bool { - false - } - - /// Returns true if the dialect supports the `UNLISTEN` statement - fn supports_unlisten(&self) -> bool { - false - } - - /// Returns true if the dialect supports the `NOTIFY` statement - fn supports_notify(&self) -> bool { + /// Returns true if the dialect supports the `LISTEN`, `UNLISTEN` and `NOTIFY` statements + fn supports_listen_notify(&self) -> bool { false } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 009af32ca..559586e3f 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -191,17 +191,9 @@ impl Dialect for PostgreSqlDialect { } /// see - fn supports_listen(&self) -> bool { - true - } - /// see - fn supports_unlisten(&self) -> bool { - true - } - /// see - fn supports_notify(&self) -> bool { + fn supports_listen_notify(&self) -> bool { true } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 6cd4d391a..35c763e93 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -534,9 +534,9 @@ impl<'a> Parser<'a> { Keyword::MERGE => self.parse_merge(), // `LISTEN`, `UNLISTEN` and `NOTIFY` are Postgres-specific // syntaxes. They are used for Postgres statement. - Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(), - Keyword::UNLISTEN if self.dialect.supports_unlisten() => self.parse_unlisten(), - Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(), + Keyword::LISTEN if self.dialect.supports_listen_notify() => self.parse_listen(), + Keyword::UNLISTEN if self.dialect.supports_listen_notify() => self.parse_unlisten(), + Keyword::NOTIFY if self.dialect.supports_listen_notify() => self.parse_notify(), // `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html Keyword::PRAGMA => self.parse_pragma(), Keyword::UNLOAD => self.parse_unload(), diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 3bdc7895e..3d9ba5da2 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11595,7 +11595,7 @@ fn test_show_dbs_schemas_tables_views() { #[test] fn parse_listen_channel() { - let dialects = all_dialects_where(|d| d.supports_listen()); + let dialects = all_dialects_where(|d| d.supports_listen_notify()); match dialects.verified_stmt("LISTEN test1") { Statement::LISTEN { channel } => { @@ -11609,7 +11609,7 @@ fn parse_listen_channel() { ParserError::ParserError("Expected: identifier, found: *".to_string()) ); - let dialects = all_dialects_where(|d| !d.supports_listen()); + let dialects = all_dialects_where(|d| !d.supports_listen_notify()); assert_eq!( dialects.parse_sql_statements("LISTEN test1").unwrap_err(), @@ -11619,7 +11619,7 @@ fn parse_listen_channel() { #[test] fn parse_unlisten_channel() { - let dialects = all_dialects_where(|d| d.supports_unlisten()); + let dialects = all_dialects_where(|d| d.supports_listen_notify()); match dialects.verified_stmt("UNLISTEN test1") { Statement::UNLISTEN { channel } => { @@ -11640,7 +11640,7 @@ fn parse_unlisten_channel() { ParserError::ParserError("Expected: wildcard or identifier, found: +".to_string()) ); - let dialects = all_dialects_where(|d| !d.supports_listen()); + let dialects = all_dialects_where(|d| !d.supports_listen_notify()); assert_eq!( dialects.parse_sql_statements("UNLISTEN test1").unwrap_err(), @@ -11650,7 +11650,7 @@ fn parse_unlisten_channel() { #[test] fn parse_notify_channel() { - let dialects = all_dialects_where(|d| d.supports_notify()); + let dialects = all_dialects_where(|d| d.supports_listen_notify()); match dialects.verified_stmt("NOTIFY test1") { Statement::NOTIFY { channel, payload } => { @@ -11686,7 +11686,7 @@ fn parse_notify_channel() { "NOTIFY test1", "NOTIFY test1, 'this is a test notification'", ]; - let dialects = all_dialects_where(|d| !d.supports_notify()); + let dialects = all_dialects_where(|d| !d.supports_listen_notify()); for &sql in &sql_statements { assert_eq!(