Skip to content

Commit ba4c0fc

Browse files
committed
Add basic CREATE FUNCTION support for SQL Server
- in this dialect, functions can have statement(s) bodies like stored procedures (including `BEGIN`..`END`) - functions must end with `RETURN`, so a corresponding statement type is also introduced
1 parent 514d2ec commit ba4c0fc

File tree

6 files changed

+285
-9
lines changed

6 files changed

+285
-9
lines changed

src/ast/ddl.rs

+6
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,12 @@ impl fmt::Display for CreateFunction {
22722272
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
22732273
write!(f, " AS {function_body}")?;
22742274
}
2275+
if let Some(CreateFunctionBody::MultiStatement(statements)) = &self.function_body {
2276+
write!(f, " AS")?;
2277+
write!(f, " BEGIN")?;
2278+
write!(f, " {}", display_separated(statements, "; "))?;
2279+
write!(f, " END")?;
2280+
}
22752281
Ok(())
22762282
}
22772283
}

src/ast/mod.rs

+60-1
Original file line numberDiff line numberDiff line change
@@ -3614,6 +3614,7 @@ pub enum Statement {
36143614
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
36153615
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
36163616
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
3617+
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
36173618
CreateFunction(CreateFunction),
36183619
/// CREATE TRIGGER
36193620
///
@@ -4054,6 +4055,12 @@ pub enum Statement {
40544055
arguments: Vec<Expr>,
40554056
options: Vec<RaisErrorOption>,
40564057
},
4058+
/// ```sql
4059+
/// RETURN [ expression ]
4060+
/// ```
4061+
///
4062+
/// See [ReturnStatement]
4063+
Return(ReturnStatement),
40574064
}
40584065

40594066
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -5745,7 +5752,7 @@ impl fmt::Display for Statement {
57455752
}
57465753
Ok(())
57475754
}
5748-
5755+
Statement::Return(r) => write!(f, "{r}"),
57495756
Statement::List(command) => write!(f, "LIST {command}"),
57505757
Statement::Remove(command) => write!(f, "REMOVE {command}"),
57515758
}
@@ -8348,6 +8355,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
83488355
///
83498356
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83508357
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
8358+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
83518359
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
83528360
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
83538361
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -8376,6 +8384,22 @@ pub enum CreateFunctionBody {
83768384
///
83778385
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
83788386
AsAfterOptions(Expr),
8387+
/// Function body with statements before the `RETURN` keyword.
8388+
///
8389+
/// Example:
8390+
/// ```sql
8391+
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
8392+
/// RETURNS INT
8393+
/// AS
8394+
/// BEGIN
8395+
/// DECLARE c INT;
8396+
/// SET c = a + b;
8397+
/// RETURN c;
8398+
/// END
8399+
/// ```
8400+
///
8401+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
8402+
MultiStatement(Vec<Statement>),
83798403
/// Function body expression using the 'RETURN' keyword.
83808404
///
83818405
/// Example:
@@ -9211,6 +9235,41 @@ pub enum CopyIntoSnowflakeKind {
92119235
Location,
92129236
}
92139237

9238+
/// Return (MsSql)
9239+
///
9240+
/// for Functions:
9241+
/// RETURN scalar_expression
9242+
///
9243+
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql>
9244+
///
9245+
/// for Triggers:
9246+
/// RETURN
9247+
///
9248+
/// See <https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql>
9249+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9250+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9251+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9252+
pub struct ReturnStatement {
9253+
pub value: Option<ReturnStatementValue>,
9254+
}
9255+
9256+
impl fmt::Display for ReturnStatement {
9257+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
9258+
match &self.value {
9259+
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
9260+
None => write!(f, "RETURN"),
9261+
}
9262+
}
9263+
}
9264+
9265+
/// Variants of a `RETURN` statement
9266+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
9267+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9268+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
9269+
pub enum ReturnStatementValue {
9270+
Expr(Expr),
9271+
}
9272+
92149273
#[cfg(test)]
92159274
mod tests {
92169275
use super::*;

src/ast/spans.rs

+1
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ impl Spanned for Statement {
519519
Statement::UNLISTEN { .. } => Span::empty(),
520520
Statement::RenameTable { .. } => Span::empty(),
521521
Statement::RaisError { .. } => Span::empty(),
522+
Statement::Return { .. } => Span::empty(),
522523
Statement::List(..) | Statement::Remove(..) => Span::empty(),
523524
}
524525
}

src/parser/mod.rs

