Skip to content

Add CREATE FUNCTION support for SQL Server #1808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/ast/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2157,6 +2157,10 @@ impl fmt::Display for ClusteredBy {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CreateFunction {
/// True if this is a `CREATE OR ALTER FUNCTION` statement
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#or-alter)
pub or_alter: bool,
pub or_replace: bool,
pub temporary: bool,
pub if_not_exists: bool,
Expand Down Expand Up @@ -2219,9 +2223,10 @@ impl fmt::Display for CreateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"CREATE {or_replace}{temp}FUNCTION {if_not_exists}{name}",
"CREATE {or_alter}{or_replace}{temp}FUNCTION {if_not_exists}{name}",
name = self.name,
temp = if self.temporary { "TEMPORARY " } else { "" },
or_alter = if self.or_alter { "OR ALTER " } else { "" },
or_replace = if self.or_replace { "OR REPLACE " } else { "" },
if_not_exists = if self.if_not_exists {
"IF NOT EXISTS "
Expand Down Expand Up @@ -2272,6 +2277,9 @@ impl fmt::Display for CreateFunction {
if let Some(CreateFunctionBody::AsAfterOptions(function_body)) = &self.function_body {
write!(f, " AS {function_body}")?;
}
if let Some(CreateFunctionBody::AsBeginEnd(bes)) = &self.function_body {
write!(f, " AS {bes}")?;
}
Ok(())
}
}
Expand Down
100 changes: 89 additions & 11 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2292,18 +2292,14 @@ pub enum ConditionalStatements {
/// SELECT 1; SELECT 2; SELECT 3; ...
Sequence { statements: Vec<Statement> },
/// BEGIN SELECT 1; SELECT 2; SELECT 3; ... END
BeginEnd {
begin_token: AttachedToken,
statements: Vec<Statement>,
end_token: AttachedToken,
},
BeginEnd(BeginEndStatements),
}

impl ConditionalStatements {
pub fn statements(&self) -> &Vec<Statement> {
match self {
ConditionalStatements::Sequence { statements } => statements,
ConditionalStatements::BeginEnd { statements, .. } => statements,
ConditionalStatements::BeginEnd(bes) => &bes.statements,
}
}
}
Expand All @@ -2317,15 +2313,44 @@ impl fmt::Display for ConditionalStatements {
}
Ok(())
}
ConditionalStatements::BeginEnd { statements, .. } => {
write!(f, "BEGIN ")?;
format_statement_list(f, statements)?;
write!(f, " END")
}
ConditionalStatements::BeginEnd(bes) => write!(f, "{}", bes),
}
}
}

/// Represents a list of statements enclosed within `BEGIN` and `END` keywords.
/// Example:
/// ```sql
/// BEGIN
/// SELECT 1;
/// SELECT 2;
/// END
/// ```
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct BeginEndStatements {
pub begin_token: AttachedToken,
pub statements: Vec<Statement>,
pub end_token: AttachedToken,
}

impl fmt::Display for BeginEndStatements {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
} = self;

write!(f, "{begin_token} ")?;
if !statements.is_empty() {
format_statement_list(f, statements)?;
}
write!(f, " {end_token}")
}
}

/// A `RAISE` statement.
///
/// Examples:
Expand Down Expand Up @@ -3614,6 +3639,7 @@ pub enum Statement {
/// 1. [Hive](https://cwiki.apache.org/confluence/display/hive/languagemanual+ddl#LanguageManualDDL-Create/Drop/ReloadFunction)
/// 2. [PostgreSQL](https://www.postgresql.org/docs/15/sql-createfunction.html)
/// 3. [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement)
/// 4. [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
CreateFunction(CreateFunction),
/// CREATE TRIGGER
///
Expand Down Expand Up @@ -4060,6 +4086,12 @@ pub enum Statement {
///
/// See: <https://learn.microsoft.com/en-us/sql/t-sql/statements/print-transact-sql>
Print(PrintStatement),
/// ```sql
/// RETURN [ expression ]
/// ```
///
/// See [ReturnStatement]
Return(ReturnStatement),
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
Expand Down Expand Up @@ -5752,6 +5784,7 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::Print(s) => write!(f, "{s}"),
Statement::Return(r) => write!(f, "{r}"),
Statement::List(command) => write!(f, "LIST {command}"),
Statement::Remove(command) => write!(f, "REMOVE {command}"),
}
Expand Down Expand Up @@ -8354,6 +8387,7 @@ impl fmt::Display for FunctionDeterminismSpecifier {
///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
/// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand Down Expand Up @@ -8382,6 +8416,22 @@ pub enum CreateFunctionBody {
///
/// [BigQuery]: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#syntax_11
AsAfterOptions(Expr),
/// Function body with statements before the `RETURN` keyword.
///
/// Example:
/// ```sql
/// CREATE FUNCTION my_scalar_udf(a INT, b INT)
/// RETURNS INT
/// AS
/// BEGIN
/// DECLARE c INT;
/// SET c = a + b;
/// RETURN c;
/// END
/// ```
///
/// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql
AsBeginEnd(BeginEndStatements),
/// Function body expression using the 'RETURN' keyword.
///
/// Example:
Expand Down Expand Up @@ -9230,6 +9280,34 @@ impl fmt::Display for PrintStatement {
}
}

/// Represents a `Return` statement.
///
/// [MsSql triggers](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql)
/// [MsSql functions](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct ReturnStatement {
pub value: Option<ReturnStatementValue>,
}

