diff --git a/src/ast/helpers/mod.rs b/src/ast/helpers/mod.rs new file mode 100644 index 000000000..41098b014 --- /dev/null +++ b/src/ast/helpers/mod.rs @@ -0,0 +1 @@ +pub mod stmt_create_table; diff --git a/src/ast/helpers/stmt_create_table.rs b/src/ast/helpers/stmt_create_table.rs new file mode 100644 index 000000000..97c567b83 --- /dev/null +++ b/src/ast/helpers/stmt_create_table.rs @@ -0,0 +1,323 @@ +#[cfg(not(feature = "std"))] +use alloc::{boxed::Box, format, string::String, vec, vec::Vec}; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +use crate::ast::{ + ColumnDef, FileFormat, HiveDistributionStyle, HiveFormat, ObjectName, OnCommit, Query, + SqlOption, Statement, TableConstraint, +}; +use crate::parser::ParserError; + +/// Builder for create table statement variant ([1]). +/// +/// This structure helps building and accessing a create table with more ease, without needing to: +/// - Match the enum itself a lot of times; or +/// - Moving a lot of variables around the code. +/// +/// # Example +/// ```rust +/// use sqlparser::ast::helpers::stmt_create_table::CreateTableBuilder; +/// use sqlparser::ast::{ColumnDef, DataType, Ident, ObjectName}; +/// let builder = CreateTableBuilder::new(ObjectName(vec![Ident::new("table_name")])) +/// .if_not_exists(true) +/// .columns(vec![ColumnDef { +/// name: Ident::new("c1"), +/// data_type: DataType::Int(None), +/// collation: None, +/// options: vec![], +/// }]); +/// // You can access internal elements with ease +/// assert!(builder.if_not_exists); +/// // Convert to a statement +/// assert_eq!( +/// builder.build().to_string(), +/// "CREATE TABLE IF NOT EXISTS table_name (c1 INT)" +/// ) +/// ``` +/// +/// [1]: crate::ast::Statement::CreateTable +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct CreateTableBuilder { + pub or_replace: bool, + pub temporary: bool, + pub external: bool, + pub global: Option, + pub if_not_exists: bool, + pub name: ObjectName, + pub columns: Vec, + pub constraints: Vec, + pub hive_distribution: HiveDistributionStyle, + pub hive_formats: Option, + pub table_properties: Vec, + pub with_options: Vec, + pub file_format: Option, + pub location: Option, + pub query: Option>, + pub without_rowid: bool, + pub like: Option, + pub clone: Option, + pub engine: Option, + pub default_charset: Option, + pub collation: Option, + pub on_commit: Option, + pub on_cluster: Option, +} + +impl CreateTableBuilder { + pub fn new(name: ObjectName) -> Self { + Self { + or_replace: false, + temporary: false, + external: false, + global: None, + if_not_exists: false, + name, + columns: vec![], + constraints: vec![], + hive_distribution: HiveDistributionStyle::NONE, + hive_formats: None, + table_properties: vec![], + with_options: vec![], + file_format: None, + location: None, + query: None, + without_rowid: false, + like: None, + clone: None, + engine: None, + default_charset: None, + collation: None, + on_commit: None, + on_cluster: None, + } + } + pub fn or_replace(mut self, or_replace: bool) -> Self { + self.or_replace = or_replace; + self + } + + pub fn temporary(mut self, temporary: bool) -> Self { + self.temporary = temporary; + self + } + + pub fn external(mut self, external: bool) -> Self { + self.external = external; + self + } + + pub fn global(mut self, global: Option) -> Self { + self.global = global; + self + } + + pub fn if_not_exists(mut self, if_not_exists: bool) -> Self { + self.if_not_exists = if_not_exists; + self + } + + pub fn columns(mut self, columns: Vec) -> Self { + self.columns = columns; + self + } + + pub fn constraints(mut self, constraints: Vec) -> Self { + self.constraints = constraints; + self + } + + pub fn hive_distribution(mut self, hive_distribution: HiveDistributionStyle) -> Self { + self.hive_distribution = hive_distribution; + self + } + + pub fn hive_formats(mut self, hive_formats: Option) -> Self { + self.hive_formats = hive_formats; + self + } + + pub fn table_properties(mut self, table_properties: Vec) -> Self { + self.table_properties = table_properties; + self + } + + pub fn with_options(mut self, with_options: Vec) -> Self { + self.with_options = with_options; + self + } + pub fn file_format(mut self, file_format: Option) -> Self { + self.file_format = file_format; + self + } + pub fn location(mut self, location: Option) -> Self { + self.location = location; + self + } + + pub fn query(mut self, query: Option>) -> Self { + self.query = query; + self + } + pub fn without_rowid(mut self, without_rowid: bool) -> Self { + self.without_rowid = without_rowid; + self + } + + pub fn like(mut self, like: Option) -> Self { + self.like = like; + self + } + + // Different name to allow the object to be cloned + pub fn clone_clause(mut self, clone: Option) -> Self { + self.clone = clone; + self + } + + pub fn engine(mut self, engine: Option) -> Self { + self.engine = engine; + self + } + + pub fn default_charset(mut self, default_charset: Option) -> Self { + self.default_charset = default_charset; + self + } + + pub fn collation(mut self, collation: Option) -> Self { + self.collation = collation; + self + } + + pub fn on_commit(mut self, on_commit: Option) -> Self { + self.on_commit = on_commit; + self + } + + pub fn on_cluster(mut self, on_cluster: Option) -> Self { + self.on_cluster = on_cluster; + self + } + + pub fn build(self) -> Statement { + Statement::CreateTable { + or_replace: self.or_replace, + temporary: self.temporary, + external: self.external, + global: self.global, + if_not_exists: self.if_not_exists, + name: self.name, + columns: self.columns, + constraints: self.constraints, + hive_distribution: self.hive_distribution, + hive_formats: self.hive_formats, + table_properties: self.table_properties, + with_options: self.with_options, + file_format: self.file_format, + location: self.location, + query: self.query, + without_rowid: self.without_rowid, + like: self.like, + clone: self.clone, + engine: self.engine, + default_charset: self.default_charset, + collation: self.collation, + on_commit: self.on_commit, + on_cluster: self.on_cluster, + } + } +} + +impl TryFrom for CreateTableBuilder { + type Error = ParserError; + + // As the builder can be transformed back to a statement, it shouldn't be a problem to take the + // ownership. + fn try_from(stmt: Statement) -> Result { + match stmt { + Statement::CreateTable { + or_replace, + temporary, + external, + global, + if_not_exists, + name, + columns, + constraints, + hive_distribution, + hive_formats, + table_properties, + with_options, + file_format, + location, + query, + without_rowid, + like, + clone, + engine, + default_charset, + collation, + on_commit, + on_cluster, + } => Ok(Self { + or_replace, + temporary, + external, + global, + if_not_exists, + name, + columns, + constraints, + hive_distribution, + hive_formats, + table_properties, + with_options, + file_format, + location, + query, + without_rowid, + like, + clone, + engine, + default_charset, + collation, + on_commit, + on_cluster, + }), + _ => Err(ParserError::ParserError(format!( + "Expected create table statement, but received: {stmt}" + ))), + } + } +} + +#[cfg(test)] +mod tests { + use crate::ast::helpers::stmt_create_table::CreateTableBuilder; + use crate::ast::{Ident, ObjectName, Statement}; + use crate::parser::ParserError; + + #[test] + pub fn test_from_valid_statement() { + let builder = CreateTableBuilder::new(ObjectName(vec![Ident::new("table_name")])); + + let stmt = builder.clone().build(); + + assert_eq!(builder, CreateTableBuilder::try_from(stmt).unwrap()); + } + + #[test] + pub fn test_from_invalid_statement() { + let stmt = Statement::Commit { chain: false }; + + assert_eq!( + CreateTableBuilder::try_from(stmt).unwrap_err(), + ParserError::ParserError( + "Expected create table statement, but received: COMMIT".to_owned() + ) + ); + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 4bc6ad1e9..342bd28cf 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -13,6 +13,7 @@ //! SQL Abstract Syntax Tree (AST) types mod data_type; mod ddl; +pub mod helpers; mod operator; mod query; mod value; diff --git a/src/ast/value.rs b/src/ast/value.rs index 2d273d031..3861ab008 100644 --- a/src/ast/value.rs +++ b/src/ast/value.rs @@ -10,8 +10,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#[cfg(not(feature = "std"))] -use alloc::boxed::Box; #[cfg(not(feature = "std"))] use alloc::string::String; use core::fmt; diff --git a/src/parser.rs b/src/parser.rs index 36c577b05..fb473b74f 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -27,6 +27,7 @@ use log::debug; use IsLateral::*; use IsOptional::*; +use crate::ast::helpers::stmt_create_table::CreateTableBuilder; use crate::ast::*; use crate::dialect::*; use crate::keywords::{self, Keyword}; @@ -2032,31 +2033,18 @@ impl<'a> Parser<'a> { }; let location = hive_formats.location.clone(); let table_properties = self.parse_options(Keyword::TBLPROPERTIES)?; - Ok(Statement::CreateTable { - name: table_name, - columns, - constraints, - hive_distribution, - hive_formats: Some(hive_formats), - with_options: vec![], - table_properties, - or_replace, - if_not_exists, - external: true, - global: None, - temporary: false, - file_format, - location, - query: None, - without_rowid: false, - like: None, - clone: None, - default_charset: None, - engine: None, - collation: None, - on_commit: None, - on_cluster: None, - }) + Ok(CreateTableBuilder::new(table_name) + .columns(columns) + .constraints(constraints) + .hive_distribution(hive_distribution) + .hive_formats(Some(hive_formats)) + .table_properties(table_properties) + .or_replace(or_replace) + .if_not_exists(if_not_exists) + .external(true) + .file_format(file_format) + .location(location) + .build()) } pub fn parse_file_format(&mut self) -> Result { @@ -2667,31 +2655,27 @@ impl<'a> Parser<'a> { None }; - Ok(Statement::CreateTable { - name: table_name, - temporary, - columns, - constraints, - with_options, - table_properties, - or_replace, - if_not_exists, - hive_distribution, - hive_formats: Some(hive_formats), - external: false, - global, - file_format: None, - location: None, - query, - without_rowid, - like, - clone, - engine, - default_charset, - collation, - on_commit, - on_cluster, - }) + Ok(CreateTableBuilder::new(table_name) + .temporary(temporary) + .columns(columns) + .constraints(constraints) + .with_options(with_options) + .table_properties(table_properties) + .or_replace(or_replace) + .if_not_exists(if_not_exists) + .hive_distribution(hive_distribution) + .hive_formats(Some(hive_formats)) + .global(global) + .query(query) + .without_rowid(without_rowid) + .like(like) + .clone_clause(clone) + .engine(engine) + .default_charset(default_charset) + .collation(collation) + .on_commit(on_commit) + .on_cluster(on_cluster) + .build()) } pub fn parse_columns(&mut self) -> Result<(Vec, Vec), ParserError> { diff --git a/src/test_utils.rs b/src/test_utils.rs index c3d60ee62..d51aec1ae 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -15,7 +15,7 @@ /// on this module, as it will change without notice. // // Integration tests (i.e. everything under `tests/`) import this -// via `tests/test_utils/mod.rs`. +// via `tests/test_utils/helpers`. #[cfg(not(feature = "std"))] use alloc::{