Skip to content

Commit a26087a

Browse files
committed
Add basic CREATE FUNCTION support for SQL Server
- In this dialect, functions can have multi-statement bodies, like stored procedures (including `BEGIN`..`END`).
- Functions must end with `RETURN`, so a corresponding statement type is also introduced.
1 parent 0d2976d commit a26087a

File tree

6 files changed

+268
-9
lines changed

6 files changed

+268
-9
lines changed

src/ast/ddl.rs

+7
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,13 @@ impl fmt::Display for CreateFunction {
22722272
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22732273
write!(f, " AS {function_body}")?;
22742274
}
2275+
if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body {
2276+
write!(
2277+
f,
2278+
" AS BEGIN {function_body} END",
2279+
function_body = display_separated(statements, "; ")
2280+
)?;
2281+
}
22752282
Ok(())
22762283
}
22772284
}

src/ast/mod.rs

+54-1
Original file line numberDiff line numberDiff line change
@@ -3610,6 +3610,7 @@ pub enum Statement {
36103610
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
36113611
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
36123612
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
3613+
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
36133614
CreateFunction(CreateFunction),
36143615
/// CREATE TRIGGER
36153616
///
@@ -4050,6 +4051,13 @@ pub enum Statement {
40504051
arguments: Vec<Expr>,
40514052
options: Vec<RaisErrorOption>,
40524053
},
4054+
/// Return (Mssql)
4055+
///
4056+
/// for Functions:
4057+
/// RETURN scalar_expression
4058+
///
4059+
/// See: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
4060+
Return(ReturnStatement),
40534061
}
40544062

40554063
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -5736,7 +5744,11 @@ impl fmt::Display for Statement {
57365744
write!(f, " WITH {}", display_comma_separated(options))?;
57375745
}
57385746
Ok(())
5739-
}
5747+
},
5748+
Statement::Return(r) => {
5749+
write!(f, "{r}")?;
5750+
Ok(())
5751+
},
57405752

57415753
Statement::List(command) => write!(f, "LIST {command}"),
57425754
Statement::Remove(command) => write!(f, "REMOVE {command}"),
@@ -8340,6 +8352,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
83408352
///
83418353
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83428354
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
8355+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
83438356
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
83448357
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
83458358
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -8368,6 +8381,22 @@ pub enum CreateFunctionBody {
83688381
///
83698382
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83708383
AsAfterOptions(Expr),
8384+
/// Function body with statements before the `RETURN` keyword.
8385+
///
8386+
/// Example:
8387+
/// ```sql
8388+
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
8389+
/// RETURNS INT
8390+
/// AS
8391+
/// BEGIN
8392+
/// DECLARE c INT;
8393+
/// SET c = a + b;
8394+
/// RETURN c;
8395+
/// END;
8396+
/// ```
8397+
///
8398+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8399+
MultiStatement(Vec<Statement>),
83718400
/// Function body expression using the 'RETURN' keyword.
83728401
///
83738402
/// Example:
@@ -9203,6 +9232,30 @@ pub enum CopyIntoSnowflakeKind {
92039232
Location,
92049233
}
92059234

9235+
/// AST node for a `RETURN` statement.
///
/// Displayed as `RETURN <expr>` when `value` holds an expression and as a
/// bare `RETURN` when `value` is `None`.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9236+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9237+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9238+
pub struct ReturnStatement {
9239+
/// The returned value, or `None` for a `RETURN` without a value.
pub value: Option<ReturnStatementValue>,
9240+
}
9241+
9242+
impl fmt::Display for ReturnStatement {
9243+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
9244+
match &self.value {
9245+
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
9246+
None => write!(f, "RETURN")
9247+
}
9248+
}
9249+
}
9250+
9251+
/// The value carried by an MsSql `RETURN` statement.
///
/// Only the expression form is modelled at present.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9253+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9254+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9255+
pub enum ReturnStatementValue {
9256+
/// An expression to return, as in `RETURN <expr>`.
Expr(Expr),
9257+
}
9258+
92069259
#[cfg(test)]
92079260
mod tests {
92089261
use super::*;

src/ast/spans.rs

+1
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,7 @@ impl Spanned for Statement {
518518
Statement::UNLISTEN { .. } => Span::empty(),
519519
Statement::RenameTable { .. } => Span::empty(),
520520
Statement::RaisError { .. } => Span::empty(),
521+
Statement::Return { .. } => Span::empty(),
521522
Statement::List(..) | Statement::Remove(..) => Span::empty(),
522523
}
523524
}

