Skip to content

feat: mysql no-escape mode #870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ $ cargo run --feature json_example --example cli FILENAME.sql [--dialectname]
"--postgres" => Box::new(PostgreSqlDialect {}),
"--ms" => Box::new(MsSqlDialect {}),
"--mysql" => Box::new(MySqlDialect {}),
"--mysql-no-escape" => Box::new(MySqlNoEscapeDialect {}),
"--snowflake" => Box::new(SnowflakeDialect {}),
"--hive" => Box::new(HiveDialect {}),
"--redshift" => Box::new(RedshiftSqlDialect {}),
Expand Down
20 changes: 15 additions & 5 deletions src/ast/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,21 @@ pub struct EscapeQuotedString<'a> {

impl<'a> fmt::Display for EscapeQuotedString<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for c in self.string.chars() {
if c == self.quote {
write!(f, "{q}{q}", q = self.quote)?;
} else {
write!(f, "{c}")?;
let mut peekable_chars = self.string.chars().peekable();
while let Some(&ch) = peekable_chars.peek() {
let quote = self.quote;
match ch {
char if char == quote => {
write!(f, "{char}{char}", char = self.quote)?;
peekable_chars.next();
if peekable_chars.peek().map(|c| *c == quote).unwrap_or(false) {
peekable_chars.next();
}
}
_ => {
write!(f, "{ch}")?;
peekable_chars.next();
}
}
}
Ok(())
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub use self::generic::GenericDialect;
pub use self::hive::HiveDialect;
pub use self::mssql::MsSqlDialect;
pub use self::mysql::MySqlDialect;
pub use self::mysql::MySqlNoEscapeDialect;
pub use self::postgresql::PostgreSqlDialect;
pub use self::redshift::RedshiftSqlDialect;
pub use self::snowflake::SnowflakeDialect;
Expand Down Expand Up @@ -154,6 +155,7 @@ pub fn dialect_from_str(dialect_name: impl AsRef<str>) -> Option<Box<dyn Dialect
match dialect_name.to_lowercase().as_str() {
"generic" => Some(Box::new(GenericDialect)),
"mysql" => Some(Box::new(MySqlDialect {})),
"mysql-no-escape" => Some(Box::new(MySqlNoEscapeDialect {})),
"postgresql" | "postgres" => Some(Box::new(PostgreSqlDialect {})),
"hive" => Some(Box::new(HiveDialect {})),
"sqlite" => Some(Box::new(SQLiteDialect {})),
Expand Down Expand Up @@ -201,6 +203,8 @@ mod tests {
assert!(parse_dialect("generic").is::<GenericDialect>());
assert!(parse_dialect("mysql").is::<MySqlDialect>());
assert!(parse_dialect("MySql").is::<MySqlDialect>());
assert!(parse_dialect("mysql-no-escape").is::<MySqlNoEscapeDialect>());
assert!(parse_dialect("MySql-No-Escape").is::<MySqlNoEscapeDialect>());
assert!(parse_dialect("postgresql").is::<PostgreSqlDialect>());
assert!(parse_dialect("postgres").is::<PostgreSqlDialect>());
assert!(parse_dialect("hive").is::<HiveDialect>());
Expand Down
26 changes: 26 additions & 0 deletions src/dialect/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,29 @@ impl Dialect for MySqlDialect {
ch == '`'
}
}

/// A [MySQL](https://www.mysql.com/) dialect variant for callers that do not
/// want queries to be escaped, either while parsing or while serializing them.
#[derive(Debug)]
pub struct MySqlNoEscapeDialect {}

impl Dialect for MySqlNoEscapeDialect {
    /// Returns `true` when `ch` may begin an unquoted identifier.
    fn is_identifier_start(&self, ch: char) -> bool {
        // See https://dev.mysql.com/doc/refman/8.0/en/identifiers.html.
        // Digits are deliberately absent here: identifiers with a leading digit
        // are picked up while tokenizing numbers, which lets the tokenizer tell
        // them apart from exponent numeric literals.
        match ch {
            '_' | '$' | '@' | '\u{0080}'..='\u{ffff}' => true,
            _ => ch.is_alphabetic(),
        }
    }

    /// Returns `true` when `ch` may appear after the first character of an
    /// unquoted identifier: any identifier-start character or an ASCII digit.
    fn is_identifier_part(&self, ch: char) -> bool {
        ch.is_ascii_digit() || self.is_identifier_start(ch)
    }

    /// Returns `true` for the backtick, the only character that opens a
    /// delimited identifier in this dialect.
    fn is_delimited_identifier_start(&self, ch: char) -> bool {
        ch == '`'
    }
}
30 changes: 17 additions & 13 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ impl<'a> Parser<'a> {
}
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
Keyword::NOT => self.parse_not(),
Keyword::MATCH if dialect_of!(self is MySqlDialect | GenericDialect) => {
Keyword::MATCH if dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect | GenericDialect) => {
self.parse_match_against()
}
// Here `w` is a word, check if it's a part of a multi-part
Expand Down Expand Up @@ -3566,7 +3566,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::RParen)?;
Ok(Some(ColumnOption::Check(expr)))
} else if self.parse_keyword(Keyword::AUTO_INCREMENT)
&& dialect_of!(self is MySqlDialect | GenericDialect)
&& dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect | GenericDialect)
{
// Support AUTO_INCREMENT for MySQL
Ok(Some(ColumnOption::DialectSpecific(vec![
Expand All @@ -3580,7 +3580,7 @@ impl<'a> Parser<'a> {
Token::make_keyword("AUTOINCREMENT"),
])))
} else if self.parse_keywords(&[Keyword::ON, Keyword::UPDATE])
&& dialect_of!(self is MySqlDialect | GenericDialect)
&& dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect | GenericDialect)
{
let expr = self.parse_expr()?;
Ok(Some(ColumnOption::OnUpdate(expr)))
Expand Down Expand Up @@ -3716,7 +3716,7 @@ impl<'a> Parser<'a> {
}
Token::Word(w)
if (w.keyword == Keyword::INDEX || w.keyword == Keyword::KEY)
&& dialect_of!(self is GenericDialect | MySqlDialect) =>
&& dialect_of!(self is GenericDialect | MySqlDialect | MySqlNoEscapeDialect) =>
{
let display_as_key = w.keyword == Keyword::KEY;

Expand All @@ -3741,7 +3741,7 @@ impl<'a> Parser<'a> {
}
Token::Word(w)
if (w.keyword == Keyword::FULLTEXT || w.keyword == Keyword::SPATIAL)
&& dialect_of!(self is GenericDialect | MySqlDialect) =>
&& dialect_of!(self is GenericDialect | MySqlDialect | MySqlNoEscapeDialect) =>
{
if let Some(name) = name {
return self.expected(
Expand Down Expand Up @@ -3900,7 +3900,7 @@ impl<'a> Parser<'a> {
cascade,
}
} else if self.parse_keywords(&[Keyword::PRIMARY, Keyword::KEY])
&& dialect_of!(self is MySqlDialect | GenericDialect)
&& dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect | GenericDialect)
{
AlterTableOperation::DropPrimaryKey
} else {
Expand Down Expand Up @@ -4995,7 +4995,7 @@ impl<'a> Parser<'a> {
offset = Some(self.parse_offset()?)
}

if dialect_of!(self is GenericDialect | MySqlDialect)
if dialect_of!(self is GenericDialect | MySqlDialect | MySqlNoEscapeDialect)
&& limit.is_some()
&& offset.is_none()
&& self.consume_token(&Token::Comma)
Expand Down Expand Up @@ -5088,7 +5088,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::RParen)?;
SetExpr::Query(Box::new(subquery))
} else if self.parse_keyword(Keyword::VALUES) {
let is_mysql = dialect_of!(self is MySqlDialect);
let is_mysql = dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect);
SetExpr::Values(self.parse_values(is_mysql)?)
} else if self.parse_keyword(Keyword::TABLE) {
SetExpr::Table(Box::new(self.parse_as_table()?))
Expand Down Expand Up @@ -5365,7 +5365,7 @@ impl<'a> Parser<'a> {
};

if variable.to_string().eq_ignore_ascii_case("NAMES")
&& dialect_of!(self is MySqlDialect | GenericDialect)
&& dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect | GenericDialect)
{
if self.parse_keyword(Keyword::DEFAULT) {
return Ok(Statement::SetNamesDefault {});
Expand Down Expand Up @@ -5458,7 +5458,7 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::COLLATION) {
Ok(self.parse_show_collation()?)
} else if self.parse_keyword(Keyword::VARIABLES)
&& dialect_of!(self is MySqlDialect | GenericDialect)
&& dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect | GenericDialect)
{
// TODO: Support GLOBAL|SESSION
Ok(Statement::ShowVariables {
Expand Down Expand Up @@ -6133,7 +6133,7 @@ impl<'a> Parser<'a> {
// Hive lets you put table here regardless
let table = self.parse_keyword(Keyword::TABLE);
let table_name = self.parse_object_name()?;
let is_mysql = dialect_of!(self is MySqlDialect);
let is_mysql = dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect);
let columns = self.parse_parenthesized_column_list(Optional, is_mysql)?;

let partitioned = if self.parse_keyword(Keyword::PARTITION) {
Expand Down Expand Up @@ -6762,7 +6762,7 @@ impl<'a> Parser<'a> {
"INSERT in MATCHED merge clause".to_string(),
));
}
let is_mysql = dialect_of!(self is MySqlDialect);
let is_mysql = dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect);
let columns = self.parse_parenthesized_column_list(Optional, is_mysql)?;
self.expect_keyword(Keyword::VALUES)?;
let values = self.parse_values(is_mysql)?;
Expand Down Expand Up @@ -7333,7 +7333,11 @@ mod tests {
}

let dialect = TestedDialects {
dialects: vec![Box::new(GenericDialect {}), Box::new(MySqlDialect {})],
dialects: vec![
Box::new(GenericDialect {}),
Box::new(MySqlDialect {}),
Box::new(MySqlNoEscapeDialect {}),
],
options: None,
};

Expand Down
12 changes: 12 additions & 0 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,25 @@ pub fn all_dialects() -> TestedDialects {
Box::new(HiveDialect {}),
Box::new(RedshiftSqlDialect {}),
Box::new(MySqlDialect {}),
Box::new(MySqlNoEscapeDialect {}),
Box::new(BigQueryDialect {}),
Box::new(SQLiteDialect {}),
],
options: None,
}
}

