Skip to content

Commit 3357edc

Browse files
yoavcloudayman-sigma
authored andcommitted
Add support for MS-SQL BEGIN/END TRY/CATCH (apache#1649)
1 parent 20e60cc commit 3357edc

File tree

9 files changed

+112
-25
lines changed

9 files changed

+112
-25
lines changed

src/ast/helpers/stmt_create_table.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,11 @@ mod tests {
548548

549549
#[test]
550550
pub fn test_from_invalid_statement() {
551-
let stmt = Statement::Commit { chain: false };
551+
let stmt = Statement::Commit {
552+
chain: false,
553+
end: false,
554+
modifier: None,
555+
};
552556

553557
assert_eq!(
554558
CreateTableBuilder::try_from(stmt).unwrap_err(),

src/ast/mod.rs

+35-6
Original file line numberDiff line numberDiff line change
@@ -2976,7 +2976,6 @@ pub enum Statement {
29762976
modes: Vec<TransactionMode>,
29772977
begin: bool,
29782978
transaction: Option<BeginTransactionKind>,
2979-
/// Only for SQLite
29802979
modifier: Option<TransactionModifier>,
29812980
},
29822981
/// ```sql
@@ -3003,7 +3002,17 @@ pub enum Statement {
30033002
/// ```sql
30043003
/// COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]
30053004
/// ```
3006-
Commit { chain: bool },
3005+
/// If `end` is false
3006+
///
3007+
/// ```sql
3008+
/// END [ TRY | CATCH ]
3009+
/// ```
3010+
/// If `end` is true
3011+
Commit {
3012+
chain: bool,
3013+
end: bool,
3014+
modifier: Option<TransactionModifier>,
3015+
},
30073016
/// ```sql
30083017
/// ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]
30093018
/// ```
@@ -4632,8 +4641,23 @@ impl fmt::Display for Statement {
46324641
}
46334642
Ok(())
46344643
}
4635-
Statement::Commit { chain } => {
4636-
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },)
4644+
Statement::Commit {
4645+
chain,
4646+
end: end_syntax,
4647+
modifier,
4648+
} => {
4649+
if *end_syntax {
4650+
write!(f, "END")?;
4651+
if let Some(modifier) = *modifier {
4652+
write!(f, " {}", modifier)?;
4653+
}
4654+
if *chain {
4655+
write!(f, " AND CHAIN")?;
4656+
}
4657+
} else {
4658+
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" })?;
4659+
}
4660+
Ok(())
46374661
}
46384662
Statement::Rollback { chain, savepoint } => {
46394663
write!(f, "ROLLBACK")?;
@@ -6406,16 +6430,19 @@ impl fmt::Display for TransactionIsolationLevel {
64066430
}
64076431
}
64086432

6409-
/// SQLite specific syntax
6433+
/// Modifier for the transaction in the `BEGIN` syntax
64106434
///
6411-
/// <https://sqlite.org/lang_transaction.html>
6435+
/// SQLite: <https://sqlite.org/lang_transaction.html>
6436+
/// MS-SQL: <https://learn.microsoft.com/en-us/sql/t-sql/language-elements/try-catch-transact-sql>
64126437
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
64136438
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
64146439
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
64156440
pub enum TransactionModifier {
64166441
Deferred,
64176442
Immediate,
64186443
Exclusive,
6444+
Try,
6445+
Catch,
64196446
}
64206447

64216448
impl fmt::Display for TransactionModifier {
@@ -6425,6 +6452,8 @@ impl fmt::Display for TransactionModifier {
64256452
Deferred => "DEFERRED",
64266453
Immediate => "IMMEDIATE",
64276454
Exclusive => "EXCLUSIVE",
6455+
Try => "TRY",
6456+
Catch => "CATCH",
64286457
})
64296458
}
64306459
}

src/dialect/mod.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,16 @@ pub trait Dialect: Debug + Any {
260260
false
261261
}
262262

263-
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE} [TRANSACTION]` statements
263+
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE | TRY | CATCH} [TRANSACTION]` statements
264264
fn supports_start_transaction_modifier(&self) -> bool {
265265
false
266266
}
267267

