Skip to content

Commit d9d69a2

Browse files
authored
Databricks: support for lambda functions (#1257)
1 parent a86c58b commit d9d69a2

File tree

6 files changed

+210
-5
lines changed

6 files changed

+210
-5
lines changed

src/ast/mod.rs

+54
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,59 @@ pub enum Expr {
793793
OuterJoin(Box<Expr>),
794794
/// A reference to the prior level in a CONNECT BY clause.
795795
Prior(Box<Expr>),
796+
/// A lambda function.
797+
///
798+
/// Syntax:
799+
/// ```plaintext
800+
/// param -> expr | (param1, ...) -> expr
801+
/// ```
802+
///
803+
/// See <https://docs.databricks.com/en/sql/language-manual/sql-ref-lambda-functions.html>.
804+
Lambda(LambdaFunction),
805+
}
806+
807+
/// A lambda function.
808+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
809+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
810+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
811+
pub struct LambdaFunction {
812+
/// The parameters to the lambda function.
813+
pub params: OneOrManyWithParens<Ident>,
814+
/// The body of the lambda function.
815+
pub body: Box<Expr>,
816+
}
817+
818+
impl fmt::Display for LambdaFunction {
819+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
820+
write!(f, "{} -> {}", self.params, self.body)
821+
}
822+
}
823+
824+
/// Encapsulates the common pattern in SQL where either one unparenthesized item
825+
/// such as an identifier or expression is permitted, or multiple of the same
826+
/// item in a parenthesized list.
827+
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
828+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
829+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
830+
pub enum OneOrManyWithParens<T> {
831+
/// A single `T`, unparenthesized.
832+
One(T),
833+
/// One or more `T`s, parenthesized.
834+
Many(Vec<T>),
835+
}
836+
837+
impl<T> fmt::Display for OneOrManyWithParens<T>
838+
where
839+
T: fmt::Display,
840+
{
841+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
842+
match self {
843+
OneOrManyWithParens::One(value) => write!(f, "{value}"),
844+
OneOrManyWithParens::Many(values) => {
845+
write!(f, "({})", display_comma_separated(values))
846+
}
847+
}
848+
}
796849
}
797850

798851
impl fmt::Display for CastFormat {
@@ -1241,6 +1294,7 @@ impl fmt::Display for Expr {
12411294
write!(f, "{expr} (+)")
12421295
}
12431296
Expr::Prior(expr) => write!(f, "PRIOR {expr}"),
1297+
Expr::Lambda(lambda) => write!(f, "{lambda}"),
12441298
}
12451299
}
12461300
}

src/dialect/databricks.rs

+4
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,8 @@ impl Dialect for DatabricksDialect {
2929
fn supports_group_by_expr(&self) -> bool {
3030
true
3131
}
32+
33+
fn supports_lambda_functions(&self) -> bool {
34+
true
35+
}
3236
}

src/dialect/mod.rs

