diff --git a/Cargo.toml b/Cargo.toml index c7f0760cb..9760b81ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sqlparser" description = "Extensible SQL Lexer and Parser with support for ANSI SQL:2011" -version = "0.2.5-alpha.0" +version = "0.3.0" authors = ["Andy Grove "] homepage = "https://github.com/andygrove/sqlparser-rs" documentation = "https://docs.rs/sqlparser/" @@ -21,3 +21,6 @@ path = "src/lib.rs" log = "0.4.5" chrono = "0.4.6" uuid = "0.7.1" + +[dev-dependencies] +simple_logger = "1.0.1" diff --git a/README.md b/README.md index e62473667..dc6b17899 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ println!("AST: {:?}", ast); This outputs ```rust -AST: SQLSelect { projection: [SQLIdentifier("a"), SQLIdentifier("b"), SQLLiteralLong(123), SQLFunction { id: "myfunc", args: [SQLIdentifier("b")] }], relation: Some(SQLIdentifier("table_1")), selection: Some(SQLBinaryExpr { left: SQLBinaryExpr { left: SQLIdentifier("a"), op: Gt, right: SQLIdentifier("b") }, op: And, right: SQLBinaryExpr { left: SQLIdentifier("b"), op: Lt, right: SQLLiteralLong(100) } }), order_by: Some([SQLOrderBy { expr: SQLIdentifier("a"), asc: false }, SQLOrderBy { expr: SQLIdentifier("b"), asc: true }]), group_by: None, having: None, limit: None } +AST: [SQLSelect(SQLSelect { projection: [SQLIdentifier("a"), SQLIdentifier("b"), SQLValue(Long(123)), SQLFunction { id: "myfunc", args: [SQLIdentifier("b")] }], relation: Some(Table { name: SQLObjectName(["table_1"]), alias: None }), joins: [], selection: Some(SQLBinaryExpr { left: SQLBinaryExpr { left: SQLIdentifier("a"), op: Gt, right: SQLIdentifier("b") }, op: And, right: SQLBinaryExpr { left: SQLIdentifier("b"), op: Lt, right: SQLValue(Long(100)) } }), order_by: Some([SQLOrderByExpr { expr: SQLIdentifier("a"), asc: Some(false) }, SQLOrderByExpr { expr: SQLIdentifier("b"), asc: None }]), group_by: None, having: None, limit: None })] ``` ## Design diff --git a/examples/cli.rs b/examples/cli.rs new 
file mode 100644 index 000000000..545e8218f --- /dev/null +++ b/examples/cli.rs @@ -0,0 +1,46 @@ +extern crate simple_logger; +extern crate sqlparser; +///! A small command-line app to run the parser. +/// Run with `cargo run --example cli` +use std::fs; + +use sqlparser::dialect::GenericSqlDialect; +use sqlparser::sqlparser::Parser; + +fn main() { + simple_logger::init().unwrap(); + + let filename = std::env::args() + .nth(1) + .expect("No arguments provided!\n\nUsage: cargo run --example cli FILENAME.sql"); + + let contents = + fs::read_to_string(&filename).expect(&format!("Unable to read the file {}", &filename)); + let without_bom = if contents.chars().nth(0).unwrap() as u64 != 0xfeff { + contents.as_str() + } else { + let mut chars = contents.chars(); + chars.next(); + chars.as_str() + }; + println!("Input:\n'{}'", &without_bom); + let parse_result = Parser::parse_sql(&GenericSqlDialect {}, without_bom.to_owned()); + match parse_result { + Ok(statements) => { + println!( + "Round-trip:\n'{}'", + statements + .iter() + .map(|s| s.to_string()) + .collect::>() + .join("\n") + ); + println!("Parse results:\n{:#?}", statements); + std::process::exit(0); + } + Err(e) => { + println!("Error during parsing: {:?}", e); + std::process::exit(1); + } + } +} diff --git a/src/dialect/ansi_sql.rs b/src/dialect/ansi_sql.rs index b91fdc6e9..4026cf61c 100644 --- a/src/dialect/ansi_sql.rs +++ b/src/dialect/ansi_sql.rs @@ -1,339 +1,8 @@ use dialect::Dialect; -use dialect::keywords::*; - pub struct AnsiSqlDialect {} impl Dialect for AnsiSqlDialect { - fn keywords(&self) -> Vec<&'static str> { - return vec![ - ABS, - ALL, - ALLOCATE, - ALTER, - AND, - ANY, - ARE, - ARRAY, - ARRAY_AGG, - ARRAY_MAX_CARDINALITY, - AS, - ASENSITIVE, - ASYMMETRIC, - AT, - ATOMIC, - AUTHORIZATION, - AVG, - BEGIN, - BEGIN_FRAME, - BEGIN_PARTITION, - BETWEEN, - BIGINT, - BINARY, - BLOB, - BOOLEAN, - BOTH, - BY, - CALL, - CALLED, - CARDINALITY, - CASCADED, - CASE, - CAST, - CEIL, - CEILING, - CHAR, - 
CHAR_LENGTH, - CHARACTER, - CHARACTER_LENGTH, - CHECK, - CLOB, - CLOSE, - COALESCE, - COLLATE, - COLLECT, - COLUMN, - COMMIT, - CONDITION, - CONNECT, - CONSTRAINT, - CONTAINS, - CONVERT, - CORR, - CORRESPONDING, - COUNT, - COVAR_POP, - COVAR_SAMP, - CREATE, - CROSS, - CUBE, - CUME_DIST, - CURRENT, - CURRENT_CATALOG, - CURRENT_DATE, - CURRENT_DEFAULT_TRANSFORM_GROUP, - CURRENT_PATH, - CURRENT_ROLE, - CURRENT_ROW, - CURRENT_SCHEMA, - CURRENT_TIME, - CURRENT_TIMESTAMP, - CURRENT_TRANSFORM_GROUP_FOR_TYPE, - CURRENT_USER, - CURSOR, - CYCLE, - DATE, - DAY, - DEALLOCATE, - DEC, - DECIMAL, - DECLARE, - DEFAULT, - DELETE, - DENSE_RANK, - DEREF, - DESCRIBE, - DETERMINISTIC, - DISCONNECT, - DISTINCT, - DOUBLE, - DROP, - DYNAMIC, - EACH, - ELEMENT, - ELSE, - END, - END_FRAME, - END_PARTITION, - END_EXEC, - EQUALS, - ESCAPE, - EVERY, - EXCEPT, - EXEC, - EXECUTE, - EXISTS, - EXP, - EXTERNAL, - EXTRACT, - FALSE, - FETCH, - FILTER, - FIRST_VALUE, - FLOAT, - FLOOR, - FOR, - FOREIGN, - FRAME_ROW, - FREE, - FROM, - FULL, - FUNCTION, - FUSION, - GET, - GLOBAL, - GRANT, - GROUP, - GROUPING, - GROUPS, - HAVING, - HOLD, - HOUR, - IDENTITY, - IN, - INDICATOR, - INNER, - INOUT, - INSENSITIVE, - INSERT, - INT, - INTEGER, - INTERSECT, - INTERSECTION, - INTERVAL, - INTO, - IS, - JOIN, - LAG, - LANGUAGE, - LARGE, - LAST_VALUE, - LATERAL, - LEAD, - LEADING, - LEFT, - LIKE, - LIKE_REGEX, - LN, - LOCAL, - LOCALTIME, - LOCALTIMESTAMP, - LOWER, - MATCH, - MAX, - MEMBER, - MERGE, - METHOD, - MIN, - MINUTE, - MOD, - MODIFIES, - MODULE, - MONTH, - MULTISET, - NATIONAL, - NATURAL, - NCHAR, - NCLOB, - NEW, - NO, - NONE, - NORMALIZE, - NOT, - NTH_VALUE, - NTILE, - NULL, - NULLIF, - NUMERIC, - OCTET_LENGTH, - OCCURRENCES_REGEX, - OF, - OFFSET, - OLD, - ON, - ONLY, - OPEN, - OR, - ORDER, - OUT, - OUTER, - OVER, - OVERLAPS, - OVERLAY, - PARAMETER, - PARTITION, - PERCENT, - PERCENT_RANK, - PERCENTILE_CONT, - PERCENTILE_DISC, - PERIOD, - PORTION, - POSITION, - POSITION_REGEX, - POWER, - PRECEDES, - PRECISION, 
- PREPARE, - PRIMARY, - PROCEDURE, - RANGE, - RANK, - READS, - REAL, - RECURSIVE, - REF, - REFERENCES, - REFERENCING, - REGR_AVGX, - REGR_AVGY, - REGR_COUNT, - REGR_INTERCEPT, - REGR_R2, - REGR_SLOPE, - REGR_SXX, - REGR_SXY, - REGR_SYY, - RELEASE, - RESULT, - RETURN, - RETURNS, - REVOKE, - RIGHT, - ROLLBACK, - ROLLUP, - ROW, - ROW_NUMBER, - ROWS, - SAVEPOINT, - SCOPE, - SCROLL, - SEARCH, - SECOND, - SELECT, - SENSITIVE, - SESSION_USER, - SET, - SIMILAR, - SMALLINT, - SOME, - SPECIFIC, - SPECIFICTYPE, - SQL, - SQLEXCEPTION, - SQLSTATE, - SQLWARNING, - SQRT, - START, - STATIC, - STDDEV_POP, - STDDEV_SAMP, - SUBMULTISET, - SUBSTRING, - SUBSTRING_REGEX, - SUCCEEDS, - SUM, - SYMMETRIC, - SYSTEM, - SYSTEM_TIME, - SYSTEM_USER, - TABLE, - TABLESAMPLE, - THEN, - TIME, - TIMESTAMP, - TIMEZONE_HOUR, - TIMEZONE_MINUTE, - TO, - TRAILING, - TRANSLATE, - TRANSLATE_REGEX, - TRANSLATION, - TREAT, - TRIGGER, - TRUNCATE, - TRIM, - TRIM_ARRAY, - TRUE, - UESCAPE, - UNION, - UNIQUE, - UNKNOWN, - UNNEST, - UPDATE, - UPPER, - USER, - USING, - VALUE, - VALUES, - VALUE_OF, - VAR_POP, - VAR_SAMP, - VARBINARY, - VARCHAR, - VARYING, - VERSIONING, - WHEN, - WHENEVER, - WHERE, - WIDTH_BUCKET, - WINDOW, - WITH, - WITHIN, - WITHOUT, - YEAR, - ]; - } - fn is_identifier_start(&self, ch: char) -> bool { (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') } diff --git a/src/dialect/generic_sql.rs b/src/dialect/generic_sql.rs index 0f18b7234..54275d69f 100644 --- a/src/dialect/generic_sql.rs +++ b/src/dialect/generic_sql.rs @@ -1,21 +1,7 @@ use dialect::Dialect; - -use dialect::keywords::*; pub struct GenericSqlDialect {} impl Dialect for GenericSqlDialect { - fn keywords(&self) -> Vec<&'static str> { - return vec![ - SELECT, FROM, WHERE, LIMIT, ORDER, GROUP, BY, HAVING, UNION, ALL, INSERT, INTO, UPDATE, - DELETE, IN, IS, NULL, SET, CREATE, EXTERNAL, TABLE, ASC, DESC, AND, OR, NOT, AS, - STORED, CSV, PARQUET, LOCATION, WITH, WITHOUT, HEADER, ROW, // SQL types - CHAR, CHARACTER, VARYING, LARGE, OBJECT, 
VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, - REAL, DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, - BOOLEAN, DATE, TIME, TIMESTAMP, CASE, WHEN, THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, - CROSS, OUTER, INNER, NATURAL, ON, USING, LIKE, - ]; - } - fn is_identifier_start(&self, ch: char) -> bool { (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '@' } diff --git a/src/dialect/keywords.rs b/src/dialect/keywords.rs index e46837243..e270e8a9f 100644 --- a/src/dialect/keywords.rs +++ b/src/dialect/keywords.rs @@ -1,12 +1,23 @@ -/// make a listing of keywords -/// with static str and their stringified value +///! This module defines +/// 1) a list of constants for every keyword that +/// can appear in SQLWord::keyword: +/// pub const KEYWORD = "KEYWORD" +/// 2) an `ALL_KEYWORDS` array with every keyword in it +/// This is not a list of *reserved* keywords: some of these can be +/// parsed as identifiers if the parser decides so. This means that +/// new keywords can be added here without affecting the parse result. +/// +/// As a matter of fact, most of these keywords are not used at all +/// and could be removed. +/// 3) a `RESERVED_FOR_TABLE_ALIAS` array with keywords reserved in a +/// "table alias" context. + macro_rules! 
keyword { ($($ident:ident),*) => { - $(pub static $ident: &'static str = stringify!($ident);)* + $(pub const $ident: &'static str = stringify!($ident);)* } } -/// enumerate all the keywords here for all dialects to support in this project keyword!( ABS, ADD, @@ -180,6 +191,7 @@ keyword!( LOCATION, LOWER, MATCH, + MATERIALIZED, MAX, MEMBER, MERGE, @@ -339,6 +351,7 @@ keyword!( VARCHAR, VARYING, VERSIONING, + VIEW, WHEN, WHENEVER, WHERE, @@ -352,4 +365,369 @@ keyword!( ); /// special case of keyword where the it is an invalid identifier -pub static END_EXEC: &'static str = "END-EXEC"; +pub const END_EXEC: &'static str = "END-EXEC"; + +pub const ALL_KEYWORDS: &'static [&'static str] = &[ + ABS, + ADD, + ASC, + ALL, + ALLOCATE, + ALTER, + AND, + ANY, + ARE, + ARRAY, + ARRAY_AGG, + ARRAY_MAX_CARDINALITY, + AS, + ASENSITIVE, + ASYMMETRIC, + AT, + ATOMIC, + AUTHORIZATION, + AVG, + BEGIN, + BEGIN_FRAME, + BEGIN_PARTITION, + BETWEEN, + BIGINT, + BINARY, + BLOB, + BOOLEAN, + BOTH, + BY, + BYTEA, + CALL, + CALLED, + CARDINALITY, + CASCADED, + CASE, + CAST, + CEIL, + CEILING, + CHAR, + CHAR_LENGTH, + CHARACTER, + CHARACTER_LENGTH, + CHECK, + CLOB, + CLOSE, + COALESCE, + COLLATE, + COLLECT, + COLUMN, + COMMIT, + CONDITION, + CONNECT, + CONSTRAINT, + CONTAINS, + CONVERT, + COPY, + CORR, + CORRESPONDING, + COUNT, + COVAR_POP, + COVAR_SAMP, + CREATE, + CROSS, + CSV, + CUBE, + CUME_DIST, + CURRENT, + CURRENT_CATALOG, + CURRENT_DATE, + CURRENT_DEFAULT_TRANSFORM_GROUP, + CURRENT_PATH, + CURRENT_ROLE, + CURRENT_ROW, + CURRENT_SCHEMA, + CURRENT_TIME, + CURRENT_TIMESTAMP, + CURRENT_TRANSFORM_GROUP_FOR_TYPE, + CURRENT_USER, + CURSOR, + CYCLE, + DATE, + DAY, + DEALLOCATE, + DEC, + DECIMAL, + DECLARE, + DEFAULT, + DELETE, + DENSE_RANK, + DEREF, + DESC, + DESCRIBE, + DETERMINISTIC, + DISCONNECT, + DISTINCT, + DOUBLE, + DROP, + DYNAMIC, + EACH, + ELEMENT, + ELSE, + END, + END_FRAME, + END_PARTITION, + EQUALS, + ESCAPE, + EVERY, + EXCEPT, + EXEC, + EXECUTE, + EXISTS, + EXP, + EXTERNAL, + 
EXTRACT, + FALSE, + FETCH, + FILTER, + FIRST_VALUE, + FLOAT, + FLOOR, + FOR, + FOREIGN, + FRAME_ROW, + FREE, + FROM, + FULL, + FUNCTION, + FUSION, + GET, + GLOBAL, + GRANT, + GROUP, + GROUPING, + GROUPS, + HAVING, + HEADER, + HOLD, + HOUR, + IDENTITY, + IN, + INDICATOR, + INNER, + INOUT, + INSENSITIVE, + INSERT, + INT, + INTEGER, + INTERSECT, + INTERSECTION, + INTERVAL, + INTO, + IS, + JOIN, + KEY, + LAG, + LANGUAGE, + LARGE, + LAST_VALUE, + LATERAL, + LEAD, + LEADING, + LEFT, + LIKE, + LIKE_REGEX, + LIMIT, + LN, + LOCAL, + LOCALTIME, + LOCALTIMESTAMP, + LOCATION, + LOWER, + MATCH, + MATERIALIZED, + MAX, + MEMBER, + MERGE, + METHOD, + MIN, + MINUTE, + MOD, + MODIFIES, + MODULE, + MONTH, + MULTISET, + NATIONAL, + NATURAL, + NCHAR, + NCLOB, + NEW, + NO, + NONE, + NORMALIZE, + NOT, + NTH_VALUE, + NTILE, + NULL, + NULLIF, + NUMERIC, + OBJECT, + OCTET_LENGTH, + OCCURRENCES_REGEX, + OF, + OFFSET, + OLD, + ON, + ONLY, + OPEN, + OR, + ORDER, + OUT, + OUTER, + OVER, + OVERLAPS, + OVERLAY, + PARAMETER, + PARTITION, + PARQUET, + PERCENT, + PERCENT_RANK, + PERCENTILE_CONT, + PERCENTILE_DISC, + PERIOD, + PORTION, + POSITION, + POSITION_REGEX, + POWER, + PRECEDES, + PRECISION, + PREPARE, + PRIMARY, + PROCEDURE, + RANGE, + RANK, + READS, + REAL, + RECURSIVE, + REF, + REFERENCES, + REFERENCING, + REGCLASS, + REGR_AVGX, + REGR_AVGY, + REGR_COUNT, + REGR_INTERCEPT, + REGR_R2, + REGR_SLOPE, + REGR_SXX, + REGR_SXY, + REGR_SYY, + RELEASE, + RESULT, + RETURN, + RETURNS, + REVOKE, + RIGHT, + ROLLBACK, + ROLLUP, + ROW, + ROW_NUMBER, + ROWS, + SAVEPOINT, + SCOPE, + SCROLL, + SEARCH, + SECOND, + SELECT, + SENSITIVE, + SESSION_USER, + SET, + SIMILAR, + SMALLINT, + SOME, + SPECIFIC, + SPECIFICTYPE, + SQL, + SQLEXCEPTION, + SQLSTATE, + SQLWARNING, + SQRT, + START, + STATIC, + STDDEV_POP, + STDDEV_SAMP, + STDIN, + STORED, + SUBMULTISET, + SUBSTRING, + SUBSTRING_REGEX, + SUCCEEDS, + SUM, + SYMMETRIC, + SYSTEM, + SYSTEM_TIME, + SYSTEM_USER, + TABLE, + TABLESAMPLE, + TEXT, + THEN, + TIME, + 
TIMESTAMP, + TIMEZONE_HOUR, + TIMEZONE_MINUTE, + TO, + TRAILING, + TRANSLATE, + TRANSLATE_REGEX, + TRANSLATION, + TREAT, + TRIGGER, + TRUNCATE, + TRIM, + TRIM_ARRAY, + TRUE, + UESCAPE, + UNION, + UNIQUE, + UNKNOWN, + UNNEST, + UPDATE, + UPPER, + USER, + USING, + UUID, + VALUE, + VALUES, + VALUE_OF, + VAR_POP, + VAR_SAMP, + VARBINARY, + VARCHAR, + VARYING, + VERSIONING, + VIEW, + WHEN, + WHENEVER, + WHERE, + WIDTH_BUCKET, + WINDOW, + WITH, + WITHIN, + WITHOUT, + YEAR, + ZONE, + END_EXEC, +]; + +/// These keywords can't be used as a table alias, so that `FROM table_name alias` +/// can be parsed unambiguously without looking ahead. +pub const RESERVED_FOR_TABLE_ALIAS: &'static [&'static str] = &[ + // Reserved as both a table and a column alias: + WITH, SELECT, WHERE, GROUP, ORDER, UNION, EXCEPT, INTERSECT, + // Reserved only as a table alias in the `FROM`/`JOIN` clauses: + ON, JOIN, INNER, CROSS, FULL, LEFT, RIGHT, NATURAL, USING, +]; + +/// Can't be used as a column alias, so that `SELECT alias` +/// can be parsed unambiguously without looking ahead. +pub const RESERVED_FOR_COLUMN_ALIAS: &'static [&'static str] = &[ + // Reserved as both a table and a column alias: + WITH, SELECT, WHERE, GROUP, ORDER, UNION, EXCEPT, INTERSECT, + // Reserved only as a column alias in the `SELECT` clause: + FROM, +]; diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index 1a704f000..95ecf7924 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -8,10 +8,16 @@ pub use self::generic_sql::GenericSqlDialect; pub use self::postgresql::PostgreSqlDialect; pub trait Dialect { - /// Get a list of keywords for this dialect - fn keywords(&self) -> Vec<&'static str>; - /// Determine if a character is a valid identifier start character + /// Determine if a character starts a quoted identifier. The default + /// implementation, accepting "double quoted" ids is both ANSI-compliant + /// and appropriate for most dialects (with the notable exception of + /// MySQL, MS SQL, and sqlite). 
You can accept one of characters listed + /// in `SQLWord::matching_end_quote()` here + fn is_delimited_identifier_start(&self, ch: char) -> bool { + ch == '"' + } + /// Determine if a character is a valid start character for an unquoted identifier fn is_identifier_start(&self, ch: char) -> bool; - /// Determine if a character is a valid identifier character + /// Determine if a character is a valid unquoted identifier character fn is_identifier_part(&self, ch: char) -> bool; } diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs index 66cb51c19..2b64c1f0a 100644 --- a/src/dialect/postgresql.rs +++ b/src/dialect/postgresql.rs @@ -1,24 +1,8 @@ use dialect::Dialect; -use dialect::keywords::*; - pub struct PostgreSqlDialect {} impl Dialect for PostgreSqlDialect { - fn keywords(&self) -> Vec<&'static str> { - return vec![ - ALTER, ONLY, SELECT, FROM, WHERE, LIMIT, ORDER, GROUP, BY, HAVING, UNION, ALL, INSERT, - INTO, UPDATE, DELETE, IN, IS, NULL, SET, CREATE, EXTERNAL, TABLE, ASC, DESC, AND, OR, - NOT, AS, STORED, CSV, WITH, WITHOUT, ROW, // SQL types - CHAR, CHARACTER, VARYING, LARGE, VARCHAR, CLOB, BINARY, VARBINARY, BLOB, FLOAT, REAL, - DOUBLE, PRECISION, INT, INTEGER, SMALLINT, BIGINT, NUMERIC, DECIMAL, DEC, BOOLEAN, - DATE, TIME, TIMESTAMP, VALUES, DEFAULT, ZONE, REGCLASS, TEXT, BYTEA, TRUE, FALSE, COPY, - STDIN, PRIMARY, KEY, UNIQUE, UUID, ADD, CONSTRAINT, FOREIGN, REFERENCES, CASE, WHEN, - THEN, ELSE, END, JOIN, LEFT, RIGHT, FULL, CROSS, OUTER, INNER, NATURAL, ON, USING, - LIKE, - ]; - } - fn is_identifier_start(&self, ch: char) -> bool { (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '@' } diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 54b650a84..fe30586b4 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -14,32 +14,63 @@ //! 
SQL Abstract Syntax Tree (AST) types +mod query; mod sql_operator; mod sqltype; mod table_key; mod value; +pub use self::query::{ + Cte, Join, JoinConstraint, JoinOperator, SQLOrderByExpr, SQLQuery, SQLSelect, SQLSelectItem, + SQLSetExpr, SQLSetOperator, TableFactor, +}; pub use self::sqltype::SQLType; pub use self::table_key::{AlterOperation, Key, TableKey}; pub use self::value::Value; pub use self::sql_operator::SQLOperator; -/// SQL Abstract Syntax Tree (AST) +/// Identifier name, in the originally quoted form (e.g. `"id"`) +pub type SQLIdent = String; + +/// Represents a parsed SQL expression, which is a common building +/// block of SQL statements (the part after SELECT, WHERE, etc.) #[derive(Debug, Clone, PartialEq)] pub enum ASTNode { /// Identifier e.g. table name or column name - SQLIdentifier(String), - /// Wildcard e.g. `*` + SQLIdentifier(SQLIdent), + /// Unqualified wildcard (`*`). SQL allows this in limited contexts (such as right + /// after `SELECT` or as part of an aggregate function, e.g. `COUNT(*)`, but we + /// currently accept it in contexts where it doesn't make sense, such as `* + *` SQLWildcard, + /// Qualified wildcard, e.g. `alias.*` or `schema.table.*`. + /// (Same caveats apply to SQLQualifiedWildcard as to SQLWildcard.) + SQLQualifiedWildcard(Vec), /// Multi part identifier e.g. `myschema.dbo.mytable` - SQLCompoundIdentifier(Vec), - /// Assigment e.g. `name = 'Fred'` in an UPDATE statement - SQLAssignment(SQLAssignment), + SQLCompoundIdentifier(Vec), /// `IS NULL` expression SQLIsNull(Box), /// `IS NOT NULL` expression SQLIsNotNull(Box), + /// `[ NOT ] IN (val1, val2, ...)` + SQLInList { + expr: Box, + list: Vec, + negated: bool, + }, + /// `[ NOT ] IN (SELECT ...)` + SQLInSubquery { + expr: Box, + subquery: Box, + negated: bool, + }, + /// [ NOT ] BETWEEN AND + SQLBetween { + expr: Box, + negated: bool, + low: Box, + high: Box, + }, /// Binary expression e.g. 
`1 + 1` or `foo > bar` SQLBinaryExpr { left: Box, @@ -61,7 +92,8 @@ pub enum ASTNode { /// SQLValue SQLValue(Value), /// Scalar function call e.g. `LEFT(foo, 5)` - SQLFunction { id: String, args: Vec }, + /// TODO: this can be a compound SQLObjectName as well (for UDFs) + SQLFunction { id: SQLIdent, args: Vec }, /// CASE [] WHEN THEN ... [ELSE ] END SQLCase { // TODO: support optional operand for "simple case" @@ -69,71 +101,9 @@ pub enum ASTNode { results: Vec, else_result: Option>, }, - /// SELECT - SQLSelect { - /// projection expressions - projection: Vec, - /// FROM - relation: Option>, - // JOIN - joins: Vec, - /// WHERE - selection: Option>, - /// ORDER BY - order_by: Option>, - /// GROUP BY - group_by: Option>, - /// HAVING - having: Option>, - /// LIMIT - limit: Option>, - }, - /// INSERT - SQLInsert { - /// TABLE - table_name: String, - /// COLUMNS - columns: Vec, - /// VALUES (vector of rows to insert) - values: Vec>, - }, - SQLCopy { - /// TABLE - table_name: String, - /// COLUMNS - columns: Vec, - /// VALUES a vector of values to be copied - values: Vec>, - }, - /// UPDATE - SQLUpdate { - /// TABLE - table_name: String, - /// Column assignments - assignments: Vec, - /// WHERE - selection: Option>, - }, - /// DELETE - SQLDelete { - /// FROM - relation: Option>, - /// WHERE - selection: Option>, - }, - /// CREATE TABLE - SQLCreateTable { - /// Table name - name: String, - /// Optional schema - columns: Vec, - }, - /// ALTER TABLE - SQLAlterTable { - /// Table name - name: String, - operation: AlterOperation, - }, + /// A parenthesized subquery `(SELECT ...)`, used in expression like + /// `SELECT (subquery) AS x` or `WHERE (subquery) = x` + SQLSubquery(Box), } impl ToString for ASTNode { @@ -141,10 +111,45 @@ impl ToString for ASTNode { match self { ASTNode::SQLIdentifier(s) => s.to_string(), ASTNode::SQLWildcard => "*".to_string(), + ASTNode::SQLQualifiedWildcard(q) => q.join(".") + "*", ASTNode::SQLCompoundIdentifier(s) => s.join("."), - 
ASTNode::SQLAssignment(ass) => ass.to_string(), ASTNode::SQLIsNull(ast) => format!("{} IS NULL", ast.as_ref().to_string()), ASTNode::SQLIsNotNull(ast) => format!("{} IS NOT NULL", ast.as_ref().to_string()), + ASTNode::SQLInList { + expr, + list, + negated, + } => format!( + "{} {}IN ({})", + expr.as_ref().to_string(), + if *negated { "NOT " } else { "" }, + list.iter() + .map(|a| a.to_string()) + .collect::>() + .join(", ") + ), + ASTNode::SQLInSubquery { + expr, + subquery, + negated, + } => format!( + "{} {}IN ({})", + expr.as_ref().to_string(), + if *negated { "NOT " } else { "" }, + subquery.to_string() + ), + ASTNode::SQLBetween { + expr, + negated, + low, + high, + } => format!( + "{} {}BETWEEN {} AND {}", + expr.to_string(), + if *negated { "NOT " } else { "" }, + low.to_string(), + high.to_string() + ), ASTNode::SQLBinaryExpr { left, op, right } => format!( "{} {} {}", left.as_ref().to_string(), @@ -188,67 +193,81 @@ impl ToString for ASTNode { } s + " END" } - ASTNode::SQLSelect { - projection, - relation, - joins, - selection, - order_by, - group_by, - having, - limit, - } => { - let mut s = format!( - "SELECT {}", - projection - .iter() - .map(|p| p.to_string()) - .collect::>() - .join(", ") - ); - if let Some(relation) = relation { - s += &format!(" FROM {}", relation.as_ref().to_string()); - } - for join in joins { - s += &join.to_string(); - } - if let Some(selection) = selection { - s += &format!(" WHERE {}", selection.as_ref().to_string()); - } - if let Some(group_by) = group_by { - s += &format!( - " GROUP BY {}", - group_by - .iter() - .map(|g| g.to_string()) - .collect::>() - .join(", ") - ); - } - if let Some(having) = having { - s += &format!(" HAVING {}", having.as_ref().to_string()); - } - if let Some(order_by) = order_by { - s += &format!( - " ORDER BY {}", - order_by - .iter() - .map(|o| o.to_string()) - .collect::>() - .join(", ") - ); - } - if let Some(limit) = limit { - s += &format!(" LIMIT {}", limit.as_ref().to_string()); - } - s - } 
- ASTNode::SQLInsert { + ASTNode::SQLSubquery(s) => format!("({})", s.to_string()), + } + } +} + +/// A top-level statement (SELECT, INSERT, CREATE, etc.) +#[derive(Debug, Clone, PartialEq)] +pub enum SQLStatement { + /// SELECT + SQLSelect(SQLQuery), + /// INSERT + SQLInsert { + /// TABLE + table_name: SQLObjectName, + /// COLUMNS + columns: Vec, + /// VALUES (vector of rows to insert) + values: Vec>, + }, + SQLCopy { + /// TABLE + table_name: SQLObjectName, + /// COLUMNS + columns: Vec, + /// VALUES a vector of values to be copied + values: Vec>, + }, + /// UPDATE + SQLUpdate { + /// TABLE + table_name: SQLObjectName, + /// Column assignments + assignments: Vec, + /// WHERE + selection: Option, + }, + /// DELETE + SQLDelete { + /// FROM + table_name: SQLObjectName, + /// WHERE + selection: Option, + }, + /// CREATE VIEW + SQLCreateView { + /// View name + name: SQLObjectName, + query: SQLQuery, + materialized: bool, + }, + /// CREATE TABLE + SQLCreateTable { + /// Table name + name: SQLObjectName, + /// Optional schema + columns: Vec, + }, + /// ALTER TABLE + SQLAlterTable { + /// Table name + name: SQLObjectName, + operation: AlterOperation, + }, +} + +impl ToString for SQLStatement { + fn to_string(&self) -> String { + match self { + SQLStatement::SQLSelect(s) => s.to_string(), + SQLStatement::SQLInsert { table_name, columns, values, } => { - let mut s = format!("INSERT INTO {}", table_name); + let mut s = format!("INSERT INTO {}", table_name.to_string()); if columns.len() > 0 { s += &format!(" ({})", columns.join(", ")); } @@ -268,12 +287,12 @@ impl ToString for ASTNode { } s } - ASTNode::SQLCopy { + SQLStatement::SQLCopy { table_name, columns, values, } => { - let mut s = format!("COPY {}", table_name); + let mut s = format!("COPY {}", table_name.to_string()); if columns.len() > 0 { s += &format!( " ({})", @@ -298,12 +317,12 @@ impl ToString for ASTNode { s += "\n\\."; s } - ASTNode::SQLUpdate { + SQLStatement::SQLUpdate { table_name, assignments, selection, 
} => { - let mut s = format!("UPDATE {}", table_name); + let mut s = format!("UPDATE {}", table_name.to_string()); if assignments.len() > 0 { s += &format!( "{}", @@ -315,84 +334,80 @@ impl ToString for ASTNode { ); } if let Some(selection) = selection { - s += &format!(" WHERE {}", selection.as_ref().to_string()); + s += &format!(" WHERE {}", selection.to_string()); } s } - ASTNode::SQLDelete { - relation, + SQLStatement::SQLDelete { + table_name, selection, } => { - let mut s = String::from("DELETE"); - if let Some(relation) = relation { - s += &format!(" FROM {}", relation.as_ref().to_string()); - } + let mut s = format!("DELETE FROM {}", table_name.to_string()); if let Some(selection) = selection { - s += &format!(" WHERE {}", selection.as_ref().to_string()); + s += &format!(" WHERE {}", selection.to_string()); } s } - ASTNode::SQLCreateTable { name, columns } => format!( - "CREATE TABLE {} ({})", + SQLStatement::SQLCreateView { name, + query, + materialized, + } => { + let modifier = if *materialized { " MATERIALIZED" } else { "" }; + format!( + "CREATE{} VIEW {} AS {}", + modifier, + name.to_string(), + query.to_string() + ) + } + SQLStatement::SQLCreateTable { name, columns } => format!( + "CREATE TABLE {} ({})", + name.to_string(), columns .iter() .map(|c| c.to_string()) .collect::>() .join(", ") ), - ASTNode::SQLAlterTable { name, operation } => { - format!("ALTER TABLE {} {}", name, operation.to_string()) + SQLStatement::SQLAlterTable { name, operation } => { + format!("ALTER TABLE {} {}", name.to_string(), operation.to_string()) } } } } -/// SQL assignment `foo = expr` as used in SQLUpdate -/// TODO: unify this with the ASTNode SQLAssignment +/// A name of a table, view, custom type, etc., possibly multi-part, i.e. 
db.schema.obj #[derive(Debug, Clone, PartialEq)] -pub struct SQLAssignment { - id: String, - value: Box, -} +pub struct SQLObjectName(pub Vec); -impl ToString for SQLAssignment { +impl ToString for SQLObjectName { fn to_string(&self) -> String { - format!("SET {} = {}", self.id, self.value.as_ref().to_string()) + self.0.join(".") } } -/// SQL ORDER BY expression +/// SQL assignment `foo = expr` as used in SQLUpdate #[derive(Debug, Clone, PartialEq)] -pub struct SQLOrderByExpr { - pub expr: Box, - pub asc: bool, -} - -impl SQLOrderByExpr { - pub fn new(expr: Box, asc: bool) -> Self { - SQLOrderByExpr { expr, asc } - } +pub struct SQLAssignment { + id: SQLIdent, + value: ASTNode, } -impl ToString for SQLOrderByExpr { +impl ToString for SQLAssignment { fn to_string(&self) -> String { - if self.asc { - format!("{} ASC", self.expr.as_ref().to_string()) - } else { - format!("{} DESC", self.expr.as_ref().to_string()) - } + format!("SET {} = {}", self.id, self.value.to_string()) } } /// SQL column definition #[derive(Debug, Clone, PartialEq)] pub struct SQLColumnDef { - pub name: String, + pub name: SQLIdent, pub data_type: SQLType, pub is_primary: bool, pub is_unique: bool, - pub default: Option>, + pub default: Option, pub allow_null: bool, } @@ -406,7 +421,7 @@ impl ToString for SQLColumnDef { s += " UNIQUE"; } if let Some(ref default) = self.default { - s += &format!(" DEFAULT {}", default.as_ref().to_string()); + s += &format!(" DEFAULT {}", default.to_string()); } if !self.allow_null { s += " NOT NULL"; @@ -414,72 +429,3 @@ impl ToString for SQLColumnDef { s } } - -#[derive(Debug, Clone, PartialEq)] -pub struct Join { - pub relation: ASTNode, - pub join_operator: JoinOperator, -} - -impl ToString for Join { - fn to_string(&self) -> String { - fn prefix(constraint: &JoinConstraint) -> String { - match constraint { - JoinConstraint::Natural => "NATURAL ".to_string(), - _ => "".to_string(), - } - } - fn suffix(constraint: &JoinConstraint) -> String { - match constraint 
{ - JoinConstraint::On(expr) => format!("ON {}", expr.to_string()), - JoinConstraint::Using(attrs) => format!("USING({})", attrs.join(", ")), - _ => "".to_string(), - } - } - match &self.join_operator { - JoinOperator::Inner(constraint) => format!( - " {}JOIN {} {}", - prefix(constraint), - self.relation.to_string(), - suffix(constraint) - ), - JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), - JoinOperator::Implicit => format!(", {}", self.relation.to_string()), - JoinOperator::LeftOuter(constraint) => format!( - " {}LEFT JOIN {} {}", - prefix(constraint), - self.relation.to_string(), - suffix(constraint) - ), - JoinOperator::RightOuter(constraint) => format!( - " {}RIGHT JOIN {} {}", - prefix(constraint), - self.relation.to_string(), - suffix(constraint) - ), - JoinOperator::FullOuter(constraint) => format!( - " {}FULL JOIN {} {}", - prefix(constraint), - self.relation.to_string(), - suffix(constraint) - ), - } - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum JoinOperator { - Inner(JoinConstraint), - LeftOuter(JoinConstraint), - RightOuter(JoinConstraint), - FullOuter(JoinConstraint), - Implicit, - Cross, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum JoinConstraint { - On(ASTNode), - Using(Vec), - Natural, -} diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs new file mode 100644 index 000000000..69577e557 --- /dev/null +++ b/src/sqlast/query.rs @@ -0,0 +1,309 @@ +use super::*; + +/// The most complete variant of a `SELECT` query expression, optionally +/// including `WITH`, `UNION` / other set operations, and `ORDER BY`. 
+#[derive(Debug, Clone, PartialEq)] +pub struct SQLQuery { + /// WITH (common table expressions, or CTEs) + pub ctes: Vec, + /// SELECT or UNION / EXCEPT / INTERSECT + pub body: SQLSetExpr, + /// ORDER BY + pub order_by: Option>, + /// LIMIT + pub limit: Option, +} + +impl ToString for SQLQuery { + fn to_string(&self) -> String { + let mut s = String::new(); + if !self.ctes.is_empty() { + s += &format!( + "WITH {} ", + self.ctes + .iter() + .map(|cte| format!("{} AS ({})", cte.alias, cte.query.to_string())) + .collect::>() + .join(", ") + ) + } + s += &self.body.to_string(); + if let Some(ref order_by) = self.order_by { + s += &format!( + " ORDER BY {}", + order_by + .iter() + .map(|o| o.to_string()) + .collect::>() + .join(", ") + ); + } + if let Some(ref limit) = self.limit { + s += &format!(" LIMIT {}", limit.to_string()); + } + s + } +} + +/// A node in a tree, representing a "query body" expression, roughly: +/// `SELECT ... [ {UNION|EXCEPT|INTERSECT} SELECT ...]` +#[derive(Debug, Clone, PartialEq)] +pub enum SQLSetExpr { + /// Restricted SELECT .. FROM .. HAVING (no ORDER BY or set operations) + Select(SQLSelect), + /// Parenthesized SELECT subquery, which may include more set operations + /// in its body and an optional ORDER BY / LIMIT. + Query(Box), + /// UNION/EXCEPT/INTERSECT of two queries + SetOperation { + op: SQLSetOperator, + all: bool, + left: Box, + right: Box, + }, + // TODO: ANSI SQL supports `TABLE` and `VALUES` here. 
+} + +impl ToString for SQLSetExpr { + fn to_string(&self) -> String { + match self { + SQLSetExpr::Select(s) => s.to_string(), + SQLSetExpr::Query(q) => format!("({})", q.to_string()), + SQLSetExpr::SetOperation { + left, + right, + op, + all, + } => { + let all_str = if *all { " ALL" } else { "" }; + format!( + "{} {}{} {}", + left.to_string(), + op.to_string(), + all_str, + right.to_string() + ) + } + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum SQLSetOperator { + Union, + Except, + Intersect, +} + +impl ToString for SQLSetOperator { + fn to_string(&self) -> String { + match self { + SQLSetOperator::Union => "UNION".to_string(), + SQLSetOperator::Except => "EXCEPT".to_string(), + SQLSetOperator::Intersect => "INTERSECT".to_string(), + } + } +} + +/// A restricted variant of `SELECT` (without CTEs/`ORDER BY`), which may +/// appear either as the only body item of an `SQLQuery`, or as an operand +/// to a set operation like `UNION`. +#[derive(Debug, Clone, PartialEq)] +pub struct SQLSelect { + /// projection expressions + pub projection: Vec, + /// FROM + pub relation: Option, + /// JOIN + pub joins: Vec, + /// WHERE + pub selection: Option, + /// GROUP BY + pub group_by: Option>, + /// HAVING + pub having: Option, +} + +impl ToString for SQLSelect { + fn to_string(&self) -> String { + let mut s = format!( + "SELECT {}", + self.projection + .iter() + .map(|p| p.to_string()) + .collect::>() + .join(", ") + ); + if let Some(ref relation) = self.relation { + s += &format!(" FROM {}", relation.to_string()); + } + for join in &self.joins { + s += &join.to_string(); + } + if let Some(ref selection) = self.selection { + s += &format!(" WHERE {}", selection.to_string()); + } + if let Some(ref group_by) = self.group_by { + s += &format!( + " GROUP BY {}", + group_by + .iter() + .map(|g| g.to_string()) + .collect::>() + .join(", ") + ); + } + if let Some(ref having) = self.having { + s += &format!(" HAVING {}", having.to_string()); + } + s + } +} + +/// A single 
CTE (used after `WITH`): `alias AS ( query )` +#[derive(Debug, Clone, PartialEq)] +pub struct Cte { + pub alias: SQLIdent, + pub query: SQLQuery, +} + +/// One item of the comma-separated list following `SELECT` +#[derive(Debug, Clone, PartialEq)] +pub enum SQLSelectItem { + /// Any expression, not followed by `[ AS ] alias` + UnnamedExpression(ASTNode), + /// An expression, followed by `[ AS ] alias` + ExpressionWithAlias(ASTNode, SQLIdent), + /// `alias.*` or even `schema.table.*` + QualifiedWildcard(SQLObjectName), + /// An unqualified `*` + Wildcard, +} + +impl ToString for SQLSelectItem { + fn to_string(&self) -> String { + match &self { + SQLSelectItem::UnnamedExpression(expr) => expr.to_string(), + SQLSelectItem::ExpressionWithAlias(expr, alias) => { + format!("{} AS {}", expr.to_string(), alias) + } + SQLSelectItem::QualifiedWildcard(prefix) => format!("{}.*", prefix.to_string()), + SQLSelectItem::Wildcard => "*".to_string(), + } + } +} + +/// A table name or a parenthesized subquery with an optional alias +#[derive(Debug, Clone, PartialEq)] +pub enum TableFactor { + Table { + name: SQLObjectName, + alias: Option, + }, + Derived { + subquery: Box, + alias: Option, + }, +} + +impl ToString for TableFactor { + fn to_string(&self) -> String { + let (base, alias) = match self { + TableFactor::Table { name, alias } => (name.to_string(), alias), + TableFactor::Derived { subquery, alias } => { + (format!("({})", subquery.to_string()), alias) + } + }; + if let Some(alias) = alias { + format!("{} AS {}", base, alias) + } else { + base + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Join { + pub relation: TableFactor, + pub join_operator: JoinOperator, +} + +impl ToString for Join { + fn to_string(&self) -> String { + fn prefix(constraint: &JoinConstraint) -> String { + match constraint { + JoinConstraint::Natural => "NATURAL ".to_string(), + _ => "".to_string(), + } + } + fn suffix(constraint: &JoinConstraint) -> String { + match constraint { + 
JoinConstraint::On(expr) => format!("ON {}", expr.to_string()), + JoinConstraint::Using(attrs) => format!("USING({})", attrs.join(", ")), + _ => "".to_string(), + } + } + match &self.join_operator { + JoinOperator::Inner(constraint) => format!( + " {}JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::Cross => format!(" CROSS JOIN {}", self.relation.to_string()), + JoinOperator::Implicit => format!(", {}", self.relation.to_string()), + JoinOperator::LeftOuter(constraint) => format!( + " {}LEFT JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::RightOuter(constraint) => format!( + " {}RIGHT JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + JoinOperator::FullOuter(constraint) => format!( + " {}FULL JOIN {} {}", + prefix(constraint), + self.relation.to_string(), + suffix(constraint) + ), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinOperator { + Inner(JoinConstraint), + LeftOuter(JoinConstraint), + RightOuter(JoinConstraint), + FullOuter(JoinConstraint), + Implicit, + Cross, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum JoinConstraint { + On(ASTNode), + Using(Vec), + Natural, +} + +/// SQL ORDER BY expression +#[derive(Debug, Clone, PartialEq)] +pub struct SQLOrderByExpr { + pub expr: ASTNode, + pub asc: Option, +} + +impl ToString for SQLOrderByExpr { + fn to_string(&self) -> String { + match self.asc { + Some(true) => format!("{} ASC", self.expr.to_string()), + Some(false) => format!("{} DESC", self.expr.to_string()), + None => self.expr.to_string(), + } + } +} diff --git a/src/sqlast/sqltype.rs b/src/sqlast/sqltype.rs index c81313ab7..06a092033 100644 --- a/src/sqlast/sqltype.rs +++ b/src/sqlast/sqltype.rs @@ -1,3 +1,5 @@ +use super::SQLObjectName; + /// SQL datatypes for literals in SQL statements #[derive(Debug, Clone, PartialEq)] pub enum SQLType { @@ -15,8 +17,8 @@ pub enum SQLType { 
Varbinary(usize), /// Large binary object e.g. BLOB(1000) Blob(usize), - /// Decimal type with precision and optional scale e.g. DECIMAL(10,2) - Decimal(usize, Option), + /// Decimal type with optional precision and scale e.g. DECIMAL(10,2) + Decimal(Option, Option), /// Small integer SmallInt, /// Integer @@ -44,7 +46,7 @@ pub enum SQLType { /// Bytea Bytea, /// Custom type such as enums - Custom(String), + Custom(SQLObjectName), /// Arrays Array(Box), } @@ -73,9 +75,13 @@ impl ToString for SQLType { SQLType::Blob(size) => format!("blob({})", size), SQLType::Decimal(precision, scale) => { if let Some(scale) = scale { - format!("numeric({},{})", precision, scale) + format!("numeric({},{})", precision.unwrap(), scale) } else { - format!("numeric({})", precision) + if let Some(precision) = precision { + format!("numeric({})", precision) + } else { + format!("numeric") + } } } SQLType::SmallInt => "smallint".to_string(), diff --git a/src/sqlast/table_key.rs b/src/sqlast/table_key.rs index 9dacc21b3..6b1078e59 100644 --- a/src/sqlast/table_key.rs +++ b/src/sqlast/table_key.rs @@ -1,7 +1,9 @@ +use super::{SQLIdent, SQLObjectName}; + #[derive(Debug, PartialEq, Clone)] pub enum AlterOperation { AddConstraint(TableKey), - RemoveConstraint { name: String }, + RemoveConstraint { name: SQLIdent }, } impl ToString for AlterOperation { @@ -17,8 +19,8 @@ impl ToString for AlterOperation { #[derive(Debug, PartialEq, Clone)] pub struct Key { - pub name: String, - pub columns: Vec, + pub name: SQLIdent, + pub columns: Vec, } #[derive(Debug, PartialEq, Clone)] @@ -28,8 +30,8 @@ pub enum TableKey { Key(Key), ForeignKey { key: Key, - foreign_table: String, - referred_columns: Vec, + foreign_table: SQLObjectName, + referred_columns: Vec, }, } @@ -51,7 +53,7 @@ impl ToString for TableKey { "{} FOREIGN KEY ({}) REFERENCES {}({})", key.name, key.columns.join(", "), - foreign_table, + foreign_table.to_string(), referred_columns.join(", ") ), } diff --git a/src/sqlast/value.rs 
b/src/sqlast/value.rs index ec11b17e0..a061080a8 100644 --- a/src/sqlast/value.rs +++ b/src/sqlast/value.rs @@ -13,6 +13,8 @@ pub enum Value { Uuid(Uuid), /// 'string value' SingleQuotedString(String), + /// N'string value' + NationalStringLiteral(String), /// Boolean value true or false, Boolean(bool), /// Date value @@ -34,6 +36,7 @@ impl ToString for Value { Value::Double(v) => v.to_string(), Value::Uuid(v) => v.to_string(), Value::SingleQuotedString(v) => format!("'{}'", v), + Value::NationalStringLiteral(v) => format!("N'{}'", v), Value::Boolean(v) => v.to_string(), Value::Date(v) => v.to_string(), Value::Time(v) => v.to_string(), diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 42a39b01e..22d0e70ce 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -14,12 +14,13 @@ //! SQL Parser +use super::dialect::keywords; use super::dialect::Dialect; use super::sqlast::*; use super::sqltokenizer::*; use chrono::{offset::FixedOffset, DateTime, NaiveDate, NaiveDateTime, NaiveTime}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum ParserError { TokenizerError(String), ParserError(String), @@ -53,20 +54,71 @@ impl Parser { } /// Parse a SQL statement and produce an Abstract Syntax Tree (AST) - pub fn parse_sql(dialect: &Dialect, sql: String) -> Result { + pub fn parse_sql(dialect: &Dialect, sql: String) -> Result, ParserError> { let mut tokenizer = Tokenizer::new(dialect, &sql); let tokens = tokenizer.tokenize()?; let mut parser = Parser::new(tokens); - parser.parse() + let mut stmts = Vec::new(); + let mut expecting_statement_delimiter = false; + debug!("Parsing sql '{}'...", sql); + loop { + // ignore empty statements (between successive statement delimiters) + while parser.consume_token(&Token::SemiColon) { + expecting_statement_delimiter = false; + } + + if parser.peek_token().is_none() { + break; + } else if expecting_statement_delimiter { + return parser_err!(format!( + "Expected end of statement, found: {}", + 
parser.peek_token().unwrap().to_string() + )); + } + + let statement = parser.parse_statement()?; + stmts.push(statement); + expecting_statement_delimiter = true; + } + Ok(stmts) + } + + /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), + /// stopping before the statement separator, if any. + pub fn parse_statement(&mut self) -> Result { + match self.next_token() { + Some(t) => match t { + Token::SQLWord(ref w) if w.keyword != "" => match w.keyword.as_ref() { + "SELECT" | "WITH" => { + self.prev_token(); + Ok(SQLStatement::SQLSelect(self.parse_query()?)) + } + "CREATE" => Ok(self.parse_create()?), + "DELETE" => Ok(self.parse_delete()?), + "INSERT" => Ok(self.parse_insert()?), + "ALTER" => Ok(self.parse_alter()?), + "COPY" => Ok(self.parse_copy()?), + _ => parser_err!(format!( + "Unexpected keyword {:?} at the beginning of a statement", + w.to_string() + )), + }, + unexpected => parser_err!(format!( + "Unexpected {:?} at the beginning of a statement", + unexpected + )), + }, + _ => parser_err!("Unexpected end of file"), + } } /// Parse a new expression - pub fn parse(&mut self) -> Result { - self.parse_expr(0) + pub fn parse_expr(&mut self) -> Result { + self.parse_subexpr(0) } /// Parse tokens until the precedence changes - pub fn parse_expr(&mut self, precedence: u8) -> Result { + pub fn parse_subexpr(&mut self, precedence: u8) -> Result { debug!("parsing expr"); let mut expr = self.parse_prefix()?; debug!("prefix: {:?}", expr); @@ -77,9 +129,7 @@ impl Parser { break; } - if let Some(infix_expr) = self.parse_infix(expr.clone(), next_precedence)? 
{ - expr = infix_expr; - } + expr = self.parse_infix(expr, next_precedence)?; } Ok(expr) } @@ -92,7 +142,7 @@ impl Parser { loop { // stop parsing on `NULL` | `NOT NULL` match self.peek_token() { - Some(Token::Keyword(ref k)) if k == "NOT" || k == "NULL" => break, + Some(Token::SQLWord(ref k)) if k.keyword == "NOT" || k.keyword == "NULL" => break, _ => {} } @@ -102,9 +152,7 @@ impl Parser { break; } - if let Some(infix_expr) = self.parse_infix(expr.clone(), next_precedence)? { - expr = infix_expr; - } + expr = self.parse_infix(expr, next_precedence)?; } Ok(expr) } @@ -113,61 +161,75 @@ impl Parser { pub fn parse_prefix(&mut self) -> Result { match self.next_token() { Some(t) => match t { - Token::Keyword(k) => match k.to_uppercase().as_ref() { - "SELECT" => Ok(self.parse_select()?), - "CREATE" => Ok(self.parse_create()?), - "DELETE" => Ok(self.parse_delete()?), - "INSERT" => Ok(self.parse_insert()?), - "ALTER" => Ok(self.parse_alter()?), - "COPY" => Ok(self.parse_copy()?), + Token::SQLWord(w) => match w.keyword.as_ref() { "TRUE" | "FALSE" | "NULL" => { self.prev_token(); self.parse_sql_value() } "CASE" => self.parse_case_expression(), - "NOT" => Ok(ASTNode::SQLUnary { - operator: SQLOperator::Not, - expr: Box::new(self.parse_expr(0)?), - }), - _ => return parser_err!(format!("No prefix parser for keyword {}", k)), - }, - Token::Mult => Ok(ASTNode::SQLWildcard), - Token::Identifier(id) => { - if "CAST" == id.to_uppercase() { - self.parse_cast_expression() - } else { - match self.peek_token() { - Some(Token::LParen) => self.parse_function(&id), - Some(Token::Period) => { - let mut id_parts: Vec = vec![id]; - while self.peek_token() == Some(Token::Period) { - self.expect_token(&Token::Period)?; - match self.next_token() { - Some(Token::Identifier(id)) => id_parts.push(id), - _ => { - return parser_err!(format!( - "Error parsing compound identifier" - )) - } + "CAST" => self.parse_cast_expression(), + "NOT" => { + let p = 
self.get_precedence(&Token::make_keyword("NOT"))?; + Ok(ASTNode::SQLUnary { + operator: SQLOperator::Not, + expr: Box::new(self.parse_subexpr(p)?), + }) + } + // another SQLWord: + _ => match self.peek_token() { + Some(Token::LParen) => self.parse_function(w.as_sql_ident()), + Some(Token::Period) => { + let mut id_parts: Vec = vec![w.as_sql_ident()]; + let mut ends_with_wildcard = false; + while self.consume_token(&Token::Period) { + match self.next_token() { + Some(Token::SQLWord(w)) => id_parts.push(w.as_sql_ident()), + Some(Token::Mult) => { + ends_with_wildcard = true; + break; + } + _ => { + return parser_err!(format!( + "Error parsing compound identifier" + )); } } + } + if ends_with_wildcard { + Ok(ASTNode::SQLQualifiedWildcard(id_parts)) + } else { Ok(ASTNode::SQLCompoundIdentifier(id_parts)) } - _ => Ok(ASTNode::SQLIdentifier(id)), } - } + _ => Ok(ASTNode::SQLIdentifier(w.as_sql_ident())), + }, + }, // End of Token::SQLWord + Token::Mult => Ok(ASTNode::SQLWildcard), + tok @ Token::Minus | tok @ Token::Plus => { + let p = self.get_precedence(&tok)?; + Ok(ASTNode::SQLUnary { + operator: self.to_sql_operator(&tok)?, + expr: Box::new(self.parse_subexpr(p)?), + }) } - Token::Number(_) | Token::SingleQuotedString(_) => { + Token::Number(_) + | Token::SingleQuotedString(_) + | Token::NationalStringLiteral(_) => { self.prev_token(); self.parse_sql_value() } Token::LParen => { - let expr = self.parse(); + let expr = if self.parse_keyword("SELECT") || self.parse_keyword("WITH") { + self.prev_token(); + ASTNode::SQLSubquery(Box::new(self.parse_query()?)) + } else { + ASTNode::SQLNested(Box::new(self.parse_expr()?)) + }; self.expect_token(&Token::RParen)?; - expr + Ok(expr) } _ => parser_err!(format!( - "Prefix parser expected a keyword but found {:?}", + "Did not expect {:?} at the beginning of an expression", t )), }, @@ -175,20 +237,17 @@ impl Parser { } } - pub fn parse_function(&mut self, id: &str) -> Result { + pub fn parse_function(&mut self, id: SQLIdent) -> 
Result { self.expect_token(&Token::LParen)?; if self.consume_token(&Token::RParen) { Ok(ASTNode::SQLFunction { - id: id.to_string(), + id: id, args: vec![], }) } else { let args = self.parse_expr_list()?; self.expect_token(&Token::RParen)?; - Ok(ASTNode::SQLFunction { - id: id.to_string(), - args, - }) + Ok(ASTNode::SQLFunction { id, args }) } } @@ -198,11 +257,11 @@ impl Parser { let mut results = vec![]; let mut else_result = None; loop { - conditions.push(self.parse_expr(0)?); + conditions.push(self.parse_expr()?); self.expect_keyword("THEN")?; - results.push(self.parse_expr(0)?); + results.push(self.parse_expr()?); if self.parse_keywords(vec!["ELSE"]) { - else_result = Some(Box::new(self.parse_expr(0)?)); + else_result = Some(Box::new(self.parse_expr()?)); if self.parse_keywords(vec!["END"]) { break; } else { @@ -229,7 +288,7 @@ impl Parser { /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` pub fn parse_cast_expression(&mut self) -> Result { self.expect_token(&Token::LParen)?; - let expr = self.parse_expr(0)?; + let expr = self.parse_expr()?; self.expect_keyword("AS")?; let data_type = self.parse_data_type()?; self.expect_token(&Token::RParen)?; @@ -239,49 +298,46 @@ impl Parser { }) } - /// Parse a postgresql casting style which is in the form of `expr::datatype` - pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result { - Ok(ASTNode::SQLCast { - expr: Box::new(expr), - data_type: self.parse_data_type()?, - }) - } - /// Parse an expression infix (typically an operator) - pub fn parse_infix( - &mut self, - expr: ASTNode, - precedence: u8, - ) -> Result, ParserError> { + pub fn parse_infix(&mut self, expr: ASTNode, precedence: u8) -> Result { debug!("parsing infix"); match self.next_token() { Some(tok) => match tok { - Token::Keyword(ref k) if k == "IS" => { + Token::SQLWord(ref k) if k.keyword == "IS" => { if self.parse_keywords(vec!["NULL"]) { - Ok(Some(ASTNode::SQLIsNull(Box::new(expr)))) + Ok(ASTNode::SQLIsNull(Box::new(expr))) } else if 
self.parse_keywords(vec!["NOT", "NULL"]) { - Ok(Some(ASTNode::SQLIsNotNull(Box::new(expr)))) + Ok(ASTNode::SQLIsNotNull(Box::new(expr))) } else { - parser_err!("Invalid tokens after IS") + parser_err!(format!( + "Expected NULL or NOT NULL after IS, found {:?}", + self.peek_token() + )) } } - Token::Keyword(ref k) if k == "NOT" => { - if self.parse_keywords(vec!["LIKE"]) { - Ok(Some(ASTNode::SQLBinaryExpr { + Token::SQLWord(ref k) if k.keyword == "NOT" => { + if self.parse_keyword("IN") { + self.parse_in(expr, true) + } else if self.parse_keyword("BETWEEN") { + self.parse_between(expr, true) + } else if self.parse_keyword("LIKE") { + Ok(ASTNode::SQLBinaryExpr { left: Box::new(expr), op: SQLOperator::NotLike, - right: Box::new(self.parse_expr(precedence)?), - })) + right: Box::new(self.parse_subexpr(precedence)?), + }) } else { - parser_err!("Invalid tokens after NOT") + parser_err!(format!( + "Expected IN or LIKE after NOT, found {:?}", + self.peek_token() + )) } } - Token::Keyword(_) => Ok(Some(ASTNode::SQLBinaryExpr { - left: Box::new(expr), - op: self.to_sql_operator(&tok)?, - right: Box::new(self.parse_expr(precedence)?), - })), - Token::Eq + Token::SQLWord(ref k) if k.keyword == "IN" => self.parse_in(expr, false), + Token::SQLWord(ref k) if k.keyword == "BETWEEN" => self.parse_between(expr, false), + Token::DoubleColon => self.parse_pg_cast(expr), + Token::SQLWord(_) + | Token::Eq | Token::Neq | Token::Gt | Token::GtEq @@ -291,21 +347,61 @@ impl Parser { | Token::Minus | Token::Mult | Token::Mod - | Token::Div => Ok(Some(ASTNode::SQLBinaryExpr { + | Token::Div => Ok(ASTNode::SQLBinaryExpr { left: Box::new(expr), op: self.to_sql_operator(&tok)?, - right: Box::new(self.parse_expr(precedence)?), - })), - Token::DoubleColon => { - let pg_cast = self.parse_pg_cast(expr)?; - Ok(Some(pg_cast)) - } + right: Box::new(self.parse_subexpr(precedence)?), + }), _ => parser_err!(format!("No infix parser for token {:?}", tok)), }, - None => Ok(None), + // This is not supposed 
to happen, because of the precedence check + // in parse_subexpr. + None => parser_err!("Unexpected EOF in parse_infix"), } } + /// Parses the parens following the `[ NOT ] IN` operator + pub fn parse_in(&mut self, expr: ASTNode, negated: bool) -> Result { + self.expect_token(&Token::LParen)?; + let in_op = if self.parse_keyword("SELECT") || self.parse_keyword("WITH") { + self.prev_token(); + ASTNode::SQLInSubquery { + expr: Box::new(expr), + subquery: Box::new(self.parse_query()?), + negated, + } + } else { + ASTNode::SQLInList { + expr: Box::new(expr), + list: self.parse_expr_list()?, + negated, + } + }; + self.expect_token(&Token::RParen)?; + Ok(in_op) + } + + /// Parses `BETWEEN AND `, assuming the `BETWEEN` keyword was already consumed + pub fn parse_between(&mut self, expr: ASTNode, negated: bool) -> Result { + let low = self.parse_prefix()?; + self.expect_keyword("AND")?; + let high = self.parse_prefix()?; + Ok(ASTNode::SQLBetween { + expr: Box::new(expr), + negated, + low: Box::new(low), + high: Box::new(high), + }) + } + + /// Parse a postgresql casting style which is in the form of `expr::datatype` + pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result { + Ok(ASTNode::SQLCast { + expr: Box::new(expr), + data_type: self.parse_data_type()?, + }) + } + /// Convert a token operator to an AST operator pub fn to_sql_operator(&self, tok: &Token) -> Result { match tok { @@ -320,10 +416,10 @@ impl Parser { &Token::Mult => Ok(SQLOperator::Multiply), &Token::Div => Ok(SQLOperator::Divide), &Token::Mod => Ok(SQLOperator::Modulus), - &Token::Keyword(ref k) if k == "AND" => Ok(SQLOperator::And), - &Token::Keyword(ref k) if k == "OR" => Ok(SQLOperator::Or), - //&Token::Keyword(ref k) if k == "NOT" => Ok(SQLOperator::Not), - &Token::Keyword(ref k) if k == "LIKE" => Ok(SQLOperator::Like), + &Token::SQLWord(ref k) if k.keyword == "AND" => Ok(SQLOperator::And), + &Token::SQLWord(ref k) if k.keyword == "OR" => Ok(SQLOperator::Or), + //&Token::SQLWord(ref k) if k.keyword 
== "NOT" => Ok(SQLOperator::Not), + &Token::SQLWord(ref k) if k.keyword == "LIKE" => Ok(SQLOperator::Like), _ => parser_err!(format!("Unsupported SQL operator {:?}", tok)), } } @@ -342,11 +438,13 @@ impl Parser { debug!("get_precedence() {:?}", tok); match tok { - &Token::Keyword(ref k) if k == "OR" => Ok(5), - &Token::Keyword(ref k) if k == "AND" => Ok(10), - &Token::Keyword(ref k) if k == "NOT" => Ok(15), - &Token::Keyword(ref k) if k == "IS" => Ok(15), - &Token::Keyword(ref k) if k == "LIKE" => Ok(20), + &Token::SQLWord(ref k) if k.keyword == "OR" => Ok(5), + &Token::SQLWord(ref k) if k.keyword == "AND" => Ok(10), + &Token::SQLWord(ref k) if k.keyword == "NOT" => Ok(15), + &Token::SQLWord(ref k) if k.keyword == "IS" => Ok(17), + &Token::SQLWord(ref k) if k.keyword == "IN" => Ok(20), + &Token::SQLWord(ref k) if k.keyword == "BETWEEN" => Ok(20), + &Token::SQLWord(ref k) if k.keyword == "LIKE" => Ok(20), &Token::Eq | &Token::Lt | &Token::LtEq | &Token::Neq | &Token::Gt | &Token::GtEq => { Ok(20) } @@ -444,14 +542,15 @@ impl Parser { /// Look for an expected keyword and consume it if it exists #[must_use] pub fn parse_keyword(&mut self, expected: &'static str) -> bool { + // Ideally, we'd accept a enum variant, not a string, but since + // it's not trivial to maintain the enum without duplicating all + // the keywords three times, we'll settle for a run-time check that + // the string actually represents a known keyword... 
+ assert!(keywords::ALL_KEYWORDS.contains(&expected)); match self.peek_token() { - Some(Token::Keyword(k)) => { - if expected.eq_ignore_ascii_case(k.as_str()) { - self.next_token(); - true - } else { - false - } + Some(Token::SQLWord(ref k)) if expected.eq_ignore_ascii_case(&k.keyword) => { + self.next_token(); + true } _ => false, } @@ -515,77 +614,12 @@ impl Parser { } /// Parse a SQL CREATE statement - pub fn parse_create(&mut self) -> Result { - if self.parse_keywords(vec!["TABLE"]) { - let table_name = self.parse_tablename()?; - // parse optional column list (schema) - let mut columns = vec![]; - if self.consume_token(&Token::LParen) { - loop { - if let Some(Token::Identifier(column_name)) = self.next_token() { - if let Ok(data_type) = self.parse_data_type() { - let is_primary = self.parse_keywords(vec!["PRIMARY", "KEY"]); - let is_unique = self.parse_keyword("UNIQUE"); - let default = if self.parse_keyword("DEFAULT") { - let expr = self.parse_default_expr(0)?; - Some(Box::new(expr)) - } else { - None - }; - let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { - false - } else if self.parse_keyword("NULL") { - true - } else { - true - }; - debug!("default: {:?}", default); - - match self.peek_token() { - Some(Token::Comma) => { - self.next_token(); - columns.push(SQLColumnDef { - name: column_name, - data_type: data_type, - allow_null, - is_primary, - is_unique, - default, - }); - } - Some(Token::RParen) => { - self.next_token(); - columns.push(SQLColumnDef { - name: column_name, - data_type: data_type, - allow_null, - is_primary, - is_unique, - default, - }); - break; - } - other => { - return parser_err!( - format!("Expected ',' or ')' after column definition but found {:?}", other) - ); - } - } - } else { - return parser_err!(format!( - "Error parsing data type in column definition near: {:?}", - self.peek_token() - )); - } - } else { - return parser_err!("Error parsing column name"); - } - } - } - Ok(ASTNode::SQLCreateTable { - name: table_name, 
- columns, - }) + pub fn parse_create(&mut self) -> Result { + if self.parse_keyword("TABLE") { + self.parse_create_table() + } else if self.parse_keyword("MATERIALIZED") || self.parse_keyword("VIEW") { + self.prev_token(); + self.parse_create_view() } else { parser_err!(format!( "Unexpected token after CREATE: {:?}", @@ -594,7 +628,85 @@ impl Parser { } } - pub fn parse_table_key(&mut self, constraint_name: &str) -> Result { + pub fn parse_create_view(&mut self) -> Result { + let materialized = self.parse_keyword("MATERIALIZED"); + self.expect_keyword("VIEW")?; + // Many dialects support `OR REPLACE` | `OR ALTER` right after `CREATE`, but we don't (yet). + // ANSI SQL and Postgres support RECURSIVE here, but we don't support it either. + let name = self.parse_object_name()?; + // Parenthesized "output" columns list could be handled here. + // Some dialects allow WITH here, followed by some keywords (e.g. MS SQL) + // or `(k1=v1, k2=v2, ...)` (Postgres) + self.expect_keyword("AS")?; + let query = self.parse_query()?; + // Optional `WITH [ CASCADED | LOCAL ] CHECK OPTION` is widely supported here. 
+ Ok(SQLStatement::SQLCreateView { + name, + query, + materialized, + }) + } + + pub fn parse_create_table(&mut self) -> Result { + let table_name = self.parse_object_name()?; + // parse optional column list (schema) + let mut columns = vec![]; + if self.consume_token(&Token::LParen) { + loop { + match self.next_token() { + Some(Token::SQLWord(column_name)) => { + let data_type = self.parse_data_type()?; + let is_primary = self.parse_keywords(vec!["PRIMARY", "KEY"]); + let is_unique = self.parse_keyword("UNIQUE"); + let default = if self.parse_keyword("DEFAULT") { + let expr = self.parse_default_expr(0)?; + Some(expr) + } else { + None + }; + let allow_null = if self.parse_keywords(vec!["NOT", "NULL"]) { + false + } else if self.parse_keyword("NULL") { + true + } else { + true + }; + debug!("default: {:?}", default); + + columns.push(SQLColumnDef { + name: column_name.as_sql_ident(), + data_type: data_type, + allow_null, + is_primary, + is_unique, + default, + }); + match self.next_token() { + Some(Token::Comma) => {} + Some(Token::RParen) => { + break; + } + other => { + return parser_err!(format!( + "Expected ',' or ')' after column definition but found {:?}", + other + )); + } + } + } + unexpected => { + return parser_err!(format!("Expected column name, got {:?}", unexpected)); + } + } + } + } + Ok(SQLStatement::SQLCreateTable { + name: table_name, + columns, + }) + } + + pub fn parse_table_key(&mut self, constraint_name: SQLIdent) -> Result { let is_primary_key = self.parse_keywords(vec!["PRIMARY", "KEY"]); let is_unique_key = self.parse_keywords(vec!["UNIQUE", "KEY"]); let is_foreign_key = self.parse_keywords(vec!["FOREIGN", "KEY"]); @@ -602,7 +714,7 @@ impl Parser { let column_names = self.parse_column_names()?; self.expect_token(&Token::RParen)?; let key = Key { - name: constraint_name.to_string(), + name: constraint_name, columns: column_names, }; if is_primary_key { @@ -610,19 +722,16 @@ impl Parser { } else if is_unique_key { Ok(TableKey::UniqueKey(key)) 
} else if is_foreign_key { - if self.parse_keyword("REFERENCES") { - let foreign_table = self.parse_tablename()?; - self.expect_token(&Token::LParen)?; - let referred_columns = self.parse_column_names()?; - self.expect_token(&Token::RParen)?; - Ok(TableKey::ForeignKey { - key, - foreign_table, - referred_columns, - }) - } else { - parser_err!("Expecting references") - } + self.expect_keyword("REFERENCES")?; + let foreign_table = self.parse_object_name()?; + self.expect_token(&Token::LParen)?; + let referred_columns = self.parse_column_names()?; + self.expect_token(&Token::RParen)?; + Ok(TableKey::ForeignKey { + key, + foreign_table, + referred_columns, + }) } else { parser_err!(format!( "Expecting primary key, unique key, or foreign key, found: {:?}", @@ -631,45 +740,30 @@ impl Parser { } } - pub fn parse_alter(&mut self) -> Result { - if self.parse_keyword("TABLE") { - let _ = self.parse_keyword("ONLY"); - let table_name = self.parse_tablename()?; - let operation: Result = - if self.parse_keywords(vec!["ADD", "CONSTRAINT"]) { - match self.next_token() { - Some(Token::Identifier(ref id)) => { - let table_key = self.parse_table_key(id)?; - Ok(AlterOperation::AddConstraint(table_key)) - } - _ => { - return parser_err!(format!( - "Expecting identifier, found : {:?}", - self.peek_token() - )); - } - } - } else { - return parser_err!(format!( - "Expecting ADD CONSTRAINT, found :{:?}", - self.peek_token() - )); - }; - Ok(ASTNode::SQLAlterTable { - name: table_name, - operation: operation?, - }) - } else { - parser_err!(format!( - "Expecting TABLE after ALTER, found {:?}", - self.peek_token() - )) - } + pub fn parse_alter(&mut self) -> Result { + self.expect_keyword("TABLE")?; + let _ = self.parse_keyword("ONLY"); + let table_name = self.parse_object_name()?; + let operation: Result = + if self.parse_keywords(vec!["ADD", "CONSTRAINT"]) { + let constraint_name = self.parse_identifier()?; + let table_key = self.parse_table_key(constraint_name)?; + 
Ok(AlterOperation::AddConstraint(table_key)) + } else { + return parser_err!(format!( + "Expecting ADD CONSTRAINT, found :{:?}", + self.peek_token() + )); + }; + Ok(SQLStatement::SQLAlterTable { + name: table_name, + operation: operation?, + }) } /// Parse a copy statement - pub fn parse_copy(&mut self) -> Result { - let table_name = self.parse_tablename()?; + pub fn parse_copy(&mut self) -> Result { + let table_name = self.parse_object_name()?; let columns = if self.consume_token(&Token::LParen) { let column_names = self.parse_column_names()?; self.expect_token(&Token::RParen)?; @@ -681,7 +775,7 @@ impl Parser { self.expect_keyword("STDIN")?; self.expect_token(&Token::SemiColon)?; let values = self.parse_tsv()?; - Ok(ASTNode::SQLCopy { + Ok(SQLStatement::SQLCopy { table_name, columns, values, @@ -717,8 +811,10 @@ impl Parser { return Ok(values); } if let Some(token) = self.next_token() { - if token == Token::Identifier("N".to_string()) { - values.push(None); + if let Token::SQLWord(SQLWord { value: v, .. 
}) = token { + if v == "N" { + values.push(None); + } } } else { continue; @@ -737,16 +833,21 @@ impl Parser { match self.next_token() { Some(t) => { match t { - Token::Keyword(k) => match k.to_uppercase().as_ref() { + Token::SQLWord(k) => match k.keyword.as_ref() { "TRUE" => Ok(Value::Boolean(true)), "FALSE" => Ok(Value::Boolean(false)), "NULL" => Ok(Value::Null), - _ => return parser_err!(format!("No value parser for keyword {}", k)), + _ => { + return parser_err!(format!( + "No value parser for keyword {}", + k.keyword + )); + } }, //TODO: parse the timestamp here (see parse_timestamp_value()) Token::Number(ref n) if n.contains(".") => match n.parse::() { Ok(n) => Ok(Value::Double(n)), - Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)), + Err(e) => parser_err!(format!("Could not parse '{}' as f64: {}", n, e)), }, Token::Number(ref n) => match n.parse::() { Ok(n) => Ok(Value::Long(n)), @@ -755,7 +856,10 @@ impl Parser { Token::SingleQuotedString(ref s) => { Ok(Value::SingleQuotedString(s.to_string())) } - _ => parser_err!(format!("Unsupported value: {:?}", self.peek_token())), + Token::NationalStringLiteral(ref s) => { + Ok(Value::NationalStringLiteral(s.to_string())) + } + _ => parser_err!(format!("Unsupported value: {:?}", t)), } } None => parser_err!("Expecting a value, but found EOF"), @@ -873,7 +977,7 @@ impl Parser { /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) pub fn parse_data_type(&mut self) -> Result { match self.next_token() { - Some(Token::Keyword(k)) => match k.to_uppercase().as_ref() { + Some(Token::SQLWord(k)) => match k.keyword.as_ref() { "BOOLEAN" => Ok(SQLType::Boolean), "FLOAT" => Ok(SQLType::Float(self.parse_optional_precision()?)), "REAL" => Ok(SQLType::Real), @@ -888,7 +992,7 @@ impl Parser { "INT" | "INTEGER" => Ok(SQLType::Int), "BIGINT" => Ok(SQLType::BigInt), "VARCHAR" => Ok(SQLType::Varchar(self.parse_optional_precision()?)), - "CHARACTER" => { + "CHAR" | "CHARACTER" => { if 
self.parse_keyword("VARYING") { Ok(SQLType::Varchar(self.parse_optional_precision()?)) } else { @@ -958,71 +1062,100 @@ impl Parser { let (precision, scale) = self.parse_optional_precision_scale()?; Ok(SQLType::Decimal(precision, scale)) } - _ => parser_err!(format!("Invalid data type '{:?}'", k)), + _ => { + self.prev_token(); + let type_name = self.parse_object_name()?; + Ok(SQLType::Custom(type_name)) + } }, - Some(Token::Identifier(_)) => { + other => parser_err!(format!("Invalid data type: '{:?}'", other)), + } + } + + /// Parse `AS identifier` (or simply `identifier` if it's not a reserved keyword) + /// Some examples with aliases: `SELECT 1 foo`, `SELECT COUNT(*) AS cnt`, + /// `SELECT ... FROM t1 foo, t2 bar`, `SELECT ... FROM (...) AS bar` + pub fn parse_optional_alias( + &mut self, + reserved_kwds: &[&str], + ) -> Result, ParserError> { + let after_as = self.parse_keyword("AS"); + let maybe_alias = self.next_token(); + match maybe_alias { + // Accept any identifier after `AS` (though many dialects have restrictions on + // keywords that may appear here). If there's no `AS`: don't parse keywords, + // which may start a construct allowed in this position, to be parsed as aliases. + // (For example, in `FROM t1 JOIN` the `JOIN` will always be parsed as a keyword, + // not an alias.) 
+ Some(Token::SQLWord(ref w)) + if after_as || !reserved_kwds.contains(&w.keyword.as_str()) => + { + Ok(Some(w.as_sql_ident())) + } + ref not_an_ident if after_as => parser_err!(format!( + "Expected an identifier after AS, got {:?}", + not_an_ident + )), + Some(_not_an_ident) => { self.prev_token(); - let type_name = self.parse_tablename()?; // TODO: this actually reads a possibly schema-qualified name of a (custom) type - Ok(SQLType::Custom(type_name)) + Ok(None) // no alias found } - other => parser_err!(format!("Invalid data type: '{:?}'", other)), + None => Ok(None), } } - pub fn parse_compound_identifier(&mut self, separator: &Token) -> Result { + /// Parse one or more identifiers with the specified separator between them + pub fn parse_list_of_ids(&mut self, separator: &Token) -> Result, ParserError> { let mut idents = vec![]; let mut expect_identifier = true; loop { let token = &self.next_token(); match token { - Some(token) => match token { - Token::Identifier(s) => { - if expect_identifier { - expect_identifier = false; - idents.push(s.to_string()); - } else { - self.prev_token(); - break; - } - } - token if token == separator => { - if expect_identifier { - return parser_err!(format!("Expecting identifier, found {:?}", token)); - } else { - expect_identifier = true; - continue; - } - } - _ => { + Some(Token::SQLWord(s)) if expect_identifier => { + expect_identifier = false; + idents.push(s.as_sql_ident()); + } + Some(token) if token == separator && !expect_identifier => { + expect_identifier = true; + continue; + } + _ => { + if token.is_some() { self.prev_token(); - break; } - }, - None => { - self.prev_token(); break; } } } - Ok(ASTNode::SQLCompoundIdentifier(idents)) + if expect_identifier { + parser_err!(format!( + "Expecting identifier, found {:?}", + self.peek_token() + )) + } else { + Ok(idents) + } } - pub fn parse_tablename(&mut self) -> Result { - let identifier = self.parse_compound_identifier(&Token::Period)?; - match identifier { - 
ASTNode::SQLCompoundIdentifier(idents) => Ok(idents.join(".")), - other => parser_err!(format!("Expecting compound identifier, found: {:?}", other)), - } + /// Parse a possibly qualified, possibly quoted identifier, e.g. + /// `foo` or `myschema."table"` + pub fn parse_object_name(&mut self) -> Result { + Ok(SQLObjectName(self.parse_list_of_ids(&Token::Period)?)) } - pub fn parse_column_names(&mut self) -> Result, ParserError> { - let identifier = self.parse_compound_identifier(&Token::Comma)?; - match identifier { - ASTNode::SQLCompoundIdentifier(idents) => Ok(idents), - other => parser_err!(format!("Expecting compound identifier, found: {:?}", other)), + /// Parse a simple one-word identifier (possibly quoted, possibly a keyword) + pub fn parse_identifier(&mut self) -> Result { + match self.next_token() { + Some(Token::SQLWord(w)) => Ok(w.as_sql_ident()), + unexpected => parser_err!(format!("Expected identifier, found {:?}", unexpected)), } } + /// Parse a comma-separated list of unqualified, possibly quoted identifiers + pub fn parse_column_names(&mut self) -> Result, ParserError> { + Ok(self.parse_list_of_ids(&Token::Comma)?) 
+ } + pub fn parse_precision(&mut self) -> Result { //TODO: error handling Ok(self.parse_optional_precision()?.unwrap()) @@ -1041,7 +1174,7 @@ impl Parser { pub fn parse_optional_precision_scale( &mut self, - ) -> Result<(usize, Option), ParserError> { + ) -> Result<(Option, Option), ParserError> { if self.consume_token(&Token::LParen) { let n = self.parse_literal_int()?; let scale = if self.consume_token(&Token::Comma) { @@ -1050,47 +1183,147 @@ impl Parser { None }; self.expect_token(&Token::RParen)?; - Ok((n as usize, scale)) + Ok((Some(n as usize), scale)) } else { - parser_err!("Expecting `(`") + Ok((None, None)) } } - pub fn parse_delete(&mut self) -> Result { - let relation: Option> = if self.parse_keyword("FROM") { - Some(Box::new(self.parse_expr(0)?)) + pub fn parse_delete(&mut self) -> Result { + self.expect_keyword("FROM")?; + let table_name = self.parse_object_name()?; + let selection = if self.parse_keyword("WHERE") { + Some(self.parse_expr()?) } else { None }; - let selection = if self.parse_keyword("WHERE") { - Some(Box::new(self.parse_expr(0)?)) + Ok(SQLStatement::SQLDelete { + table_name, + selection, + }) + } + + /// Parse a query expression, i.e. a `SELECT` statement optionally + /// preceeded with some `WITH` CTE declarations and optionally followed + /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't + /// expect the initial keyword to be already consumed + pub fn parse_query(&mut self) -> Result { + let ctes = if self.parse_keyword("WITH") { + // TODO: optional RECURSIVE + self.parse_cte_list()? + } else { + vec![] + }; + + let body = self.parse_query_body(0)?; + + let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) { + Some(self.parse_order_by_expr_list()?) } else { None }; - let _ = self.consume_token(&Token::SemiColon); + let limit = if self.parse_keyword("LIMIT") { + self.parse_limit()? 
+ } else { + None + }; - // parse next token - if let Some(next_token) = self.peek_token() { - parser_err!(format!( - "Unexpected token at end of DELETE: {:?}", - next_token - )) + Ok(SQLQuery { + ctes, + body, + limit, + order_by, + }) + } + + /// Parse one or more (comma-separated) `alias AS (subquery)` CTEs, + /// assuming the initial `WITH` was already consumed. + fn parse_cte_list(&mut self) -> Result, ParserError> { + let mut cte = vec![]; + loop { + let alias = self.parse_identifier()?; + // TODO: Optional `( )` + self.expect_keyword("AS")?; + self.expect_token(&Token::LParen)?; + cte.push(Cte { + alias, + query: self.parse_query()?, + }); + self.expect_token(&Token::RParen)?; + if !self.consume_token(&Token::Comma) { + break; + } + } + return Ok(cte); + } + + /// Parse a "query body", which is an expression with roughly the + /// following grammar: + /// ```text + /// query_body ::= restricted_select | '(' subquery ')' | set_operation + /// restricted_select ::= 'SELECT' [expr_list] [ from ] [ where ] [ groupby_having ] + /// subquery ::= query_body [ order_by_limit ] + /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body + /// ``` + fn parse_query_body(&mut self, precedence: u8) -> Result { + // We parse the expression using a Pratt parser, as in `parse_expr()`. + // Start by parsing a restricted SELECT or a `(subquery)`: + let mut expr = if self.parse_keyword("SELECT") { + SQLSetExpr::Select(self.parse_select()?) + } else if self.consume_token(&Token::LParen) { + // CTEs are not allowed here, but the parser currently accepts them + let subquery = self.parse_query()?; + self.expect_token(&Token::RParen)?; + SQLSetExpr::Query(Box::new(subquery)) } else { - Ok(ASTNode::SQLDelete { - relation, - selection, - }) + parser_err!("Expected SELECT or a subquery in the query body!")? 
+ }; + + loop { + // The query can be optionally followed by a set operator: + let next_token = self.peek_token(); + let op = self.parse_set_operator(&next_token); + let next_precedence = match op { + // UNION and EXCEPT have the same binding power and evaluate left-to-right + Some(SQLSetOperator::Union) | Some(SQLSetOperator::Except) => 10, + // INTERSECT has higher precedence than UNION/EXCEPT + Some(SQLSetOperator::Intersect) => 20, + // Unexpected token or EOF => stop parsing the query body + None => break, + }; + if precedence >= next_precedence { + break; + } + self.next_token(); // skip past the set operator + expr = SQLSetExpr::SetOperation { + left: Box::new(expr), + op: op.unwrap(), + all: self.parse_keyword("ALL"), + right: Box::new(self.parse_query_body(next_precedence)?), + }; + } + + Ok(expr) + } + + fn parse_set_operator(&mut self, token: &Option) -> Option { + match token { + Some(Token::SQLWord(w)) if w.keyword == "UNION" => Some(SQLSetOperator::Union), + Some(Token::SQLWord(w)) if w.keyword == "EXCEPT" => Some(SQLSetOperator::Except), + Some(Token::SQLWord(w)) if w.keyword == "INTERSECT" => Some(SQLSetOperator::Intersect), + _ => None, } } - /// Parse a SELECT statement - pub fn parse_select(&mut self) -> Result { - let projection = self.parse_expr_list()?; + /// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`), + /// assuming the initial `SELECT` was already consumed + pub fn parse_select(&mut self) -> Result { + let projection = self.parse_select_list()?; - let (relation, joins): (Option>, Vec) = if self.parse_keyword("FROM") { - let relation = Some(Box::new(self.parse_expr(0)?)); + let (relation, joins) = if self.parse_keyword("FROM") { + let relation = Some(self.parse_table_factor()?); let joins = self.parse_joins()?; (relation, joins) } else { @@ -1098,8 +1331,7 @@ impl Parser { }; let selection = if self.parse_keyword("WHERE") { - let expr = self.parse_expr(0)?; - Some(Box::new(expr)) + Some(self.parse_expr()?) 
} else { None }; @@ -1111,41 +1343,32 @@ impl Parser { }; let having = if self.parse_keyword("HAVING") { - Some(Box::new(self.parse_expr(0)?)) - } else { - None - }; - - let order_by = if self.parse_keywords(vec!["ORDER", "BY"]) { - Some(self.parse_order_by_expr_list()?) + Some(self.parse_expr()?) } else { None }; - let limit = if self.parse_keyword("LIMIT") { - self.parse_limit()? - } else { - None - }; - - let _ = self.consume_token(&Token::SemiColon); + Ok(SQLSelect { + projection, + selection, + relation, + joins, + group_by, + having, + }) + } - if let Some(next_token) = self.peek_token() { - parser_err!(format!( - "Unexpected token at end of SELECT: {:?}", - next_token - )) + /// A table name or a parenthesized subquery, followed by optional `[AS] alias` + pub fn parse_table_factor(&mut self) -> Result { + if self.consume_token(&Token::LParen) { + let subquery = Box::new(self.parse_query()?); + self.expect_token(&Token::RParen)?; + let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Derived { subquery, alias }) } else { - Ok(ASTNode::SQLSelect { - projection, - selection, - relation, - joins, - limit, - order_by, - group_by, - having, - }) + let name = self.parse_object_name()?; + let alias = self.parse_optional_alias(keywords::RESERVED_FOR_TABLE_ALIAS)?; + Ok(TableFactor::Table { name, alias }) } } @@ -1153,29 +1376,13 @@ impl Parser { if natural { Ok(JoinConstraint::Natural) } else if self.parse_keyword("ON") { - let constraint = self.parse_expr(0)?; + let constraint = self.parse_expr()?; Ok(JoinConstraint::On(constraint)) } else if self.parse_keyword("USING") { - if self.consume_token(&Token::LParen) { - let attributes = self - .parse_expr_list()? 
- .into_iter() - .map(|ast_node| match ast_node { - ASTNode::SQLIdentifier(ident) => Ok(ident), - unexpected => { - parser_err!(format!("Expected identifier, found {:?}", unexpected)) - } - }) - .collect::, ParserError>>()?; - - if self.consume_token(&Token::RParen) { - Ok(JoinConstraint::Using(attributes)) - } else { - parser_err!(format!("Expected token ')', found {:?}", self.peek_token())) - } - } else { - parser_err!(format!("Expected token '(', found {:?}", self.peek_token())) - } + self.expect_token(&Token::LParen)?; + let attributes = self.parse_column_names()?; + self.expect_token(&Token::RParen)?; + Ok(JoinConstraint::Using(attributes)) } else { parser_err!(format!( "Unexpected token after JOIN: {:?}", @@ -1190,7 +1397,7 @@ impl Parser { let natural = match &self.peek_token() { Some(Token::Comma) => { self.next_token(); - let relation = self.parse_expr(0)?; + let relation = self.parse_table_factor()?; let join = Join { relation, join_operator: JoinOperator::Implicit, @@ -1198,10 +1405,10 @@ impl Parser { joins.push(join); continue; } - Some(Token::Keyword(kw)) if kw == "CROSS" => { + Some(Token::SQLWord(kw)) if kw.keyword == "CROSS" => { self.next_token(); self.expect_keyword("JOIN")?; - let relation = self.parse_expr(0)?; + let relation = self.parse_table_factor()?; let join = Join { relation, join_operator: JoinOperator::Cross, @@ -1209,7 +1416,7 @@ impl Parser { joins.push(join); continue; } - Some(Token::Keyword(kw)) if kw == "NATURAL" => { + Some(Token::SQLWord(kw)) if kw.keyword == "NATURAL" => { self.next_token(); true } @@ -1218,49 +1425,49 @@ impl Parser { }; let join = match &self.peek_token() { - Some(Token::Keyword(kw)) if kw == "INNER" => { + Some(Token::SQLWord(kw)) if kw.keyword == "INNER" => { self.next_token(); self.expect_keyword("JOIN")?; Join { - relation: self.parse_expr(0)?, + relation: self.parse_table_factor()?, join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), } } - Some(Token::Keyword(kw)) if kw == "JOIN" 
=> { + Some(Token::SQLWord(kw)) if kw.keyword == "JOIN" => { self.next_token(); Join { - relation: self.parse_expr(0)?, + relation: self.parse_table_factor()?, join_operator: JoinOperator::Inner(self.parse_join_constraint(natural)?), } } - Some(Token::Keyword(kw)) if kw == "LEFT" => { + Some(Token::SQLWord(kw)) if kw.keyword == "LEFT" => { self.next_token(); let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { - relation: self.parse_expr(0)?, + relation: self.parse_table_factor()?, join_operator: JoinOperator::LeftOuter( self.parse_join_constraint(natural)?, ), } } - Some(Token::Keyword(kw)) if kw == "RIGHT" => { + Some(Token::SQLWord(kw)) if kw.keyword == "RIGHT" => { self.next_token(); let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { - relation: self.parse_expr(0)?, + relation: self.parse_table_factor()?, join_operator: JoinOperator::RightOuter( self.parse_join_constraint(natural)?, ), } } - Some(Token::Keyword(kw)) if kw == "FULL" => { + Some(Token::SQLWord(kw)) if kw.keyword == "FULL" => { self.next_token(); let _ = self.parse_keyword("OUTER"); self.expect_keyword("JOIN")?; Join { - relation: self.parse_expr(0)?, + relation: self.parse_table_factor()?, join_operator: JoinOperator::FullOuter( self.parse_join_constraint(natural)?, ), @@ -1275,9 +1482,9 @@ impl Parser { } /// Parse an INSERT statement - pub fn parse_insert(&mut self) -> Result { + pub fn parse_insert(&mut self) -> Result { self.expect_keyword("INTO")?; - let table_name = self.parse_tablename()?; + let table_name = self.parse_object_name()?; let columns = if self.consume_token(&Token::LParen) { let column_names = self.parse_column_names()?; self.expect_token(&Token::RParen)?; @@ -1289,7 +1496,7 @@ impl Parser { self.expect_token(&Token::LParen)?; let values = self.parse_expr_list()?; self.expect_token(&Token::RParen)?; - Ok(ASTNode::SQLInsert { + Ok(SQLStatement::SQLInsert { table_name, columns, values: vec![values], @@ -1300,54 +1507,62 @@ impl Parser 
{ pub fn parse_expr_list(&mut self) -> Result, ParserError> { let mut expr_list: Vec = vec![]; loop { - expr_list.push(self.parse_expr(0)?); - if let Some(t) = self.peek_token() { - if t == Token::Comma { - self.next_token(); + expr_list.push(self.parse_expr()?); + match self.peek_token() { + Some(Token::Comma) => self.next_token(), + _ => break, + }; + } + Ok(expr_list) + } + + /// Parse a comma-delimited list of projections after SELECT + pub fn parse_select_list(&mut self) -> Result, ParserError> { + let mut projections: Vec = vec![]; + loop { + let expr = self.parse_expr()?; + if let ASTNode::SQLWildcard = expr { + projections.push(SQLSelectItem::Wildcard); + } else if let ASTNode::SQLQualifiedWildcard(prefix) = expr { + projections.push(SQLSelectItem::QualifiedWildcard(SQLObjectName(prefix))); + } else { + // `expr` is a regular SQL expression and can be followed by an alias + if let Some(alias) = + self.parse_optional_alias(keywords::RESERVED_FOR_COLUMN_ALIAS)? + { + projections.push(SQLSelectItem::ExpressionWithAlias(expr, alias)); } else { - break; + projections.push(SQLSelectItem::UnnamedExpression(expr)); } - } else { - //EOF - break; } + + match self.peek_token() { + Some(Token::Comma) => self.next_token(), + _ => break, + }; } - Ok(expr_list) + Ok(projections) } /// Parse a comma-delimited list of SQL ORDER BY expressions pub fn parse_order_by_expr_list(&mut self) -> Result, ParserError> { let mut expr_list: Vec = vec![]; loop { - let expr = self.parse_expr(0)?; - - // look for optional ASC / DESC specifier - let asc = match self.peek_token() { - Some(Token::Keyword(k)) => match k.to_uppercase().as_ref() { - "ASC" => { - self.next_token(); - true - } - "DESC" => { - self.next_token(); - false - } - _ => true, - }, - Some(Token::Comma) => true, - _ => true, + let expr = self.parse_expr()?; + + let asc = if self.parse_keyword("ASC") { + Some(true) + } else if self.parse_keyword("DESC") { + Some(false) + } else { + None }; - 
expr_list.push(SQLOrderByExpr::new(Box::new(expr), asc)); + expr_list.push(SQLOrderByExpr { expr, asc }); - if let Some(t) = self.peek_token() { - if t == Token::Comma { - self.next_token(); - } else { - break; - } + if let Some(Token::Comma) = self.peek_token() { + self.next_token(); } else { - // EOF break; } } @@ -1355,12 +1570,18 @@ impl Parser { } /// Parse a LIMIT clause - pub fn parse_limit(&mut self) -> Result>, ParserError> { + pub fn parse_limit(&mut self) -> Result, ParserError> { if self.parse_keyword("ALL") { Ok(None) } else { self.parse_literal_int() - .map(|n| Some(Box::new(ASTNode::SQLValue(Value::Long(n))))) + .map(|n| Some(ASTNode::SQLValue(Value::Long(n)))) } } } + +impl SQLWord { + pub fn as_sql_ident(&self) -> SQLIdent { + self.to_string() + } +} diff --git a/src/sqltokenizer.rs b/src/sqltokenizer.rs index 504088227..83105736c 100644 --- a/src/sqltokenizer.rs +++ b/src/sqltokenizer.rs @@ -21,23 +21,22 @@ use std::iter::Peekable; use std::str::Chars; +use super::dialect::keywords::ALL_KEYWORDS; use super::dialect::Dialect; /// SQL Token enumeration #[derive(Debug, Clone, PartialEq)] pub enum Token { - /// SQL identifier e.g. table or column name - Identifier(String), - /// SQL keyword e.g. 
Keyword("SELECT") - Keyword(String), + /// A keyword (like SELECT) or an optionally quoted SQL identifier + SQLWord(SQLWord), /// Numeric literal Number(String), /// A character that could not be tokenized Char(char), /// Single quoted string: i.e: 'string' SingleQuotedString(String), - /// Double quoted string: i.e: "string" - DoubleQuotedString(String), + /// "National" string literal: i.e: N'string' + NationalStringLiteral(String), /// Comma Comma, /// Whitespace (space, tab, etc) @@ -93,12 +92,11 @@ pub enum Token { impl ToString for Token { fn to_string(&self) -> String { match self { - Token::Identifier(ref id) => id.to_string(), - Token::Keyword(ref k) => k.to_string(), + Token::SQLWord(ref w) => w.to_string(), Token::Number(ref n) => n.to_string(), Token::Char(ref c) => c.to_string(), Token::SingleQuotedString(ref s) => format!("'{}'", s), - Token::DoubleQuotedString(ref s) => format!("\"{}\"", s), + Token::NationalStringLiteral(ref s) => format!("N'{}'", s), Token::Comma => ",".to_string(), Token::Whitespace(ws) => ws.to_string(), Token::Eq => "=".to_string(), @@ -128,11 +126,72 @@ impl ToString for Token { } } +impl Token { + pub fn make_keyword(keyword: &str) -> Self { + Token::make_word(keyword, None) + } + pub fn make_word(word: &str, quote_style: Option) -> Self { + let word_uppercase = word.to_uppercase(); + //TODO: need to reintroduce FnvHashSet at some point .. 
iterating over keywords is + // not fast but I want the simplicity for now while I experiment with pluggable + // dialects + let is_keyword = quote_style == None && ALL_KEYWORDS.contains(&word_uppercase.as_str()); + Token::SQLWord(SQLWord { + value: word.to_string(), + quote_style: quote_style, + keyword: if is_keyword { + word_uppercase.to_string() + } else { + "".to_string() + }, + }) + } +} + +/// A keyword (like SELECT) or an optionally quoted SQL identifier +#[derive(Debug, Clone, PartialEq)] +pub struct SQLWord { + /// The value of the token, without the enclosing quotes, and with the + /// escape sequences (if any) processed (TODO: escapes are not handled) + pub value: String, + /// An identifier can be "quoted" (<delimited identifier> in ANSI parlance). + /// The standard and most implementations allow using double quotes for this, + /// but some implementations support other quoting styles as well (e.g. \[MS SQL]) + pub quote_style: Option, + /// If the word was not quoted and it matched one of the known keywords, + /// this will have one of the values from dialect::keywords, otherwise empty + pub keyword: String, +} + +impl ToString for SQLWord { + fn to_string(&self) -> String { + match self.quote_style { + Some(s) if s == '"' || s == '[' || s == '`' => { + format!("{}{}{}", s, self.value, SQLWord::matching_end_quote(s)) + } + None => self.value.clone(), + _ => panic!("Unexpected quote_style!"), + } + } +} +impl SQLWord { + fn matching_end_quote(ch: char) -> char { + match ch { + '"' => '"', // ANSI and most dialects + '[' => ']', // MS SQL + '`' => '`', // MySQL + _ => panic!("unexpected quoting style!"), + } + } +} + #[derive(Debug, Clone, PartialEq)] pub enum Whitespace { Space, Newline, Tab, + SingleLineComment(String), + MultiLineComment(String), } impl ToString for Whitespace { @@ -141,6 +200,8 @@ impl ToString for Whitespace { Whitespace::Space => " ".to_string(), Whitespace::Newline => "\n".to_string(), Whitespace::Tab => "\t".to_string(), + 
Whitespace::SingleLineComment(s) => format!("--{}", s), + Whitespace::MultiLineComment(s) => format!("/*{}*/", s), } } } @@ -168,13 +229,6 @@ impl<'a> Tokenizer<'a> { } } - fn is_keyword(&self, s: &str) -> bool { - //TODO: need to reintroduce FnvHashSet at some point .. iterating over keywords is - // not fast but I want the simplicity for now while I experiment with pluggable - // dialects - return self.dialect.keywords().contains(&s); - } - /// Tokenize the statement and produce a vector of tokens pub fn tokenize(&mut self) -> Result, TokenizerError> { let mut peekable = self.query.chars().peekable(); @@ -189,11 +243,10 @@ impl<'a> Tokenizer<'a> { } Token::Whitespace(Whitespace::Tab) => self.col += 4, - Token::Identifier(s) => self.col += s.len() as u64, - Token::Keyword(s) => self.col += s.len() as u64, + Token::SQLWord(w) if w.quote_style == None => self.col += w.value.len() as u64, + Token::SQLWord(w) if w.quote_style != None => self.col += w.value.len() as u64 + 2, Token::Number(s) => self.col += s.len() as u64, Token::SingleQuotedString(s) => self.col += s.len() as u64, - Token::DoubleQuotedString(s) => self.col += s.len() as u64, _ => self.col += 1, } @@ -219,63 +272,44 @@ impl<'a> Tokenizer<'a> { chars.next(); Ok(Some(Token::Whitespace(Whitespace::Newline))) } - // identifier or keyword - ch if self.dialect.is_identifier_start(ch) => { - let mut s = String::new(); - chars.next(); // consume - s.push(ch); - while let Some(&ch) = chars.peek() { - if self.dialect.is_identifier_part(ch) { - chars.next(); // consume - s.push(ch); - } else { - break; + 'N' => { + chars.next(); // consume, to check the next char + match chars.peek() { + Some('\'') => { + // N'...' 
- a + let s = self.tokenize_single_quoted_string(chars); + Ok(Some(Token::NationalStringLiteral(s))) + } + _ => { + // regular identifier starting with an "N" + let s = self.tokenize_word('N', chars); + Ok(Some(Token::make_word(&s, None))) } - } - let upper_str = s.to_uppercase(); - if self.is_keyword(upper_str.as_str()) { - Ok(Some(Token::Keyword(upper_str))) - } else { - Ok(Some(Token::Identifier(s))) } } + // identifier or keyword + ch if self.dialect.is_identifier_start(ch) => { + chars.next(); // consume the first char + let s = self.tokenize_word(ch, chars); + Ok(Some(Token::make_word(&s, None))) + } // string '\'' => { - //TODO: handle escaped quotes in string - //TODO: handle EOF before terminating quote - let mut s = String::new(); - chars.next(); // consume - while let Some(&ch) = chars.peek() { - match ch { - '\'' => { - chars.next(); // consume - break; - } - _ => { - chars.next(); // consume - s.push(ch); - } - } - } + let s = self.tokenize_single_quoted_string(chars); Ok(Some(Token::SingleQuotedString(s))) } - // string - '"' => { + // delimited (quoted) identifier + quote_start if self.dialect.is_delimited_identifier_start(quote_start) => { let mut s = String::new(); - chars.next(); // consume - while let Some(&ch) = chars.peek() { + chars.next(); // consume the opening quote + let quote_end = SQLWord::matching_end_quote(quote_start); + while let Some(ch) = chars.next() { match ch { - '"' => { - chars.next(); // consume - break; - } - _ => { - chars.next(); // consume - s.push(ch); - } + c if c == quote_end => break, + _ => s.push(ch), } } - Ok(Some(Token::DoubleQuotedString(s))) + Ok(Some(Token::make_word(&s, Some(quote_start)))) } // numbers '0'...'9' => { @@ -296,10 +330,45 @@ impl<'a> Tokenizer<'a> { ')' => self.consume_and_return(chars, Token::RParen), ',' => self.consume_and_return(chars, Token::Comma), // operators + '-' => { + chars.next(); // consume the '-' + match chars.peek() { + Some('-') => { + chars.next(); // consume the second '-', 
starting a single-line comment + let mut s = String::new(); + loop { + match chars.next() { + Some(ch) if ch != '\n' => { + s.push(ch); + } + other => { + if other.is_some() { + s.push('\n'); + } + break Ok(Some(Token::Whitespace( + Whitespace::SingleLineComment(s), + ))); + } + } + } + } + // a regular '-' operator + _ => Ok(Some(Token::Minus)), + } + } + '/' => { + chars.next(); // consume the '/' + match chars.peek() { + Some('*') => { + chars.next(); // consume the '*', starting a multi-line comment + self.tokenize_multiline_comment(chars) + } + // a regular '/' operator + _ => Ok(Some(Token::Div)), + } + } '+' => self.consume_and_return(chars, Token::Plus), - '-' => self.consume_and_return(chars, Token::Minus), '*' => self.consume_and_return(chars, Token::Mult), - '/' => self.consume_and_return(chars, Token::Div), '%' => self.consume_and_return(chars, Token::Mod), '=' => self.consume_and_return(chars, Token::Eq), '.' => self.consume_and_return(chars, Token::Period), @@ -366,6 +435,75 @@ impl<'a> Tokenizer<'a> { } } + /// Tokenize an identifier or keyword, after the first char is already consumed. + fn tokenize_word(&self, first_char: char, chars: &mut Peekable) -> String { + let mut s = String::new(); + s.push(first_char); + while let Some(&ch) = chars.peek() { + if self.dialect.is_identifier_part(ch) { + chars.next(); // consume + s.push(ch); + } else { + break; + } + } + s + } + + /// Read a single quoted string, starting with the opening quote. 
+ fn tokenize_single_quoted_string(&self, chars: &mut Peekable) -> String { + //TODO: handle escaped quotes in string + //TODO: handle newlines in string + //TODO: handle EOF before terminating quote + //TODO: handle 'string' 'string continuation' + let mut s = String::new(); + chars.next(); // consume the opening quote + while let Some(&ch) = chars.peek() { + match ch { + '\'' => { + chars.next(); // consume + break; + } + _ => { + chars.next(); // consume + s.push(ch); + } + } + } + s + } + + fn tokenize_multiline_comment( + &self, + chars: &mut Peekable, + ) -> Result, TokenizerError> { + let mut s = String::new(); + let mut maybe_closing_comment = false; + // TODO: deal with nested comments + loop { + match chars.next() { + Some(ch) => { + if maybe_closing_comment { + if ch == '/' { + break Ok(Some(Token::Whitespace(Whitespace::MultiLineComment(s)))); + } else { + s.push('*'); + } + } + maybe_closing_comment = ch == '*'; + if !maybe_closing_comment { + s.push(ch); + } + } + None => { + break Err(TokenizerError( + "Unexpected EOF while in a multi-line comment".to_string(), + )); + } + } + } + } + fn consume_and_return( &self, chars: &mut Peekable, @@ -389,7 +527,7 @@ mod tests { let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ - Token::Keyword(String::from("SELECT")), + Token::make_keyword("SELECT"), Token::Whitespace(Whitespace::Space), Token::Number(String::from("1")), ]; @@ -405,9 +543,9 @@ mod tests { let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ - Token::Keyword(String::from("SELECT")), + Token::make_keyword("SELECT"), Token::Whitespace(Whitespace::Space), - Token::Identifier(String::from("sqrt")), + Token::make_word("sqrt", None), Token::LParen, Token::Number(String::from("1")), Token::RParen, @@ -424,23 +562,23 @@ mod tests { let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ - Token::Keyword(String::from("SELECT")), + Token::make_keyword("SELECT"), Token::Whitespace(Whitespace::Space), Token::Mult, 
Token::Whitespace(Whitespace::Space), - Token::Keyword(String::from("FROM")), + Token::make_keyword("FROM"), Token::Whitespace(Whitespace::Space), - Token::Identifier(String::from("customer")), + Token::make_word("customer", None), Token::Whitespace(Whitespace::Space), - Token::Keyword(String::from("WHERE")), + Token::make_keyword("WHERE"), Token::Whitespace(Whitespace::Space), - Token::Identifier(String::from("id")), + Token::make_word("id", None), Token::Whitespace(Whitespace::Space), Token::Eq, Token::Whitespace(Whitespace::Space), Token::Number(String::from("1")), Token::Whitespace(Whitespace::Space), - Token::Keyword(String::from("LIMIT")), + Token::make_keyword("LIMIT"), Token::Whitespace(Whitespace::Space), Token::Number(String::from("5")), ]; @@ -456,17 +594,17 @@ mod tests { let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ - Token::Keyword(String::from("SELECT")), + Token::make_keyword("SELECT"), Token::Whitespace(Whitespace::Space), Token::Mult, Token::Whitespace(Whitespace::Space), - Token::Keyword(String::from("FROM")), + Token::make_keyword("FROM"), Token::Whitespace(Whitespace::Space), - Token::Identifier(String::from("customer")), + Token::make_word("customer", None), Token::Whitespace(Whitespace::Space), - Token::Keyword(String::from("WHERE")), + Token::make_keyword("WHERE"), Token::Whitespace(Whitespace::Space), - Token::Identifier(String::from("salary")), + Token::make_word("salary", None), Token::Whitespace(Whitespace::Space), Token::Neq, Token::Whitespace(Whitespace::Space), @@ -491,7 +629,7 @@ mod tests { Token::Char('ط'), Token::Char('ف'), Token::Char('ى'), - Token::Identifier("h".to_string()), + Token::make_word("h", None), ]; compare(expected, tokens); } @@ -507,20 +645,20 @@ mod tests { let expected = vec![ Token::Whitespace(Whitespace::Newline), Token::Whitespace(Whitespace::Newline), - Token::Keyword("SELECT".into()), + Token::make_keyword("SELECT"), Token::Whitespace(Whitespace::Space), Token::Mult, 
Token::Whitespace(Whitespace::Space), - Token::Keyword("FROM".into()), + Token::make_keyword("FROM"), Token::Whitespace(Whitespace::Space), - Token::Keyword("TABLE".into()), + Token::make_keyword("table"), Token::Whitespace(Whitespace::Tab), Token::Char('م'), Token::Char('ص'), Token::Char('ط'), Token::Char('ف'), Token::Char('ى'), - Token::Identifier("h".to_string()), + Token::make_word("h", None), ]; compare(expected, tokens); } @@ -533,13 +671,75 @@ mod tests { let tokens = tokenizer.tokenize().unwrap(); let expected = vec![ - Token::Identifier(String::from("a")), + Token::make_word("a", None), Token::Whitespace(Whitespace::Space), - Token::Keyword("IS".to_string()), + Token::make_keyword("IS"), Token::Whitespace(Whitespace::Space), - Token::Keyword("NULL".to_string()), + Token::make_keyword("NULL"), + ]; + + compare(expected, tokens); + } + + #[test] + fn tokenize_comment() { + let sql = String::from("0--this is a comment\n1"); + + let dialect = GenericSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize().unwrap(); + let expected = vec![ + Token::Number("0".to_string()), + Token::Whitespace(Whitespace::SingleLineComment( + "this is a comment\n".to_string(), + )), + Token::Number("1".to_string()), ]; + compare(expected, tokens); + } + + #[test] + fn tokenize_comment_at_eof() { + let sql = String::from("--this is a comment"); + let dialect = GenericSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize().unwrap(); + let expected = vec![Token::Whitespace(Whitespace::SingleLineComment( + "this is a comment".to_string(), + ))]; + compare(expected, tokens); + } + + #[test] + fn tokenize_multiline_comment() { + let sql = String::from("0/*multi-line\n* /comment*/1"); + + let dialect = GenericSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize().unwrap(); + let expected = vec![ + Token::Number("0".to_string()), + 
Token::Whitespace(Whitespace::MultiLineComment( + "multi-line\n* /comment".to_string(), + )), + Token::Number("1".to_string()), + ]; + compare(expected, tokens); + } + + #[test] + fn tokenize_multiline_comment_with_even_asterisks() { + let sql = String::from("\n/** Comment **/\n"); + + let dialect = GenericSqlDialect {}; + let mut tokenizer = Tokenizer::new(&dialect, &sql); + let tokens = tokenizer.tokenize().unwrap(); + let expected = vec![ + Token::Whitespace(Whitespace::Newline), + Token::Whitespace(Whitespace::MultiLineComment("* Comment *".to_string())), + Token::Whitespace(Whitespace::Newline), + ]; compare(expected, tokens); } diff --git a/tests/sqlparser_ansi.rs b/tests/sqlparser_ansi.rs index 4fec4f49e..73054fb78 100644 --- a/tests/sqlparser_ansi.rs +++ b/tests/sqlparser_ansi.rs @@ -4,25 +4,19 @@ extern crate sqlparser; use sqlparser::dialect::AnsiSqlDialect; use sqlparser::sqlast::*; use sqlparser::sqlparser::*; -use sqlparser::sqltokenizer::*; #[test] fn parse_simple_select() { let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1"); - let ast = parse_sql(&sql); - match ast { - ASTNode::SQLSelect { projection, .. } => { + let ast = Parser::parse_sql(&AnsiSqlDialect {}, sql).unwrap(); + assert_eq!(1, ast.len()); + match ast.first().unwrap() { + SQLStatement::SQLSelect(SQLQuery { + body: SQLSetExpr::Select(SQLSelect { projection, .. }), + .. 
+ }) => { assert_eq!(3, projection.len()); } _ => assert!(false), } } - -fn parse_sql(sql: &str) -> ASTNode { - let dialect = AnsiSqlDialect {}; - let mut tokenizer = Tokenizer::new(&dialect, &sql); - let tokens = tokenizer.tokenize().unwrap(); - let mut parser = Parser::new(tokens); - let ast = parser.parse().unwrap(); - ast -} diff --git a/tests/sqlparser_generic.rs b/tests/sqlparser_generic.rs index 5c8679724..b3e418a06 100644 --- a/tests/sqlparser_generic.rs +++ b/tests/sqlparser_generic.rs @@ -8,16 +8,10 @@ use sqlparser::sqltokenizer::*; #[test] fn parse_delete_statement() { - let sql: &str = "DELETE FROM 'table'"; - - match verified(&sql) { - ASTNode::SQLDelete { relation, .. } => { - assert_eq!( - Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( - "table".to_string() - )))), - relation - ); + let sql = "DELETE FROM \"table\""; + match verified_stmt(sql) { + SQLStatement::SQLDelete { table_name, .. } => { + assert_eq!(SQLObjectName(vec!["\"table\"".to_string()]), table_name); } _ => assert!(false), @@ -26,23 +20,17 @@ fn parse_delete_statement() { #[test] fn parse_where_delete_statement() { - let sql: &str = "DELETE FROM 'table' WHERE name = 5"; - use self::ASTNode::*; use self::SQLOperator::*; - match verified(&sql) { - ASTNode::SQLDelete { - relation, + let sql = "DELETE FROM foo WHERE name = 5"; + match verified_stmt(sql) { + SQLStatement::SQLDelete { + table_name, selection, .. 
} => { - assert_eq!( - Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( - "table".to_string() - )))), - relation - ); + assert_eq!(SQLObjectName(vec!["foo".to_string()]), table_name); assert_eq!( SQLBinaryExpr { @@ -50,7 +38,7 @@ fn parse_where_delete_statement() { op: Eq, right: Box::new(SQLValue(Value::Long(5))), }, - *selection.unwrap(), + selection.unwrap(), ); } @@ -60,72 +48,91 @@ fn parse_where_delete_statement() { #[test] fn parse_simple_select() { - let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT 5"); - match verified(&sql) { - ASTNode::SQLSelect { - projection, limit, .. - } => { - assert_eq!(3, projection.len()); - assert_eq!(Some(Box::new(ASTNode::SQLValue(Value::Long(5)))), limit); - } - _ => assert!(false), - } + let sql = "SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT 5"; + let select = verified_only_select(sql); + assert_eq!(3, select.projection.len()); + let select = verified_query(sql); + assert_eq!(Some(ASTNode::SQLValue(Value::Long(5))), select.limit); } #[test] fn parse_select_wildcard() { - let sql = String::from("SELECT * FROM customer"); - match verified(&sql) { - ASTNode::SQLSelect { projection, .. 
} => { - assert_eq!(1, projection.len()); - assert_eq!(ASTNode::SQLWildcard, projection[0]); - } - _ => assert!(false), + let sql = "SELECT * FROM foo"; + let select = verified_only_select(sql); + assert_eq!(&SQLSelectItem::Wildcard, only(&select.projection)); + + let sql = "SELECT foo.* FROM foo"; + let select = verified_only_select(sql); + assert_eq!( + &SQLSelectItem::QualifiedWildcard(SQLObjectName(vec!["foo".to_string()])), + only(&select.projection) + ); + + let sql = "SELECT myschema.mytable.* FROM myschema.mytable"; + let select = verified_only_select(sql); + assert_eq!( + &SQLSelectItem::QualifiedWildcard(SQLObjectName(vec![ + "myschema".to_string(), + "mytable".to_string(), + ])), + only(&select.projection) + ); +} + +#[test] +fn parse_column_aliases() { + let sql = "SELECT a.col + 1 AS newname FROM foo AS a"; + let select = verified_only_select(sql); + if let SQLSelectItem::ExpressionWithAlias( + ASTNode::SQLBinaryExpr { + ref op, ref right, .. + }, + ref alias, + ) = only(&select.projection) + { + assert_eq!(&SQLOperator::Plus, op); + assert_eq!(&ASTNode::SQLValue(Value::Long(1)), right.as_ref()); + assert_eq!("newname", alias); + } else { + panic!("Expected ExpressionWithAlias") } + + // alias without AS is parsed correctly: + one_statement_parses_to("SELECT a.col + 1 newname FROM foo AS a", &sql); } #[test] fn parse_select_count_wildcard() { - let sql = String::from("SELECT COUNT(*) FROM customer"); - match verified(&sql) { - ASTNode::SQLSelect { projection, .. 
} => { - assert_eq!(1, projection.len()); - assert_eq!( - ASTNode::SQLFunction { - id: "COUNT".to_string(), - args: vec![ASTNode::SQLWildcard], - }, - projection[0] - ); - } - _ => assert!(false), - } + let sql = "SELECT COUNT(*) FROM customer"; + let select = verified_only_select(sql); + assert_eq!( + &ASTNode::SQLFunction { + id: "COUNT".to_string(), + args: vec![ASTNode::SQLWildcard], + }, + expr_from_projection(only(&select.projection)) + ); } #[test] fn parse_not() { - let sql = String::from( - "SELECT id FROM customer \ - WHERE NOT salary = ''", - ); - let _ast = verified(&sql); + let sql = "SELECT id FROM customer WHERE NOT salary = ''"; + let _ast = verified_only_select(sql); //TODO: add assertions } #[test] fn parse_select_string_predicate() { - let sql = String::from( - "SELECT id, fname, lname FROM customer \ - WHERE salary != 'Not Provided' AND salary != ''", - ); - let _ast = verified(&sql); + let sql = "SELECT id, fname, lname FROM customer \ + WHERE salary != 'Not Provided' AND salary != ''"; + let _ast = verified_only_select(sql); //TODO: add assertions } #[test] fn parse_projection_nested_type() { - let sql = String::from("SELECT customer.address.state FROM foo"); - let _ast = verified(&sql); + let sql = "SELECT customer.address.state FROM foo"; + let _ast = verified_only_select(sql); //TODO: add assertions } @@ -133,7 +140,7 @@ fn parse_projection_nested_type() { fn parse_compound_expr_1() { use self::ASTNode::*; use self::SQLOperator::*; - let sql = String::from("a + b * c"); + let sql = "a + b * c"; assert_eq!( SQLBinaryExpr { left: Box::new(SQLIdentifier("a".to_string())), @@ -144,7 +151,7 @@ fn parse_compound_expr_1() { right: Box::new(SQLIdentifier("c".to_string())) }) }, - verified(&sql) + verified_expr(sql) ); } @@ -152,7 +159,7 @@ fn parse_compound_expr_1() { fn parse_compound_expr_2() { use self::ASTNode::*; use self::SQLOperator::*; - let sql = String::from("a * b + c"); + let sql = "a * b + c"; assert_eq!( SQLBinaryExpr { left: 
Box::new(SQLBinaryExpr { @@ -163,144 +170,232 @@ fn parse_compound_expr_2() { op: Plus, right: Box::new(SQLIdentifier("c".to_string())) }, - verified(&sql) + verified_expr(sql) + ); +} + +#[test] +fn parse_unary_math() { + use self::ASTNode::*; + use self::SQLOperator::*; + let sql = "- a + - b"; + assert_eq!( + SQLBinaryExpr { + left: Box::new(SQLUnary { + operator: Minus, + expr: Box::new(SQLIdentifier("a".to_string())), + }), + op: Plus, + right: Box::new(SQLUnary { + operator: Minus, + expr: Box::new(SQLIdentifier("b".to_string())), + }), + }, + verified_expr(sql) ); } #[test] fn parse_is_null() { use self::ASTNode::*; - let sql = String::from("a IS NULL"); + let sql = "a IS NULL"; assert_eq!( SQLIsNull(Box::new(SQLIdentifier("a".to_string()))), - verified(&sql) + verified_expr(sql) ); } #[test] fn parse_is_not_null() { use self::ASTNode::*; - let sql = String::from("a IS NOT NULL"); + let sql = "a IS NOT NULL"; assert_eq!( SQLIsNotNull(Box::new(SQLIdentifier("a".to_string()))), - verified(&sql) + verified_expr(sql) ); } #[test] -fn parse_like() { - let sql = String::from("SELECT * FROM customers WHERE name LIKE '%a'"); - match verified(&sql) { - ASTNode::SQLSelect { selection, .. } => { - assert_eq!( - ASTNode::SQLBinaryExpr { - left: Box::new(ASTNode::SQLIdentifier("name".to_string())), - op: SQLOperator::Like, - right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( - "%a".to_string() - ))), - }, - *selection.unwrap() - ); - } +fn parse_not_precedence() { + use self::ASTNode::*; + // NOT has higher precedence than OR/AND, so the following must parse as (NOT true) OR true + let sql = "NOT true OR true"; + match verified_expr(sql) { + SQLBinaryExpr { + op: SQLOperator::Or, + .. + } => assert!(true), _ => assert!(false), - } + }; + + // But NOT has lower precedence than comparison operators, so the following parses as NOT (a IS NULL) + let sql = "NOT a IS NULL"; + match verified_expr(sql) { + SQLUnary { + operator: SQLOperator::Not, + .. 
+ } => assert!(true), + _ => assert!(false), + }; +} + +#[test] +fn parse_like() { + let sql = "SELECT * FROM customers WHERE name LIKE '%a'"; + let select = verified_only_select(sql); + assert_eq!( + ASTNode::SQLBinaryExpr { + left: Box::new(ASTNode::SQLIdentifier("name".to_string())), + op: SQLOperator::Like, + right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( + "%a".to_string() + ))), + }, + select.selection.unwrap() + ); } #[test] fn parse_not_like() { - let sql = String::from("SELECT * FROM customers WHERE name NOT LIKE '%a'"); - match verified(&sql) { - ASTNode::SQLSelect { selection, .. } => { - assert_eq!( - ASTNode::SQLBinaryExpr { - left: Box::new(ASTNode::SQLIdentifier("name".to_string())), - op: SQLOperator::NotLike, - right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( - "%a".to_string() - ))), - }, - *selection.unwrap() - ); - } - _ => assert!(false), + let sql = "SELECT * FROM customers WHERE name NOT LIKE '%a'"; + let select = verified_only_select(sql); + assert_eq!( + ASTNode::SQLBinaryExpr { + left: Box::new(ASTNode::SQLIdentifier("name".to_string())), + op: SQLOperator::NotLike, + right: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( + "%a".to_string() + ))), + }, + select.selection.unwrap() + ); +} + +#[test] +fn parse_in_list() { + fn chk(negated: bool) { + let sql = &format!( + "SELECT * FROM customers WHERE segment {}IN ('HIGH', 'MED')", + if negated { "NOT " } else { "" } + ); + let select = verified_only_select(sql); + assert_eq!( + ASTNode::SQLInList { + expr: Box::new(ASTNode::SQLIdentifier("segment".to_string())), + list: vec![ + ASTNode::SQLValue(Value::SingleQuotedString("HIGH".to_string())), + ASTNode::SQLValue(Value::SingleQuotedString("MED".to_string())), + ], + negated, + }, + select.selection.unwrap() + ); } + chk(false); + chk(true); } #[test] -fn parse_select_order_by() { - let sql = String::from( - "SELECT id, fname, lname FROM customer WHERE id < 5 ORDER BY lname ASC, fname DESC", +fn 
parse_in_subquery() { + let sql = "SELECT * FROM customers WHERE segment IN (SELECT segm FROM bar)"; + let select = verified_only_select(sql); + assert_eq!( + ASTNode::SQLInSubquery { + expr: Box::new(ASTNode::SQLIdentifier("segment".to_string())), + subquery: Box::new(verified_query("SELECT segm FROM bar")), + negated: false, + }, + select.selection.unwrap() ); - match verified(&sql) { - ASTNode::SQLSelect { order_by, .. } => { - assert_eq!( - Some(vec![ - SQLOrderByExpr { - expr: Box::new(ASTNode::SQLIdentifier("lname".to_string())), - asc: true, - }, - SQLOrderByExpr { - expr: Box::new(ASTNode::SQLIdentifier("fname".to_string())), - asc: false, - }, - ]), - order_by - ); - } - _ => assert!(false), +} + +#[test] +fn parse_between() { + fn chk(negated: bool) { + let sql = &format!( + "SELECT * FROM customers WHERE age {}BETWEEN 25 AND 32", + if negated { "NOT " } else { "" } + ); + let select = verified_only_select(sql); + assert_eq!( + ASTNode::SQLBetween { + expr: Box::new(ASTNode::SQLIdentifier("age".to_string())), + low: Box::new(ASTNode::SQLValue(Value::Long(25))), + high: Box::new(ASTNode::SQLValue(Value::Long(32))), + negated, + }, + select.selection.unwrap() + ); } + chk(false); + chk(true); +} + +#[test] +fn parse_select_order_by() { + fn chk(sql: &str) { + let select = verified_query(sql); + assert_eq!( + Some(vec![ + SQLOrderByExpr { + expr: ASTNode::SQLIdentifier("lname".to_string()), + asc: Some(true), + }, + SQLOrderByExpr { + expr: ASTNode::SQLIdentifier("fname".to_string()), + asc: Some(false), + }, + SQLOrderByExpr { + expr: ASTNode::SQLIdentifier("id".to_string()), + asc: None, + }, + ]), + select.order_by + ); + } + chk("SELECT id, fname, lname FROM customer WHERE id < 5 ORDER BY lname ASC, fname DESC, id"); + // make sure ORDER is not treated as an alias + chk("SELECT id, fname, lname FROM customer ORDER BY lname ASC, fname DESC, id"); + chk("SELECT 1 AS lname, 2 AS fname, 3 AS id, 4 ORDER BY lname ASC, fname DESC, id"); } #[test] fn 
parse_select_order_by_limit() { - let sql = String::from( - "SELECT id, fname, lname FROM customer WHERE id < 5 ORDER BY lname ASC, fname DESC LIMIT 2", + let sql = "SELECT id, fname, lname FROM customer WHERE id < 5 \ + ORDER BY lname ASC, fname DESC LIMIT 2"; + let select = verified_query(sql); + assert_eq!( + Some(vec![ + SQLOrderByExpr { + expr: ASTNode::SQLIdentifier("lname".to_string()), + asc: Some(true), + }, + SQLOrderByExpr { + expr: ASTNode::SQLIdentifier("fname".to_string()), + asc: Some(false), + }, + ]), + select.order_by ); - let ast = parse_sql(&sql); - match ast { - ASTNode::SQLSelect { - order_by, limit, .. - } => { - assert_eq!( - Some(vec![ - SQLOrderByExpr { - expr: Box::new(ASTNode::SQLIdentifier("lname".to_string())), - asc: true, - }, - SQLOrderByExpr { - expr: Box::new(ASTNode::SQLIdentifier("fname".to_string())), - asc: false, - }, - ]), - order_by - ); - assert_eq!(Some(Box::new(ASTNode::SQLValue(Value::Long(2)))), limit); - } - _ => assert!(false), - } + assert_eq!(Some(ASTNode::SQLValue(Value::Long(2))), select.limit); } #[test] fn parse_select_group_by() { - let sql = String::from("SELECT id, fname, lname FROM customer GROUP BY lname, fname"); - match verified(&sql) { - ASTNode::SQLSelect { group_by, .. 
} => { - assert_eq!( - Some(vec![ - ASTNode::SQLIdentifier("lname".to_string()), - ASTNode::SQLIdentifier("fname".to_string()), - ]), - group_by - ); - } - _ => assert!(false), - } + let sql = "SELECT id, fname, lname FROM customer GROUP BY lname, fname"; + let select = verified_only_select(sql); + assert_eq!( + Some(vec![ + ASTNode::SQLIdentifier("lname".to_string()), + ASTNode::SQLIdentifier("fname".to_string()), + ]), + select.group_by + ); } #[test] fn parse_limit_accepts_all() { - parses_to( + one_statement_parses_to( "SELECT id, fname, lname FROM customer WHERE id = 1 LIMIT ALL", "SELECT id, fname, lname FROM customer WHERE id = 1", ); @@ -308,21 +403,16 @@ fn parse_limit_accepts_all() { #[test] fn parse_cast() { - let sql = String::from("SELECT CAST(id AS bigint) FROM customer"); - match verified(&sql) { - ASTNode::SQLSelect { projection, .. } => { - assert_eq!(1, projection.len()); - assert_eq!( - ASTNode::SQLCast { - expr: Box::new(ASTNode::SQLIdentifier("id".to_string())), - data_type: SQLType::BigInt - }, - projection[0] - ); - } - _ => assert!(false), - } - parses_to( + let sql = "SELECT CAST(id AS bigint) FROM customer"; + let select = verified_only_select(sql); + assert_eq!( + &ASTNode::SQLCast { + expr: Box::new(ASTNode::SQLIdentifier("id".to_string())), + data_type: SQLType::BigInt + }, + expr_from_projection(only(&select.projection)) + ); + one_statement_parses_to( "SELECT CAST(id AS BIGINT) FROM customer", "SELECT CAST(id AS bigint) FROM customer", ); @@ -336,16 +426,16 @@ fn parse_create_table() { lat DOUBLE NULL,\ lng DOUBLE NULL)", ); - parses_to( + let ast = one_statement_parses_to( &sql, "CREATE TABLE uk_cities (\ name character varying(100) NOT NULL, \ lat double, \ lng double)", ); - match parse_sql(&sql) { - ASTNode::SQLCreateTable { name, columns } => { - assert_eq!("uk_cities", name); + match ast { + SQLStatement::SQLCreateTable { name, columns } => { + assert_eq!("uk_cities", name.to_string()); assert_eq!(3, columns.len()); let c_name = 
&columns[0]; @@ -369,66 +459,99 @@ fn parse_create_table() { #[test] fn parse_scalar_function_in_projection() { - let sql = String::from("SELECT sqrt(id) FROM foo"); - match verified(&sql) { - ASTNode::SQLSelect { projection, .. } => { - assert_eq!( - vec![ASTNode::SQLFunction { - id: String::from("sqrt"), - args: vec![ASTNode::SQLIdentifier(String::from("id"))], - }], - projection - ); - } - _ => assert!(false), - } + let sql = "SELECT sqrt(id) FROM foo"; + let select = verified_only_select(sql); + assert_eq!( + &ASTNode::SQLFunction { + id: String::from("sqrt"), + args: vec![ASTNode::SQLIdentifier(String::from("id"))], + }, + expr_from_projection(only(&select.projection)) + ); } #[test] fn parse_aggregate_with_group_by() { - let sql = String::from("SELECT a, COUNT(1), MIN(b), MAX(b) FROM foo GROUP BY a"); - let _ast = verified(&sql); + let sql = "SELECT a, COUNT(1), MIN(b), MAX(b) FROM foo GROUP BY a"; + let _ast = verified_only_select(sql); //TODO: assertions } #[test] fn parse_literal_string() { - let sql = "SELECT 'one'"; - match verified(&sql) { - ASTNode::SQLSelect { ref projection, .. 
} => { - assert_eq!( - projection[0], - ASTNode::SQLValue(Value::SingleQuotedString("one".to_string())) - ); - } - _ => panic!(), - } + let sql = "SELECT 'one', N'national string'"; + let select = verified_only_select(sql); + assert_eq!(2, select.projection.len()); + assert_eq!( + &ASTNode::SQLValue(Value::SingleQuotedString("one".to_string())), + expr_from_projection(&select.projection[0]) + ); + assert_eq!( + &ASTNode::SQLValue(Value::NationalStringLiteral("national string".to_string())), + expr_from_projection(&select.projection[1]) + ); } #[test] fn parse_simple_math_expr_plus() { let sql = "SELECT a + b, 2 + a, 2.5 + a, a_f + b_f, 2 + a_f, 2.5 + a_f FROM c"; - parse_sql(&sql); + verified_only_select(sql); } #[test] fn parse_simple_math_expr_minus() { let sql = "SELECT a - b, 2 - a, 2.5 - a, a_f - b_f, 2 - a_f, 2.5 - a_f FROM c"; - parse_sql(&sql); + verified_only_select(sql); } #[test] fn parse_select_version() { let sql = "SELECT @@version"; - match verified(&sql) { - ASTNode::SQLSelect { ref projection, .. 
} => { - assert_eq!( - projection[0], - ASTNode::SQLIdentifier("@@version".to_string()) - ); + let select = verified_only_select(sql); + assert_eq!( + &ASTNode::SQLIdentifier("@@version".to_string()), + expr_from_projection(only(&select.projection)), + ); +} + +#[test] +fn parse_delimited_identifiers() { + // check that quoted identifiers in any position remain quoted after serialization + let select = verified_only_select( + r#"SELECT "alias"."bar baz", "myfun"(), "simple id" AS "column alias" FROM "a table" AS "alias""# + ); + // check FROM + match select.relation.unwrap() { + TableFactor::Table { name, alias } => { + assert_eq!(vec![r#""a table""#.to_string()], name.0); + assert_eq!(r#""alias""#, alias.unwrap()); + } + _ => panic!("Expecting TableFactor::Table"), + } + // check SELECT + assert_eq!(3, select.projection.len()); + assert_eq!( + &ASTNode::SQLCompoundIdentifier(vec![r#""alias""#.to_string(), r#""bar baz""#.to_string()]), + expr_from_projection(&select.projection[0]), + ); + assert_eq!( + &ASTNode::SQLFunction { + id: r#""myfun""#.to_string(), + args: vec![] + }, + expr_from_projection(&select.projection[1]), + ); + match &select.projection[2] { + &SQLSelectItem::ExpressionWithAlias(ref expr, ref alias) => { + assert_eq!(&ASTNode::SQLIdentifier(r#""simple id""#.to_string()), expr); + assert_eq!(r#""column alias""#, alias); } - _ => panic!(), + _ => panic!("Expected ExpressionWithAlias"), } + + verified_stmt(r#"CREATE TABLE "foo" ("bar" "int")"#); + verified_stmt(r#"ALTER TABLE foo ADD CONSTRAINT "bar" PRIMARY KEY (baz)"#); + //TODO verified_stmt(r#"UPDATE foo SET "bar" = 5"#); } #[test] @@ -436,141 +559,102 @@ fn parse_parens() { use self::ASTNode::*; use self::SQLOperator::*; let sql = "(a + b) - (c + d)"; - let ast = parse_sql(&sql); assert_eq!( SQLBinaryExpr { - left: Box::new(SQLBinaryExpr { + left: Box::new(SQLNested(Box::new(SQLBinaryExpr { left: Box::new(SQLIdentifier("a".to_string())), op: Plus, right: Box::new(SQLIdentifier("b".to_string())) 
- }), + }))), op: Minus, - right: Box::new(SQLBinaryExpr { + right: Box::new(SQLNested(Box::new(SQLBinaryExpr { left: Box::new(SQLIdentifier("c".to_string())), op: Plus, right: Box::new(SQLIdentifier("d".to_string())) - }) + }))) }, - ast + verified_expr(sql) ); } #[test] fn parse_case_expression() { let sql = "SELECT CASE WHEN bar IS NULL THEN 'null' WHEN bar = 0 THEN '=0' WHEN bar >= 0 THEN '>=0' ELSE '<0' END FROM foo"; - let ast = parse_sql(&sql); - assert_eq!(sql, ast.to_string()); - use self::ASTNode::*; + use self::ASTNode::{SQLBinaryExpr, SQLCase, SQLIdentifier, SQLIsNull, SQLValue}; use self::SQLOperator::*; - match ast { - ASTNode::SQLSelect { projection, .. } => { - assert_eq!(1, projection.len()); - assert_eq!( - SQLCase { - conditions: vec![ - SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), - SQLBinaryExpr { - left: Box::new(SQLIdentifier("bar".to_string())), - op: Eq, - right: Box::new(SQLValue(Value::Long(0))) - }, - SQLBinaryExpr { - left: Box::new(SQLIdentifier("bar".to_string())), - op: GtEq, - right: Box::new(SQLValue(Value::Long(0))) - } - ], - results: vec![ - SQLValue(Value::SingleQuotedString("null".to_string())), - SQLValue(Value::SingleQuotedString("=0".to_string())), - SQLValue(Value::SingleQuotedString(">=0".to_string())) - ], - else_result: Some(Box::new(SQLValue(Value::SingleQuotedString( - "<0".to_string() - )))) + let select = verified_only_select(sql); + assert_eq!( + &SQLCase { + conditions: vec![ + SQLIsNull(Box::new(SQLIdentifier("bar".to_string()))), + SQLBinaryExpr { + left: Box::new(SQLIdentifier("bar".to_string())), + op: Eq, + right: Box::new(SQLValue(Value::Long(0))) }, - projection[0] - ); - } - _ => assert!(false), - } -} - -#[test] -fn parse_select_with_semi_colon() { - let sql = String::from("SELECT id, fname, lname FROM customer WHERE id = 1;"); - let ast = parse_sql(&sql); - match ast { - ASTNode::SQLSelect { projection, .. 
} => { - assert_eq!(3, projection.len()); - } - _ => assert!(false), - } -} - -#[test] -fn parse_delete_with_semi_colon() { - let sql: &str = "DELETE FROM 'table';"; - - match parse_sql(&sql) { - ASTNode::SQLDelete { relation, .. } => { - assert_eq!( - Some(Box::new(ASTNode::SQLValue(Value::SingleQuotedString( - "table".to_string() - )))), - relation - ); - } - _ => assert!(false), - } + SQLBinaryExpr { + left: Box::new(SQLIdentifier("bar".to_string())), + op: GtEq, + right: Box::new(SQLValue(Value::Long(0))) + } + ], + results: vec![ + SQLValue(Value::SingleQuotedString("null".to_string())), + SQLValue(Value::SingleQuotedString("=0".to_string())), + SQLValue(Value::SingleQuotedString(">=0".to_string())) + ], + else_result: Some(Box::new(SQLValue(Value::SingleQuotedString( + "<0".to_string() + )))) + }, + expr_from_projection(only(&select.projection)), + ); } #[test] fn parse_implicit_join() { let sql = "SELECT * FROM t1, t2"; - - match verified(sql) { - ASTNode::SQLSelect { joins, .. } => { - assert_eq!(joins.len(), 1); - assert_eq!( - joins[0], - Join { - relation: ASTNode::SQLIdentifier("t2".to_string()), - join_operator: JoinOperator::Implicit - } - ) - } - _ => assert!(false), - } + let select = verified_only_select(sql); + assert_eq!( + &Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2".to_string()]), + alias: None, + }, + join_operator: JoinOperator::Implicit + }, + only(&select.joins), + ); } #[test] fn parse_cross_join() { let sql = "SELECT * FROM t1 CROSS JOIN t2"; - - match verified(sql) { - ASTNode::SQLSelect { joins, .. 
} => { - assert_eq!(joins.len(), 1); - assert_eq!( - joins[0], - Join { - relation: ASTNode::SQLIdentifier("t2".to_string()), - join_operator: JoinOperator::Cross - } - ) - } - _ => assert!(false), - } + let select = verified_only_select(sql); + assert_eq!( + &Join { + relation: TableFactor::Table { + name: SQLObjectName(vec!["t2".to_string()]), + alias: None, + }, + join_operator: JoinOperator::Cross + }, + only(&select.joins), + ); } #[test] fn parse_joins_on() { fn join_with_constraint( relation: impl Into, + alias: Option, f: impl Fn(JoinConstraint) -> JoinOperator, ) -> Join { Join { - relation: ASTNode::SQLIdentifier(relation.into()), + relation: TableFactor::Table { + name: SQLObjectName(vec![relation.into()]), + alias, + }, join_operator: f(JoinConstraint::On(ASTNode::SQLBinaryExpr { left: Box::new(ASTNode::SQLIdentifier("c1".into())), op: SQLOperator::Eq, @@ -578,21 +662,35 @@ fn parse_joins_on() { })), } } + // Test parsing of aliases + assert_eq!( + verified_only_select("SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2").joins, + vec![join_with_constraint( + "t2", + Some("foo".to_string()), + JoinOperator::Inner + )] + ); + one_statement_parses_to( + "SELECT * FROM t1 JOIN t2 foo ON c1 = c2", + "SELECT * FROM t1 JOIN t2 AS foo ON c1 = c2", + ); + // Test parsing of different join operators assert_eq!( - joins_from(verified("SELECT * FROM t1 JOIN t2 ON c1 = c2")), - vec![join_with_constraint("t2", JoinOperator::Inner)] + verified_only_select("SELECT * FROM t1 JOIN t2 ON c1 = c2").joins, + vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2")), - vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + verified_only_select("SELECT * FROM t1 LEFT JOIN t2 ON c1 = c2").joins, + vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2")), - vec![join_with_constraint("t2", 
JoinOperator::RightOuter)] + verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 ON c1 = c2").joins, + vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] ); assert_eq!( - joins_from(verified("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2")), - vec![join_with_constraint("t2", JoinOperator::FullOuter)] + verified_only_select("SELECT * FROM t1 FULL JOIN t2 ON c1 = c2").joins, + vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } @@ -600,86 +698,316 @@ fn parse_joins_on() { fn parse_joins_using() { fn join_with_constraint( relation: impl Into, + alias: Option, f: impl Fn(JoinConstraint) -> JoinOperator, ) -> Join { Join { - relation: ASTNode::SQLIdentifier(relation.into()), + relation: TableFactor::Table { + name: SQLObjectName(vec![relation.into()]), + alias, + }, join_operator: f(JoinConstraint::Using(vec!["c1".into()])), } } - + // Test parsing of aliases assert_eq!( - joins_from(verified("SELECT * FROM t1 JOIN t2 USING(c1)")), - vec![join_with_constraint("t2", JoinOperator::Inner)] + verified_only_select("SELECT * FROM t1 JOIN t2 AS foo USING(c1)").joins, + vec![join_with_constraint( + "t2", + Some("foo".to_string()), + JoinOperator::Inner + )] + ); + one_statement_parses_to( + "SELECT * FROM t1 JOIN t2 foo USING(c1)", + "SELECT * FROM t1 JOIN t2 AS foo USING(c1)", ); + // Test parsing of different join operators assert_eq!( - joins_from(verified("SELECT * FROM t1 LEFT JOIN t2 USING(c1)")), - vec![join_with_constraint("t2", JoinOperator::LeftOuter)] + verified_only_select("SELECT * FROM t1 JOIN t2 USING(c1)").joins, + vec![join_with_constraint("t2", None, JoinOperator::Inner)] ); assert_eq!( - joins_from(verified("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)")), - vec![join_with_constraint("t2", JoinOperator::RightOuter)] + verified_only_select("SELECT * FROM t1 LEFT JOIN t2 USING(c1)").joins, + vec![join_with_constraint("t2", None, JoinOperator::LeftOuter)] ); assert_eq!( - joins_from(verified("SELECT * FROM t1 FULL JOIN t2 USING(c1)")), - 
vec![join_with_constraint("t2", JoinOperator::FullOuter)] + verified_only_select("SELECT * FROM t1 RIGHT JOIN t2 USING(c1)").joins, + vec![join_with_constraint("t2", None, JoinOperator::RightOuter)] + ); + assert_eq!( + verified_only_select("SELECT * FROM t1 FULL JOIN t2 USING(c1)").joins, + vec![join_with_constraint("t2", None, JoinOperator::FullOuter)] ); } #[test] fn parse_complex_join() { let sql = "SELECT c1, c2 FROM t1, t4 JOIN t2 ON t2.c = t1.c LEFT JOIN t3 USING(q, c) WHERE t4.c = t1.c"; - assert_eq!(sql, parse_sql(sql).to_string()); + verified_only_select(sql); } #[test] fn parse_join_syntax_variants() { - parses_to( + one_statement_parses_to( "SELECT c1 FROM t1 INNER JOIN t2 USING(c1)", "SELECT c1 FROM t1 JOIN t2 USING(c1)", ); - parses_to( + one_statement_parses_to( "SELECT c1 FROM t1 LEFT OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 LEFT JOIN t2 USING(c1)", ); - parses_to( + one_statement_parses_to( "SELECT c1 FROM t1 RIGHT OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 RIGHT JOIN t2 USING(c1)", ); - parses_to( + one_statement_parses_to( "SELECT c1 FROM t1 FULL OUTER JOIN t2 USING(c1)", "SELECT c1 FROM t1 FULL JOIN t2 USING(c1)", ); } -fn verified(query: &str) -> ASTNode { - let ast = parse_sql(query); - assert_eq!(query, &ast.to_string()); - ast +#[test] +fn parse_ctes() { + let cte_sqls = vec!["SELECT 1 AS foo", "SELECT 2 AS bar"]; + let with = &format!( + "WITH a AS ({}), b AS ({}) SELECT foo + bar FROM a, b", + cte_sqls[0], cte_sqls[1] + ); + + fn assert_ctes_in_select(expected: &Vec<&str>, sel: &SQLQuery) { + for i in 0..1 { + let Cte { + ref query, + ref alias, + } = sel.ctes[i]; + assert_eq!(expected[i], query.to_string()); + assert_eq!(if i == 0 { "a" } else { "b" }, alias); + } + } + + // Top-level CTE + assert_ctes_in_select(&cte_sqls, &verified_query(with)); + // CTE in a subquery + let sql = &format!("SELECT ({})", with); + let select = verified_only_select(sql); + match expr_from_projection(only(&select.projection)) { + 
&ASTNode::SQLSubquery(ref subquery) => { + assert_ctes_in_select(&cte_sqls, subquery.as_ref()); + } + _ => panic!("Expected subquery"), + } + // CTE in a derived table + let sql = &format!("SELECT * FROM ({})", with); + let select = verified_only_select(sql); + match select.relation { + Some(TableFactor::Derived { subquery, .. }) => { + assert_ctes_in_select(&cte_sqls, subquery.as_ref()) + } + _ => panic!("Expected derived table"), + } + // CTE in a view + let sql = &format!("CREATE VIEW v AS {}", with); + match verified_stmt(sql) { + SQLStatement::SQLCreateView { query, .. } => assert_ctes_in_select(&cte_sqls, &query), + _ => panic!("Expected CREATE VIEW"), + } + // CTE in a CTE... + let sql = &format!("WITH outer_cte AS ({}) SELECT * FROM outer_cte", with); + let select = verified_query(sql); + assert_ctes_in_select(&cte_sqls, &only(&select.ctes).query); } -fn parses_to(from: &str, to: &str) { - assert_eq!(to, &parse_sql(from).to_string()) +#[test] +fn parse_derived_tables() { + let sql = "SELECT a.x, b.y FROM (SELECT x FROM foo) AS a CROSS JOIN (SELECT y FROM bar) AS b"; + let _ = verified_only_select(sql); + //TODO: add assertions } -fn joins_from(ast: ASTNode) -> Vec { - match ast { - ASTNode::SQLSelect { joins, .. 
} => joins, +#[test] +fn parse_union() { + // TODO: add assertions + verified_stmt("SELECT 1 UNION SELECT 2"); + verified_stmt("SELECT 1 UNION ALL SELECT 2"); + verified_stmt("SELECT 1 EXCEPT SELECT 2"); + verified_stmt("SELECT 1 EXCEPT ALL SELECT 2"); + verified_stmt("SELECT 1 INTERSECT SELECT 2"); + verified_stmt("SELECT 1 INTERSECT ALL SELECT 2"); + verified_stmt("SELECT 1 UNION SELECT 2 UNION SELECT 3"); + verified_stmt("SELECT 1 EXCEPT SELECT 2 UNION SELECT 3"); // Union[Except[1,2], 3] + verified_stmt("SELECT 1 INTERSECT (SELECT 2 EXCEPT SELECT 3)"); + verified_stmt("WITH cte AS (SELECT 1 AS foo) (SELECT foo FROM cte ORDER BY 1 LIMIT 1)"); + verified_stmt("SELECT 1 UNION (SELECT 2 ORDER BY 1 LIMIT 1)"); + verified_stmt("SELECT 1 UNION SELECT 2 INTERSECT SELECT 3"); // Union[1, Intersect[2,3]] + verified_stmt("SELECT foo FROM tab UNION SELECT bar FROM TAB"); +} + +#[test] +fn parse_multiple_statements() { + fn test_with(sql1: &str, sql2_kw: &str, sql2_rest: &str) { + // Check that a string consisting of two statements delimited by a semicolon + // parses the same as both statements individually: + let res = parse_sql_statements(&(sql1.to_owned() + ";" + sql2_kw + sql2_rest)); + assert_eq!( + vec![ + one_statement_parses_to(&sql1, ""), + one_statement_parses_to(&(sql2_kw.to_owned() + sql2_rest), ""), + ], + res.unwrap() + ); + // Check that extra semicolon at the end is stripped by normalization: + one_statement_parses_to(&(sql1.to_owned() + ";"), sql1); + // Check that forgetting the semicolon results in an error: + let res = parse_sql_statements(&(sql1.to_owned() + " " + sql2_kw + sql2_rest)); + assert_eq!( + ParserError::ParserError("Expected end of statement, found: ".to_string() + sql2_kw), + res.unwrap_err() + ); + } + test_with("SELECT foo", "SELECT", " bar"); + // ensure that SELECT/WITH is not parsed as a table or column alias if ';' + // separating the statements is omitted: + test_with("SELECT foo FROM baz", "SELECT", " bar"); + test_with("SELECT 
foo", "WITH", " cte AS (SELECT 1 AS s) SELECT bar"); + test_with( + "SELECT foo FROM baz", + "WITH", + " cte AS (SELECT 1 AS s) SELECT bar", + ); + test_with("DELETE FROM foo", "SELECT", " bar"); + test_with("INSERT INTO foo VALUES(1)", "SELECT", " bar"); + test_with("CREATE TABLE foo (baz int)", "SELECT", " bar"); + // Make sure that empty statements do not cause an error: + let res = parse_sql_statements(";;"); + assert_eq!(0, res.unwrap().len()); +} + +#[test] +fn parse_scalar_subqueries() { + use self::ASTNode::*; + let sql = "(SELECT 1) + (SELECT 2)"; + match verified_expr(sql) { + SQLBinaryExpr { + op: SQLOperator::Plus, .. + //left: box SQLSubquery { .. }, + //right: box SQLSubquery { .. }, + } => assert!(true), + _ => assert!(false), + }; +} + +#[test] +fn parse_create_view() { + let sql = "CREATE VIEW myschema.myview AS SELECT foo FROM bar"; + match verified_stmt(sql) { + SQLStatement::SQLCreateView { + name, + query, + materialized, + } => { + assert_eq!("myschema.myview", name.to_string()); + assert_eq!("SELECT foo FROM bar", query.to_string()); + assert!(!materialized); + } + _ => assert!(false), + } +} + +#[test] +fn parse_create_materialized_view() { + let sql = "CREATE MATERIALIZED VIEW myschema.myview AS SELECT foo FROM bar"; + match verified_stmt(sql) { + SQLStatement::SQLCreateView { + name, + query, + materialized, + } => { + assert_eq!("myschema.myview", name.to_string()); + assert_eq!("SELECT foo FROM bar", query.to_string()); + assert!(materialized); + } + _ => assert!(false), + } +} + +#[test] +fn parse_invalid_subquery_without_parens() { + let res = parse_sql_statements("SELECT SELECT 1 FROM bar WHERE 1=1 FROM baz"); + assert_eq!( + ParserError::ParserError("Expected end of statement, found: 1".to_string()), + res.unwrap_err() + ); +} + +fn only<'a, T>(v: &'a Vec) -> &'a T { + assert_eq!(1, v.len()); + v.first().unwrap() +} + +fn verified_query(query: &str) -> SQLQuery { + match verified_stmt(query) { + SQLStatement::SQLSelect(select) => 
select, _ => panic!("Expected SELECT"), } } -fn parse_sql(sql: &str) -> ASTNode { - let generic_ast = parse_sql_with(sql, &GenericSqlDialect {}); - let pg_ast = parse_sql_with(sql, &PostgreSqlDialect {}); +fn expr_from_projection(item: &SQLSelectItem) -> &ASTNode { + match item { + SQLSelectItem::UnnamedExpression(expr) => expr, + _ => panic!("Expected UnnamedExpression"), + } +} + +fn verified_only_select(query: &str) -> SQLSelect { + match verified_query(query).body { + SQLSetExpr::Select(s) => s, + _ => panic!("Expected SQLSetExpr::Select"), + } +} + +fn verified_stmt(query: &str) -> SQLStatement { + one_statement_parses_to(query, query) +} + +fn verified_expr(query: &str) -> ASTNode { + let ast = parse_sql_expr(query); + assert_eq!(query, &ast.to_string()); + ast +} + +/// Ensures that `sql` parses as a single statement, optionally checking that +/// converting AST back to string equals to `canonical` (unless an empty string +/// is provided). +fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement { + let mut statements = parse_sql_statements(&sql).unwrap(); + assert_eq!(statements.len(), 1); + + let only_statement = statements.pop().unwrap(); + if !canonical.is_empty() { + assert_eq!(canonical, only_statement.to_string()) + } + only_statement +} + +fn parse_sql_statements(sql: &str) -> Result, ParserError> { + let generic_ast = Parser::parse_sql(&GenericSqlDialect {}, sql.to_string()); + let pg_ast = Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()); + assert_eq!(generic_ast, pg_ast); + generic_ast +} + +fn parse_sql_expr(sql: &str) -> ASTNode { + let generic_ast = parse_sql_expr_with(&GenericSqlDialect {}, &sql.to_string()); + let pg_ast = parse_sql_expr_with(&PostgreSqlDialect {}, &sql.to_string()); assert_eq!(generic_ast, pg_ast); generic_ast } -fn parse_sql_with(sql: &str, dialect: &Dialect) -> ASTNode { +fn parse_sql_expr_with(dialect: &Dialect, sql: &str) -> ASTNode { let mut tokenizer = Tokenizer::new(dialect, &sql); let tokens 
= tokenizer.tokenize().unwrap(); let mut parser = Parser::new(tokens); - let ast = parser.parse().unwrap(); + let ast = parser.parse_expr().unwrap(); ast } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 6b6598c67..80e57176c 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -13,34 +13,25 @@ fn test_prev_index() { let sql: &str = "SELECT version()"; let mut parser = parser(sql); assert_eq!(parser.prev_token(), None); - assert_eq!(parser.next_token(), Some(Token::Keyword("SELECT".into()))); - assert_eq!( - parser.next_token(), - Some(Token::Identifier("version".into())) - ); - assert_eq!( - parser.prev_token(), - Some(Token::Identifier("version".into())) - ); - assert_eq!( - parser.peek_token(), - Some(Token::Identifier("version".into())) - ); - assert_eq!(parser.prev_token(), Some(Token::Keyword("SELECT".into()))); + assert_eq!(parser.next_token(), Some(Token::make_keyword("SELECT"))); + assert_eq!(parser.next_token(), Some(Token::make_word("version", None))); + assert_eq!(parser.prev_token(), Some(Token::make_word("version", None))); + assert_eq!(parser.peek_token(), Some(Token::make_word("version", None))); + assert_eq!(parser.prev_token(), Some(Token::make_keyword("SELECT"))); assert_eq!(parser.prev_token(), None); } #[test] fn parse_simple_insert() { let sql = String::from("INSERT INTO customer VALUES(1, 2, 3)"); - match verified(&sql) { - ASTNode::SQLInsert { + match verified_stmt(&sql) { + SQLStatement::SQLInsert { table_name, columns, values, .. } => { - assert_eq!(table_name, "customer"); + assert_eq!(table_name.to_string(), "customer"); assert!(columns.is_empty()); assert_eq!( vec![vec![ @@ -58,14 +49,14 @@ fn parse_simple_insert() { #[test] fn parse_common_insert() { let sql = String::from("INSERT INTO public.customer VALUES(1, 2, 3)"); - match verified(&sql) { - ASTNode::SQLInsert { + match verified_stmt(&sql) { + SQLStatement::SQLInsert { table_name, columns, values, .. 
} => { - assert_eq!(table_name, "public.customer"); + assert_eq!(table_name.to_string(), "public.customer"); assert!(columns.is_empty()); assert_eq!( vec![vec![ @@ -83,14 +74,14 @@ fn parse_common_insert() { #[test] fn parse_complex_insert() { let sql = String::from("INSERT INTO db.public.customer VALUES(1, 2, 3)"); - match verified(&sql) { - ASTNode::SQLInsert { + match verified_stmt(&sql) { + SQLStatement::SQLInsert { table_name, columns, values, .. } => { - assert_eq!(table_name, "db.public.customer"); + assert_eq!(table_name.to_string(), "db.public.customer"); assert!(columns.is_empty()); assert_eq!( vec![vec![ @@ -108,21 +99,28 @@ fn parse_complex_insert() { #[test] fn parse_invalid_table_name() { let mut parser = parser("db.public..customer"); - let ast = parser.parse_tablename(); + let ast = parser.parse_object_name(); + assert!(ast.is_err()); +} + +#[test] +fn parse_no_table_name() { + let mut parser = parser(""); + let ast = parser.parse_object_name(); assert!(ast.is_err()); } #[test] fn parse_insert_with_columns() { let sql = String::from("INSERT INTO public.customer (id, name, active) VALUES(1, 2, 3)"); - match verified(&sql) { - ASTNode::SQLInsert { + match verified_stmt(&sql) { + SQLStatement::SQLInsert { table_name, columns, values, .. 
} => { - assert_eq!(table_name, "public.customer"); + assert_eq!(table_name.to_string(), "public.customer"); assert_eq!( columns, vec!["id".to_string(), "name".to_string(), "active".to_string()] @@ -143,8 +141,7 @@ fn parse_insert_with_columns() { #[test] fn parse_insert_invalid() { let sql = String::from("INSERT public.customer (id, name, active) VALUES (1, 2, 3)"); - let mut parser = parser(&sql); - match parser.parse() { + match Parser::parse_sql(&PostgreSqlDialect {}, sql) { Err(_) => {} _ => assert!(false), } @@ -165,9 +162,9 @@ fn parse_create_table_with_defaults() { last_update timestamp without time zone DEFAULT now() NOT NULL, active integer NOT NULL)", ); - match parse_sql(&sql) { - ASTNode::SQLCreateTable { name, columns } => { - assert_eq!("public.customer", name); + match one_statement_parses_to(&sql, "") { + SQLStatement::SQLCreateTable { name, columns } => { + assert_eq!("public.customer", name.to_string()); assert_eq!(10, columns.len()); let c_name = &columns[0]; @@ -206,10 +203,9 @@ fn parse_create_table_from_pg_dump() { release_year public.year, active integer )"); - let ast = parse_sql(&sql); - match ast { - ASTNode::SQLCreateTable { name, columns } => { - assert_eq!("public.customer", name); + match one_statement_parses_to(&sql, "") { + SQLStatement::SQLCreateTable { name, columns } => { + assert_eq!("public.customer", name.to_string()); let c_customer_id = &columns[0]; assert_eq!("customer_id", c_customer_id.name); @@ -228,7 +224,7 @@ fn parse_create_table_from_pg_dump() { let c_create_date1 = &columns[8]; assert_eq!( - Some(Box::new(ASTNode::SQLCast { + Some(ASTNode::SQLCast { expr: Box::new(ASTNode::SQLCast { expr: Box::new(ASTNode::SQLValue(Value::SingleQuotedString( "now".to_string() @@ -236,13 +232,16 @@ fn parse_create_table_from_pg_dump() { data_type: SQLType::Text }), data_type: SQLType::Date - })), + }), c_create_date1.default ); let c_release_year = &columns[10]; assert_eq!( - SQLType::Custom("public.year".to_string()), + 
SQLType::Custom(SQLObjectName(vec![ + "public".to_string(), + "year".to_string() + ])), c_release_year.data_type ); } @@ -261,9 +260,9 @@ fn parse_create_table_with_inherit() { use_metric boolean DEFAULT true\ )", ); - match verified(&sql) { - ASTNode::SQLCreateTable { name, columns } => { - assert_eq!("bazaar.settings", name); + match verified_stmt(&sql) { + SQLStatement::SQLCreateTable { name, columns } => { + assert_eq!("bazaar.settings", name.to_string()); let c_name = &columns[0]; assert_eq!("settings_id", c_name.name); @@ -290,9 +289,9 @@ fn parse_alter_table_constraint_primary_key() { ALTER TABLE bazaar.address \ ADD CONSTRAINT address_pkey PRIMARY KEY (address_id)", ); - match verified(&sql) { - ASTNode::SQLAlterTable { name, .. } => { - assert_eq!(name, "bazaar.address"); + match verified_stmt(&sql) { + SQLStatement::SQLAlterTable { name, .. } => { + assert_eq!(name.to_string(), "bazaar.address"); } _ => assert!(false), } @@ -303,9 +302,9 @@ fn parse_alter_table_constraint_foreign_key() { let sql = String::from("\ ALTER TABLE public.customer \ ADD CONSTRAINT customer_address_id_fkey FOREIGN KEY (address_id) REFERENCES public.address(address_id)"); - match verified(&sql) { - ASTNode::SQLAlterTable { name, .. } => { - assert_eq!(name, "public.customer"); + match verified_stmt(&sql) { + SQLStatement::SQLAlterTable { name, .. 
} => { + assert_eq!(name.to_string(), "public.customer"); } _ => assert!(false), } @@ -333,7 +332,7 @@ Kwara & Kogi PHP ₱ USD $ \N Some other value \\."#); - let ast = parse_sql(&sql); + let ast = one_statement_parses_to(&sql, ""); println!("{:#?}", ast); //assert_eq!(sql, ast.to_string()); } @@ -341,7 +340,7 @@ PHP ₱ USD $ #[test] fn parse_timestamps_example() { let sql = "2016-02-15 09:43:33"; - let _ = parse_sql(sql); + let _ = parse_sql_expr(sql); //TODO add assertion //assert_eq!(sql, ast.to_string()); } @@ -349,7 +348,7 @@ fn parse_timestamps_example() { #[test] fn parse_timestamps_with_millis_example() { let sql = "2017-11-02 19:15:42.308637"; - let _ = parse_sql(sql); + let _ = parse_sql_expr(sql); //TODO add assertion //assert_eq!(sql, ast.to_string()); } @@ -357,27 +356,43 @@ fn parse_timestamps_with_millis_example() { #[test] fn parse_example_value() { let sql = "SARAH.LEWIS@sakilacustomer.org"; - let ast = parse_sql(sql); + let ast = parse_sql_expr(sql); assert_eq!(sql, ast.to_string()); } #[test] fn parse_function_now() { let sql = "now()"; - let ast = parse_sql(sql); + let ast = parse_sql_expr(sql); assert_eq!(sql, ast.to_string()); } -fn verified(query: &str) -> ASTNode { - let ast = parse_sql(query); - assert_eq!(query, &ast.to_string()); - ast +fn verified_stmt(query: &str) -> SQLStatement { + one_statement_parses_to(query, query) +} + +/// Ensures that `sql` parses as a single statement, optionally checking that +/// converting AST back to string equals to `canonical` (unless an empty string +/// is provided). 
+fn one_statement_parses_to(sql: &str, canonical: &str) -> SQLStatement { + let mut statements = parse_sql_statements(&sql).unwrap(); + assert_eq!(statements.len(), 1); + + let only_statement = statements.pop().unwrap(); + if !canonical.is_empty() { + assert_eq!(canonical, only_statement.to_string()) + } + only_statement +} + +fn parse_sql_statements(sql: &str) -> Result, ParserError> { + Parser::parse_sql(&PostgreSqlDialect {}, sql.to_string()) } -fn parse_sql(sql: &str) -> ASTNode { +fn parse_sql_expr(sql: &str) -> ASTNode { debug!("sql: {}", sql); let mut parser = parser(sql); - let ast = parser.parse().unwrap(); + let ast = parser.parse_expr().unwrap(); ast }