Skip to content

Add support for MSSQL's XQuery methods #1500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,23 @@ pub enum Expr {
},
/// Scalar function call e.g. `LEFT(foo, 5)`
Function(Function),
/// Arbitrary expr method call
///
/// Syntax:
///
/// `<arbitrary-expr>.<function-call>.<function-call-expr>...`
///
/// > `arbitrary-expr` can be any expression including a function call.
///
/// Example:
///
/// ```sql
/// SELECT (SELECT ',' + name FROM sys.objects FOR XML PATH(''), TYPE).value('.','NVARCHAR(MAX)')
/// SELECT CONVERT(XML,'<Book>abc</Book>').value('.','NVARCHAR(MAX)').value('.','NVARCHAR(MAX)')
/// ```
///
/// (mssql): https://learn.microsoft.com/en-us/sql/t-sql/xml/xml-data-type-methods?view=sql-server-ver16
Method(Method),
/// `CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END`
///
/// Note we only recognize a complete single expression as `<condition>`,
Expand Down Expand Up @@ -1464,6 +1481,7 @@ impl fmt::Display for Expr {
write!(f, " '{}'", &value::escape_single_quote_string(value))
}
Expr::Function(fun) => write!(f, "{fun}"),
Expr::Method(method) => write!(f, "{method}"),
Expr::Case {
operand,
conditions,
Expand Down Expand Up @@ -5593,6 +5611,27 @@ impl fmt::Display for FunctionArgumentClause {
}
}

/// A method call
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct Method {
pub expr: Box<Expr>,
// always non-empty
pub method_chain: Vec<Function>,
}

impl fmt::Display for Method {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}.{}",
self.expr,
display_separated(&self.method_chain, ".")
)
}
}

#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand Down
9 changes: 9 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ pub trait Dialect: Debug + Any {
false
}

/// Returns true if the dialect supports method calls, for example:
///
/// ```sql
/// SELECT (SELECT ',' + name FROM sys.objects FOR XML PATH(''), TYPE).value('.','NVARCHAR(MAX)')
/// ```
fn supports_methods(&self) -> bool {
false
}

/// Returns true if the dialect supports multiple variable assignment
/// using parentheses in a `SET` variable declaration.
///
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,8 @@ impl Dialect for MsSqlDialect {
fn supports_try_convert(&self) -> bool {
true
}

fn supports_methods(&self) -> bool {
true
}
}
38 changes: 38 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,7 @@ impl<'a> Parser<'a> {
}
};
self.expect_token(&Token::RParen)?;
let expr = self.try_parse_method(expr)?;
if !self.consume_token(&Token::Period) {
Ok(expr)
} else {
Expand Down Expand Up @@ -1288,6 +1289,8 @@ impl<'a> Parser<'a> {
_ => self.expected("an expression", next_token),
}?;

let expr = self.try_parse_method(expr)?;

if self.parse_keyword(Keyword::COLLATE) {
Ok(Expr::Collate {
expr: Box::new(expr),
Expand Down Expand Up @@ -1345,6 +1348,41 @@ impl<'a> Parser<'a> {
})
}

/// Parses method call expression
fn try_parse_method(&mut self, expr: Expr) -> Result<Expr, ParserError> {
if !self.dialect.supports_methods() {
return Ok(expr);
}
let method_chain = self.maybe_parse(|p| {
let mut method_chain = Vec::new();
while p.consume_token(&Token::Period) {
let tok = p.next_token();
let name = match tok.token {
Token::Word(word) => word.to_ident(),
_ => return p.expected("identifier", tok),
};
let func = match p.parse_function(ObjectName(vec![name]))? {
Expr::Function(func) => func,
_ => return p.expected("function", p.peek_token()),
};
method_chain.push(func);
}
if !method_chain.is_empty() {
Ok(method_chain)
} else {
p.expected("function", p.peek_token())
}
})?;
if let Some(method_chain) = method_chain {
Ok(Expr::Method(Method {
expr: Box::new(expr),
method_chain,
}))
} else {
Ok(expr)
}
}

pub fn parse_function(&mut self, name: ObjectName) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;

Expand Down
70 changes: 70 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11387,6 +11387,76 @@ fn test_try_convert() {
dialects.verified_expr("TRY_CONVERT('foo', VARCHAR(MAX))");
}

#[test]
fn parse_method_select() {
let dialects = all_dialects_where(|d| d.supports_methods());
let _ = dialects.verified_only_select(
"SELECT LEFT('abc', 1).value('.', 'NVARCHAR(MAX)').value('.', 'NVARCHAR(MAX)') AS T",
);
let _ = dialects.verified_only_select("SELECT STUFF((SELECT ',' + name FROM sys.objects FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)'), 1, 1, '') AS T");
let _ = dialects
.verified_only_select("SELECT CAST(column AS XML).value('.', 'NVARCHAR(MAX)') AS T");

// `CONVERT` support
let dialects = all_dialects_where(|d| {
d.supports_methods() && d.supports_try_convert() && d.convert_type_before_value()
});
let _ = dialects.verified_only_select("SELECT CONVERT(XML, '<Book>abc</Book>').value('.', 'NVARCHAR(MAX)').value('.', 'NVARCHAR(MAX)') AS T");
}

#[test]
fn parse_method_expr() {
let dialects = all_dialects_where(|d| d.supports_methods());
let expr = dialects
.verified_expr("LEFT('abc', 1).value('.', 'NVARCHAR(MAX)').value('.', 'NVARCHAR(MAX)')");
match expr {
Expr::Method(Method { expr, method_chain }) => {
assert!(matches!(*expr, Expr::Function(_)));
assert!(matches!(
method_chain[..],
[Function { .. }, Function { .. }]
));
}
_ => unreachable!(),
}
let expr = dialects.verified_expr(
"(SELECT ',' + name FROM sys.objects FOR XML PATH(''), TYPE).value('.', 'NVARCHAR(MAX)')",
);
match expr {
Expr::Method(Method { expr, method_chain }) => {
assert!(matches!(*expr, Expr::Subquery(_)));
assert!(matches!(method_chain[..], [Function { .. }]));
}
_ => unreachable!(),
}
let expr = dialects.verified_expr("CAST(column AS XML).value('.', 'NVARCHAR(MAX)')");
match expr {
Expr::Method(Method { expr, method_chain }) => {
assert!(matches!(*expr, Expr::Cast { .. }));
assert!(matches!(method_chain[..], [Function { .. }]));
}
_ => unreachable!(),
}

// `CONVERT` support
let dialects = all_dialects_where(|d| {
d.supports_methods() && d.supports_try_convert() && d.convert_type_before_value()
});
let expr = dialects.verified_expr(
"CONVERT(XML, '<Book>abc</Book>').value('.', 'NVARCHAR(MAX)').value('.', 'NVARCHAR(MAX)')",
);
match expr {
Expr::Method(Method { expr, method_chain }) => {
assert!(matches!(*expr, Expr::Convert { .. }));
assert!(matches!(
method_chain[..],
[Function { .. }, Function { .. }]
));
}
_ => unreachable!(),
}
}

#[test]
fn test_show_dbs_schemas_tables_views() {
verified_stmt("SHOW DATABASES");
Expand Down
Loading