Skip to content

Support OVER clause for window/analytic functions and qualified function names #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ println!("AST: {:?}", ast);
This outputs

```rust
AST: [SQLSelect(SQLSelect { projection: [SQLIdentifier("a"), SQLIdentifier("b"), SQLValue(Long(123)), SQLFunction { id: "myfunc", args: [SQLIdentifier("b")] }], relation: Some(Table { name: SQLObjectName(["table_1"]), alias: None }), joins: [], selection: Some(SQLBinaryExpr { left: SQLBinaryExpr { left: SQLIdentifier("a"), op: Gt, right: SQLIdentifier("b") }, op: And, right: SQLBinaryExpr { left: SQLIdentifier("b"), op: Lt, right: SQLValue(Long(100)) } }), order_by: Some([SQLOrderByExpr { expr: SQLIdentifier("a"), asc: Some(false) }, SQLOrderByExpr { expr: SQLIdentifier("b"), asc: None }]), group_by: None, having: None, limit: None })]
AST: [SQLSelect(SQLQuery { ctes: [], body: Select(SQLSelect { distinct: false, projection: [UnnamedExpression(SQLIdentifier("a")), UnnamedExpression(SQLIdentifier("b")), UnnamedExpression(SQLValue(Long(123))), UnnamedExpression(SQLFunction { name: SQLObjectName(["myfunc"]), args: [SQLIdentifier("b")], over: None })], relation: Some(Table { name: SQLObjectName(["table_1"]), alias: None }), joins: [], selection: Some(SQLBinaryExpr { left: SQLBinaryExpr { left: SQLIdentifier("a"), op: Gt, right: SQLIdentifier("b") }, op: And, right: SQLBinaryExpr { left: SQLIdentifier("b"), op: Lt, right: SQLValue(Long(100)) } }), group_by: None, having: None }), order_by: Some([SQLOrderByExpr { expr: SQLIdentifier("a"), asc: Some(false) }, SQLOrderByExpr { expr: SQLIdentifier("b"), asc: None }]), limit: None })]
```

