Skip to content

Commit 0c99d1b

Browse files
committed
add utility functions for mutable visits
1 parent d4c0f35 commit 0c99d1b

File tree

1 file changed

+115
-0
lines changed

1 file changed

+115
-0
lines changed

src/ast/visitor.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,14 @@ impl<E, F: FnMut(&ObjectName) -> ControlFlow<E>> Visitor for RelationVisitor<F>
298298
}
299299
}
300300

301+
impl<E, F: FnMut(&mut ObjectName) -> ControlFlow<E>> VisitorMut for RelationVisitor<F> {
302+
type Break = E;
303+
304+
fn pre_visit_relation(&mut self, relation: &mut ObjectName) -> ControlFlow<Self::Break> {
305+
self.0(relation)
306+
}
307+
}
308+
301309
/// Invokes the provided closure on all relations (e.g. table names) present in `v`
302310
///
303311
/// # Example
@@ -335,6 +343,36 @@ where
335343
ControlFlow::Continue(())
336344
}
337345

346+
/// Invokes the provided closure on all relations (e.g. table names) present in `v`
347+
///
348+
/// # Example
349+
/// ```
350+
/// # use sqlparser::parser::Parser;
351+
/// # use sqlparser::dialect::GenericDialect;
352+
/// # use sqlparser::ast::{ObjectName, visit_relations_mut};
353+
/// # use core::ops::ControlFlow;
354+
/// let sql = "SELECT a FROM foo";
355+
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql)
356+
/// .unwrap();
357+
///
358+
/// // visit statements, renaming table foo to bar
359+
/// visit_relations_mut(&mut statements, |table| {
360+
/// table.0[0].value = table.0[0].value.replace("foo", "bar");
361+
/// ControlFlow::<()>::Continue(())
362+
/// });
363+
///
364+
/// assert_eq!(statements[0].to_string(), "SELECT a FROM bar");
365+
/// ```
366+
pub fn visit_relations_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
367+
where
368+
V: VisitMut,
369+
F: FnMut(&mut ObjectName) -> ControlFlow<E>,
370+
{
371+
let mut visitor = RelationVisitor(f);
372+
v.visit(&mut visitor)?;
373+
ControlFlow::Continue(())
374+
}
375+
338376
struct ExprVisitor<F>(F);
339377

340378
impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
@@ -345,6 +383,14 @@ impl<E, F: FnMut(&Expr) -> ControlFlow<E>> Visitor for ExprVisitor<F> {
345383
}
346384
}
347385

386+
impl<E, F: FnMut(&mut Expr) -> ControlFlow<E>> VisitorMut for ExprVisitor<F> {
387+
type Break = E;
388+
389+
fn pre_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
390+
self.0(expr)
391+
}
392+
}
393+
348394
/// Invokes the provided closure on all expressions (e.g. `1 + 2`) present in `v`
349395
///
350396
/// # Example
@@ -384,6 +430,36 @@ where
384430
ControlFlow::Continue(())
385431
}
386432

433+
/// Invokes the provided closure on all expressions present in `v`
434+
///
435+
/// # Example
436+
/// ```
437+
/// # use sqlparser::parser::Parser;
438+
/// # use sqlparser::dialect::GenericDialect;
439+
/// # use sqlparser::ast::{Expr, visit_expressions_mut, visit_statements_mut};
440+
/// # use core::ops::ControlFlow;
441+
/// let sql = "SELECT (SELECT y FROM z LIMIT 9) FROM t LIMIT 3";
442+
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
443+
///
444+
/// // Remove all select limits in sub-queries
445+
/// visit_expressions_mut(&mut statements, |expr| {
446+
/// if let Expr::Subquery(q) = expr {
447+
/// q.limit = None
448+
/// }
449+
/// ControlFlow::<()>::Continue(())
450+
/// });
451+
///
452+
/// assert_eq!(statements[0].to_string(), "SELECT (SELECT y FROM z) FROM t LIMIT 3");
453+
/// ```
454+
pub fn visit_expressions_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
455+
where
456+
V: VisitMut,
457+
F: FnMut(&mut Expr) -> ControlFlow<E>,
458+
{
459+
v.visit(&mut ExprVisitor(f))?;
460+
ControlFlow::Continue(())
461+
}
462+
387463
struct StatementVisitor<F>(F);
388464

389465
impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F> {
@@ -394,6 +470,14 @@ impl<E, F: FnMut(&Statement) -> ControlFlow<E>> Visitor for StatementVisitor<F>
394470
}
395471
}
396472

473+
impl<E, F: FnMut(&mut Statement) -> ControlFlow<E>> VisitorMut for StatementVisitor<F> {
474+
type Break = E;
475+
476+
fn pre_visit_statement(&mut self, statement: &mut Statement) -> ControlFlow<Self::Break> {
477+
self.0(statement)
478+
}
479+
}
480+
397481
/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
398482
///
399483
/// # Example
@@ -431,6 +515,37 @@ where
431515
ControlFlow::Continue(())
432516
}
433517

518+
/// Invokes the provided closure on all statements (e.g. `SELECT`, `CREATE TABLE`, etc) present in `v`
519+
///
520+
/// # Example
521+
/// ```
522+
/// # use sqlparser::parser::Parser;
523+
/// # use sqlparser::dialect::GenericDialect;
524+
/// # use sqlparser::ast::{Statement, visit_statements_mut};
525+
/// # use core::ops::ControlFlow;
526+
/// let sql = "SELECT x FROM foo LIMIT 9+$limit; SELECT * FROM t LIMIT f()";
527+
/// let mut statements = Parser::parse_sql(&GenericDialect{}, sql).unwrap();
528+
///
529+
/// // Remove all select limits in outer statements (not in sub-queries)
530+
/// visit_statements_mut(&mut statements, |stmt| {
531+
/// if let Statement::Query(q) = stmt {
532+
/// q.limit = None
533+
/// }
534+
/// ControlFlow::<()>::Continue(())
535+
/// });
536+
///
537+
/// assert_eq!(statements[0].to_string(), "SELECT x FROM foo");
538+
/// assert_eq!(statements[1].to_string(), "SELECT * FROM t");
539+
/// ```
540+
pub fn visit_statements_mut<V, E, F>(v: &mut V, f: F) -> ControlFlow<E>
541+
where
542+
V: VisitMut,
543+
F: FnMut(&mut Statement) -> ControlFlow<E>,
544+
{
545+
v.visit(&mut StatementVisitor(f))?;
546+
ControlFlow::Continue(())
547+
}
548+
434549
#[cfg(test)]
435550
mod tests {
436551
use super::*;

0 commit comments

Comments
 (0)