diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8c4079213..3f83ac56f 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2256,6 +2256,57 @@ impl fmt::Display for ConditionalStatements { } } +/// A `RAISE` statement. +/// +/// Examples: +/// ```sql +/// RAISE USING MESSAGE = 'error'; +/// +/// RAISE myerror; +/// ``` +/// +/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#raise) +/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/raise) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct RaiseStatement { + pub value: Option, +} + +impl fmt::Display for RaiseStatement { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let RaiseStatement { value } = self; + + write!(f, "RAISE")?; + if let Some(value) = value { + write!(f, " {value}")?; + } + + Ok(()) + } +} + +/// Represents the error value of a [RaiseStatement]. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum RaiseStatementValue { + /// `RAISE USING MESSAGE = 'error'` + UsingMessage(Expr), + /// `RAISE myerror` + Expr(Expr), +} + +impl fmt::Display for RaiseStatementValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RaiseStatementValue::Expr(expr) => write!(f, "{expr}"), + RaiseStatementValue::UsingMessage(expr) => write!(f, "USING MESSAGE = {expr}"), + } + } +} + /// Represents an expression assignment within a variable `DECLARE` statement. /// /// Examples: @@ -2827,6 +2878,8 @@ pub enum Statement { Case(CaseStatement), /// An `IF` statement. If(IfStatement), + /// A `RAISE` statement. + Raise(RaiseStatement), /// ```sql /// CALL /// ``` @@ -4142,6 +4195,9 @@ impl fmt::Display for Statement { Statement::If(stmt) => { write!(f, "{stmt}") } + Statement::Raise(stmt) => { + write!(f, "{stmt}") + } Statement::AttachDatabase { schema_name, database_file_name, diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 0ee11f23f..c0ddcbbe7 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -32,11 +32,11 @@ use super::{ JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect, - Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, - SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, - TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered, - TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef, - WildcardAdditionalOptions, With, WithFill, + Query, RaiseStatement, RaiseStatementValue, ReferentialAction, RenameSelectItem, + ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, + Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, + TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, + Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill, }; /// Given an iterator of spans, return the [Span::union] of all spans. @@ -337,6 +337,7 @@ impl Spanned for Statement { } => source.span(), Statement::Case(stmt) => stmt.span(), Statement::If(stmt) => stmt.span(), + Statement::Raise(stmt) => stmt.span(), Statement::Call(function) => function.span(), Statement::Copy { source, @@ -782,6 +783,23 @@ impl Spanned for ConditionalStatements { } } +impl Spanned for RaiseStatement { + fn span(&self) -> Span { + let RaiseStatement { value } = self; + + union_spans(value.iter().map(|value| value.span())) + } +} + +impl Spanned for RaiseStatementValue { + fn span(&self) -> Span { + match self { + RaiseStatementValue::UsingMessage(expr) => expr.span(), + RaiseStatementValue::Expr(expr) => expr.span(), + } + } +} + /// # partial span /// /// Missing spans: diff --git a/src/keywords.rs b/src/keywords.rs index 47da10096..565aa638e 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -533,6 +533,7 @@ define_keywords!( MEDIUMTEXT, MEMBER, MERGE, + MESSAGE, METADATA, METHOD, METRIC, @@ -695,6 +696,7 @@ define_keywords!( QUARTER, QUERY, QUOTE, + RAISE, RAISERROR, RANGE, RANK, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index e4c170edd..97280b982 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -536,6 +536,10 @@ impl<'a> Parser<'a> { self.prev_token(); self.parse_if_stmt() } + Keyword::RAISE => { + self.prev_token(); + self.parse_raise_stmt() + } Keyword::SELECT | Keyword::WITH | Keyword::VALUES | Keyword::FROM => { self.prev_token(); self.parse_query().map(Statement::Query) @@ -719,6 +723,22 @@ impl<'a> Parser<'a> { }) } + /// Parse a `RAISE` statement. + /// + /// See [Statement::Raise] + pub fn parse_raise_stmt(&mut self) -> Result { + self.expect_keyword_is(Keyword::RAISE)?; + + let value = if self.parse_keywords(&[Keyword::USING, Keyword::MESSAGE]) { + self.expect_token(&Token::Eq)?; + Some(RaiseStatementValue::UsingMessage(self.parse_expr()?)) + } else { + self.maybe_parse(|parser| parser.parse_expr().map(RaiseStatementValue::Expr))? + }; + + Ok(Statement::Raise(RaiseStatement { value })) + } + pub fn parse_comment(&mut self) -> Result { let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index c65cc6b53..c8df7cab3 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -14298,6 +14298,29 @@ fn parse_if_statement() { ); } +#[test] +fn parse_raise_statement() { + let sql = "RAISE USING MESSAGE = 42"; + let Statement::Raise(stmt) = verified_stmt(sql) else { + unreachable!() + }; + assert_eq!( + Some(RaiseStatementValue::UsingMessage(Expr::value(number("42")))), + stmt.value + ); + + verified_stmt("RAISE USING MESSAGE = 'error'"); + verified_stmt("RAISE myerror"); + verified_stmt("RAISE 42"); + verified_stmt("RAISE using"); + verified_stmt("RAISE"); + + assert_eq!( + ParserError::ParserError("Expected: =, found: error".to_string()), + parse_sql_statements("RAISE USING MESSAGE error").unwrap_err() + ); +} + #[test] fn test_lambdas() { let dialects = all_dialects_where(|d| d.supports_lambda_functions());