## Design
Expand Down
6 changes: 6 additions & 0 deletions src/dialect/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ keyword!(
FIRST_VALUE,
FLOAT,
FLOOR,
FOLLOWING,
FOR,
FOREIGN,
FRAME_ROW,
Expand Down Expand Up @@ -246,6 +247,7 @@ keyword!(
POSITION_REGEX,
POWER,
PRECEDES,
PRECEDING,
PRECISION,
PREPARE,
PRIMARY,
Expand Down Expand Up @@ -333,6 +335,7 @@ keyword!(
TRIM_ARRAY,
TRUE,
UESCAPE,
UNBOUNDED,
UNION,
UNIQUE,
UNKNOWN,
Expand Down Expand Up @@ -488,6 +491,7 @@ pub const ALL_KEYWORDS: &'static [&'static str] = &[
FIRST_VALUE,
FLOAT,
FLOOR,
FOLLOWING,
FOR,
FOREIGN,
FRAME_ROW,
Expand Down Expand Up @@ -595,6 +599,7 @@ pub const ALL_KEYWORDS: &'static [&'static str] = &[
POSITION_REGEX,
POWER,
PRECEDES,
PRECEDING,
PRECISION,
PREPARE,
PRIMARY,
Expand Down Expand Up @@ -682,6 +687,7 @@ pub const ALL_KEYWORDS: &'static [&'static str] = &[
TRIM_ARRAY,
TRUE,
UESCAPE,
UNBOUNDED,
UNION,
UNIQUE,
UNKNOWN,
Expand Down
189 changes: 140 additions & 49 deletions src/sqlast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ pub use self::value::Value;

pub use self::sql_operator::SQLOperator;

/// Renders every element of `vec` with `ToString` and joins the results
/// with `", "`, in slice order.
///
/// Behaves like `vec.join(", ")` but accepts any element type that
/// implements `ToString`. An empty slice yields an empty string.
fn comma_separated_string<T: ToString>(vec: &[T]) -> String {
    let mut joined = String::new();
    for (position, element) in vec.iter().enumerate() {
        // The separator goes between elements only, never before the first.
        if position > 0 {
            joined.push_str(", ");
        }
        joined.push_str(&element.to_string());
    }
    joined
}

/// Identifier name, in the originally quoted form (e.g. `"id"`)
pub type SQLIdent = String;

Expand All @@ -46,7 +54,7 @@ pub enum ASTNode {
/// Qualified wildcard, e.g. `alias.*` or `schema.table.*`.
/// (Same caveats apply to SQLQualifiedWildcard as to SQLWildcard.)
SQLQualifiedWildcard(Vec<SQLIdent>),
/// Multi part identifier e.g. `myschema.dbo.mytable`
/// Multi-part identifier, e.g. `table_alias.column` or `schema.table.col`
SQLCompoundIdentifier(Vec<SQLIdent>),
/// `IS NULL` expression
SQLIsNull(Box<ASTNode>),
Expand Down Expand Up @@ -92,8 +100,11 @@ pub enum ASTNode {
/// SQLValue
SQLValue(Value),
/// Scalar function call e.g. `LEFT(foo, 5)`
/// TODO: this can be a compound SQLObjectName as well (for UDFs)
SQLFunction { id: SQLIdent, args: Vec<ASTNode> },
SQLFunction {
name: SQLObjectName,
args: Vec<ASTNode>,
over: Option<SQLWindowSpec>,
},
/// CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END
SQLCase {
// TODO: support optional operand for "simple case"
Expand Down Expand Up @@ -123,10 +134,7 @@ impl ToString for ASTNode {
"{} {}IN ({})",
expr.as_ref().to_string(),
if *negated { "NOT " } else { "" },
list.iter()
.map(|a| a.to_string())
.collect::<Vec<String>>()
.join(", ")
comma_separated_string(list)
),
ASTNode::SQLInSubquery {
expr,
Expand Down Expand Up @@ -166,14 +174,13 @@ impl ToString for ASTNode {
format!("{} {}", operator.to_string(), expr.as_ref().to_string())
}
ASTNode::SQLValue(v) => v.to_string(),
ASTNode::SQLFunction { id, args } => format!(
"{}({})",
id,
args.iter()
.map(|a| a.to_string())
.collect::<Vec<String>>()
.join(", ")
),
ASTNode::SQLFunction { name, args, over } => {
let mut s = format!("{}({})", name.to_string(), comma_separated_string(args));
if let Some(o) = over {
s += &format!(" OVER ({})", o.to_string())
}
s
}
ASTNode::SQLCase {
conditions,
results,
Expand All @@ -198,6 +205,116 @@ impl ToString for ASTNode {
}
}

/// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`)
///
/// Every part is optional: an empty `partition_by` / `order_by` or a `None`
/// `window_frame` is simply left out when the specification is converted
/// back to a string.
#[derive(Debug, Clone, PartialEq)]
pub struct SQLWindowSpec {
/// Expressions of the `PARTITION BY` clause; empty when the clause is absent.
pub partition_by: Vec<ASTNode>,
/// Items of the `ORDER BY` clause; empty when the clause is absent.
pub order_by: Vec<SQLOrderByExpr>,
/// The frame clause (e.g. `ROWS BETWEEN .. AND ..`), if one was given.
pub window_frame: Option<SQLWindowFrame>,
}

impl ToString for SQLWindowSpec {
    /// Renders the text that goes inside the parentheses of `OVER (...)`.
    ///
    /// The clauses appear in the fixed order `PARTITION BY`, `ORDER BY`,
    /// frame, separated by single spaces. A clause with nothing to show
    /// (empty list or missing frame) is omitted, so a completely empty
    /// specification produces an empty string.
    fn to_string(&self) -> String {
        let partition_clause = if self.partition_by.is_empty() {
            None
        } else {
            Some(format!(
                "PARTITION BY {}",
                comma_separated_string(&self.partition_by)
            ))
        };

        let order_clause = if self.order_by.is_empty() {
            None
        } else {
            Some(format!(
                "ORDER BY {}",
                comma_separated_string(&self.order_by)
            ))
        };

        // A frame with an end bound uses the `BETWEEN .. AND ..` form;
        // otherwise only the start bound follows the units keyword.
        let frame_clause = self.window_frame.as_ref().map(|frame| {
            let units = frame.units.to_string();
            let start = frame.start_bound.to_string();
            match &frame.end_bound {
                Some(end) => format!("{} BETWEEN {} AND {}", units, start, end.to_string()),
                None => format!("{} {}", units, start),
            }
        });

        // Each `Option` contributes zero or one clause, in order.
        partition_clause
            .into_iter()
            .chain(order_clause)
            .chain(frame_clause)
            .collect::<Vec<String>>()
            .join(" ")
    }
}

/// Specifies the data processed by a window function, e.g.
/// `RANGE UNBOUNDED PRECEDING` or `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
#[derive(Debug, Clone, PartialEq)]
pub struct SQLWindowFrame {
/// The frame units keyword (`ROWS`, `RANGE` or `GROUPS`).
pub units: SQLWindowFrameUnits,
/// The only bound of the short form (`<units> <bound>`), or the left bound
/// of the `BETWEEN .. AND` form.
pub start_bound: SQLWindowFrameBound,
/// The right bound of the `BETWEEN .. AND` clause. `None` means the frame
/// was written in the short, single-bound form.
pub end_bound: Option<SQLWindowFrameBound>,
// TBD: EXCLUDE
}

/// The units keyword that introduces a window frame clause.
#[derive(Debug, Clone, PartialEq)]
pub enum SQLWindowFrameUnits {
    /// `ROWS`
    Rows,
    /// `RANGE`
    Range,
    /// `GROUPS`
    Groups,
}

impl ToString for SQLWindowFrameUnits {
    /// Returns the upper-case SQL keyword for this variant.
    fn to_string(&self) -> String {
        let keyword: &str = match *self {
            SQLWindowFrameUnits::Rows => "ROWS",
            SQLWindowFrameUnits::Range => "RANGE",
            SQLWindowFrameUnits::Groups => "GROUPS",
        };
        String::from(keyword)
    }
}

impl FromStr for SQLWindowFrameUnits {
    type Err = ParserError;

    /// Parses a frame units keyword.
    ///
    /// Only the exact upper-case spellings `ROWS`, `RANGE` and `GROUPS` are
    /// recognised; any other input is reported as a `ParserError` that
    /// echoes the offending text.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let units = match s {
            "ROWS" => SQLWindowFrameUnits::Rows,
            "RANGE" => SQLWindowFrameUnits::Range,
            "GROUPS" => SQLWindowFrameUnits::Groups,
            unexpected => {
                let message = format!(
                    "Expected ROWS, RANGE, or GROUPS, found: {}",
                    unexpected
                );
                return Err(ParserError::ParserError(message));
            }
        };
        Ok(units)
    }
}

/// One edge of a window frame.
///
/// For `Preceding` and `Following`, `Some(n)` is a numeric offset and `None`
/// stands for `UNBOUNDED`.
#[derive(Debug, Clone, PartialEq)]
pub enum SQLWindowFrameBound {
    /// "CURRENT ROW"
    CurrentRow,
    /// "<N> PRECEDING" or "UNBOUNDED PRECEDING"
    Preceding(Option<u64>),
    /// "<N> FOLLOWING" or "UNBOUNDED FOLLOWING". Documented by the original
    /// author as appearing only in SQLWindowFrame::end_bound; the type
    /// itself does not enforce that.
    Following(Option<u64>),
}

impl ToString for SQLWindowFrameBound {
    /// Renders the bound as SQL text: `CURRENT ROW`, `<n> PRECEDING`,
    /// `<n> FOLLOWING`, or the `UNBOUNDED` form when no offset is present.
    fn to_string(&self) -> String {
        // Split into "which direction" and "how far"; CURRENT ROW has neither.
        let (offset, direction) = match self {
            SQLWindowFrameBound::CurrentRow => return String::from("CURRENT ROW"),
            SQLWindowFrameBound::Preceding(offset) => (offset, "PRECEDING"),
            SQLWindowFrameBound::Following(offset) => (offset, "FOLLOWING"),
        };
        match offset {
            Some(n) => format!("{} {}", n, direction),
            None => format!("UNBOUNDED {}", direction),
        }
    }
}

/// A top-level statement (SELECT, INSERT, CREATE, etc.)
#[derive(Debug, Clone, PartialEq)]
pub enum SQLStatement {
Expand Down Expand Up @@ -279,11 +396,7 @@ impl ToString for SQLStatement {
" VALUES({})",
values
.iter()
.map(|row| row
.iter()
.map(|c| c.to_string())
.collect::<Vec<String>>()
.join(", "))
.map(|row| comma_separated_string(row))
.collect::<Vec<String>>()
.join(", ")
);
Expand All @@ -296,15 +409,8 @@ impl ToString for SQLStatement {
values,
} => {
let mut s = format!("COPY {}", table_name.to_string());
if columns.len() > 0 {
s += &format!(
" ({})",
columns
.iter()
.map(|c| c.to_string())
.collect::<Vec<String>>()
.join(", ")
);
if !columns.is_empty() {
s += &format!(" ({})", comma_separated_string(columns));
}
s += " FROM stdin; ";
if values.len() > 0 {
Expand All @@ -326,15 +432,8 @@ impl ToString for SQLStatement {
selection,
} => {
let mut s = format!("UPDATE {}", table_name.to_string());
if assignments.len() > 0 {
s += &format!(
"{}",
assignments
.iter()
.map(|ass| ass.to_string())
.collect::<Vec<String>>()
.join(", ")
);
if !assignments.is_empty() {
s += &comma_separated_string(assignments);
}
if let Some(selection) = selection {
s += &format!(" WHERE {}", selection.to_string());
Expand Down Expand Up @@ -373,12 +472,8 @@ impl ToString for SQLStatement {
} if *external => format!(
"CREATE EXTERNAL TABLE {} ({}) STORED AS {} LOCATION '{}'",
name.to_string(),
columns
.iter()
.map(|c| c.to_string())
.collect::<Vec<String>>()
.join(", "),
file_format.as_ref().map(|f| f.to_string()).unwrap(),
comma_separated_string(columns),
file_format.as_ref().unwrap().to_string(),
location.as_ref().unwrap()
),
SQLStatement::SQLCreateTable {
Expand All @@ -390,11 +485,7 @@ impl ToString for SQLStatement {
} => format!(
"CREATE TABLE {} ({})",
name.to_string(),
columns
.iter()
.map(|c| c.to_string())
.collect::<Vec<String>>()
.join(", ")
comma_separated_string(columns)
),
SQLStatement::SQLAlterTable { name, operation } => {
format!("ALTER TABLE {} {}", name.to_string(), operation.to_string())
Expand Down
28 changes: 5 additions & 23 deletions src/sqlast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,7 @@ impl ToString for SQLQuery {
}
s += &self.body.to_string();
if let Some(ref order_by) = self.order_by {
s += &format!(
" ORDER BY {}",
order_by
.iter()
.map(|o| o.to_string())
.collect::<Vec<String>>()
.join(", ")
);
s += &format!(" ORDER BY {}", comma_separated_string(order_by));
}
if let Some(ref limit) = self.limit {
s += &format!(" LIMIT {}", limit.to_string());
Expand Down Expand Up @@ -130,11 +123,7 @@ impl ToString for SQLSelect {
let mut s = format!(
"SELECT{} {}",
if self.distinct { " DISTINCT" } else { "" },
self.projection
.iter()
.map(|p| p.to_string())
.collect::<Vec<String>>()
.join(", ")
comma_separated_string(&self.projection)
);
if let Some(ref relation) = self.relation {
s += &format!(" FROM {}", relation.to_string());
Expand All @@ -146,14 +135,7 @@ impl ToString for SQLSelect {
s += &format!(" WHERE {}", selection.to_string());
}
if let Some(ref group_by) = self.group_by {
s += &format!(
" GROUP BY {}",
group_by
.iter()
.map(|g| g.to_string())
.collect::<Vec<String>>()
.join(", ")
);
s += &format!(" GROUP BY {}", comma_separated_string(group_by));
}
if let Some(ref having) = self.having {
s += &format!(" HAVING {}", having.to_string());
Expand All @@ -175,7 +157,7 @@ pub enum SQLSelectItem {
/// Any expression, not followed by `[ AS ] alias`
UnnamedExpression(ASTNode),
/// An expression, followed by `[ AS ] alias`
ExpressionWithAlias(ASTNode, SQLIdent),
ExpressionWithAlias { expr: ASTNode, alias: SQLIdent },
/// `alias.*` or even `schema.table.*`
QualifiedWildcard(SQLObjectName),
/// An unqualified `*`
Expand All @@ -186,7 +168,7 @@ impl ToString for SQLSelectItem {
fn to_string(&self) -> String {
match &self {
SQLSelectItem::UnnamedExpression(expr) => expr.to_string(),
SQLSelectItem::ExpressionWithAlias(expr, alias) => {
SQLSelectItem::ExpressionWithAlias { expr, alias } => {
format!("{} AS {}", expr.to_string(), alias)
}
SQLSelectItem::QualifiedWildcard(prefix) => format!("{}.*", prefix.to_string()),
Expand Down
Loading