+77-6
Original file line numberDiff line numberDiff line change
@@ -577,13 +577,7 @@ impl<'a> Parser<'a> {
577577
Keyword::GRANT => self.parse_grant(),
578578
Keyword::REVOKE => self.parse_revoke(),
579579
Keyword::START => self.parse_start_transaction(),
580-
// `BEGIN` is a nonstandard but common alias for the
581-
// standard `START TRANSACTION` statement. It is supported
582-
// by at least PostgreSQL and MySQL.
583580
Keyword::BEGIN => self.parse_begin(),
584-
// `END` is a nonstandard but common alias for the
585-
// standard `COMMIT TRANSACTION` statement. It is supported
586-
// by PostgreSQL.
587581
Keyword::END => self.parse_end(),
588582
Keyword::SAVEPOINT => self.parse_savepoint(),
589583
Keyword::RELEASE => self.parse_release(),
@@ -617,6 +611,7 @@ impl<'a> Parser<'a> {
617611
}
618612
// `COMMENT` is snowflake specific https://docs.snowflake.com/en/sql-reference/sql/comment
619613
Keyword::COMMENT if self.dialect.supports_comment_on() => self.parse_comment(),
614+
Keyword::RETURN => self.parse_return(),
620615
_ => self.expected("an SQL statement", next_token),
621616
},
622617
Token::LParen => {
@@ -4881,6 +4876,8 @@ impl<'a> Parser<'a> {
48814876
self.parse_create_macro(or_replace, temporary)
48824877
} else if dialect_of!(self is BigQueryDialect) {
48834878
self.parse_bigquery_create_function(or_replace, temporary)
4879+
} else if dialect_of!(self is MsSqlDialect) {
4880+
self.parse_mssql_create_function(or_replace, temporary)
48844881
} else {
48854882
self.prev_token();
48864883
self.expected("an object type after CREATE", self.peek_token())
@@ -5135,6 +5132,72 @@ impl<'a> Parser<'a> {
51355132
}))
51365133
}
51375134

5135+
/// Parse `CREATE FUNCTION` for [MsSql]
5136+
///
5137+
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
5138+
fn parse_mssql_create_function(
5139+
&mut self,
5140+
or_replace: bool,
5141+
temporary: bool,
5142+
) -> Result<Statement, ParserError> {
5143+
let name = self.parse_object_name(false)?;
5144+
5145+
let parse_function_param =
5146+
|parser: &mut Parser| -> Result<OperateFunctionArg, ParserError> {
5147+
let name = parser.parse_identifier()?;
5148+
let data_type = parser.parse_data_type()?;
5149+
Ok(OperateFunctionArg {
5150+
mode: None,
5151+
name: Some(name),
5152+
data_type,
5153+
default_expr: None,
5154+
})
5155+
};
5156+
self.expect_token(&Token::LParen)?;
5157+
let args = self.parse_comma_separated0(parse_function_param, Token::RParen)?;
5158+
self.expect_token(&Token::RParen)?;
5159+
5160+
let return_type = if self.parse_keyword(Keyword::RETURNS) {
5161+
Some(self.parse_data_type()?)
5162+
} else {
5163+
return parser_err!("Expected RETURNS keyword", self.peek_token().span.start);
5164+
};
5165+
5166+
self.expect_keyword_is(Keyword::AS)?;
5167+
self.expect_keyword_is(Keyword::BEGIN)?;
5168+
let mut result = self.parse_statements()?;
5169+
// note: `parse_statements` will consume the `END` token & produce a Commit statement...
5170+
if let Some(Statement::Commit {
5171+
chain,
5172+
end,
5173+
modifier,
5174+
}) = result.last()
5175+
{
5176+
if *chain == false && *end == true && *modifier == None {
5177+
result = result[..result.len() - 1].to_vec();
5178+
}
5179+
}
5180+
let function_body = Some(CreateFunctionBody::MultiStatement(result));
5181+
5182+
Ok(Statement::CreateFunction(CreateFunction {
5183+
or_replace,
5184+
temporary,
5185+
if_not_exists: false,
5186+
name,
5187+
args: Some(args),
5188+
return_type,
5189+
function_body,
5190+
language: None,
5191+
determinism_specifier: None,
5192+
options: None,
5193+
remote_connection: None,
5194+
using: None,
5195+
behavior: None,
5196+
called_on_null: None,
5197+
parallel: None,
5198+
}))
5199+
}
5200+
51385201
fn parse_function_arg(&mut self) -> Result<OperateFunctionArg, ParserError> {
51395202
let mode = if self.parse_keyword(Keyword::IN) {
51405203
Some(ArgMode::In)
@@ -15058,6 +15121,14 @@ impl<'a> Parser<'a> {
1505815121
}
1505915122
}
1506015123

15124+
/// Parse [Statement::Return]
15125+
fn parse_return(&mut self) -> Result<Statement, ParserError> {
15126+
let expr = self.parse_expr()?;
15127+
Ok(Statement::Return(ReturnStatement {
15128+
value: Some(ReturnStatementValue::Expr(expr)),
15129+
}))
15130+
}
15131+
1506115132
/// Consume the parser and return its underlying token buffer
1506215133
pub fn into_tokens(self) -> Vec<TokenWithSpan> {
1506315134
self.tokens

tests/sqlparser_hive.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use sqlparser::ast::{
2525
Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName, OrderByExpr,
2626
OrderByOptions, SelectItem, Set, Statement, TableFactor, UnaryOperator, Use, Value,
2727
};
28-
use sqlparser::dialect::{GenericDialect, HiveDialect, MsSqlDialect};
28+
use sqlparser::dialect::{AnsiDialect, GenericDialect, HiveDialect};
2929
use sqlparser::parser::ParserError;
3030
use sqlparser::test_utils::*;
3131

@@ -423,7 +423,7 @@ fn parse_create_function() {
423423
}
424424

425425
// Test error in dialect that doesn't support parsing CREATE FUNCTION
426-
let unsupported_dialects = TestedDialects::new(vec![Box::new(MsSqlDialect {})]);
426+
let unsupported_dialects = TestedDialects::new(vec![Box::new(AnsiDialect {})]);
427427

428428
assert_eq!(
429429
unsupported_dialects.parse_sql_statements(sql).unwrap_err(),

0 commit comments

Comments
 (0)