diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index e6477f56b..ff2a3ad04 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -319,6 +319,10 @@ pub enum DataType { /// [`SQLiteDialect`](crate::dialect::SQLiteDialect), from statements such /// as `CREATE TABLE t1 (a)`. Unspecified, + /// Trigger data type, returned by functions associated with triggers + /// + /// [postgresql]: https://www.postgresql.org/docs/current/plpgsql-trigger.html + Trigger, } impl fmt::Display for DataType { @@ -543,6 +547,7 @@ impl fmt::Display for DataType { write!(f, "Nested({})", display_comma_separated(fields)) } DataType::Unspecified => Ok(()), + DataType::Trigger => write!(f, "TRIGGER"), } } } diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index d207f5766..bebd98604 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -1175,7 +1175,7 @@ fn display_option_spaced(option: &Option) -> impl fmt::Displ /// ` = [ DEFERRABLE | NOT DEFERRABLE ] [ INITIALLY DEFERRED | INITIALLY IMMEDIATE ] [ ENFORCED | NOT ENFORCED ]` /// /// Used in UNIQUE and foreign key constraints. The individual settings may occur in any order. -#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Default, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub struct ConstraintCharacteristics { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 86e2592a3..ae0522ccc 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -53,6 +53,12 @@ pub use self::query::{ TableAlias, TableFactor, TableFunctionArgs, TableVersion, TableWithJoins, Top, TopQuantity, ValueTableMode, Values, WildcardAdditionalOptions, With, WithFill, }; + +pub use self::trigger::{ + TriggerEvent, TriggerExecBody, TriggerExecBodyType, TriggerObject, TriggerPeriod, + TriggerReferencing, TriggerReferencingType, +}; + pub use self::value::{ escape_double_quote_string, escape_quoted_string, DateTimeField, DollarQuotedString, TrimWhereField, Value, @@ -71,6 +77,7 @@ mod dml; pub mod helpers; mod operator; mod query; +mod trigger; mod value; #[cfg(feature = "visitor")] @@ -2282,7 +2289,7 @@ pub enum Statement { DropFunction { if_exists: bool, /// One or more function to drop - func_desc: Vec, + func_desc: Vec, /// `CASCADE` or `RESTRICT` option: Option, }, @@ -2292,7 +2299,7 @@ pub enum Statement { DropProcedure { if_exists: bool, /// One or more function to drop - proc_desc: Vec, + proc_desc: Vec, /// `CASCADE` or `RESTRICT` option: Option, }, @@ -2618,6 +2625,96 @@ pub enum Statement { /// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_a_remote_function) remote_connection: Option, }, + /// CREATE TRIGGER + /// + /// Examples: + /// + /// ```sql + /// CREATE TRIGGER trigger_name + /// BEFORE INSERT ON table_name + /// FOR EACH ROW + /// EXECUTE FUNCTION trigger_function(); + /// ``` + /// + /// Postgres: + CreateTrigger { + /// The `OR REPLACE` clause is used to re-create the trigger if it already exists. + /// + /// Example: + /// ```sql + /// CREATE OR REPLACE TRIGGER trigger_name + /// AFTER INSERT ON table_name + /// FOR EACH ROW + /// EXECUTE FUNCTION trigger_function(); + /// ``` + or_replace: bool, + /// The `CONSTRAINT` keyword is used to create a trigger as a constraint. + is_constraint: bool, + /// The name of the trigger to be created. + name: ObjectName, + /// Determines whether the function is called before, after, or instead of the event. + /// + /// Example of BEFORE: + /// + /// ```sql + /// CREATE TRIGGER trigger_name + /// BEFORE INSERT ON table_name + /// FOR EACH ROW + /// EXECUTE FUNCTION trigger_function(); + /// ``` + /// + /// Example of AFTER: + /// + /// ```sql + /// CREATE TRIGGER trigger_name + /// AFTER INSERT ON table_name + /// FOR EACH ROW + /// EXECUTE FUNCTION trigger_function(); + /// ``` + /// + /// Example of INSTEAD OF: + /// + /// ```sql + /// CREATE TRIGGER trigger_name + /// INSTEAD OF INSERT ON table_name + /// FOR EACH ROW + /// EXECUTE FUNCTION trigger_function(); + /// ``` + period: TriggerPeriod, + /// Multiple events can be specified using OR, such as `INSERT`, `UPDATE`, `DELETE`, or `TRUNCATE`. + events: Vec, + /// The table on which the trigger is to be created. + table_name: ObjectName, + /// The optional referenced table name that can be referenced via + /// the `FROM` keyword. + referenced_table_name: Option, + /// This keyword immediately precedes the declaration of one or two relation names that provide access to the transition relations of the triggering statement. + referencing: Vec, + /// This specifies whether the trigger function should be fired once for + /// every row affected by the trigger event, or just once per SQL statement. + trigger_object: TriggerObject, + /// Whether to include the `EACH` term of the `FOR EACH`, as it is optional syntax. + include_each: bool, + /// Triggering conditions + condition: Option, + /// Execute logic block + exec_body: TriggerExecBody, + /// The characteristic of the trigger, which include whether the trigger is `DEFERRABLE`, `INITIALLY DEFERRED`, or `INITIALLY IMMEDIATE`, + characteristics: Option, + }, + /// DROP TRIGGER + /// + /// ```sql + /// DROP TRIGGER [ IF EXISTS ] name ON table_name [ CASCADE | RESTRICT ] + /// ``` + /// + DropTrigger { + if_exists: bool, + trigger_name: ObjectName, + table_name: ObjectName, + /// `CASCADE` or `RESTRICT` + option: Option, + }, /// ```sql /// CREATE PROCEDURE /// ``` @@ -3394,6 +3491,71 @@ impl fmt::Display for Statement { } Ok(()) } + Statement::CreateTrigger { + or_replace, + is_constraint, + name, + period, + events, + table_name, + referenced_table_name, + referencing, + trigger_object, + condition, + include_each, + exec_body, + characteristics, + } => { + write!( + f, + "CREATE {or_replace}{is_constraint}TRIGGER {name} {period}", + or_replace = if *or_replace { "OR REPLACE " } else { "" }, + is_constraint = if *is_constraint { "CONSTRAINT " } else { "" }, + )?; + + if !events.is_empty() { + write!(f, " {}", display_separated(events, " OR "))?; + } + write!(f, " ON {table_name}")?; + + if let Some(referenced_table_name) = referenced_table_name { + write!(f, " FROM {referenced_table_name}")?; + } + + if let Some(characteristics) = characteristics { + write!(f, " {characteristics}")?; + } + + if !referencing.is_empty() { + write!(f, " REFERENCING {}", display_separated(referencing, " "))?; + } + + if *include_each { + write!(f, " FOR EACH {trigger_object}")?; + } else { + write!(f, " FOR {trigger_object}")?; + } + if let Some(condition) = condition { + write!(f, " WHEN {condition}")?; + } + write!(f, " EXECUTE {exec_body}") + } + Statement::DropTrigger { + if_exists, + trigger_name, + table_name, + option, + } => { + write!(f, "DROP TRIGGER")?; + if *if_exists { + write!(f, " IF EXISTS")?; + } + write!(f, " {trigger_name} ON {table_name}")?; + if let Some(option) = option { + write!(f, " {option}")?; + } + Ok(()) + } Statement::CreateProcedure { name, or_alter, @@ -6026,16 +6188,16 @@ impl fmt::Display for DropFunctionOption { } } -/// Function describe in DROP FUNCTION. +/// Generic function description for DROP FUNCTION and CREATE TRIGGER. #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct DropFunctionDesc { +pub struct FunctionDesc { pub name: ObjectName, pub args: Option>, } -impl fmt::Display for DropFunctionDesc { +impl fmt::Display for FunctionDesc { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.name)?; if let Some(args) = &self.args { diff --git a/src/ast/trigger.rs b/src/ast/trigger.rs new file mode 100644 index 000000000..a0913db94 --- /dev/null +++ b/src/ast/trigger.rs @@ -0,0 +1,158 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! SQL Abstract Syntax Tree (AST) for triggers. +use super::*; + +/// This specifies whether the trigger function should be fired once for every row affected by the trigger event, or just once per SQL statement. +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TriggerObject { + Row, + Statement, +} + +impl fmt::Display for TriggerObject { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TriggerObject::Row => write!(f, "ROW"), + TriggerObject::Statement => write!(f, "STATEMENT"), + } + } +} + +/// This clause indicates whether the following relation name is for the before-image transition relation or the after-image transition relation +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TriggerReferencingType { + OldTable, + NewTable, +} + +impl fmt::Display for TriggerReferencingType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TriggerReferencingType::OldTable => write!(f, "OLD TABLE"), + TriggerReferencingType::NewTable => write!(f, "NEW TABLE"), + } + } +} + +/// This keyword immediately precedes the declaration of one or two relation names that provide access to the transition relations of the triggering statement +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct TriggerReferencing { + pub refer_type: TriggerReferencingType, + pub is_as: bool, + pub transition_relation_name: ObjectName, +} + +impl fmt::Display for TriggerReferencing { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{refer_type}{is_as} {relation_name}", + refer_type = self.refer_type, + is_as = if self.is_as { " AS" } else { "" }, + relation_name = self.transition_relation_name + ) + } +} + +/// Used to describe trigger events +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TriggerEvent { + Insert, + Update(Vec), + Delete, + Truncate, +} + +impl fmt::Display for TriggerEvent { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TriggerEvent::Insert => write!(f, "INSERT"), + TriggerEvent::Update(columns) => { + write!(f, "UPDATE")?; + if !columns.is_empty() { + write!(f, " OF")?; + write!(f, " {}", display_comma_separated(columns))?; + } + Ok(()) + } + TriggerEvent::Delete => write!(f, "DELETE"), + TriggerEvent::Truncate => write!(f, "TRUNCATE"), + } + } +} + +/// Trigger period +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TriggerPeriod { + After, + Before, + InsteadOf, +} + +impl fmt::Display for TriggerPeriod { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TriggerPeriod::After => write!(f, "AFTER"), + TriggerPeriod::Before => write!(f, "BEFORE"), + TriggerPeriod::InsteadOf => write!(f, "INSTEAD OF"), + } + } +} + +/// Types of trigger body execution body. +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TriggerExecBodyType { + Function, + Procedure, +} + +impl fmt::Display for TriggerExecBodyType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + TriggerExecBodyType::Function => write!(f, "FUNCTION"), + TriggerExecBodyType::Procedure => write!(f, "PROCEDURE"), + } + } +} +/// This keyword immediately precedes the declaration of one or two relation names that provide access to the transition relations of the triggering statement +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct TriggerExecBody { + pub exec_type: TriggerExecBodyType, + pub func_desc: FunctionDesc, +} + +impl fmt::Display for TriggerExecBody { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{exec_type} {func_desc}", + exec_type = self.exec_type, + func_desc = self.func_desc + ) + } +} diff --git a/src/keywords.rs b/src/keywords.rs index c175da874..0c9d3dd6c 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -20,7 +20,7 @@ //! As a matter of fact, most of these keywords are not used at all //! and could be removed. //! 3) a `RESERVED_FOR_TABLE_ALIAS` array with keywords reserved in a -//! "table alias" context. +//! "table alias" context. #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -108,6 +108,7 @@ define_keywords!( AVRO, BACKWARD, BASE64, + BEFORE, BEGIN, BEGIN_FRAME, BEGIN_PARTITION, @@ -378,6 +379,7 @@ define_keywords!( INSENSITIVE, INSERT, INSTALL, + INSTEAD, INT, INT128, INT16, @@ -683,6 +685,7 @@ define_keywords!( STABLE, STAGE, START, + STATEMENT, STATIC, STATISTICS, STATUS, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 60a7b4d0b..5706df56c 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -3368,6 +3368,25 @@ impl<'a> Parser<'a> { Ok(values) } + /// Parse a keyword-separated list of 1+ items accepted by `F` + pub fn parse_keyword_separated( + &mut self, + keyword: Keyword, + mut f: F, + ) -> Result, ParserError> + where + F: FnMut(&mut Parser<'a>) -> Result, + { + let mut values = vec![]; + loop { + values.push(f(self)?); + if !self.parse_keyword(keyword) { + break; + } + } + Ok(values) + } + pub fn parse_parenthesized(&mut self, mut f: F) -> Result where F: FnMut(&mut Parser<'a>) -> Result, @@ -3471,6 +3490,10 @@ impl<'a> Parser<'a> { self.parse_create_external_table(or_replace) } else if self.parse_keyword(Keyword::FUNCTION) { self.parse_create_function(or_replace, temporary) + } else if self.parse_keyword(Keyword::TRIGGER) { + self.parse_create_trigger(or_replace, false) + } else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) { + self.parse_create_trigger(or_replace, true) } else if self.parse_keyword(Keyword::MACRO) { self.parse_create_macro(or_replace, temporary) } else if self.parse_keyword(Keyword::SECRET) { @@ -4061,6 +4084,180 @@ impl<'a> Parser<'a> { }) } + /// Parse statements of the DropTrigger type such as: + /// + /// ```sql + /// DROP TRIGGER [ IF EXISTS ] name ON table_name [ CASCADE | RESTRICT ] + /// ``` + pub fn parse_drop_trigger(&mut self) -> Result { + if !dialect_of!(self is PostgreSqlDialect | GenericDialect) { + self.prev_token(); + return self.expected("an object type after DROP", self.peek_token()); + } + let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); + let trigger_name = self.parse_object_name(false)?; + self.expect_keyword(Keyword::ON)?; + let table_name = self.parse_object_name(false)?; + let option = self + .parse_one_of_keywords(&[Keyword::CASCADE, Keyword::RESTRICT]) + .map(|keyword| match keyword { + Keyword::CASCADE => ReferentialAction::Cascade, + Keyword::RESTRICT => ReferentialAction::Restrict, + _ => unreachable!(), + }); + Ok(Statement::DropTrigger { + if_exists, + trigger_name, + table_name, + option, + }) + } + + pub fn parse_create_trigger( + &mut self, + or_replace: bool, + is_constraint: bool, + ) -> Result { + if !dialect_of!(self is PostgreSqlDialect | GenericDialect) { + self.prev_token(); + return self.expected("an object type after CREATE", self.peek_token()); + } + + let name = self.parse_object_name(false)?; + let period = self.parse_trigger_period()?; + + let events = self.parse_keyword_separated(Keyword::OR, Parser::parse_trigger_event)?; + self.expect_keyword(Keyword::ON)?; + let table_name = self.parse_object_name(false)?; + + let referenced_table_name = if self.parse_keyword(Keyword::FROM) { + self.parse_object_name(true).ok() + } else { + None + }; + + let characteristics = self.parse_constraint_characteristics()?; + + let mut referencing = vec![]; + if self.parse_keyword(Keyword::REFERENCING) { + while let Some(refer) = self.parse_trigger_referencing()? { + referencing.push(refer); + } + } + + self.expect_keyword(Keyword::FOR)?; + let include_each = self.parse_keyword(Keyword::EACH); + let trigger_object = + match self.expect_one_of_keywords(&[Keyword::ROW, Keyword::STATEMENT])? { + Keyword::ROW => TriggerObject::Row, + Keyword::STATEMENT => TriggerObject::Statement, + _ => unreachable!(), + }; + + let condition = self + .parse_keyword(Keyword::WHEN) + .then(|| self.parse_expr()) + .transpose()?; + + self.expect_keyword(Keyword::EXECUTE)?; + + let exec_body = self.parse_trigger_exec_body()?; + + Ok(Statement::CreateTrigger { + or_replace, + is_constraint, + name, + period, + events, + table_name, + referenced_table_name, + referencing, + trigger_object, + include_each, + condition, + exec_body, + characteristics, + }) + } + + pub fn parse_trigger_period(&mut self) -> Result { + Ok( + match self.expect_one_of_keywords(&[ + Keyword::BEFORE, + Keyword::AFTER, + Keyword::INSTEAD, + ])? { + Keyword::BEFORE => TriggerPeriod::Before, + Keyword::AFTER => TriggerPeriod::After, + Keyword::INSTEAD => self + .expect_keyword(Keyword::OF) + .map(|_| TriggerPeriod::InsteadOf)?, + _ => unreachable!(), + }, + ) + } + + pub fn parse_trigger_event(&mut self) -> Result { + Ok( + match self.expect_one_of_keywords(&[ + Keyword::INSERT, + Keyword::UPDATE, + Keyword::DELETE, + Keyword::TRUNCATE, + ])? { + Keyword::INSERT => TriggerEvent::Insert, + Keyword::UPDATE => { + if self.parse_keyword(Keyword::OF) { + let cols = self.parse_comma_separated(|ident| { + Parser::parse_identifier(ident, false) + })?; + TriggerEvent::Update(cols) + } else { + TriggerEvent::Update(vec![]) + } + } + Keyword::DELETE => TriggerEvent::Delete, + Keyword::TRUNCATE => TriggerEvent::Truncate, + _ => unreachable!(), + }, + ) + } + + pub fn parse_trigger_referencing(&mut self) -> Result, ParserError> { + let refer_type = match self.parse_one_of_keywords(&[Keyword::OLD, Keyword::NEW]) { + Some(Keyword::OLD) if self.parse_keyword(Keyword::TABLE) => { + TriggerReferencingType::OldTable + } + Some(Keyword::NEW) if self.parse_keyword(Keyword::TABLE) => { + TriggerReferencingType::NewTable + } + _ => { + return Ok(None); + } + }; + + let is_as = self.parse_keyword(Keyword::AS); + let transition_relation_name = self.parse_object_name(false)?; + Ok(Some(TriggerReferencing { + refer_type, + is_as, + transition_relation_name, + })) + } + + pub fn parse_trigger_exec_body(&mut self) -> Result { + Ok(TriggerExecBody { + exec_type: match self + .expect_one_of_keywords(&[Keyword::FUNCTION, Keyword::PROCEDURE])? + { + Keyword::FUNCTION => TriggerExecBodyType::Function, + Keyword::PROCEDURE => TriggerExecBodyType::Procedure, + _ => unreachable!(), + }, + func_desc: self.parse_function_desc()?, + }) + } + pub fn parse_create_macro( &mut self, or_replace: bool, @@ -4509,9 +4706,11 @@ impl<'a> Parser<'a> { return self.parse_drop_procedure(); } else if self.parse_keyword(Keyword::SECRET) { return self.parse_drop_secret(temporary, persistent); + } else if self.parse_keyword(Keyword::TRIGGER) { + return self.parse_drop_trigger(); } else { return self.expected( - "TABLE, VIEW, INDEX, ROLE, SCHEMA, FUNCTION, PROCEDURE, STAGE or SEQUENCE after DROP", + "TABLE, VIEW, INDEX, ROLE, SCHEMA, FUNCTION, PROCEDURE, STAGE, TRIGGER, SECRET or SEQUENCE after DROP", self.peek_token(), ); }; @@ -4550,7 +4749,7 @@ impl<'a> Parser<'a> { /// ``` fn parse_drop_function(&mut self) -> Result { let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); - let func_desc = self.parse_comma_separated(Parser::parse_drop_function_desc)?; + let func_desc = self.parse_comma_separated(Parser::parse_function_desc)?; let option = match self.parse_one_of_keywords(&[Keyword::CASCADE, Keyword::RESTRICT]) { Some(Keyword::CASCADE) => Some(ReferentialAction::Cascade), Some(Keyword::RESTRICT) => Some(ReferentialAction::Restrict), @@ -4569,7 +4768,7 @@ impl<'a> Parser<'a> { /// ``` fn parse_drop_procedure(&mut self) -> Result { let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); - let proc_desc = self.parse_comma_separated(Parser::parse_drop_function_desc)?; + let proc_desc = self.parse_comma_separated(Parser::parse_function_desc)?; let option = match self.parse_one_of_keywords(&[Keyword::CASCADE, Keyword::RESTRICT]) { Some(Keyword::CASCADE) => Some(ReferentialAction::Cascade), Some(Keyword::RESTRICT) => Some(ReferentialAction::Restrict), @@ -4583,7 +4782,7 @@ impl<'a> Parser<'a> { }) } - fn parse_drop_function_desc(&mut self) -> Result { + fn parse_function_desc(&mut self) -> Result { let name = self.parse_object_name(false)?; let args = if self.consume_token(&Token::LParen) { @@ -4598,7 +4797,7 @@ impl<'a> Parser<'a> { None }; - Ok(DropFunctionDesc { name, args }) + Ok(FunctionDesc { name, args }) } /// See [DuckDB Docs](https://duckdb.org/docs/sql/statements/create_secret.html) for more details. @@ -5882,11 +6081,7 @@ impl<'a> Parser<'a> { pub fn parse_constraint_characteristics( &mut self, ) -> Result, ParserError> { - let mut cc = ConstraintCharacteristics { - deferrable: None, - initially: None, - enforced: None, - }; + let mut cc = ConstraintCharacteristics::default(); loop { if cc.deferrable.is_none() && self.parse_keywords(&[Keyword::NOT, Keyword::DEFERRABLE]) @@ -7285,6 +7480,7 @@ impl<'a> Parser<'a> { let field_defs = self.parse_click_house_tuple_def()?; Ok(DataType::Tuple(field_defs)) } + Keyword::TRIGGER => Ok(DataType::Trigger), _ => { self.prev_token(); let type_name = self.parse_object_name(false)?; diff --git a/src/test_utils.rs b/src/test_utils.rs index d9100d351..5c05ec996 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -124,6 +124,7 @@ impl TestedDialects { } let only_statement = statements.pop().unwrap(); + if !canonical.is_empty() { assert_eq!(canonical, only_statement.to_string()) } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index f370748d2..2f9fe86c9 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -3623,7 +3623,7 @@ fn parse_drop_function() { pg().verified_stmt(sql), Statement::DropFunction { if_exists: true, - func_desc: vec![DropFunctionDesc { + func_desc: vec![FunctionDesc { name: ObjectName(vec![Ident { value: "test_func".to_string(), quote_style: None @@ -3639,7 +3639,7 @@ fn parse_drop_function() { pg().verified_stmt(sql), Statement::DropFunction { if_exists: true, - func_desc: vec![DropFunctionDesc { + func_desc: vec![FunctionDesc { name: ObjectName(vec![Ident { value: "test_func".to_string(), quote_style: None @@ -3664,7 +3664,7 @@ fn parse_drop_function() { Statement::DropFunction { if_exists: true, func_desc: vec![ - DropFunctionDesc { + FunctionDesc { name: ObjectName(vec![Ident { value: "test_func1".to_string(), quote_style: None @@ -3682,7 +3682,7 @@ fn parse_drop_function() { } ]), }, - DropFunctionDesc { + FunctionDesc { name: ObjectName(vec![Ident { value: "test_func2".to_string(), quote_style: None @@ -3713,7 +3713,7 @@ fn parse_drop_procedure() { pg().verified_stmt(sql), Statement::DropProcedure { if_exists: true, - proc_desc: vec![DropFunctionDesc { + proc_desc: vec![FunctionDesc { name: ObjectName(vec![Ident { value: "test_proc".to_string(), quote_style: None @@ -3729,7 +3729,7 @@ fn parse_drop_procedure() { pg().verified_stmt(sql), Statement::DropProcedure { if_exists: true, - proc_desc: vec![DropFunctionDesc { + proc_desc: vec![FunctionDesc { name: ObjectName(vec![Ident { value: "test_proc".to_string(), quote_style: None @@ -3754,7 +3754,7 @@ fn parse_drop_procedure() { Statement::DropProcedure { if_exists: true, proc_desc: vec![ - DropFunctionDesc { + FunctionDesc { name: ObjectName(vec![Ident { value: "test_proc1".to_string(), quote_style: None @@ -3772,7 +3772,7 @@ fn parse_drop_procedure() { } ]), }, - DropFunctionDesc { + FunctionDesc { name: ObjectName(vec![Ident { value: "test_proc2".to_string(), quote_style: None @@ -4455,6 +4455,478 @@ fn test_escaped_string_literal() { } } +#[test] +fn parse_create_simple_before_insert_trigger() { + let sql = "CREATE TRIGGER check_insert BEFORE INSERT ON accounts FOR EACH ROW EXECUTE FUNCTION check_account_insert"; + let expected = Statement::CreateTrigger { + or_replace: false, + is_constraint: false, + name: ObjectName(vec![Ident::new("check_insert")]), + period: TriggerPeriod::Before, + events: vec![TriggerEvent::Insert], + table_name: ObjectName(vec![Ident::new("accounts")]), + referenced_table_name: None, + referencing: vec![], + trigger_object: TriggerObject::Row, + include_each: true, + condition: None, + exec_body: TriggerExecBody { + exec_type: TriggerExecBodyType::Function, + func_desc: FunctionDesc { + name: ObjectName(vec![Ident::new("check_account_insert")]), + args: None, + }, + }, + characteristics: None, + }; + + assert_eq!(pg().verified_stmt(sql), expected); +} + +#[test] +fn parse_create_after_update_trigger_with_condition() { + let sql = "CREATE TRIGGER check_update AFTER UPDATE ON accounts FOR EACH ROW WHEN (NEW.balance > 10000) EXECUTE FUNCTION check_account_update"; + let expected = Statement::CreateTrigger { + or_replace: false, + is_constraint: false, + name: ObjectName(vec![Ident::new("check_update")]), + period: TriggerPeriod::After, + events: vec![TriggerEvent::Update(vec![])], + table_name: ObjectName(vec![Ident::new("accounts")]), + referenced_table_name: None, + referencing: vec![], + trigger_object: TriggerObject::Row, + include_each: true, + condition: Some(Expr::Nested(Box::new(Expr::BinaryOp { + left: Box::new(Expr::CompoundIdentifier(vec![ + Ident::new("NEW"), + Ident::new("balance"), + ])), + op: BinaryOperator::Gt, + right: Box::new(Expr::Value(number("10000"))), + }))), + exec_body: TriggerExecBody { + exec_type: TriggerExecBodyType::Function, + func_desc: FunctionDesc { + name: ObjectName(vec![Ident::new("check_account_update")]), + args: None, + }, + }, + characteristics: None, + }; + + assert_eq!(pg().verified_stmt(sql), expected); +} + +#[test] +fn parse_create_instead_of_delete_trigger() { + let sql = "CREATE TRIGGER check_delete INSTEAD OF DELETE ON accounts FOR EACH ROW EXECUTE FUNCTION check_account_deletes"; + let expected = Statement::CreateTrigger { + or_replace: false, + is_constraint: false, + name: ObjectName(vec![Ident::new("check_delete")]), + period: TriggerPeriod::InsteadOf, + events: vec![TriggerEvent::Delete], + table_name: ObjectName(vec![Ident::new("accounts")]), + referenced_table_name: None, + referencing: vec![], + trigger_object: TriggerObject::Row, + include_each: true, + condition: None, + exec_body: TriggerExecBody { + exec_type: TriggerExecBodyType::Function, + func_desc: FunctionDesc { + name: ObjectName(vec![Ident::new("check_account_deletes")]), + args: None, + }, + }, + characteristics: None, + }; + + assert_eq!(pg().verified_stmt(sql), expected); +} + +#[test] +fn parse_create_trigger_with_multiple_events_and_deferrable() { + let sql = "CREATE CONSTRAINT TRIGGER check_multiple_events BEFORE INSERT OR UPDATE OR DELETE ON accounts DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION check_account_changes"; + let expected = Statement::CreateTrigger { + or_replace: false, + is_constraint: true, + name: ObjectName(vec![Ident::new("check_multiple_events")]), + period: TriggerPeriod::Before, + events: vec![ + TriggerEvent::Insert, + TriggerEvent::Update(vec![]), + TriggerEvent::Delete, + ], + table_name: ObjectName(vec![Ident::new("accounts")]), + referenced_table_name: None, + referencing: vec![], + trigger_object: TriggerObject::Row, + include_each: true, + condition: None, + exec_body: TriggerExecBody { + exec_type: TriggerExecBodyType::Function, + func_desc: FunctionDesc { + name: ObjectName(vec![Ident::new("check_account_changes")]), + args: None, + }, + }, + characteristics: Some(ConstraintCharacteristics { + deferrable: Some(true), + initially: Some(DeferrableInitial::Deferred), + enforced: None, + }), + }; + + assert_eq!(pg().verified_stmt(sql), expected); +} + +#[test] +fn parse_create_trigger_with_referencing() { + let sql = "CREATE TRIGGER check_referencing BEFORE INSERT ON accounts REFERENCING NEW TABLE AS new_accounts OLD TABLE AS old_accounts FOR EACH ROW EXECUTE FUNCTION check_account_referencing"; + let expected = Statement::CreateTrigger { + or_replace: false, + is_constraint: false, + name: ObjectName(vec![Ident::new("check_referencing")]), + period: TriggerPeriod::Before, + events: vec![TriggerEvent::Insert], + table_name: ObjectName(vec![Ident::new("accounts")]), + referenced_table_name: None, + referencing: vec![ + TriggerReferencing { + refer_type: TriggerReferencingType::NewTable, + is_as: true, + transition_relation_name: ObjectName(vec![Ident::new("new_accounts")]), + }, + TriggerReferencing { + refer_type: TriggerReferencingType::OldTable, + is_as: true, + transition_relation_name: ObjectName(vec![Ident::new("old_accounts")]), + }, + ], + trigger_object: TriggerObject::Row, + include_each: true, + condition: None, + exec_body: TriggerExecBody { + exec_type: TriggerExecBodyType::Function, + func_desc: FunctionDesc { + name: ObjectName(vec![Ident::new("check_account_referencing")]), + args: None, + }, + }, + characteristics: None, + }; + + assert_eq!(pg().verified_stmt(sql), expected); +} + +#[test] +/// While in the parse_create_trigger test we test the full syntax of the CREATE TRIGGER statement, +/// here we test the invalid cases of the CREATE TRIGGER statement which should cause an appropriate +/// error to be returned. +fn parse_create_trigger_invalid_cases() { + // Test invalid cases for the CREATE TRIGGER statement + let invalid_cases = vec![ + ( + "CREATE TRIGGER check_update BEFORE UPDATE ON accounts FUNCTION check_account_update", + "Expected: FOR, found: FUNCTION" + ), + ( + "CREATE TRIGGER check_update TOMORROW UPDATE ON accounts EXECUTE FUNCTION check_account_update", + "Expected: one of BEFORE or AFTER or INSTEAD, found: TOMORROW" + ), + ( + "CREATE TRIGGER check_update BEFORE SAVE ON accounts EXECUTE FUNCTION check_account_update", + "Expected: one of INSERT or UPDATE or DELETE or TRUNCATE, found: SAVE" + ) + ]; + + for (sql, expected_error) in invalid_cases { + let res = pg().parse_sql_statements(sql); + assert_eq!( + format!("sql parser error: {expected_error}"), + res.unwrap_err().to_string() + ); + } +} + +#[test] +fn parse_drop_trigger() { + for if_exists in [true, false] { + for option in [ + None, + Some(ReferentialAction::Cascade), + Some(ReferentialAction::Restrict), + ] { + let sql = &format!( + "DROP TRIGGER{} check_update ON table_name{}", + if if_exists { " IF EXISTS" } else { "" }, + option + .map(|o| format!(" {}", o)) + .unwrap_or_else(|| "".to_string()) + ); + assert_eq!( + pg().verified_stmt(sql), + Statement::DropTrigger { + if_exists, + trigger_name: ObjectName(vec![Ident::new("check_update")]), + table_name: ObjectName(vec![Ident::new("table_name")]), + option + } + ); + } + } +} + +#[test] +fn parse_drop_trigger_invalid_cases() { + // Test invalid cases for the DROP TRIGGER statement + let invalid_cases = vec![ + ( + "DROP TRIGGER check_update ON table_name CASCADE RESTRICT", + "Expected: end of statement, found: RESTRICT", + ), + ( + "DROP TRIGGER check_update ON table_name CASCADE CASCADE", + "Expected: end of statement, found: CASCADE", + ), + ( + "DROP TRIGGER check_update ON table_name CASCADE CASCADE CASCADE", + "Expected: end of statement, found: CASCADE", + ), + ]; + + for (sql, expected_error) in invalid_cases { + let res = pg().parse_sql_statements(sql); + assert_eq!( + format!("sql parser error: {expected_error}"), + res.unwrap_err().to_string() + ); + } +} + +#[test] +fn parse_trigger_related_functions() { + // First we define all parts of the trigger definition, + // including the table creation, the function creation, the trigger creation and the trigger drop. + // The following example is taken from the PostgreSQL documentation + + let sql_table_creation = r#" + CREATE TABLE emp ( + empname text, + salary integer, + last_date timestamp, + last_user text + ); + "#; + + let sql_create_function = r#" + CREATE FUNCTION emp_stamp() RETURNS trigger AS $emp_stamp$ + BEGIN + -- Check that empname and salary are given + IF NEW.empname IS NULL THEN + RAISE EXCEPTION 'empname cannot be null'; + END IF; + IF NEW.salary IS NULL THEN + RAISE EXCEPTION '% cannot have null salary', NEW.empname; + END IF; + + -- Who works for us when they must pay for it? + IF NEW.salary < 0 THEN + RAISE EXCEPTION '% cannot have a negative salary', NEW.empname; + END IF; + + -- Remember who changed the payroll when + NEW.last_date := current_timestamp; + NEW.last_user := current_user; + RETURN NEW; + END; + $emp_stamp$ LANGUAGE plpgsql; + "#; + + let sql_create_trigger = r#" + CREATE TRIGGER emp_stamp BEFORE INSERT OR UPDATE ON emp + FOR EACH ROW EXECUTE FUNCTION emp_stamp(); + "#; + + let sql_drop_trigger = r#" + DROP TRIGGER emp_stamp ON emp; + "#; + + // Now we parse the statements and check if they are parsed correctly. + let mut statements = pg() + .parse_sql_statements(&format!( + "{}{}{}{}", + sql_table_creation, sql_create_function, sql_create_trigger, sql_drop_trigger + )) + .unwrap(); + + assert_eq!(statements.len(), 4); + let drop_trigger = statements.pop().unwrap(); + let create_trigger = statements.pop().unwrap(); + let create_function = statements.pop().unwrap(); + let create_table = statements.pop().unwrap(); + + // Check the first statement + let create_table = match create_table { + Statement::CreateTable(create_table) => create_table, + _ => panic!("Expected CreateTable statement"), + }; + + assert_eq!( + create_table, + CreateTable { + or_replace: false, + temporary: false, + external: false, + global: None, + if_not_exists: false, + transient: false, + volatile: false, + name: ObjectName(vec![Ident::new("emp")]), + columns: vec![ + ColumnDef { + name: "empname".into(), + data_type: DataType::Text, + collation: None, + options: vec![], + }, + ColumnDef { + name: "salary".into(), + data_type: DataType::Integer(None), + collation: None, + options: vec![], + }, + ColumnDef { + name: "last_date".into(), + data_type: DataType::Timestamp(None, TimezoneInfo::None), + collation: None, + options: vec![], + }, + ColumnDef { + name: "last_user".into(), + data_type: DataType::Text, + collation: None, + options: vec![], + }, + ], + constraints: vec![], + hive_distribution: HiveDistributionStyle::NONE, + hive_formats: Some(HiveFormat { + row_format: None, + serde_properties: None, + storage: None, + location: None + }), + table_properties: vec![], + with_options: vec![], + file_format: None, + location: None, + query: None, + without_rowid: false, + like: None, + clone: None, + engine: None, + comment: None, + auto_increment_offset: None, + default_charset: None, + collation: None, + on_commit: None, + on_cluster: None, + primary_key: None, + order_by: None, + partition_by: None, + cluster_by: None, + options: None, + strict: false, + copy_grants: false, + enable_schema_evolution: None, + change_tracking: None, + data_retention_time_in_days: None, + max_data_extension_time_in_days: None, + default_ddl_collation: None, + with_aggregation_policy: None, + with_row_access_policy: None, + with_tags: None, + } + ); + + // Check the second statement + + assert_eq!( + create_function, + Statement::CreateFunction { + or_replace: false, + temporary: false, + if_not_exists: false, + name: ObjectName(vec![Ident::new("emp_stamp")]), + args: None, + return_type: Some(DataType::Trigger), + function_body: Some( + CreateFunctionBody::AsBeforeOptions( + Expr::Value( + Value::DollarQuotedString( + DollarQuotedString { + value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n \n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n \n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(), + tag: Some( + "emp_stamp".to_owned(), + ), + }, + ), + ), + ), + ), + behavior: None, + called_on_null: None, + parallel: None, + using: None, + language: Some(Ident::new("plpgsql")), + determinism_specifier: None, + options: None, + remote_connection: None + } + ); + + // Check the third statement + + assert_eq!( + create_trigger, + Statement::CreateTrigger { + or_replace: false, + is_constraint: false, + name: ObjectName(vec![Ident::new("emp_stamp")]), + period: TriggerPeriod::Before, + events: vec![TriggerEvent::Insert, TriggerEvent::Update(vec![])], + table_name: ObjectName(vec![Ident::new("emp")]), + referenced_table_name: None, + referencing: vec![], + trigger_object: TriggerObject::Row, + include_each: true, + condition: None, + exec_body: TriggerExecBody { + exec_type: TriggerExecBodyType::Function, + func_desc: FunctionDesc { + name: ObjectName(vec![Ident::new("emp_stamp")]), + args: None, + } + }, + characteristics: None + } + ); + + // Check the fourth statement + assert_eq!( + drop_trigger, + Statement::DropTrigger { + if_exists: false, + trigger_name: ObjectName(vec![Ident::new("emp_stamp")]), + table_name: ObjectName(vec![Ident::new("emp")]), + option: None + } + ); +} + #[test] fn test_unicode_string_literal() { let pairs = [