impl fmt::Display for ReturnStatement {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.value {
Some(ReturnStatementValue::Expr(expr)) => write!(f, "RETURN {}", expr),
None => write!(f, "RETURN"),
}
}
}

/// Variants of a `RETURN` statement
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum ReturnStatementValue {
Expr(Expr),
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
26 changes: 19 additions & 7 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::tokenizer::Span;
use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, AttachedToken,
CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
BeginEndStatements, CaseStatement, CloseCursor, ClusteredIndex, ColumnDef, ColumnOption,
ColumnOptionDef, ConditionalStatementBlock, ConditionalStatements, ConflictTarget, ConnectBy,
ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
Expand Down Expand Up @@ -520,6 +520,7 @@ impl Spanned for Statement {
Statement::RenameTable { .. } => Span::empty(),
Statement::RaisError { .. } => Span::empty(),
Statement::Print { .. } => Span::empty(),
Statement::Return { .. } => Span::empty(),
Statement::List(..) | Statement::Remove(..) => Span::empty(),
}
}
Expand Down Expand Up @@ -778,11 +779,7 @@ impl Spanned for ConditionalStatements {
ConditionalStatements::Sequence { statements } => {
union_spans(statements.iter().map(|s| s.span()))
}
ConditionalStatements::BeginEnd {
begin_token: AttachedToken(start),
statements: _,
end_token: AttachedToken(end),
} => union_spans([start.span, end.span].into_iter()),
ConditionalStatements::BeginEnd(bes) => bes.span(),
}
}
}
Expand Down Expand Up @@ -2281,6 +2278,21 @@ impl Spanned for TableObject {
}
}

impl Spanned for BeginEndStatements {
fn span(&self) -> Span {
let BeginEndStatements {
begin_token,
statements,
end_token,
} = self;
union_spans(
core::iter::once(begin_token.0.span)
.chain(statements.iter().map(|i| i.span()))
.chain(core::iter::once(end_token.0.span)),
)
}
}

#[cfg(test)]
pub mod tests {
use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};
Expand Down
16 changes: 11 additions & 5 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement};
use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, IfStatement, Statement,
};
use crate::dialect::Dialect;
use crate::keywords::{self, Keyword};
use crate::parser::{Parser, ParserError};
Expand Down Expand Up @@ -149,11 +151,11 @@ impl MsSqlDialect {
start_token: AttachedToken(if_token),
condition: Some(condition),
then_token: None,
conditional_statements: ConditionalStatements::BeginEnd {
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
},
}),
}
} else {
let stmt = parser.parse_statement()?;
Expand All @@ -167,8 +169,10 @@ impl MsSqlDialect {
}
};

let mut prior_statement_ended_with_semi_colon = false;
while let Token::SemiColon = parser.peek_token_ref().token {
parser.advance_token();
prior_statement_ended_with_semi_colon = true;
}

let mut else_block = None;
Expand All @@ -182,11 +186,11 @@ impl MsSqlDialect {
start_token: AttachedToken(else_token),
condition: None,
then_token: None,
conditional_statements: ConditionalStatements::BeginEnd {
conditional_statements: ConditionalStatements::BeginEnd(BeginEndStatements {
begin_token: AttachedToken(begin_token),
statements,
end_token: AttachedToken(end_token),
},
}),
});
} else {
let stmt = parser.parse_statement()?;
Expand All @@ -199,6 +203,8 @@ impl MsSqlDialect {
},
});
}
} else if prior_statement_ended_with_semi_colon {
parser.prev_token();
}

Ok(Statement::If(IfStatement {
Expand Down
Loading