Skip to content

Commit e582607

Browse files
committed
Refactor BeginEndStatements into a reusable struct, then use for functions
- this lets us disacard the former (unfortunate, bespoke) multi statement parsing because we can just use `parse_statement_list` - however, `parse_statement_list` also needed a small change to allow subsequent statements to come after the final `END`
1 parent 75bd710 commit e582607

File tree

6 files changed

+231
-141
lines changed

6 files changed

+231
-141
lines changed

src/ast/ddl.rs

+2-4
Original file line numberDiff line numberDiff line change
@@ -2277,11 +2277,9 @@ impl fmt::Display for CreateFunction {
22772277
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22782278
write!(f, " AS {function_body}")?;
22792279
}
2280-
if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body {
2280+
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
22812281
write!(f, " AS")?;
2282-
write!(f, " BEGIN")?;
2283-
write!(f, " {}", display_separated(statements, "; "))?;
2284-
write!(f, " END")?;
2282+
write!(f, " {}", bes)?;
22852283
}
22862284
Ok(())
22872285
}

src/ast/mod.rs

+30-12
Original file line numberDiff line numberDiff line change
@@ -2292,18 +2292,14 @@ pub enum ConditionalStatements {
22922292
/// SELECT 1; SELECT 2; SELECT 3; ...
22932293
Sequence { statements: Vec<Statement> },
22942294
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
2295-
BeginEnd {
2296-
begin_token: AttachedToken,
2297-
statements: Vec<Statement>,
2298-
end_token: AttachedToken,
2299-
},
2295+
BeginEnd(BeginEndStatements),
23002296
}
23012297

23022298
impl ConditionalStatements {
23032299
pub fn statements(&self) -> &Vec<Statement> {
23042300
match self {
23052301
ConditionalStatements::Sequence { statements } => statements,
2306-
ConditionalStatements::BeginEnd { statements, .. } => statements,
2302+
ConditionalStatements::BeginEnd(bes) => &bes.statements,
23072303
}
23082304
}
23092305
}
@@ -2317,12 +2313,34 @@ impl fmt::Display for ConditionalStatements {
23172313
}
23182314
Ok(())
23192315
}
2320-
ConditionalStatements::BeginEnd { statements, .. } => {
2321-
write!(f, "BEGIN ")?;
2322-
format_statement_list(f, statements)?;
2323-
write!(f, " END")
2324-
}
2316+
ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
2317+
}
2318+
}
2319+
}
2320+
2321+
/// A shared representation of `BEGIN`, multiple statements, and `END` tokens.
2322+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2323+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2324+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
2325+
pub struct BeginEndStatements {
2326+
pub begin_token: AttachedToken,
2327+
pub statements: Vec<Statement>,
2328+
pub end_token: AttachedToken,
2329+
}
2330+
2331+
impl fmt::Display for BeginEndStatements {
2332+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2333+
let BeginEndStatements {
2334+
begin_token: AttachedToken(begin_token),
2335+
statements,
2336+
end_token: AttachedToken(end_token),
2337+
} = self;
2338+
2339+
write!(f, "{begin_token} ")?;
2340+
if !statements.is_empty() {
2341+
format_statement_list(f, statements)?;
23252342
}
2343+
write!(f, " {end_token}")
23262344
}
23272345
}
23282346

@@ -8399,7 +8417,7 @@ pub enum CreateFunctionBody {
83998417
/// ```
84008418
///
84018419
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8402-
MultiStatement(Vec<Statement>),
8420+
AsBeginEnd(BeginEndStatements),
84038421
/// Function body expression using the 'RETURN' keyword.
84048422
///
84058423
/// Example:

src/ast/spans.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -778,11 +778,9 @@ impl Spanned for ConditionalStatements {
778778
ConditionalStatements::Sequence { statements } => {
779779
union_spans(statements.iter().map(|s| s.span()))
780780
}
781-
ConditionalStatements::BeginEnd {
782-
begin_token: AttachedToken(start),
783-
statements: _,
784-
end_token: AttachedToken(end),
785-
} => union_spans([start.span, end.span].into_iter()),
781+
ConditionalStatements::BeginEnd(bes) => {
782+
union_spans([bes.begin_token.0.span, bes.end_token.0.span].into_iter())
783+
}
786784
}
787785
}
788786
}