+8
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,14 @@ pub trait Dialect: Debug + Any {
209209
fn supports_dictionary_syntax(&self) -> bool {
210210
false
211211
}
212+
/// Returns true if the dialect supports lambda functions, for example:
213+
///
214+
/// ```sql
215+
/// SELECT transform(array(1, 2, 3), x -> x + 1); -- returns [2,3,4]
216+
/// ```
217+
fn supports_lambda_functions(&self) -> bool {
218+
false
219+
}
212220
/// Returns true if the dialect has a CONVERT function which accepts a type first
213221
/// and an expression second, e.g. `CONVERT(varchar, 1)`
214222
fn convert_type_before_value(&self) -> bool {

src/parser/mod.rs

+39-2
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,19 @@ impl<'a> Parser<'a> {
10181018
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
10191019
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
10201020
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
1021-
Keyword::EXISTS => self.parse_exists_expr(false),
1021+
Keyword::EXISTS
1022+
// Support parsing Databricks has a function named `exists`.
1023+
if !dialect_of!(self is DatabricksDialect)
1024+
|| matches!(
1025+
self.peek_nth_token(1).token,
1026+
Token::Word(Word {
1027+
keyword: Keyword::SELECT | Keyword::WITH,
1028+
..
1029+
})
1030+
) =>
1031+
{
1032+
self.parse_exists_expr(false)
1033+
}
10221034
Keyword::EXTRACT => self.parse_extract_expr(),
10231035
Keyword::CEIL => self.parse_ceil_floor_expr(true),
10241036
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
@@ -1036,7 +1048,7 @@ impl<'a> Parser<'a> {
10361048
}
10371049
Keyword::ARRAY
10381050
if self.peek_token() == Token::LParen
1039-
&& !dialect_of!(self is ClickHouseDialect) =>
1051+
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
10401052
{
10411053
self.expect_token(&Token::LParen)?;
10421054
let query = self.parse_boxed_query()?;
@@ -1124,6 +1136,13 @@ impl<'a> Parser<'a> {
11241136
value: self.parse_introduced_string_value()?,
11251137
})
11261138
}
1139+
Token::Arrow if self.dialect.supports_lambda_functions() => {
1140+
self.expect_token(&Token::Arrow)?;
1141+
return Ok(Expr::Lambda(LambdaFunction {
1142+
params: OneOrManyWithParens::One(w.to_ident()),
1143+
body: Box::new(self.parse_expr()?),
1144+
}));
1145+
}
11271146
_ => Ok(Expr::Identifier(w.to_ident())),
11281147
},
11291148
}, // End of Token::Word
@@ -1182,6 +1201,8 @@ impl<'a> Parser<'a> {
11821201
if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) {
11831202
self.prev_token();
11841203
Expr::Subquery(self.parse_boxed_query()?)
1204+
} else if let Some(lambda) = self.try_parse_lambda() {
1205+
return Ok(lambda);
11851206
} else {
11861207
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
11871208
match exprs.len() {
@@ -1231,6 +1252,22 @@ impl<'a> Parser<'a> {
12311252
}
12321253
}
12331254

1255+
fn try_parse_lambda(&mut self) -> Option<Expr> {
1256+
if !self.dialect.supports_lambda_functions() {
1257+
return None;
1258+
}
1259+
self.maybe_parse(|p| {
1260+
let params = p.parse_comma_separated(|p| p.parse_identifier(false))?;
1261+
p.expect_token(&Token::RParen)?;
1262+
p.expect_token(&Token::Arrow)?;
1263+
let expr = p.parse_expr()?;
1264+
Ok(Expr::Lambda(LambdaFunction {
1265+
params: OneOrManyWithParens::Many(params),
1266+
body: Box::new(expr),
1267+
}))
1268+
})
1269+
}
1270+
12341271
pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
12351272
self.expect_token(&Token::LParen)?;
12361273

tests/sqlparser_common.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -1380,7 +1380,11 @@ fn pg_and_generic() -> TestedDialects {
13801380
fn parse_json_ops_without_colon() {
13811381
use self::BinaryOperator::*;
13821382
let binary_ops = [
1383-
("->", Arrow, all_dialects()),
1383+
(
1384+
"->",
1385+
Arrow,
1386+
all_dialects_except(|d| d.supports_lambda_functions()),
1387+
),
13841388
("->>", LongArrow, all_dialects()),
13851389
("#>", HashArrow, pg_and_generic()),
13861390
("#>>", HashLongArrow, pg_and_generic()),
@@ -6174,15 +6178,17 @@ fn parse_exists_subquery() {
61746178
verified_stmt("SELECT * FROM t WHERE EXISTS (WITH u AS (SELECT 1) SELECT * FROM u)");
61756179
verified_stmt("SELECT EXISTS (SELECT 1)");
61766180

6177-
let res = parse_sql_statements("SELECT EXISTS (");
6181+
let res = all_dialects_except(|d| d.is::<DatabricksDialect>())
6182+
.parse_sql_statements("SELECT EXISTS (");
61786183
assert_eq!(
61796184
ParserError::ParserError(
61806185
"Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string()
61816186
),
61826187
res.unwrap_err(),
61836188
);
61846189

6185-
let res = parse_sql_statements("SELECT EXISTS (NULL)");
6190+
let res = all_dialects_except(|d| d.is::<DatabricksDialect>())
6191+
.parse_sql_statements("SELECT EXISTS (NULL)");
61866192
assert_eq!(
61876193
ParserError::ParserError(
61886194
"Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string()

tests/sqlparser_databricks.rs

+96
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use sqlparser::ast::*;
22
use sqlparser::dialect::DatabricksDialect;
3+
use sqlparser::parser::ParserError;
34
use test_utils::*;
45

56
#[macro_use]
@@ -28,3 +29,98 @@ fn test_databricks_identifiers() {
2829
SelectItem::UnnamedExpr(Expr::Value(Value::DoubleQuotedString("Ä".to_owned())))
2930
);
3031
}
32+
33+
#[test]
34+
fn test_databricks_exists() {
35+
// exists is a function in databricks
36+
assert_eq!(
37+
databricks().verified_expr("exists(array(1, 2, 3), x -> x IS NULL)"),
38+
call(
39+
"exists",
40+
[
41+
call(
42+
"array",
43+
[
44+
Expr::Value(number("1")),
45+
Expr::Value(number("2")),
46+
Expr::Value(number("3"))
47+
]
48+
),
49+
Expr::Lambda(LambdaFunction {
50+
params: OneOrManyWithParens::One(Ident::new("x")),
51+
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x")))))
52+
})
53+
]
54+
),
55+
);
56+
57+
let res = databricks().parse_sql_statements("SELECT EXISTS (");
58+
assert_eq!(
59+
// TODO: improve this error message...
60+
ParserError::ParserError("Expected an expression:, found: EOF".to_string()),
61+
res.unwrap_err(),
62+
);
63+
}
64+
65+
#[test]
66+
fn test_databricks_lambdas() {
67+
#[rustfmt::skip]
68+
let sql = concat!(
69+
"SELECT array_sort(array('Hello', 'World'), ",
70+
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
71+
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
72+
"ELSE 1 END)",
73+
);
74+
pretty_assertions::assert_eq!(
75+
SelectItem::UnnamedExpr(call(
76+
"array_sort",
77+
[
78+
call(
79+
"array",
80+
[
81+
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
82+
Expr::Value(Value::SingleQuotedString("World".to_owned()))
83+
]
84+
),
85+
Expr::Lambda(LambdaFunction {
86+
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
87+
body: Box::new(Expr::Case {
88+
operand: None,
89+
conditions: vec![
90+
Expr::BinaryOp {
91+
left: Box::new(Expr::Identifier(Ident::new("p1"))),
92+
op: BinaryOperator::Eq,
93+
right: Box::new(Expr::Identifier(Ident::new("p2")))
94+
},
95+
Expr::BinaryOp {
96+
left: Box::new(call(
97+
"reverse",
98+
[Expr::Identifier(Ident::new("p1"))]
99+
)),
100+
op: BinaryOperator::Lt,
101+
right: Box::new(call(
102+
"reverse",
103+
[Expr::Identifier(Ident::new("p2"))]
104+
))
105+
}
106+
],
107+
results: vec![
108+
Expr::Value(number("0")),
109+
Expr::UnaryOp {
110+
op: UnaryOperator::Minus,
111+
expr: Box::new(Expr::Value(number("1")))
112+
}
113+
],
114+
else_result: Some(Box::new(Expr::Value(number("1"))))
115+
})
116+
})
117+
]
118+
)),
119+
databricks().verified_only_select(sql).projection[0]
120+
);
121+
122+
databricks().verified_expr(
123+
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
124+
);
125+
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
126+
}

0 commit comments

Comments
 (0)