Skip to content

Databricks: support for lambda functions #1257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,59 @@ pub enum Expr {
OuterJoin(Box<Expr>),
/// A reference to the prior level in a CONNECT BY clause.
Prior(Box<Expr>),
/// A lambda function.
///
/// Syntax:
/// ```plaintext
/// param -> expr | (param1, ...) -> expr
/// ```
///
/// See <https://docs.databricks.com/en/sql/language-manual/sql-ref-lambda-functions.html>.
Lambda(LambdaFunction),
}

/// A lambda function.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct LambdaFunction {
/// The parameters to the lambda function.
pub params: OneOrManyWithParens<Ident>,
/// The body of the lambda function.
pub body: Box<Expr>,
}

impl fmt::Display for LambdaFunction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} -> {}", self.params, self.body)
}
}

/// Encapsulates the common pattern in SQL where either one unparenthesized item
/// such as an identifier or expression is permitted, or multiple of the same
/// item in a parenthesized list.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum OneOrManyWithParens<T> {
/// A single `T`, unparenthesized.
One(T),
/// One or more `T`s, parenthesized.
Many(Vec<T>),
}

impl<T> fmt::Display for OneOrManyWithParens<T>
where
T: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
OneOrManyWithParens::One(value) => write!(f, "{value}"),
OneOrManyWithParens::Many(values) => {
write!(f, "({})", display_comma_separated(values))
}
}
}
}

impl fmt::Display for CastFormat {
Expand Down Expand Up @@ -1241,6 +1294,7 @@ impl fmt::Display for Expr {
write!(f, "{expr} (+)")
}
Expr::Prior(expr) => write!(f, "PRIOR {expr}"),
Expr::Lambda(lambda) => write!(f, "{lambda}"),
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ impl Dialect for DatabricksDialect {
fn supports_group_by_expr(&self) -> bool {
true
}

fn supports_lambda_functions(&self) -> bool {
true
}
}
8 changes: 8 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,14 @@ pub trait Dialect: Debug + Any {
fn supports_dictionary_syntax(&self) -> bool {
false
}
/// Returns true if the dialect supports lambda functions, for example:
///
/// ```sql
/// SELECT transform(array(1, 2, 3), x -> x + 1); -- returns [2,3,4]
/// ```
fn supports_lambda_functions(&self) -> bool {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I was unsure of here is whether to enable this for the generic dialect. Our policy is to enable everything for generic, but this syntax conflicts with the -> operator from dialects like Postgres. So supporting this would be a breaking change of somewhat dubious benefit, since we both gain and lose a feature in generic. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this makes sense to not support via GenericDialect when there is a conflict

🤔 that might be a good thing to add to the policy https://github.com/sqlparser-rs/sqlparser-rs?tab=readme-ov-file#new-syntax

false
}
/// Returns true if the dialect has a CONVERT function which accepts a type first
/// and an expression second, e.g. `CONVERT(varchar, 1)`
fn convert_type_before_value(&self) -> bool {
Expand Down
41 changes: 39 additions & 2 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,19 @@ impl<'a> Parser<'a> {
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
Keyword::EXISTS => self.parse_exists_expr(false),
Keyword::EXISTS
// Support parsing Databricks has a function named `exists`.
if !dialect_of!(self is DatabricksDialect)
|| matches!(
self.peek_nth_token(1).token,
Token::Word(Word {
keyword: Keyword::SELECT | Keyword::WITH,
..
})
) =>
{
self.parse_exists_expr(false)
}
Keyword::EXTRACT => self.parse_extract_expr(),
Keyword::CEIL => self.parse_ceil_floor_expr(true),
Keyword::FLOOR => self.parse_ceil_floor_expr(false),
Expand All @@ -1036,7 +1048,7 @@ impl<'a> Parser<'a> {
}
Keyword::ARRAY
if self.peek_token() == Token::LParen
&& !dialect_of!(self is ClickHouseDialect) =>
&& !dialect_of!(self is ClickHouseDialect | DatabricksDialect) =>
{
self.expect_token(&Token::LParen)?;
let query = self.parse_boxed_query()?;
Expand Down Expand Up @@ -1124,6 +1136,13 @@ impl<'a> Parser<'a> {
value: self.parse_introduced_string_value()?,
})
}
Token::Arrow if self.dialect.supports_lambda_functions() => {
self.expect_token(&Token::Arrow)?;
return Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(w.to_ident()),
body: Box::new(self.parse_expr()?),
}));
}
_ => Ok(Expr::Identifier(w.to_ident())),
},
}, // End of Token::Word
Expand Down Expand Up @@ -1182,6 +1201,8 @@ impl<'a> Parser<'a> {
if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) {
self.prev_token();
Expr::Subquery(self.parse_boxed_query()?)
} else if let Some(lambda) = self.try_parse_lambda() {
return Ok(lambda);
} else {
let exprs = self.parse_comma_separated(Parser::parse_expr)?;
match exprs.len() {
Expand Down Expand Up @@ -1231,6 +1252,22 @@ impl<'a> Parser<'a> {
}
}

fn try_parse_lambda(&mut self) -> Option<Expr> {
if !self.dialect.supports_lambda_functions() {
return None;
}
self.maybe_parse(|p| {
let params = p.parse_comma_separated(|p| p.parse_identifier(false))?;
p.expect_token(&Token::RParen)?;
p.expect_token(&Token::Arrow)?;
let expr = p.parse_expr()?;
Ok(Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(params),
body: Box::new(expr),
}))
})
}

pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;

Expand Down
12 changes: 9 additions & 3 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,11 @@ fn pg_and_generic() -> TestedDialects {
fn parse_json_ops_without_colon() {
use self::BinaryOperator::*;
let binary_ops = [
("->", Arrow, all_dialects()),
(
"->",
Arrow,
all_dialects_except(|d| d.supports_lambda_functions()),
),
("->>", LongArrow, all_dialects()),
("#>", HashArrow, pg_and_generic()),
("#>>", HashLongArrow, pg_and_generic()),
Expand Down Expand Up @@ -6174,15 +6178,17 @@ fn parse_exists_subquery() {
verified_stmt("SELECT * FROM t WHERE EXISTS (WITH u AS (SELECT 1) SELECT * FROM u)");
verified_stmt("SELECT EXISTS (SELECT 1)");

let res = parse_sql_statements("SELECT EXISTS (");
let res = all_dialects_except(|d| d.is::<DatabricksDialect>())
.parse_sql_statements("SELECT EXISTS (");
assert_eq!(
ParserError::ParserError(
"Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string()
),
res.unwrap_err(),
);

let res = parse_sql_statements("SELECT EXISTS (NULL)");
let res = all_dialects_except(|d| d.is::<DatabricksDialect>())
.parse_sql_statements("SELECT EXISTS (NULL)");
assert_eq!(
ParserError::ParserError(
"Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string()
Expand Down
96 changes: 96 additions & 0 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use sqlparser::ast::*;
use sqlparser::dialect::DatabricksDialect;
use sqlparser::parser::ParserError;
use test_utils::*;

#[macro_use]
Expand Down Expand Up @@ -28,3 +29,98 @@ fn test_databricks_identifiers() {
SelectItem::UnnamedExpr(Expr::Value(Value::DoubleQuotedString("Ä".to_owned())))
);
}

#[test]
fn test_databricks_exists() {
// exists is a function in databricks
assert_eq!(
databricks().verified_expr("exists(array(1, 2, 3), x -> x IS NULL)"),
call(
"exists",
[
call(
"array",
[
Expr::Value(number("1")),
Expr::Value(number("2")),
Expr::Value(number("3"))
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::One(Ident::new("x")),
body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x")))))
})
]
),
);

let res = databricks().parse_sql_statements("SELECT EXISTS (");
assert_eq!(
// TODO: improve this error message...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was the todo meant for this PR or as a general improvement/followup with error reporting?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the latter

ParserError::ParserError("Expected an expression:, found: EOF".to_string()),
res.unwrap_err(),
);
}

#[test]
fn test_databricks_lambdas() {
#[rustfmt::skip]
let sql = concat!(
"SELECT array_sort(array('Hello', 'World'), ",
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
"ELSE 1 END)",
);
pretty_assertions::assert_eq!(
SelectItem::UnnamedExpr(call(
"array_sort",
[
call(
"array",
[
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
Expr::Value(Value::SingleQuotedString("World".to_owned()))
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
}
],
results: vec![
Expr::Value(number("0")),
Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
],
else_result: Some(Box::new(Expr::Value(number("1"))))
})
})
]
)),
databricks().verified_only_select(sql).projection[0]
);

databricks().verified_expr(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty neat!

"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
);
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
}
Loading