Skip to content

Suppor postgres TRUNCATE syntax #1406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2013,9 +2013,19 @@ pub enum Statement {
Truncate {
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
table_name: ObjectName,
table_names: Vec<ObjectName>,
partitions: Option<Vec<Expr>>,
/// TABLE - optional keyword;
table: bool,
/// Postgres-specific option
/// [ TRUNCATE TABLE ONLY ]
only: bool,
/// Postgres-specific option
/// [ RESTART IDENTITY | CONTINUE IDENTITY ]
identity: Option<TruncateIdentityOption>,
/// Postgres-specific option
/// [ CASCADE | RESTRICT ]
cascade: Option<TruncateCascadeOption>,
},
/// ```sql
/// MSCK
Expand Down Expand Up @@ -3131,12 +3141,38 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::Truncate {
table_name,
table_name: _,
table_names,
partitions,
table,
only,
identity,
cascade,
} => {
let table = if *table { "TABLE " } else { "" };
write!(f, "TRUNCATE {table}{table_name}")?;
let only = if *only { "ONLY " } else { "" };

let table_names = table_names
.iter()
.map(|table_name| table_name.to_string()) // replace `to_string()` with the appropriate method if necessary
.collect::<Vec<String>>()
.join(", ");

write!(f, "TRUNCATE {table}{only}{table_names}")?;

if let Some(identity) = identity {
match identity {
TruncateIdentityOption::Restart => write!(f, " RESTART IDENTITY")?,
TruncateIdentityOption::Continue => write!(f, " CONTINUE IDENTITY")?,
}
}
if let Some(cascade) = cascade {
match cascade {
TruncateCascadeOption::Cascade => write!(f, " CASCADE")?,
TruncateCascadeOption::Restrict => write!(f, " RESTRICT")?,
}
}

if let Some(ref parts) = partitions {
if !parts.is_empty() {
write!(f, " PARTITION ({})", display_comma_separated(parts))?;
Expand Down Expand Up @@ -4587,6 +4623,26 @@ impl fmt::Display for SequenceOptions {
}
}

/// PostgreSQL identity option for TRUNCATE table
/// [ RESTART IDENTITY | CONTINUE IDENTITY ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TruncateIdentityOption {
Restart,
Continue,
}

/// PostgreSQL cascade option for TRUNCATE table
/// [ CASCADE | RESTRICT ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TruncateCascadeOption {
Cascade,
Restrict,
}

/// Can use to describe options in create sequence or table column type identity
/// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ define_keywords!(
CONNECTION,
CONSTRAINT,
CONTAINS,
CONTINUE,
CONVERT,
COPY,
COPY_OPTIONS,
Expand Down
34 changes: 33 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,17 +681,49 @@ impl<'a> Parser<'a> {

pub fn parse_truncate(&mut self) -> Result<Statement, ParserError> {
let table = self.parse_keyword(Keyword::TABLE);
let table_name = self.parse_object_name(false)?;
let only = self.parse_keyword(Keyword::ONLY);

let table_names = self.parse_comma_separated(|p| p.parse_object_name(false))?;

// Unwrap is safe - the preceding parse fails if there is not at least one table name
let table_name = table_names.first().unwrap().clone();

let mut partitions = None;
if self.parse_keyword(Keyword::PARTITION) {
self.expect_token(&Token::LParen)?;
partitions = Some(self.parse_comma_separated(Parser::parse_expr)?);
self.expect_token(&Token::RParen)?;
}

let mut identity = None;
let mut cascade = None;

if dialect_of!(self is PostgreSqlDialect | GenericDialect) {
identity = if self.parse_keywords(&[Keyword::RESTART, Keyword::IDENTITY]) {
Some(TruncateIdentityOption::Restart)
} else if self.parse_keywords(&[Keyword::CONTINUE, Keyword::IDENTITY]) {
Some(TruncateIdentityOption::Continue)
} else {
None
};

cascade = if self.parse_keyword(Keyword::CASCADE) {
Some(TruncateCascadeOption::Cascade)
} else if self.parse_keyword(Keyword::RESTRICT) {
Some(TruncateCascadeOption::Restrict)
} else {
None
};
};

Ok(Statement::Truncate {
table_name,
table_names,
partitions,
table,
only,
identity,
cascade,
})
}

Expand Down
76 changes: 71 additions & 5 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ fn parse_alter_table_constraints_rename() {
fn parse_alter_table_disable() {
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE ROW LEVEL SECURITY");
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE RULE rule_name");
}

#[test]
fn parse_alter_table_disable_trigger() {
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER ALL");
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER USER");
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER trigger_name");
Expand All @@ -589,6 +593,13 @@ fn parse_alter_table_enable() {
pg_and_generic().verified_stmt("ALTER TABLE tab ENABLE TRIGGER trigger_name");
}

#[test]
fn parse_truncate_table() {
pg_and_generic()
.verified_stmt("TRUNCATE TABLE \"users\", \"orders\" RESTART IDENTITY RESTRICT");
pg_and_generic().verified_stmt("TRUNCATE users, orders RESTART IDENTITY");
}

#[test]
fn parse_create_extension() {
pg_and_generic().verified_stmt("CREATE EXTENSION extension_name");
Expand Down Expand Up @@ -3953,11 +3964,66 @@ fn parse_select_group_by_cube() {
#[test]
fn parse_truncate() {
let truncate = pg_and_generic().verified_stmt("TRUNCATE db.table_name");
let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]);
let table_names = vec![table_name.clone()];
assert_eq!(
Statement::Truncate {
table_name,
table_names,
partitions: None,
table: false,
only: false,
identity: None,
cascade: None,
},
truncate
);
}

#[test]
fn parse_truncate_with_options() {
let truncate = pg_and_generic()
.verified_stmt("TRUNCATE TABLE ONLY db.table_name RESTART IDENTITY CASCADE");

let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]);
let table_names = vec![table_name.clone()];

assert_eq!(
Statement::Truncate {
table_name: ObjectName(vec![Ident::new("db"), Ident::new("table_name")]),
table_name,
table_names,
partitions: None,
table: false
table: true,
only: true,
identity: Some(TruncateIdentityOption::Restart),
cascade: Some(TruncateCascadeOption::Cascade)
},
truncate
);
}

#[test]
fn parse_truncate_with_table_list() {
let truncate = pg().verified_stmt(
"TRUNCATE TABLE db.table_name, db.other_table_name RESTART IDENTITY CASCADE",
);

let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]);

let table_names = vec![
table_name.clone(),
ObjectName(vec![Ident::new("db"), Ident::new("other_table_name")]),
];

assert_eq!(
Statement::Truncate {
table_name,
table_names,
partitions: None,
table: true,
only: false,
identity: Some(TruncateIdentityOption::Restart),
cascade: Some(TruncateCascadeOption::Cascade)
},
truncate
);
Expand Down Expand Up @@ -4731,12 +4797,12 @@ fn parse_trigger_related_functions() {
IF NEW.salary IS NULL THEN
RAISE EXCEPTION '% cannot have null salary', NEW.empname;
END IF;

-- Who works for us when they must pay for it?
IF NEW.salary < 0 THEN
RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;
END IF;

-- Remember who changed the payroll when
NEW.last_date := current_timestamp;
NEW.last_user := current_user;
Expand Down Expand Up @@ -4868,7 +4934,7 @@ fn parse_trigger_related_functions() {
Expr::Value(
Value::DollarQuotedString(
DollarQuotedString {
value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n \n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n \n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(),
value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n\n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n\n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(),
tag: Some(
"emp_stamp".to_owned(),
),
Expand Down
Loading