268+
/// Returns true if the dialect supports `END {TRY | CATCH}` statements
269+
fn supports_end_transaction_modifier(&self) -> bool {
270+
false
271+
}
272+
268273
/// Returns true if the dialect supports named arguments of the form `FUN(a = '1', b = '2')`.
269274
fn supports_named_fn_args_with_eq_operator(&self) -> bool {
270275
false

src/dialect/mssql.rs

+7
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,11 @@ impl Dialect for MsSqlDialect {
7878
fn supports_named_fn_args_with_rarrow_operator(&self) -> bool {
7979
false
8080
}
81+
82+
fn supports_start_transaction_modifier(&self) -> bool {
83+
true
84+
}
85+
fn supports_end_transaction_modifier(&self) -> bool {
86+
true
87+
}
8188
}

src/keywords.rs

+2
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ define_keywords!(
151151
CASE,
152152
CAST,
153153
CATALOG,
154+
CATCH,
154155
CEIL,
155156
CEILING,
156157
CENTURY,
@@ -812,6 +813,7 @@ define_keywords!(
812813
TRIM_ARRAY,
813814
TRUE,
814815
TRUNCATE,
816+
TRY,
815817
TRY_CAST,
816818
TRY_CONVERT,
817819
TUPLE,

src/parser/mod.rs

+17
Original file line numberDiff line numberDiff line change
@@ -12810,6 +12810,10 @@ impl<'a> Parser<'a> {
1281012810
Some(TransactionModifier::Immediate)
1281112811
} else if self.parse_keyword(Keyword::EXCLUSIVE) {
1281212812
Some(TransactionModifier::Exclusive)
12813+
} else if self.parse_keyword(Keyword::TRY) {
12814+
Some(TransactionModifier::Try)
12815+
} else if self.parse_keyword(Keyword::CATCH) {
12816+
Some(TransactionModifier::Catch)
1281312817
} else {
1281412818
None
1281512819
};
@@ -12827,8 +12831,19 @@ impl<'a> Parser<'a> {
1282712831
}
1282812832

1282912833
pub fn parse_end(&mut self) -> Result<Statement, ParserError> {
12834+
let modifier = if !self.dialect.supports_end_transaction_modifier() {
12835+
None
12836+
} else if self.parse_keyword(Keyword::TRY) {
12837+
Some(TransactionModifier::Try)
12838+
} else if self.parse_keyword(Keyword::CATCH) {
12839+
Some(TransactionModifier::Catch)
12840+
} else {
12841+
None
12842+
};
1283012843
Ok(Statement::Commit {
1283112844
chain: self.parse_commit_rollback_chain()?,
12845+
end: true,
12846+
modifier,
1283212847
})
1283312848
}
1283412849

@@ -12871,6 +12886,8 @@ impl<'a> Parser<'a> {
1287112886
pub fn parse_commit(&mut self) -> Result<Statement, ParserError> {
1287212887
Ok(Statement::Commit {
1287312888
chain: self.parse_commit_rollback_chain()?,
12889+
end: false,
12890+
modifier: None,
1287412891
})
1287512892
}
1287612893

tests/sqlparser_common.rs

+34-9
Original file line numberDiff line numberDiff line change
@@ -7888,6 +7888,27 @@ fn parse_start_transaction() {
78887888
ParserError::ParserError("Expected: transaction mode, found: EOF".to_string()),
78897889
res.unwrap_err()
78907890
);
7891+
7892+
// MS-SQL syntax
7893+
let dialects = all_dialects_where(|d| d.supports_start_transaction_modifier());
7894+
dialects.verified_stmt("BEGIN TRY");
7895+
dialects.verified_stmt("BEGIN CATCH");
7896+
7897+
let dialects = all_dialects_where(|d| {
7898+
d.supports_start_transaction_modifier() && d.supports_end_transaction_modifier()
7899+
});
7900+
dialects
7901+
.parse_sql_statements(
7902+
r#"
7903+
BEGIN TRY;
7904+
SELECT 1/0;
7905+
END TRY;
7906+
BEGIN CATCH;
7907+
EXECUTE foo;
7908+
END CATCH;
7909+
"#,
7910+
)
7911+
.unwrap();
78917912
}
78927913

78937914
#[test]
@@ -8103,12 +8124,12 @@ fn parse_set_time_zone_alias() {
81038124
#[test]
81048125
fn parse_commit() {
81058126
match verified_stmt("COMMIT") {
8106-
Statement::Commit { chain: false } => (),
8127+
Statement::Commit { chain: false, .. } => (),
81078128
_ => unreachable!(),
81088129
}
81098130

81108131
match verified_stmt("COMMIT AND CHAIN") {
8111-
Statement::Commit { chain: true } => (),
8132+
Statement::Commit { chain: true, .. } => (),
81128133
_ => unreachable!(),
81138134
}
81148135

@@ -8123,13 +8144,17 @@ fn parse_commit() {
81238144

81248145
#[test]
81258146
fn parse_end() {
8126-
one_statement_parses_to("END AND NO CHAIN", "COMMIT");
8127-
one_statement_parses_to("END WORK AND NO CHAIN", "COMMIT");
8128-
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "COMMIT");
8129-
one_statement_parses_to("END WORK AND CHAIN", "COMMIT AND CHAIN");
8130-
one_statement_parses_to("END TRANSACTION AND CHAIN", "COMMIT AND CHAIN");
8131-
one_statement_parses_to("END WORK", "COMMIT");
8132-
one_statement_parses_to("END TRANSACTION", "COMMIT");
8147+
one_statement_parses_to("END AND NO CHAIN", "END");
8148+
one_statement_parses_to("END WORK AND NO CHAIN", "END");
8149+
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "END");
8150+
one_statement_parses_to("END WORK AND CHAIN", "END AND CHAIN");
8151+
one_statement_parses_to("END TRANSACTION AND CHAIN", "END AND CHAIN");
8152+
one_statement_parses_to("END WORK", "END");
8153+
one_statement_parses_to("END TRANSACTION", "END");
8154+
// MS-SQL syntax
8155+
let dialects = all_dialects_where(|d| d.supports_end_transaction_modifier());
8156+
dialects.verified_stmt("END TRY");
8157+
dialects.verified_stmt("END CATCH");
81338158
}
81348159

81358160
#[test]

tests/sqlparser_custom_dialect.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ fn custom_statement_parser() -> Result<(), ParserError> {
115115
for _ in 0..3 {
116116
let _ = parser.next_token();
117117
}
118-
Some(Ok(Statement::Commit { chain: false }))
118+
Some(Ok(Statement::Commit {
119+
chain: false,
120+
end: false,
121+
modifier: None,
122+
}))
119123
} else {
120124
None
121125
}

tests/sqlparser_sqlite.rs

+1-7
Original file line numberDiff line numberDiff line change
@@ -523,13 +523,7 @@ fn parse_start_transaction_with_modifier() {
523523
sqlite_and_generic().verified_stmt("BEGIN IMMEDIATE");
524524
sqlite_and_generic().verified_stmt("BEGIN EXCLUSIVE");
525525

526-
let unsupported_dialects = TestedDialects::new(
527-
all_dialects()
528-
.dialects
529-
.into_iter()
530-
.filter(|x| !(x.is::<SQLiteDialect>() || x.is::<GenericDialect>()))
531-
.collect(),
532-
);
526+
let unsupported_dialects = all_dialects_except(|d| d.supports_start_transaction_modifier());
533527
let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED");
534528
assert_eq!(
535529
ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()),

0 commit comments

Comments
 (0)