src/parser/mod.rs

+65-6
Original file line numberDiff line numberDiff line change
@@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
577577
Keyword::GRANT => self.parse_grant(),
578578
Keyword::REVOKE => self.parse_revoke(),
579579
Keyword::START => self.parse_start_transaction(),
580-
// `BEGIN` is a nonstandard but common alias for the
581-
// standard `START TRANSACTION` statement. It is supported
582-
// by at least PostgreSQL and MySQL.
583580
Keyword::BEGIN => self.parse_begin(),
584-
// `END` is a nonstandard but common alias for the
585-
// standard `COMMIT TRANSACTION` statement. It is supported
586-
// by PostgreSQL.
587581
Keyword::END => self.parse_end(),
588582
Keyword::SAVEPOINT => self.parse_savepoint(),
589583
Keyword::RELEASE => self.parse_release(),
@@ -617,6 +611,7 @@ impl<'a> Parser<'a> {
617611
}
618612
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
619613
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
614+
Keyword::RETURN => self.parse_return(),
620615
_ => self.expected("an SQL statement", next_token),
621616
},
622617
Token::LParen => {
@@ -4881,6 +4876,8 @@ impl<'a> Parser<'a> {
48814876
self.parse_create_macro(or_replace, temporary)
48824877
} else if dialect_of!(self is BigQueryDialect) {
48834878
self.parse_bigquery_create_function(or_replace, temporary)
4879+
} else if dialect_of!(self is MsSqlDialect) {
4880+
self.parse_mssql_create_function(or_replace, temporary)
48844881
} else {
48854882
self.prev_token();
48864883
self.expected("an object type after CREATE", self.peek_token())
@@ -5135,6 +5132,61 @@ impl<'a> Parser<'a> {
51355132
}))
51365133
}
51375134

5135+
/// Parse `CREATE FUNCTION` for [MsSql]
5136+
///
5137+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
5138+
fn parse_mssql_create_function(
5139+
&mut self,
5140+
or_replace: bool,
5141+
temporary: bool,
5142+
) -> Result<Statement, ParserError> {
5143+
let name = self.parse_object_name(false)?;
5144+
5145+
let parse_function_param =
5146+
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
5147+
let name = parser.parse_identifier()?;
5148+
let data_type = parser.parse_data_type()?;
5149+
Ok(OperateFunctionArg {
5150+
mode: None,
5151+
name: Some(name),
5152+
data_type,
5153+
default_expr: None,
5154+
})
5155+
};
5156+
self.expect_token(&Token::LParen)?;
5157+
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
5158+
self.expect_token(&Token::RParen)?;
5159+
5160+
let return_type = if self.parse_keyword(Keyword::RETURNS) {
5161+
Some(self.parse_data_type()?)
5162+
} else {
5163+
None
5164+
};
5165+
5166+
self.expect_keyword_is(Keyword::AS)?;
5167+
self.expect_keyword_is(Keyword::BEGIN)?;
5168+
let function_body = Some(CreateFunctionBody::MultiStatement(self.parse_statements()?));
5169+
self.expect_keyword_is(Keyword::END)?;
5170+
5171+
Ok(Statement::CreateFunction(CreateFunction {
5172+
or_replace,
5173+
temporary,
5174+
if_not_exists: false,
5175+
name,
5176+
args: Some(args),
5177+
return_type,
5178+
function_body,
5179+
language: None,
5180+
determinism_specifier: None,
5181+
options: None,
5182+
remote_connection: None,
5183+
using: None,
5184+
behavior: None,
5185+
called_on_null: None,
5186+
parallel: None,
5187+
}))
5188+
}
5189+
51385190
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
51395191
let mode = if self.parse_keyword(Keyword::IN) {
51405192
Some(ArgMode::In)
@@ -15017,6 +15069,13 @@ impl<'a> Parser<'a> {
1501715069
}
1501815070
}
1501915071

