Skip to content

Commit 9a87256

Browse files
committed
Fix DoubleColonCast skipping AT TIME ZONE apache#1266
1 parent eb36bd7 commit 9a87256

File tree

3 files changed

+110
-74
lines changed

3 files changed

+110
-74
lines changed

src/ast/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ pub enum CastKind {
424424
/// See <https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-and-operators#safe_casting>.
425425
SafeCast,
426426
/// `<expr> :: <datatype>`
427-
DoubleColon,
427+
DoubleColonCast,
428428
}
429429

430430
/// An SQL expression of any type.
@@ -1073,7 +1073,7 @@ impl fmt::Display for Expr {
10731073
write!(f, "SAFE_CAST({expr} AS {data_type})")
10741074
}
10751075
}
1076-
CastKind::DoubleColon => {
1076+
CastKind::DoubleColonCast => {
10771077
write!(f, "{expr}::{data_type}")
10781078
}
10791079
},

src/parser/mod.rs

Lines changed: 87 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use core::fmt;
2424

2525
use log::debug;
2626

27+
use recursion::RecursionCounter;
2728
use IsLateral::*;
2829
use IsOptional::*;
2930

@@ -114,6 +115,7 @@ mod recursion {
114115
Self { remaining_depth }
115116
}
116117
}
118+
117119
impl Drop for DepthGuard {
118120
fn drop(&mut self) {
119121
let old_value = self.remaining_depth.get();
@@ -143,8 +145,6 @@ mod recursion {
143145
pub struct DepthGuard {}
144146
}
145147

146-
use recursion::RecursionCounter;
147-
148148
#[derive(PartialEq, Eq)]
149149
pub enum IsOptional {
150150
Optional,
@@ -995,17 +995,17 @@ impl<'a> Parser<'a> {
995995
| Keyword::CURRENT_USER
996996
| Keyword::SESSION_USER
997997
| Keyword::USER
998-
if dialect_of!(self is PostgreSqlDialect | GenericDialect) =>
999-
{
1000-
Ok(Expr::Function(Function {
1001-
name: ObjectName(vec![w.to_ident()]),
1002-
args: FunctionArguments::None,
1003-
null_treatment: None,
1004-
filter: None,
1005-
over: None,
1006-
within_group: vec![],
1007-
}))
1008-
}
998+
if dialect_of!(self is PostgreSqlDialect | GenericDialect) =>
999+
{
1000+
Ok(Expr::Function(Function {
1001+
name: ObjectName(vec![w.to_ident()]),
1002+
args: FunctionArguments::None,
1003+
null_treatment: None,
1004+
filter: None,
1005+
over: None,
1006+
within_group: vec![],
1007+
}))
1008+
}
10091009
Keyword::CURRENT_TIMESTAMP
10101010
| Keyword::CURRENT_TIME
10111011
| Keyword::CURRENT_DATE
@@ -1019,18 +1019,18 @@ impl<'a> Parser<'a> {
10191019
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
10201020
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
10211021
Keyword::EXISTS
1022-
// Support parsing Databricks has a function named `exists`.
1023-
if !dialect_of!(self is DatabricksDialect)
1024-
|| matches!(
1022+
// Support parsing Databricks has a function named `exists`.
1023+
if !dialect_of!(self is DatabricksDialect)
1024+
|| matches!(
10251025
self.peek_nth_token(1).token,
10261026
Token::Word(Word {
10271027
keyword: Keyword::SELECT | Keyword::WITH,
10281028
..
10291029
})
10301030
) =>
1031-
{
1032-
self.parse_exists_expr(false)
1033-
}
1031+
{
1032+
self.parse_exists_expr(false)
1033+
}
10341034
Keyword::EXTRACT => self.parse_extract_expr(),
10351035
Keyword::CEIL => self.parse_ceil_floor_expr(true),
10361036
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
@@ -1047,21 +1047,21 @@ impl<'a> Parser<'a> {
10471047
self.parse_array_expr(true)
10481048
}
10491049
Keyword::ARRAY
1050-
if self.peek_token() == Token::LParen
1051-
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
1052-
{
1053-
self.expect_token(&Token::LParen)?;
1054-
let query = self.parse_boxed_query()?;
1055-
self.expect_token(&Token::RParen)?;
1056-
Ok(Expr::Function(Function {
1057-
name: ObjectName(vec![w.to_ident()]),
1058-
args: FunctionArguments::Subquery(query),
1059-
filter: None,
1060-
null_treatment: None,
1061-
over: None,
1062-
within_group: vec![],
1063-
}))
1064-
}
1050+
if self.peek_token() == Token::LParen
1051+
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
1052+
{
1053+
self.expect_token(&Token::LParen)?;
1054+
let query = self.parse_boxed_query()?;
1055+
self.expect_token(&Token::RParen)?;
1056+
Ok(Expr::Function(Function {
1057+
name: ObjectName(vec![w.to_ident()]),
1058+
args: FunctionArguments::Subquery(query),
1059+
filter: None,
1060+
null_treatment: None,
1061+
over: None,
1062+
within_group: vec![],
1063+
}))
1064+
}
10651065
Keyword::NOT => self.parse_not(),
10661066
Keyword::MATCH if dialect_of!(self is MySqlDialect | GenericDialect) => {
10671067
self.parse_match_against()
@@ -1129,13 +1129,13 @@ impl<'a> Parser<'a> {
11291129
Token::SingleQuotedString(_)
11301130
| Token::DoubleQuotedString(_)
11311131
| Token::HexStringLiteral(_)
1132-
if w.value.starts_with('_') =>
1133-
{
1134-
Ok(Expr::IntroducedString {
1135-
introducer: w.value,
1136-
value: self.parse_introduced_string_value()?,
1137-
})
1138-
}
1132+
if w.value.starts_with('_') =>
1133+
{
1134+
Ok(Expr::IntroducedString {
1135+
introducer: w.value,
1136+
value: self.parse_introduced_string_value()?,
1137+
})
1138+
}
11391139
Token::Arrow if self.dialect.supports_lambda_functions() => {
11401140
self.expect_token(&Token::Arrow)?;
11411141
return Ok(Expr::Lambda(LambdaFunction {
@@ -1222,7 +1222,7 @@ impl<'a> Parser<'a> {
12221222
return parser_err!(
12231223
format!("Expected identifier, found: {tok}"),
12241224
tok.location
1225-
)
1225+
);
12261226
}
12271227
};
12281228
Ok(Expr::CompositeAccess {
@@ -1528,19 +1528,23 @@ impl<'a> Parser<'a> {
15281528
pub fn parse_optional_cast_format(&mut self) -> Result<Option<CastFormat>, ParserError> {
15291529
if self.parse_keyword(Keyword::FORMAT) {
15301530
let value = self.parse_value()?;
1531-
if self.parse_keywords(&[Keyword::AT, Keyword::TIME, Keyword::ZONE]) {
1532-
Ok(Some(CastFormat::ValueAtTimeZone(
1533-
value,
1534-
self.parse_value()?,
1535-
)))
1536-
} else {
1537-
Ok(Some(CastFormat::Value(value)))
1531+
match self.parse_optional_time_zone()? {
1532+
Some(tz) => Ok(Some(CastFormat::ValueAtTimeZone(value, tz))),
1533+
None => Ok(Some(CastFormat::Value(value))),
15381534
}
15391535
} else {
15401536
Ok(None)
15411537
}
15421538
}
15431539

1540+
pub fn parse_optional_time_zone(&mut self) -> Result<Option<Value>, ParserError> {
1541+
if self.parse_keywords(&[Keyword::AT, Keyword::TIME, Keyword::ZONE]) {
1542+
self.parse_value().map(Some)
1543+
} else {
1544+
Ok(None)
1545+
}
1546+
}
1547+
15441548
/// mssql-like convert function
15451549
fn parse_mssql_convert(&mut self) -> Result<Expr, ParserError> {
15461550
self.expect_token(&Token::LParen)?;
@@ -2458,8 +2462,6 @@ impl<'a> Parser<'a> {
24582462
}
24592463
}
24602464
Keyword::AT => {
2461-
// if self.parse_keyword(Keyword::TIME) {
2462-
// self.expect_keyword(Keyword::ZONE)?;
24632465
if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) {
24642466
let time_zone = self.next_token();
24652467
match time_zone.token {
@@ -2534,12 +2536,35 @@ impl<'a> Parser<'a> {
25342536
),
25352537
}
25362538
} else if Token::DoubleColon == tok {
2537-
Ok(Expr::Cast {
2538-
kind: CastKind::DoubleColon,
2539+
let data_type = self.parse_data_type()?;
2540+
2541+
let cast_expr = Expr::Cast {
2542+
kind: CastKind::DoubleColonCast,
25392543
expr: Box::new(expr),
2540-
data_type: self.parse_data_type()?,
2544+
data_type: data_type.clone(),
25412545
format: None,
2542-
})
2546+
};
2547+
2548+
match data_type {
2549+
DataType::Date
2550+
| DataType::Datetime(_)
2551+
| DataType::Timestamp(_, _)
2552+
| DataType::Time(_, _) => {
2553+
let value = self.parse_optional_time_zone()?;
2554+
match value {
2555+
Some(Value::SingleQuotedString(tz)) => Ok(Expr::AtTimeZone {
2556+
timestamp: Box::new(cast_expr),
2557+
time_zone: tz,
2558+
}),
2559+
None => Ok(cast_expr),
2560+
_ => Err(ParserError::ParserError(format!(
2561+
"Expected Token::SingleQuotedString after AT TIME ZONE, but found: {}",
2562+
value.unwrap()
2563+
))),
2564+
}
2565+
}
2566+
_ => Ok(cast_expr),
2567+
}
25432568
} else if Token::ExclamationMark == tok {
25442569
// PostgreSQL factorial operation
25452570
Ok(Expr::UnaryOp {
@@ -2738,16 +2763,6 @@ impl<'a> Parser<'a> {
27382763
})
27392764
}
27402765

2741-
/// Parse a postgresql casting style which is in the form of `expr::datatype`
2742-
pub fn parse_pg_cast(&mut self, expr: Expr) -> Result<Expr, ParserError> {
2743-
Ok(Expr::Cast {
2744-
kind: CastKind::DoubleColon,
2745-
expr: Box::new(expr),
2746-
data_type: self.parse_data_type()?,
2747-
format: None,
2748-
})
2749-
}
2750-
27512766
// use https://www.postgresql.org/docs/7.0/operators.htm#AEN2026 as a reference
27522767
// higher number = higher precedence
27532768
const MUL_DIV_MOD_OP_PREC: u8 = 40;
@@ -2967,7 +2982,7 @@ impl<'a> Parser<'a> {
29672982
token => {
29682983
return token
29692984
.cloned()
2970-
.unwrap_or_else(|| TokenWithLocation::wrap(Token::EOF))
2985+
.unwrap_or_else(|| TokenWithLocation::wrap(Token::EOF));
29712986
}
29722987
}
29732988
}
@@ -6968,12 +6983,12 @@ impl<'a> Parser<'a> {
69686983
Token::EOF => {
69696984
return Err(ParserError::ParserError(
69706985
"Empty input when parsing identifier".to_string(),
6971-
))?
6986+
))?;
69726987
}
69736988
token => {
69746989
return Err(ParserError::ParserError(format!(
69756990
"Unexpected token in identifier: {token}"
6976-
)))?
6991+
)))?;
69776992
}
69786993
};
69796994

@@ -6986,19 +7001,19 @@ impl<'a> Parser<'a> {
69867001
Token::EOF => {
69877002
return Err(ParserError::ParserError(
69887003
"Trailing period in identifier".to_string(),
6989-
))?
7004+
))?;
69907005
}
69917006
token => {
69927007
return Err(ParserError::ParserError(format!(
69937008
"Unexpected token following period in identifier: {token}"
6994-
)))?
7009+
)))?;
69957010
}
69967011
},
69977012
Token::EOF => break,
69987013
token => {
69997014
return Err(ParserError::ParserError(format!(
70007015
"Unexpected token in identifier: {token}"
7001-
)))?
7016+
)))?;
70027017
}
70037018
}
70047019
}
@@ -8371,7 +8386,7 @@ impl<'a> Parser<'a> {
83718386
_ => {
83728387
return Err(ParserError::ParserError(format!(
83738388
"expected OUTER, SEMI, ANTI or JOIN after {kw:?}"
8374-
)))
8389+
)));
83758390
}
83768391
}
83778392
}