pub fn all_dialects_other_than_mysqlnoescape() -> TestedDialects {
Copy link
Contributor Author

@canalun canalun May 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `XX_other_than_mysqlnoescape` helper functions, including this one, are meant to be used in tests that exercise escaping.
I added these helpers because some test utilities assume that every dialect produces the same output, and those utilities are already widely used.

let mut all_dialects = all_dialects();
let index_of_mysqlnoescape = all_dialects
.dialects
.iter()
.position(|dialect| dialect.is::<MySqlNoEscapeDialect>())
.unwrap();
all_dialects.dialects.remove(index_of_mysqlnoescape);
all_dialects
}

pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
assert_eq!(
expected,
Expand Down
81 changes: 58 additions & 23 deletions src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use sqlparser_derive::{Visit, VisitMut};

use crate::ast::DollarQuotedString;
use crate::dialect::{BigQueryDialect, GenericDialect, SnowflakeDialect};
use crate::dialect::{Dialect, MySqlDialect};
use crate::dialect::{Dialect, MySqlDialect, MySqlNoEscapeDialect};
use crate::keywords::{Keyword, ALL_KEYWORDS, ALL_KEYWORDS_INDEX};

/// SQL Token enumeration
Expand Down Expand Up @@ -636,7 +636,7 @@ impl<'a> Tokenizer<'a> {
let error_loc = chars.location();
chars.next(); // consume the opening quote
let quote_end = Word::matching_end_quote(quote_start);
let (s, last_char) = parse_quoted_ident(chars, quote_end);
let (s, last_char) = self.parse_quoted_ident(chars, quote_end);

if last_char == Some(quote_end) {
Ok(Some(Token::make_word(&s, Some(quote_start))))
Expand Down Expand Up @@ -705,7 +705,9 @@ impl<'a> Tokenizer<'a> {

// mysql dialect supports identifiers that start with a numeric prefix,
// as long as they aren't an exponent number.
if dialect_of!(self is MySqlDialect) && exponent_part.is_empty() {
if dialect_of!(self is MySqlDialect | MySqlNoEscapeDialect)
&& exponent_part.is_empty()
{
let word =
peeking_take_while(chars, |ch| self.dialect.is_identifier_part(ch));

Expand Down Expand Up @@ -1112,6 +1114,10 @@ impl<'a> Tokenizer<'a> {
chars.next(); // consume
if chars.peek().map(|c| *c == quote_style).unwrap_or(false) {
s.push(ch);
if dialect_of!(self is MySqlNoEscapeDialect) {
// In no-escape mode, the given query has to be saved completely
s.push(ch);
}
chars.next();
} else {
return Ok(s);
Expand All @@ -1120,7 +1126,7 @@ impl<'a> Tokenizer<'a> {
'\\' => {
// consume
chars.next();
// slash escaping is specific to MySQL dialect
// slash escaping is specific to MySQL dialect.
if dialect_of!(self is MySqlDialect) {
if let Some(next) = chars.peek() {
// See https://dev.mysql.com/doc/refman/8.0/en/string-literals.html#character-escape-sequences
Expand All @@ -1137,6 +1143,13 @@ impl<'a> Tokenizer<'a> {
s.push(n);
chars.next(); // consume next
}
} else if dialect_of!(self is MySqlNoEscapeDialect) {
// In no-escape mode, the given query has to be saved completely including backslashes.
if let Some(next) = chars.peek() {
s.push(ch);
s.push(*next);
chars.next(); // consume next
}
} else {
s.push(ch);
}
Expand Down Expand Up @@ -1183,6 +1196,29 @@ impl<'a> Tokenizer<'a> {
}
}

/// Reads the body of a quoted identifier, stopping at the closing `quote_end`.
///
/// The opening quote must already have been consumed by the caller. A doubled
/// `quote_end` inside the body is treated as an escaped quote: it contributes a
/// single `quote_end` to the result, or both characters when the active dialect
/// is `MySqlNoEscapeDialect`, so that the original text is kept as written.
///
/// Returns the collected text together with `Some(quote_end)` if the closing
/// quote was found, or `None` if the input ended first.
fn parse_quoted_ident(&self, chars: &mut State, quote_end: char) -> (String, Option<char>) {
    let mut ident = String::new();
    let terminator = loop {
        match chars.next() {
            // Input exhausted before a closing quote was seen.
            None => break None,
            Some(ch) if ch != quote_end => ident.push(ch),
            Some(ch) => {
                if chars.peek() != Some(&quote_end) {
                    // A lone quote closes the identifier.
                    break Some(quote_end);
                }
                // Doubled quote: consume the second one and keep going.
                chars.next();
                ident.push(ch);
                if dialect_of!(self is MySqlNoEscapeDialect) {
                    // In no-escape mode, the given query has to be saved completely
                    ident.push(ch);
                }
            }
        }
    };
    (ident, terminator)
}

#[allow(clippy::unnecessary_wraps)]
fn consume_and_return(
&self,
Expand Down Expand Up @@ -1210,25 +1246,6 @@ fn peeking_take_while(chars: &mut State, mut predicate: impl FnMut(char) -> bool
s
}

/// Reads the body of a quoted identifier, stopping at the closing `quote_end`.
///
/// The opening quote must already have been consumed by the caller. A doubled
/// `quote_end` inside the body is collapsed into a single `quote_end` character.
///
/// Returns the collected text together with `Some(quote_end)` if the closing
/// quote was found, or `None` if the input ended first.
fn parse_quoted_ident(chars: &mut State, quote_end: char) -> (String, Option<char>) {
    let mut ident = String::new();
    let terminator = loop {
        match chars.next() {
            Some(ch) if ch == quote_end => {
                if chars.peek() == Some(&quote_end) {
                    // Doubled quote: consume the second one, keep a single copy.
                    chars.next();
                    ident.push(ch);
                } else {
                    // A lone quote closes the identifier.
                    break Some(quote_end);
                }
            }
            Some(ch) => ident.push(ch),
            // Input exhausted before a closing quote was seen.
            None => break None,
        }
    };
    (ident, terminator)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -1870,6 +1887,24 @@ mod tests {
compare(expected, tokens);
}

#[test]
fn tokenize_quoted_identifier_with_no_escape() {
    // With the no-escape dialect, doubled quotes inside a double-quoted token
    // must come back exactly as written instead of being collapsed to one.
    let dialect = MySqlNoEscapeDialect {};
    let sql = r#" "a "" b" "a """ "c """"" "#;
    let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();

    // The input is a leading space followed by three quoted tokens, each
    // followed by a single space.
    let mut expected = vec![Token::Whitespace(Whitespace::Space)];
    for &body in [r#"a "" b"#, r#"a """#, r#"c """""#].iter() {
        expected.push(Token::DoubleQuotedString(String::from(body)));
        expected.push(Token::Whitespace(Whitespace::Space));
    }

    compare(expected, tokens);
}

#[test]
fn tokenize_with_location() {
let sql = "SELECT a,\n b";
Expand Down
Loading