Skip to content

Commit 1b50dd9

Browse files
committed
Fix conditionals in multi statement functions
- when a multi statement block concludes (eg, `BEGIN`..`END`), that last `END` means we should *not* be expecting a statement delimiter
1 parent 618eb4d commit 1b50dd9

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

src/parser/mod.rs

+15-3
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,14 @@ impl<'a> Parser<'a> {
484484
}
485485

486486
let statement = self.parse_statement()?;
487+
expecting_statement_delimiter = match &statement {
488+
Statement::If(s) => match s.if_block.conditional_statements {
489+
ConditionalStatements::BeginEnd { .. } => false,
490+
_ => true,
491+
},
492+
_ => true
493+
};
487494
stmts.push(statement);
488-
expecting_statement_delimiter = true;
489495
}
490496
Ok(stmts)
491497
}
@@ -5170,8 +5176,14 @@ impl<'a> Parser<'a> {
51705176

51715177
self.expect_keyword_is(Keyword::AS)?;
51725178
self.expect_keyword_is(Keyword::BEGIN)?;
5173-
let function_body = Some(CreateFunctionBody::MultiStatement(self.parse_statements()?));
5174-
self.expect_keyword_is(Keyword::END)?;
5179+
let mut result = self.parse_statements()?;
5180+
// note: `parse_statements` will consume the `END` token & produce a Commit statement...
5181+
if let Some(Statement::Commit{ chain, end, modifier }) = result.last() {
5182+
if *chain == false && *end == true && *modifier == None {
5183+
result = result[..result.len() - 1].to_vec();
5184+
}
5185+
}
5186+
let function_body = Some(CreateFunctionBody::MultiStatement(result));
51755187

51765188
Ok(Statement::CreateFunction(CreateFunction {
51775189
or_alter,

tests/sqlparser_mssql.rs

+79-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
mod test_utils;
2424

2525
use helpers::attached_token::AttachedToken;
26-
use sqlparser::tokenizer::{Location, Span};
26+
use sqlparser::keywords::Keyword;
27+
use sqlparser::tokenizer::{Location, Span, TokenWithSpan};
2728
use test_utils::*;
2829

2930
use sqlparser::ast::DataType::{Int, Text, Varbinary};
@@ -326,6 +327,83 @@ fn parse_create_function() {
326327
remote_connection: None,
327328
}),
328329
);
330+
331+
let create_function_with_conditional = r#"
332+
CREATE FUNCTION some_scalar_udf()
333+
RETURNS INT
334+
AS
335+
BEGIN
336+
IF 1=2
337+
BEGIN
338+
RETURN 1;
339+
END
340+
341+
RETURN 0;
342+
END
343+
"#;
344+
let create_stmt = ms().one_statement_parses_to(create_function_with_conditional, "");
345+
assert_eq!(
346+
create_stmt,
347+
Statement::CreateFunction(CreateFunction {
348+
or_alter: false,
349+
or_replace: false,
350+
temporary: false,
351+
if_not_exists: false,
352+
name: ObjectName::from(vec![Ident {
353+
value: "some_scalar_udf".into(),
354+
quote_style: None,
355+
span: Span::empty(),
356+
}]),
357+
args: Some(vec![]),
358+
return_type: Some(DataType::Int(None)),
359+
function_body: Some(CreateFunctionBody::MultiStatement(vec![
360+
Statement::If(IfStatement {
361+
if_block: ConditionalStatementBlock {
362+
start_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
363+
value: "IF".to_string(),
364+
quote_style: None,
365+
keyword: Keyword::IF
366+
}))),
367+
condition: Some(Expr::BinaryOp {
368+
left: Box::new(Expr::Value(Value::Number("1".to_string(), false).with_empty_span())),
369+
op: sqlparser::ast::BinaryOperator::Eq,
370+
right: Box::new(Expr::Value(Value::Number("2".to_string(), false).with_empty_span())),
371+
}),
372+
then_token: None,
373+
conditional_statements: ConditionalStatements::BeginEnd {
374+
begin_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
375+
value: "BEGIN".to_string(),
376+
quote_style: None,
377+
keyword: Keyword::BEGIN
378+
}))),
379+
statements: vec![Statement::Return(ReturnStatement {
380+
value: Some(ReturnStatementValue::Expr(Expr::Value((number("1")).with_empty_span()))),
381+
})],
382+
end_token: AttachedToken(TokenWithSpan::wrap(sqlparser::tokenizer::Token::Word(sqlparser::tokenizer::Word {
383+
value: "END".to_string(),
384+
quote_style: None,
385+
keyword: Keyword::END
386+
}))),
387+
},
388+
},
389+
elseif_blocks: vec![],
390+
else_block: None,
391+
end_token: None,
392+
}),
393+
Statement::Return(ReturnStatement {
394+
value: Some(ReturnStatementValue::Expr(Expr::Value((number("0")).with_empty_span()))),
395+
}),
396+
])),
397+
behavior: None,
398+
called_on_null: None,
399+
parallel: None,
400+
using: None,
401+
language: None,
402+
determinism_specifier: None,
403+
options: None,
404+
remote_connection: None,
405+
})
406+
);
329407
}
330408

331409
#[test]

0 commit comments

Comments
 (0)