src/dialect/mssql.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
use crate::ast::helpers::attached_token::AttachedToken;
19-
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement};
19+
use crate::ast::{
20+
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
21+
};
2022
use crate::dialect::Dialect;
2123
use crate::keywords::{self, Keyword};
2224
use crate::parser::{Parser, ParserError};
@@ -149,11 +151,11 @@ impl MsSqlDialect {
149151
start_token: AttachedToken(if_token),
150152
condition: Some(condition),
151153
then_token: None,
152-
conditional_statements: ConditionalStatements::BeginEnd {
154+
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
153155
begin_token: AttachedToken(begin_token),
154156
statements,
155157
end_token: AttachedToken(end_token),
156-
},
158+
}),
157159
}
158160
} else {
159161
let stmt = parser.parse_statement()?;
@@ -182,11 +184,11 @@ impl MsSqlDialect {
182184
start_token: AttachedToken(else_token),
183185
condition: None,
184186
then_token: None,
185-
conditional_statements: ConditionalStatements::BeginEnd {
187+
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
186188
begin_token: AttachedToken(begin_token),
187189
statements,
188190
end_token: AttachedToken(end_token),
189-
},
191+
}),
190192
});
191193
} else {
192194
let stmt = parser.parse_statement()?;

src/parser/mod.rs

+25-21
Original file line numberDiff line numberDiff line change
@@ -485,11 +485,11 @@ impl<'a> Parser<'a> {
485485

486486
let statement = self.parse_statement()?;
487487
expecting_statement_delimiter = match &statement {
488-
Statement::If(s) => match s.if_block.conditional_statements {
489-
ConditionalStatements::BeginEnd { .. } => false,
490-
_ => true,
491-
},
492-
_ => true
488+
Statement::If(s) => !matches!(
489+
s.if_block.conditional_statements,
490+
ConditionalStatements::BeginEnd { .. }
491+
),
492+
_ => true,
493493
};
494494
stmts.push(statement);
495495
}
@@ -4460,9 +4460,17 @@ impl<'a> Parser<'a> {
44604460
break;
44614461
}
44624462
}
4463-
44644463
values.push(self.parse_statement()?);
4465-
self.expect_token(&Token::SemiColon)?;
4464+
4465+
let semi_colon_expected = match values.last() {
4466+
Some(Statement::If(if_statement)) => if_statement.end_token.is_some(),
4467+
Some(_) => true,
4468+
None => false,
4469+
};
4470+
4471+
if semi_colon_expected {
4472+
self.expect_token(&Token::SemiColon)?;
4473+
}
44664474
}
44674475
Ok(values)
44684476
}
@@ -5175,20 +5183,16 @@ impl<'a> Parser<'a> {
51755183
};
51765184

51775185
self.expect_keyword_is(Keyword::AS)?;
5178-
self.expect_keyword_is(Keyword::BEGIN)?;
5179-
let mut result = self.parse_statements()?;
5180-
// note: `parse_statements` will consume the `END` token & produce a Commit statement...
5181-
if let Some(Statement::Commit {
5182-
chain,
5183-
end,
5184-
modifier,
5185-
}) = result.last()
5186-
{
5187-
if *chain == false && *end == true && *modifier == None {
5188-
result = result[..result.len() - 1].to_vec();
5189-
}
5190-
}
5191-
let function_body = Some(CreateFunctionBody::MultiStatement(result));
5186+
5187+
let begin_token = self.expect_keyword(Keyword::BEGIN)?;
5188+
let statements = self.parse_statement_list(&[Keyword::END])?;
5189+
let end_token = self.expect_keyword(Keyword::END)?;
5190+
5191+
let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements {
5192+
begin_token: AttachedToken(begin_token),
5193+
statements,
5194+
end_token: AttachedToken(end_token),
5195+
}));
51925196

51935197
Ok(Statement::CreateFunction(CreateFunction {
51945198
or_alter,

0 commit comments

Comments
 (0)