15072+
fn parse_return(&mut self) -> Result<Statement, ParserError> {
15073+
let expr = self.parse_expr()?;
15074+
Ok(Statement::Return(ReturnStatement {
15075+
value: Some(ReturnStatementValue::Expr(expr)),
15076+
}))
15077+
}
15078+
1502015079
/// Consume the parser and return its underlying token buffer
1502115080
pub fn into_tokens(self) -> Vec<TokenWithSpan> {
1502215081
self.tokens

tests/sqlparser_hive.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use sqlparser::ast::{
2525
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
2626
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
2727
};
28-
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
28+
use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
2929
use sqlparser::parser::ParserError;
3030
use sqlparser::test_utils::*;
3131

@@ -423,7 +423,7 @@ fn parse_create_function() {
423423
}
424424

425425
// Test error in dialect that doesn't support parsing CREATE FUNCTION
426-
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]);
426+
let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
427427

428428
assert_eq!(
429429
unsupported_dialects.parse_sql_statements(sql).unwrap_err(),

tests/sqlparser_mssql.rs

+139
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,145 @@ fn parse_mssql_create_procedure() {
187187
let _ = ms().verified_stmt("CREATE PROCEDURE [foo] AS BEGIN UPDATE bar SET col = 'test'; SELECT [foo] FROM BAR WHERE [FOO] > 10 END");
188188
}
189189

190+
#[test]
fn parse_create_function() {
    // Unquoted identifier with an empty span.
    let ident = |value: &str| Ident {
        value: value.into(),
        quote_style: None,
        span: Span::empty(),
    };

    // Both statements under test declare the same `(@foo INT, @bar VARCHAR(256))`
    // parameter list.
    let expected_args = || {
        vec![
            OperateFunctionArg {
                mode: None,
                name: Some(ident("@foo")),
                data_type: DataType::Int(None),
                default_expr: None,
            },
            OperateFunctionArg {
                mode: None,
                name: Some(ident("@bar")),
                data_type: DataType::Varchar(Some(CharacterLength::IntegerLength {
                    length: 256,
                    unit: None,
                })),
                default_expr: None,
            },
        ]
    };

    // Expected AST for `some_scalar_udf`; only the statement list of the body
    // differs between the two cases.
    let expected_function = |body: Vec<Statement>| {
        Statement::CreateFunction(CreateFunction {
            or_replace: false,
            temporary: false,
            if_not_exists: false,
            name: ObjectName::from(vec![ident("some_scalar_udf")]),
            args: Some(expected_args()),
            return_type: Some(DataType::Int(None)),
            function_body: Some(CreateFunctionBody::MultiStatement(body)),
            behavior: None,
            called_on_null: None,
            parallel: None,
            using: None,
            language: None,
            determinism_specifier: None,
            options: None,
            remote_connection: None,
        })
    };

    // Case 1: the body consists of a single `RETURN <literal>`.
    let return_expression_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN RETURN 1 END";
    let return_literal = Statement::Return(ReturnStatement {
        value: Some(ReturnStatementValue::Expr(Expr::Value(
            (number("1")).with_empty_span(),
        ))),
    });
    assert_eq!(
        ms().verified_stmt(return_expression_function),
        expected_function(vec![return_literal]),
    );

    // Case 2: a `SET` statement followed by `RETURN <variable>`.
    let multi_statement_function = "CREATE FUNCTION some_scalar_udf(@foo INT, @bar VARCHAR(256)) RETURNS INT AS BEGIN SET @foo = @foo + 1; RETURN @foo END";
    let increment_foo = Statement::Set(Set::SingleAssignment {
        scope: None,
        hivevar: false,
        variable: ObjectName::from(vec![ident("@foo")]),
        values: vec![sqlparser::ast::Expr::BinaryOp {
            left: Box::new(sqlparser::ast::Expr::Identifier(ident("@foo"))),
            op: sqlparser::ast::BinaryOperator::Plus,
            right: Box::new(Expr::Value(
                (Value::Number("1".into(), false)).with_empty_span(),
            )),
        }],
    });
    let return_foo = Statement::Return(ReturnStatement {
        value: Some(ReturnStatementValue::Expr(Expr::Identifier(ident("@foo")))),
    });
    assert_eq!(
        ms().verified_stmt(multi_statement_function),
        expected_function(vec![increment_foo, return_foo]),
    );
}
328+
190329
#[test]
191330
fn parse_mssql_apply_join() {
192331
let _ = ms_and_generic().verified_only_select(

0 commit comments

Comments
 (0)