diff --git a/src/ast/mod.rs b/src/ast/mod.rs index e19a12b5e..480f0d232 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -793,6 +793,59 @@ pub enum Expr { OuterJoin(Box), /// A reference to the prior level in a CONNECT BY clause. Prior(Box), + /// A lambda function. + /// + /// Syntax: + /// ```plaintext + /// param -> expr | (param1, ...) -> expr + /// ``` + /// + /// See . + Lambda(LambdaFunction), +} + +/// A lambda function. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct LambdaFunction { + /// The parameters to the lambda function. + pub params: OneOrManyWithParens, + /// The body of the lambda function. + pub body: Box, +} + +impl fmt::Display for LambdaFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{} -> {}", self.params, self.body) + } +} + +/// Encapsulates the common pattern in SQL where either one unparenthesized item +/// such as an identifier or expression is permitted, or multiple of the same +/// item in a parenthesized list. +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum OneOrManyWithParens { + /// A single `T`, unparenthesized. + One(T), + /// One or more `T`s, parenthesized. + Many(Vec), +} + +impl fmt::Display for OneOrManyWithParens +where + T: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OneOrManyWithParens::One(value) => write!(f, "{value}"), + OneOrManyWithParens::Many(values) => { + write!(f, "({})", display_comma_separated(values)) + } + } + } } impl fmt::Display for CastFormat { @@ -1241,6 +1294,7 @@ impl fmt::Display for Expr { write!(f, "{expr} (+)") } Expr::Prior(expr) => write!(f, "PRIOR {expr}"), + Expr::Lambda(lambda) => write!(f, "{lambda}"), } } } diff --git a/src/dialect/databricks.rs b/src/dialect/databricks.rs index 63d0e9827..929ec26cd 100644 --- a/src/dialect/databricks.rs +++ b/src/dialect/databricks.rs @@ -29,4 +29,8 @@ impl Dialect for DatabricksDialect { fn supports_group_by_expr(&self) -> bool { true } + + fn supports_lambda_functions(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index a04390570..74d0077c8 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -209,6 +209,14 @@ pub trait Dialect: Debug + Any { fn supports_dictionary_syntax(&self) -> bool { false } + /// Returns true if the dialect supports lambda functions, for example: + /// + /// ```sql + /// SELECT transform(array(1, 2, 3), x -> x + 1); -- returns [2,3,4] + /// ``` + fn supports_lambda_functions(&self) -> bool { + false + } /// Returns true if the dialect has a CONVERT function which accepts a type first /// and an expression second, e.g. `CONVERT(varchar, 1)` fn convert_type_before_value(&self) -> bool { diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 7b92fdf7f..fbc07eeed 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1018,7 +1018,19 @@ impl<'a> Parser<'a> { Keyword::CAST => self.parse_cast_expr(CastKind::Cast), Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast), Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast), - Keyword::EXISTS => self.parse_exists_expr(false), + Keyword::EXISTS + // Support parsing Databricks has a function named `exists`. + if !dialect_of!(self is DatabricksDialect) + || matches!( + self.peek_nth_token(1).token, + Token::Word(Word { + keyword: Keyword::SELECT | Keyword::WITH, + .. + }) + ) => + { + self.parse_exists_expr(false) + } Keyword::EXTRACT => self.parse_extract_expr(), Keyword::CEIL => self.parse_ceil_floor_expr(true), Keyword::FLOOR => self.parse_ceil_floor_expr(false), @@ -1036,7 +1048,7 @@ impl<'a> Parser<'a> { } Keyword::ARRAY if self.peek_token() == Token::LParen - && !dialect_of!(self is ClickHouseDialect) => + && !dialect_of!(self is ClickHouseDialect | DatabricksDialect) => { self.expect_token(&Token::LParen)?; let query = self.parse_boxed_query()?; @@ -1124,6 +1136,13 @@ impl<'a> Parser<'a> { value: self.parse_introduced_string_value()?, }) } + Token::Arrow if self.dialect.supports_lambda_functions() => { + self.expect_token(&Token::Arrow)?; + return Ok(Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::One(w.to_ident()), + body: Box::new(self.parse_expr()?), + })); + } _ => Ok(Expr::Identifier(w.to_ident())), }, }, // End of Token::Word @@ -1182,6 +1201,8 @@ impl<'a> Parser<'a> { if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) { self.prev_token(); Expr::Subquery(self.parse_boxed_query()?) + } else if let Some(lambda) = self.try_parse_lambda() { + return Ok(lambda); } else { let exprs = self.parse_comma_separated(Parser::parse_expr)?; match exprs.len() { @@ -1231,6 +1252,22 @@ impl<'a> Parser<'a> { } } + fn try_parse_lambda(&mut self) -> Option { + if !self.dialect.supports_lambda_functions() { + return None; + } + self.maybe_parse(|p| { + let params = p.parse_comma_separated(|p| p.parse_identifier(false))?; + p.expect_token(&Token::RParen)?; + p.expect_token(&Token::Arrow)?; + let expr = p.parse_expr()?; + Ok(Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::Many(params), + body: Box::new(expr), + })) + }) + } + pub fn parse_function(&mut self, name: ObjectName) -> Result { self.expect_token(&Token::LParen)?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index e08f5c4e3..a17daac5d 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1380,7 +1380,11 @@ fn pg_and_generic() -> TestedDialects { fn parse_json_ops_without_colon() { use self::BinaryOperator::*; let binary_ops = [ - ("->", Arrow, all_dialects()), + ( + "->", + Arrow, + all_dialects_except(|d| d.supports_lambda_functions()), + ), ("->>", LongArrow, all_dialects()), ("#>", HashArrow, pg_and_generic()), ("#>>", HashLongArrow, pg_and_generic()), @@ -6174,7 +6178,8 @@ fn parse_exists_subquery() { verified_stmt("SELECT * FROM t WHERE EXISTS (WITH u AS (SELECT 1) SELECT * FROM u)"); verified_stmt("SELECT EXISTS (SELECT 1)"); - let res = parse_sql_statements("SELECT EXISTS ("); + let res = all_dialects_except(|d| d.is::()) + .parse_sql_statements("SELECT EXISTS ("); assert_eq!( ParserError::ParserError( "Expected SELECT, VALUES, or a subquery in the query body, found: EOF".to_string() @@ -6182,7 +6187,8 @@ fn parse_exists_subquery() { res.unwrap_err(), ); - let res = parse_sql_statements("SELECT EXISTS (NULL)"); + let res = all_dialects_except(|d| d.is::()) + .parse_sql_statements("SELECT EXISTS (NULL)"); assert_eq!( ParserError::ParserError( "Expected SELECT, VALUES, or a subquery in the query body, found: NULL".to_string() diff --git a/tests/sqlparser_databricks.rs b/tests/sqlparser_databricks.rs index 20795463b..8f0579fc9 100644 --- a/tests/sqlparser_databricks.rs +++ b/tests/sqlparser_databricks.rs @@ -1,5 +1,6 @@ use sqlparser::ast::*; use sqlparser::dialect::DatabricksDialect; +use sqlparser::parser::ParserError; use test_utils::*; #[macro_use] @@ -28,3 +29,98 @@ fn test_databricks_identifiers() { SelectItem::UnnamedExpr(Expr::Value(Value::DoubleQuotedString("Ä".to_owned()))) ); } + +#[test] +fn test_databricks_exists() { + // exists is a function in databricks + assert_eq!( + databricks().verified_expr("exists(array(1, 2, 3), x -> x IS NULL)"), + call( + "exists", + [ + call( + "array", + [ + Expr::Value(number("1")), + Expr::Value(number("2")), + Expr::Value(number("3")) + ] + ), + Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::One(Ident::new("x")), + body: Box::new(Expr::IsNull(Box::new(Expr::Identifier(Ident::new("x"))))) + }) + ] + ), + ); + + let res = databricks().parse_sql_statements("SELECT EXISTS ("); + assert_eq!( + // TODO: improve this error message... + ParserError::ParserError("Expected an expression:, found: EOF".to_string()), + res.unwrap_err(), + ); +} + +#[test] +fn test_databricks_lambdas() { + #[rustfmt::skip] + let sql = concat!( + "SELECT array_sort(array('Hello', 'World'), ", + "(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ", + "WHEN reverse(p1) < reverse(p2) THEN -1 ", + "ELSE 1 END)", + ); + pretty_assertions::assert_eq!( + SelectItem::UnnamedExpr(call( + "array_sort", + [ + call( + "array", + [ + Expr::Value(Value::SingleQuotedString("Hello".to_owned())), + Expr::Value(Value::SingleQuotedString("World".to_owned())) + ] + ), + Expr::Lambda(LambdaFunction { + params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]), + body: Box::new(Expr::Case { + operand: None, + conditions: vec![ + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("p1"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Identifier(Ident::new("p2"))) + }, + Expr::BinaryOp { + left: Box::new(call( + "reverse", + [Expr::Identifier(Ident::new("p1"))] + )), + op: BinaryOperator::Lt, + right: Box::new(call( + "reverse", + [Expr::Identifier(Ident::new("p2"))] + )) + } + ], + results: vec![ + Expr::Value(number("0")), + Expr::UnaryOp { + op: UnaryOperator::Minus, + expr: Box::new(Expr::Value(number("1"))) + } + ], + else_result: Some(Box::new(Expr::Value(number("1")))) + }) + }) + ] + )), + databricks().verified_only_select(sql).projection[0] + ); + + databricks().verified_expr( + "map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))", + ); + databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)"); +}