tests/sqlparser_common.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7022,6 +7022,27 @@ fn parse_set_variable() {
70227022
one_statement_parses_to("SET SOMETHING TO '1'", "SET SOMETHING = '1'");
70237023
}
70247024

7025+
#[test]
7026+
fn parse_double_colon_cast_at_timezone() {
7027+
let sql = "SELECT '2001-01-01T00:00:00.000Z'::TIMESTAMP AT TIME ZONE 'Europe/Brussels' FROM t";
7028+
let select = verified_only_select(sql);
7029+
7030+
assert_eq!(
7031+
&Expr::AtTimeZone {
7032+
timestamp: Box::new(Expr::Cast {
7033+
kind: CastKind::DoubleColonCast,
7034+
expr: Box::new(Expr::Value(Value::SingleQuotedString(
7035+
"2001-01-01T00:00:00.000Z".to_string()
7036+
),)),
7037+
data_type: DataType::Timestamp(None, TimezoneInfo::None),
7038+
format: None
7039+
}),
7040+
time_zone: "Europe/Brussels".to_string()
7041+
},
7042+
expr_from_projection(only(&select.projection)),
7043+
);
7044+
}
7045+
70257046
#[test]
70267047
fn parse_set_time_zone() {
70277048
match verified_stmt("SET TIMEZONE = 'UTC'") {

0 commit comments

Comments
 (0)