Skip to content

Commit 73947a5

Browse files
wugeeriffyio
andauthored
Add support for PostgreSQL UNLISTEN syntax and Add support for Postgres LOAD extension expr (#1531)
Co-authored-by: Ifeanyi Ubah <[email protected]>
1 parent 92be237 commit 73947a5

File tree

7 files changed

+113
-33
lines changed

7 files changed

+113
-33
lines changed

src/ast/mod.rs

+11
Original file line numberDiff line numberDiff line change
@@ -3340,6 +3340,13 @@ pub enum Statement {
33403340
/// See Postgres <https://www.postgresql.org/docs/current/sql-listen.html>
33413341
LISTEN { channel: Ident },
33423342
/// ```sql
3343+
/// UNLISTEN
3344+
/// ```
3345+
/// stop listening for a notification
3346+
///
3347+
/// See Postgres <https://www.postgresql.org/docs/current/sql-unlisten.html>
3348+
UNLISTEN { channel: Ident },
3349+
/// ```sql
33433350
/// NOTIFY channel [ , payload ]
33443351
/// ```
33453352
/// send a notification event together with an optional “payload” string to channel
@@ -4948,6 +4955,10 @@ impl fmt::Display for Statement {
49484955
write!(f, "LISTEN {channel}")?;
49494956
Ok(())
49504957
}
4958+
Statement::UNLISTEN { channel } => {
4959+
write!(f, "UNLISTEN {channel}")?;
4960+
Ok(())
4961+
}
49514962
Statement::NOTIFY { channel, payload } => {
49524963
write!(f, "NOTIFY {channel}")?;
49534964
if let Some(payload) = payload {

src/dialect/mod.rs

+2-7
Original file line numberDiff line numberDiff line change
@@ -633,13 +633,8 @@ pub trait Dialect: Debug + Any {
633633
false
634634
}
635635

636-
/// Returns true if the dialect supports the `LISTEN` statement
637-
fn supports_listen(&self) -> bool {
638-
false
639-
}
640-
641-
/// Returns true if the dialect supports the `NOTIFY` statement
642-
fn supports_notify(&self) -> bool {
636+
/// Returns true if the dialect supports the `LISTEN`, `UNLISTEN` and `NOTIFY` statements
637+
fn supports_listen_notify(&self) -> bool {
643638
false
644639
}
645640

src/dialect/postgresql.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,9 @@ impl Dialect for PostgreSqlDialect {
191191
}
192192

193193
/// see <https://www.postgresql.org/docs/current/sql-listen.html>
194-
fn supports_listen(&self) -> bool {
195-
true
196-
}
197-
194+
/// see <https://www.postgresql.org/docs/current/sql-unlisten.html>
198195
/// see <https://www.postgresql.org/docs/current/sql-notify.html>
199-
fn supports_notify(&self) -> bool {
196+
fn supports_listen_notify(&self) -> bool {
200197
true
201198
}
202199

@@ -209,6 +206,11 @@ impl Dialect for PostgreSqlDialect {
209206
fn supports_comment_on(&self) -> bool {
210207
true
211208
}
209+
210+
/// See <https://www.postgresql.org/docs/current/sql-load.html>
211+
fn supports_load_extension(&self) -> bool {
212+
true
213+
}
212214
}
213215

214216
pub fn parse_create(parser: &mut Parser) -> Option<Result<Statement, ParserError>> {

src/keywords.rs

+1
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ define_keywords!(
799799
UNION,
800800
UNIQUE,
801801
UNKNOWN,
802+
UNLISTEN,
802803
UNLOAD,
803804
UNLOCK,
804805
UNLOGGED,

src/parser/mod.rs

+19-3
Original file line numberDiff line numberDiff line change
@@ -532,10 +532,11 @@ impl<'a> Parser<'a> {
532532
Keyword::EXECUTE | Keyword::EXEC => self.parse_execute(),
533533
Keyword::PREPARE => self.parse_prepare(),
534534
Keyword::MERGE => self.parse_merge(),
535-
// `LISTEN` and `NOTIFY` are Postgres-specific
535+
// `LISTEN`, `UNLISTEN` and `NOTIFY` are Postgres-specific
536536
// syntaxes. They are used for Postgres statement.
537-
Keyword::LISTEN if self.dialect.supports_listen() => self.parse_listen(),
538-
Keyword::NOTIFY if self.dialect.supports_notify() => self.parse_notify(),
537+
Keyword::LISTEN if self.dialect.supports_listen_notify() => self.parse_listen(),
538+
Keyword::UNLISTEN if self.dialect.supports_listen_notify() => self.parse_unlisten(),
539+
Keyword::NOTIFY if self.dialect.supports_listen_notify() => self.parse_notify(),
539540
// `PRAGMA` is sqlite specific https://www.sqlite.org/pragma.html
540541
Keyword::PRAGMA => self.parse_pragma(),
541542
Keyword::UNLOAD => self.parse_unload(),
@@ -999,6 +1000,21 @@ impl<'a> Parser<'a> {
9991000
Ok(Statement::LISTEN { channel })
10001001
}
10011002

1003+
pub fn parse_unlisten(&mut self) -> Result<Statement, ParserError> {
1004+
let channel = if self.consume_token(&Token::Mul) {
1005+
Ident::new(Expr::Wildcard.to_string())
1006+
} else {
1007+
match self.parse_identifier(false) {
1008+
Ok(expr) => expr,
1009+
_ => {
1010+
self.prev_token();
1011+
return self.expected("wildcard or identifier", self.peek_token());
1012+
}
1013+
}
1014+
};
1015+
Ok(Statement::UNLISTEN { channel })
1016+
}
1017+
10021018
pub fn parse_notify(&mut self) -> Result<Statement, ParserError> {
10031019
let channel = self.parse_identifier(false)?;
10041020
let payload = if self.consume_token(&Token::Comma) {

tests/sqlparser_common.rs

+73-4
Original file line numberDiff line numberDiff line change
@@ -11595,7 +11595,7 @@ fn test_show_dbs_schemas_tables_views() {
1159511595

1159611596
#[test]
1159711597
fn parse_listen_channel() {
11598-
let dialects = all_dialects_where(|d| d.supports_listen());
11598+
let dialects = all_dialects_where(|d| d.supports_listen_notify());
1159911599

1160011600
match dialects.verified_stmt("LISTEN test1") {
1160111601
Statement::LISTEN { channel } => {
@@ -11609,17 +11609,48 @@ fn parse_listen_channel() {
1160911609
ParserError::ParserError("Expected: identifier, found: *".to_string())
1161011610
);
1161111611

11612-
let dialects = all_dialects_where(|d| !d.supports_listen());
11612+
let dialects = all_dialects_where(|d| !d.supports_listen_notify());
1161311613

1161411614
assert_eq!(
1161511615
dialects.parse_sql_statements("LISTEN test1").unwrap_err(),
1161611616
ParserError::ParserError("Expected: an SQL statement, found: LISTEN".to_string())
1161711617
);
1161811618
}
1161911619

11620+
#[test]
11621+
fn parse_unlisten_channel() {
11622+
let dialects = all_dialects_where(|d| d.supports_listen_notify());
11623+
11624+
match dialects.verified_stmt("UNLISTEN test1") {
11625+
Statement::UNLISTEN { channel } => {
11626+
assert_eq!(Ident::new("test1"), channel);
11627+
}
11628+
_ => unreachable!(),
11629+
};
11630+
11631+
match dialects.verified_stmt("UNLISTEN *") {
11632+
Statement::UNLISTEN { channel } => {
11633+
assert_eq!(Ident::new("*"), channel);
11634+
}
11635+
_ => unreachable!(),
11636+
};
11637+
11638+
assert_eq!(
11639+
dialects.parse_sql_statements("UNLISTEN +").unwrap_err(),
11640+
ParserError::ParserError("Expected: wildcard or identifier, found: +".to_string())
11641+
);
11642+
11643+
let dialects = all_dialects_where(|d| !d.supports_listen_notify());
11644+
11645+
assert_eq!(
11646+
dialects.parse_sql_statements("UNLISTEN test1").unwrap_err(),
11647+
ParserError::ParserError("Expected: an SQL statement, found: UNLISTEN".to_string())
11648+
);
11649+
}
11650+
1162011651
#[test]
1162111652
fn parse_notify_channel() {
11622-
let dialects = all_dialects_where(|d| d.supports_notify());
11653+
let dialects = all_dialects_where(|d| d.supports_listen_notify());
1162311654

1162411655
match dialects.verified_stmt("NOTIFY test1") {
1162511656
Statement::NOTIFY { channel, payload } => {
@@ -11655,7 +11686,7 @@ fn parse_notify_channel() {
1165511686
"NOTIFY test1",
1165611687
"NOTIFY test1, 'this is a test notification'",
1165711688
];
11658-
let dialects = all_dialects_where(|d| !d.supports_notify());
11689+
let dialects = all_dialects_where(|d| !d.supports_listen_notify());
1165911690

1166011691
for &sql in &sql_statements {
1166111692
assert_eq!(
@@ -11864,6 +11895,44 @@ fn parse_load_data() {
1186411895
);
1186511896
}
1186611897

11898+
#[test]
11899+
fn test_load_extension() {
11900+
let dialects = all_dialects_where(|d| d.supports_load_extension());
11901+
let not_supports_load_extension_dialects = all_dialects_where(|d| !d.supports_load_extension());
11902+
let sql = "LOAD my_extension";
11903+
11904+
match dialects.verified_stmt(sql) {
11905+
Statement::Load { extension_name } => {
11906+
assert_eq!(Ident::new("my_extension"), extension_name);
11907+
}
11908+
_ => unreachable!(),
11909+
};
11910+
11911+
assert_eq!(
11912+
not_supports_load_extension_dialects
11913+
.parse_sql_statements(sql)
11914+
.unwrap_err(),
11915+
ParserError::ParserError(
11916+
"Expected: `DATA` or an extension name after `LOAD`, found: my_extension".to_string()
11917+
)
11918+
);
11919+
11920+
let sql = "LOAD 'filename'";
11921+
11922+
match dialects.verified_stmt(sql) {
11923+
Statement::Load { extension_name } => {
11924+
assert_eq!(
11925+
Ident {
11926+
value: "filename".to_string(),
11927+
quote_style: Some('\'')
11928+
},
11929+
extension_name
11930+
);
11931+
}
11932+
_ => unreachable!(),
11933+
};
11934+
}
11935+
1186711936
#[test]
1186811937
fn test_select_top() {
1186911938
let dialects = all_dialects_where(|d| d.supports_top_before_distinct());

tests/sqlparser_duckdb.rs

-14
Original file line numberDiff line numberDiff line change
@@ -359,20 +359,6 @@ fn test_duckdb_install() {
359359
);
360360
}
361361

362-
#[test]
363-
fn test_duckdb_load_extension() {
364-
let stmt = duckdb().verified_stmt("LOAD my_extension");
365-
assert_eq!(
366-
Statement::Load {
367-
extension_name: Ident {
368-
value: "my_extension".to_string(),
369-
quote_style: None
370-
}
371-
},
372-
stmt
373-
);
374-
}
375-
376362
#[test]
377363
fn test_duckdb_struct_literal() {
378364
//struct literal syntax https://duckdb.org/docs/sql/data_types/struct#creating-structs

0 commit comments

Comments
 (0)