Skip to content

Commit 97e261c

Browse files
committed
Add support for MS-SQL BEGIN/END TRY/CATCH
1 parent 8cfc462 commit 97e261c

File tree

9 files changed

+112
-25
lines changed

9 files changed

+112
-25
lines changed

src/ast/helpers/stmt_create_table.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,11 @@ mod tests {
548548

549549
#[test]
550550
pub fn test_from_invalid_statement() {
551-
let stmt = Statement::Commit { chain: false };
551+
let stmt = Statement::Commit {
552+
chain: false,
553+
end: false,
554+
modifier: None,
555+
};
552556

553557
assert_eq!(
554558
CreateTableBuilder::try_from(stmt).unwrap_err(),

src/ast/mod.rs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2958,7 +2958,6 @@ pub enum Statement {
29582958
modes: Vec<TransactionMode>,
29592959
begin: bool,
29602960
transaction: Option<BeginTransactionKind>,
2961-
/// Only for SQLite
29622961
modifier: Option<TransactionModifier>,
29632962
},
29642963
/// ```sql
@@ -2985,7 +2984,17 @@ pub enum Statement {
29852984
/// ```sql
29862985
/// COMMIT [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ]
29872986
/// ```
2988-
Commit { chain: bool },
2987+
/// If `end` is false
2988+
///
2989+
/// ```sql
2990+
/// END [ TRY | CATCH ]
2991+
/// ```
2992+
/// If `end` is true
2993+
Commit {
2994+
chain: bool,
2995+
end: bool,
2996+
modifier: Option<TransactionModifier>,
2997+
},
29892998
/// ```sql
29902999
/// ROLLBACK [ TRANSACTION | WORK ] [ AND [ NO ] CHAIN ] [ TO [ SAVEPOINT ] savepoint_name ]
29913000
/// ```
@@ -4614,8 +4623,23 @@ impl fmt::Display for Statement {
46144623
}
46154624
Ok(())
46164625
}
4617-
Statement::Commit { chain } => {
4618-
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" },)
4626+
Statement::Commit {
4627+
chain,
4628+
end: end_syntax,
4629+
modifier,
4630+
} => {
4631+
if *end_syntax {
4632+
write!(f, "END")?;
4633+
if let Some(modifier) = *modifier {
4634+
write!(f, " {}", modifier)?;
4635+
}
4636+
if *chain {
4637+
write!(f, " AND CHAIN")?;
4638+
}
4639+
} else {
4640+
write!(f, "COMMIT{}", if *chain { " AND CHAIN" } else { "" })?;
4641+
}
4642+
Ok(())
46194643
}
46204644
Statement::Rollback { chain, savepoint } => {
46214645
write!(f, "ROLLBACK")?;
@@ -6388,16 +6412,19 @@ impl fmt::Display for TransactionIsolationLevel {
63886412
}
63896413
}
63906414

6391-
/// SQLite specific syntax
6415+
/// Modifier for the transaction in the `BEGIN` syntax
63926416
///
6393-
/// <https://sqlite.org/lang_transaction.html>
6417+
/// SQLite: <https://sqlite.org/lang_transaction.html>
6418+
/// MS-SQL: <https://learn.microsoft.com/en-us/sql/t-sql/language-elements/try-catch-transact-sql>
63946419
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
63956420
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
63966421
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
63976422
pub enum TransactionModifier {
63986423
Deferred,
63996424
Immediate,
64006425
Exclusive,
6426+
Try,
6427+
Catch,
64016428
}
64026429

64036430
impl fmt::Display for TransactionModifier {
@@ -6407,6 +6434,8 @@ impl fmt::Display for TransactionModifier {
64076434
Deferred => "DEFERRED",
64086435
Immediate => "IMMEDIATE",
64096436
Exclusive => "EXCLUSIVE",
6437+
Try => "TRY",
6438+
Catch => "CATCH",
64106439
})
64116440
}
64126441
}

src/dialect/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,16 @@ pub trait Dialect: Debug + Any {
260260
false
261261
}
262262

263-
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE} [TRANSACTION]` statements
263+
/// Returns true if the dialect supports `BEGIN {DEFERRED | IMMEDIATE | EXCLUSIVE | TRY | CATCH} [TRANSACTION]` statements
264264
fn supports_start_transaction_modifier(&self) -> bool {
265265
false
266266
}
267267

268+
/// Returns true if the dialect supports `END {TRY | CATCH}` statements
269+
fn supports_end_transaction_modifier(&self) -> bool {
270+
false
271+
}
272+
268273
/// Returns true if the dialect supports named arguments of the form `FUN(a = '1', b = '2')`.
269274
fn supports_named_fn_args_with_eq_operator(&self) -> bool {
270275
false

src/dialect/mssql.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,11 @@ impl Dialect for MsSqlDialect {
7878
fn supports_named_fn_args_with_rarrow_operator(&self) -> bool {
7979
false
8080
}
81+
82+
fn supports_start_transaction_modifier(&self) -> bool {
83+
true
84+
}
85+
fn supports_end_transaction_modifier(&self) -> bool {
86+
true
87+
}
8188
}

src/keywords.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ define_keywords!(
151151
CASE,
152152
CAST,
153153
CATALOG,
154+
CATCH,
154155
CEIL,
155156
CEILING,
156157
CENTURY,
@@ -808,6 +809,7 @@ define_keywords!(
808809
TRIM_ARRAY,
809810
TRUE,
810811
TRUNCATE,
812+
TRY,
811813
TRY_CAST,
812814
TRY_CONVERT,
813815
TUPLE,

src/parser/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12786,6 +12786,10 @@ impl<'a> Parser<'a> {
1278612786
Some(TransactionModifier::Immediate)
1278712787
} else if self.parse_keyword(Keyword::EXCLUSIVE) {
1278812788
Some(TransactionModifier::Exclusive)
12789+
} else if self.parse_keyword(Keyword::TRY) {
12790+
Some(TransactionModifier::Try)
12791+
} else if self.parse_keyword(Keyword::CATCH) {
12792+
Some(TransactionModifier::Catch)
1278912793
} else {
1279012794
None
1279112795
};
@@ -12803,8 +12807,19 @@ impl<'a> Parser<'a> {
1280312807
}
1280412808

1280512809
pub fn parse_end(&mut self) -> Result<Statement, ParserError> {
12810+
let modifier = if !self.dialect.supports_end_transaction_modifier() {
12811+
None
12812+
} else if self.parse_keyword(Keyword::TRY) {
12813+
Some(TransactionModifier::Try)
12814+
} else if self.parse_keyword(Keyword::CATCH) {
12815+
Some(TransactionModifier::Catch)
12816+
} else {
12817+
None
12818+
};
1280612819
Ok(Statement::Commit {
1280712820
chain: self.parse_commit_rollback_chain()?,
12821+
end: true,
12822+
modifier,
1280812823
})
1280912824
}
1281012825

@@ -12847,6 +12862,8 @@ impl<'a> Parser<'a> {
1284712862
pub fn parse_commit(&mut self) -> Result<Statement, ParserError> {
1284812863
Ok(Statement::Commit {
1284912864
chain: self.parse_commit_rollback_chain()?,
12865+
end: false,
12866+
modifier: None,
1285012867
})
1285112868
}
1285212869

tests/sqlparser_common.rs

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7854,6 +7854,27 @@ fn parse_start_transaction() {
78547854
ParserError::ParserError("Expected: transaction mode, found: EOF".to_string()),
78557855
res.unwrap_err()
78567856
);
7857+
7858+
// MS-SQL syntax
7859+
let dialects = all_dialects_where(|d| d.supports_start_transaction_modifier());
7860+
dialects.verified_stmt("BEGIN TRY");
7861+
dialects.verified_stmt("BEGIN CATCH");
7862+
7863+
let dialects = all_dialects_where(|d| {
7864+
d.supports_start_transaction_modifier() && d.supports_end_transaction_modifier()
7865+
});
7866+
dialects
7867+
.parse_sql_statements(
7868+
r#"
7869+
BEGIN TRY;
7870+
SELECT 1/0;
7871+
END TRY;
7872+
BEGIN CATCH;
7873+
EXECUTE foo;
7874+
END CATCH;
7875+
"#,
7876+
)
7877+
.unwrap();
78577878
}
78587879

78597880
#[test]
@@ -8069,12 +8090,12 @@ fn parse_set_time_zone_alias() {
80698090
#[test]
80708091
fn parse_commit() {
80718092
match verified_stmt("COMMIT") {
8072-
Statement::Commit { chain: false } => (),
8093+
Statement::Commit { chain: false, .. } => (),
80738094
_ => unreachable!(),
80748095
}
80758096

80768097
match verified_stmt("COMMIT AND CHAIN") {
8077-
Statement::Commit { chain: true } => (),
8098+
Statement::Commit { chain: true, .. } => (),
80788099
_ => unreachable!(),
80798100
}
80808101

@@ -8089,13 +8110,17 @@ fn parse_commit() {
80898110

80908111
#[test]
80918112
fn parse_end() {
8092-
one_statement_parses_to("END AND NO CHAIN", "COMMIT");
8093-
one_statement_parses_to("END WORK AND NO CHAIN", "COMMIT");
8094-
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "COMMIT");
8095-
one_statement_parses_to("END WORK AND CHAIN", "COMMIT AND CHAIN");
8096-
one_statement_parses_to("END TRANSACTION AND CHAIN", "COMMIT AND CHAIN");
8097-
one_statement_parses_to("END WORK", "COMMIT");
8098-
one_statement_parses_to("END TRANSACTION", "COMMIT");
8113+
one_statement_parses_to("END AND NO CHAIN", "END");
8114+
one_statement_parses_to("END WORK AND NO CHAIN", "END");
8115+
one_statement_parses_to("END TRANSACTION AND NO CHAIN", "END");
8116+
one_statement_parses_to("END WORK AND CHAIN", "END AND CHAIN");
8117+
one_statement_parses_to("END TRANSACTION AND CHAIN", "END AND CHAIN");
8118+
one_statement_parses_to("END WORK", "END");
8119+
one_statement_parses_to("END TRANSACTION", "END");
8120+
// MS-SQL syntax
8121+
let dialects = all_dialects_where(|d| d.supports_end_transaction_modifier());
8122+
dialects.verified_stmt("END TRY");
8123+
dialects.verified_stmt("END CATCH");
80998124
}
81008125

81018126
#[test]

tests/sqlparser_custom_dialect.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ fn custom_statement_parser() -> Result<(), ParserError> {
115115
for _ in 0..3 {
116116
let _ = parser.next_token();
117117
}
118-
Some(Ok(Statement::Commit { chain: false }))
118+
Some(Ok(Statement::Commit {
119+
chain: false,
120+
end: false,
121+
modifier: None,
122+
}))
119123
} else {
120124
None
121125
}

tests/sqlparser_sqlite.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -523,13 +523,7 @@ fn parse_start_transaction_with_modifier() {
523523
sqlite_and_generic().verified_stmt("BEGIN IMMEDIATE");
524524
sqlite_and_generic().verified_stmt("BEGIN EXCLUSIVE");
525525

526-
let unsupported_dialects = TestedDialects::new(
527-
all_dialects()
528-
.dialects
529-
.into_iter()
530-
.filter(|x| !(x.is::<SQLiteDialect>() || x.is::<GenericDialect>()))
531-
.collect(),
532-
);
526+
let unsupported_dialects = all_dialects_except(|d| d.supports_start_transaction_modifier());
533527
let res = unsupported_dialects.parse_sql_statements("BEGIN DEFERRED");
534528
assert_eq!(
535529
ParserError::ParserError("Expected: end of statement, found: DEFERRED".to_string()),

0 commit comments

Comments
 (0)