Skip to content

Commit f98a2f9

Browse files
canalun and alamb authored
feat: mysql no-escape mode (#870)
Co-authored-by: Andrew Lamb <[email protected]>
1 parent eb28848 commit f98a2f9

File tree

7 files changed

+485
-121
lines changed

7 files changed

+485
-121
lines changed

src/ast/value.rs

+47-6
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ impl fmt::Display for Value {
7171
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
7272
match self {
7373
Value::Number(v, l) => write!(f, "{}{long}", v, long = if *l { "L" } else { "" }),
74-
Value::DoubleQuotedString(v) => write!(f, "\"{v}\""),
74+
Value::DoubleQuotedString(v) => write!(f, "\"{}\"", escape_double_quote_string(v)),
7575
Value::SingleQuotedString(v) => write!(f, "'{}'", escape_single_quote_string(v)),
7676
Value::DollarQuotedString(v) => write!(f, "{v}"),
7777
Value::EscapedStringLiteral(v) => write!(f, "E'{}'", escape_escaped_string(v)),
@@ -187,12 +187,49 @@ pub struct EscapeQuotedString<'a> {
187187

188188
impl<'a> fmt::Display for EscapeQuotedString<'a> {
189189
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190-
for c in self.string.chars() {
191-
if c == self.quote {
192-
write!(f, "{q}{q}", q = self.quote)?;
193-
} else {
194-
write!(f, "{c}")?;
190+
// EscapeQuotedString doesn't know which mode of escape was
191+
// chosen by the user. So this code must correctly display
192+
// strings without knowing if the strings are already escaped
193+
// or not.
194+
//
195+
// If the quote symbol in the string is repeated twice, OR, if
196+
// the quote symbol is after backslash, display all the chars
197+
// without any escape. However, if the quote symbol is used
198+
// just between usual chars, `fmt()` should display it twice.
199+
//
200+
// The following table has examples
201+
//
202+
// | original query | mode | AST Node | serialized |
203+
// | ------------- | --------- | -------------------------------------------------- | ------------ |
204+
// | `"A""B""A"` | no-escape | `DoubleQuotedString(String::from("A\"\"B\"\"A"))` | `"A""B""A"` |
205+
// | `"A""B""A"` | default | `DoubleQuotedString(String::from("A\"B\"A"))` | `"A""B""A"` |
206+
// | `"A\"B\"A"` | no-escape | `DoubleQuotedString(String::from("A\\\"B\\\"A"))` | `"A\"B\"A"` |
207+
// | `"A\"B\"A"` | default | `DoubleQuotedString(String::from("A\"B\"A"))` | `"A""B""A"` |
208+
let quote = self.quote;
209+
let mut previous_char = char::default();
210+
let mut peekable_chars = self.string.chars().peekable();
211+
while let Some(&ch) = peekable_chars.peek() {
212+
match ch {
213+
char if char == quote => {
214+
if previous_char == '\\' {
215+
write!(f, "{char}")?;
216+
peekable_chars.next();
217+
continue;
218+
}
219+
peekable_chars.next();
220+
if peekable_chars.peek().map(|c| *c == quote).unwrap_or(false) {
221+
write!(f, "{char}{char}")?;
222+
peekable_chars.next();
223+
} else {
224+
write!(f, "{char}{char}")?;
225+
}
226+
}
227+
_ => {
228+
write!(f, "{ch}")?;
229+
peekable_chars.next();
230+
}
195231
}
232+
previous_char = ch;
196233
}
197234
Ok(())
198235
}
@@ -206,6 +243,10 @@ pub fn escape_single_quote_string(s: &str) -> EscapeQuotedString<'_> {
206243
escape_quoted_string(s, '\'')
207244
}
208245

246+
pub fn escape_double_quote_string(s: &str) -> EscapeQuotedString<'_> {
247+
escape_quoted_string(s, '\"')
248+
}
249+
209250
pub struct EscapeEscapedStringLiteral<'a>(&'a str);
210251

211252
impl<'a> fmt::Display for EscapeEscapedStringLiteral<'a> {

src/ast/visitor.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -632,8 +632,7 @@ mod tests {
632632

633633
fn do_visit(sql: &str) -> Vec<String> {
634634
let dialect = GenericDialect {};
635-
let mut tokenizer = Tokenizer::new(&dialect, sql);
636-
let tokens = tokenizer.tokenize().unwrap();
635+
let tokens = Tokenizer::new(&dialect, sql).tokenize().unwrap();
637636
let s = Parser::new(&dialect)
638637
.with_tokens(tokens)
639638
.parse_statement()

src/parser.rs

+58-10
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,52 @@ impl std::error::Error for ParserError {}
195195
// By default, allow expressions up to this deep before erroring
196196
const DEFAULT_REMAINING_DEPTH: usize = 50;
197197

198-
#[derive(Debug, Default, Clone, PartialEq, Eq)]
198+
/// Options that control how the [`Parser`] parses SQL text
199+
#[derive(Debug, Clone, PartialEq, Eq)]
199200
pub struct ParserOptions {
200201
pub trailing_commas: bool,
202+
/// Controls how literal values are unescaped. See
203+
/// [`Tokenizer::with_unescape`] for more details.
204+
pub unescape: bool,
205+
}
206+
207+
impl Default for ParserOptions {
208+
fn default() -> Self {
209+
Self {
210+
trailing_commas: false,
211+
unescape: true,
212+
}
213+
}
214+
}
215+
216+
impl ParserOptions {
217+
/// Create a new [`ParserOptions`]
218+
pub fn new() -> Self {
219+
Default::default()
220+
}
221+
222+
/// Set if trailing commas are allowed.
223+
///
224+
/// If this option is `false` (the default), the following SQL will
225+
/// not parse. If the option is `true`, the SQL will parse.
226+
///
227+
/// ```sql
228+
/// SELECT
229+
/// foo,
230+
/// bar,
231+
/// FROM baz
232+
/// ```
233+
pub fn with_trailing_commas(mut self, trailing_commas: bool) -> Self {
234+
self.trailing_commas = trailing_commas;
235+
self
236+
}
237+
238+
/// Set if literal values are unescaped. Defaults to true. See
239+
/// [`Tokenizer::with_unescape`] for more details.
240+
pub fn with_unescape(mut self, unescape: bool) -> Self {
241+
self.unescape = unescape;
242+
self
243+
}
201244
}
202245

203246
pub struct Parser<'a> {
@@ -206,8 +249,9 @@ pub struct Parser<'a> {
206249
index: usize,
207250
/// The current dialect to use
208251
dialect: &'a dyn Dialect,
209-
/// Additional options that allow you to mix & match behavior otherwise
210-
/// constrained to certain dialects (e.g. trailing commas)
252+
/// Additional options that allow you to mix & match behavior
253+
/// otherwise constrained to certain dialects (e.g. trailing
254+
/// commas) and/or format of parse (e.g. unescaping)
211255
options: ParserOptions,
212256
/// ensure the stack does not overflow by limiting recursion depth
213257
recursion_counter: RecursionCounter,
@@ -267,17 +311,20 @@ impl<'a> Parser<'a> {
267311
/// Specify additional parser options
268312
///
269313
///
270-
/// [`Parser`] supports additional options ([`ParserOptions`]) that allow you to
271-
/// mix & match behavior otherwise constrained to certain dialects (e.g. trailing
272-
/// commas).
314+
/// [`Parser`] supports additional options ([`ParserOptions`])
315+
/// that allow you to mix & match behavior otherwise constrained
316+
/// to certain dialects (e.g. trailing commas).
273317
///
274318
/// Example:
275319
/// ```
276320
/// # use sqlparser::{parser::{Parser, ParserError, ParserOptions}, dialect::GenericDialect};
277321
/// # fn main() -> Result<(), ParserError> {
278322
/// let dialect = GenericDialect{};
323+
/// let options = ParserOptions::new()
324+
/// .with_trailing_commas(true)
325+
/// .with_unescape(false);
279326
/// let result = Parser::new(&dialect)
280-
/// .with_options(ParserOptions { trailing_commas: true })
327+
/// .with_options(options)
281328
/// .try_with_sql("SELECT a, b, COUNT(*), FROM foo GROUP BY a, b,")?
282329
/// .parse_statements();
283330
/// assert!(matches!(result, Ok(_)));
@@ -317,8 +364,9 @@ impl<'a> Parser<'a> {
317364
/// See example on [`Parser::new()`] for an example
318365
pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
319366
debug!("Parsing sql '{}'...", sql);
320-
let mut tokenizer = Tokenizer::new(self.dialect, sql);
321-
let tokens = tokenizer.tokenize()?;
367+
let tokens = Tokenizer::new(self.dialect, sql)
368+
.with_unescape(self.options.unescape)
369+
.tokenize()?;
322370
Ok(self.with_tokens(tokens))
323371
}
324372

@@ -3654,7 +3702,7 @@ impl<'a> Parser<'a> {
36543702
self.expect_token(&Token::RParen)?;
36553703
Ok(Some(ColumnOption::Check(expr)))
36563704
} else if self.parse_keyword(Keyword::AUTO_INCREMENT)
3657-
&& dialect_of!(self is MySqlDialect | GenericDialect)
3705+
&& dialect_of!(self is MySqlDialect | GenericDialect)
36583706
{
36593707
// Support AUTO_INCREMENT for MySQL
36603708
Ok(Some(ColumnOption::DialectSpecific(vec![

0 commit comments

Comments
 (0)