Skip to content

Commit 86aa044

Browse files
authored
add {pre,post}_visit_query to Visitor (#1044)
1 parent 640b939 commit 86aa044

File tree

3 files changed

+140
-15
lines changed

3 files changed

+140
-15
lines changed

derive/README.md

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,33 +48,86 @@ impl Visit for Bar {
4848
}
4949
```
5050

51-
Additionally certain types may wish to call a corresponding method on visitor before recursing
51+
Some types may wish to call a corresponding method on the visitor:
5252

5353
```rust
5454
#[derive(Visit, VisitMut)]
5555
#[visit(with = "visit_expr")]
5656
enum Expr {
57-
A(),
58-
B(String, #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] ObjectName, bool),
57+
IsNull(Box<Expr>),
58+
..
5959
}
6060
```
6161

62-
Will generate
62+
This will result in the following sequence of visitor calls when an `IsNull`
63+
expression is visited
64+
65+
```
66+
visitor.pre_visit_expr(<is null expr>)
67+
visitor.pre_visit_expr(<is null operand>)
68+
visitor.post_visit_expr(<is null operand>)
69+
visitor.post_visit_expr(<is null expr>)
70+
```
71+
72+
For some types it is only appropriate to call a particular visitor method in
73+
some contexts. For example, not every `ObjectName` refers to a relation.
74+
75+
In these cases, the `visit` attribute can be used on the field for which we'd
76+
like to call the method:
6377

6478
```rust
65-
impl Visit for Bar {
79+
#[derive(Visit, VisitMut)]
80+
#[visit(with = "visit_table_factor")]
81+
pub enum TableFactor {
82+
Table {
83+
#[visit(with = "visit_relation")]
84+
name: ObjectName,
85+
alias: Option<TableAlias>,
86+
},
87+
..
88+
}
89+
```
90+
91+
This will generate
92+
93+
```rust
94+
impl Visit for TableFactor {
6695
fn visit<V: Visitor>(&self, visitor: &mut V) -> ControlFlow<V::Break> {
67-
visitor.visit_expr(self)?;
96+
visitor.pre_visit_table_factor(self)?;
6897
match self {
69-
Self::A() => {}
70-
Self::B(_1, _2, _3) => {
71-
_1.visit(visitor)?;
72-
visitor.visit_relation(_3)?;
73-
_2.visit(visitor)?;
74-
_3.visit(visitor)?;
98+
Self::Table { name, alias } => {
99+
visitor.pre_visit_relation(name)?;
100+
alias.visit(name)?;
101+
visitor.post_visit_relation(name)?;
102+
alias.visit(visitor)?;
75103
}
76104
}
105+
visitor.post_visit_table_factor(self)?;
77106
ControlFlow::Continue(())
78107
}
79108
}
80109
```
110+
111+
Note that annotating both the type and the field is incorrect as it will result
112+
in redundant calls to the method. For example
113+
114+
```rust
115+
#[derive(Visit, VisitMut)]
116+
#[visit(with = "visit_expr")]
117+
enum Expr {
118+
IsNull(#[visit(with = "visit_expr")] Box<Expr>),
119+
..
120+
}
121+
```
122+
123+
will result in these calls to the visitor
124+
125+
126+
```
127+
visitor.pre_visit_expr(<is null expr>)
128+
visitor.pre_visit_expr(<is null operand>)
129+
visitor.pre_visit_expr(<is null operand>)
130+
visitor.post_visit_expr(<is null operand>)
131+
visitor.post_visit_expr(<is null operand>)
132+
visitor.post_visit_expr(<is null expr>)
133+
```

src/ast/query.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use crate::ast::*;
2626
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2727
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2828
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
29+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
2930
pub struct Query {
3031
/// WITH (common table expressions, or CTEs)
3132
pub with: Option<With>,
@@ -739,7 +740,6 @@ pub enum TableFactor {
739740
/// For example `FROM monthly_sales PIVOT(sum(amount) FOR MONTH IN ('JAN', 'FEB'))`
740741
/// See <https://docs.snowflake.com/en/sql-reference/constructs/pivot>
741742
Pivot {
742-
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
743743
table: Box<TableFactor>,
744744
aggregate_function: Expr, // Function expression
745745
value_column: Vec<Ident>,
@@ -755,7 +755,6 @@ pub enum TableFactor {
755755
///
756756
/// See <https://docs.snowflake.com/en/sql-reference/constructs/unpivot>.
757757
Unpivot {
758-
#[cfg_attr(feature = "visitor", visit(with = "visit_table_factor"))]
759758
table: Box<TableFactor>,
760759
value: Ident,
761760
name: Ident,

src/ast/visitor.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
1414
15-
use crate::ast::{Expr, ObjectName, Statement, TableFactor};
15+
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
1616
use core::ops::ControlFlow;
1717

1818
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
@@ -179,6 +179,16 @@ pub trait Visitor {
179179
/// Type returned when the recursion returns early.
180180
type Break;
181181

182+
/// Invoked for any queries that appear in the AST before visiting children
183+
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
184+
ControlFlow::Continue(())
185+
}
186+
187+
/// Invoked for any queries that appear in the AST after visiting children
188+
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189+
ControlFlow::Continue(())
190+
}
191+
182192
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
183193
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
184194
ControlFlow::Continue(())
@@ -267,6 +277,16 @@ pub trait VisitorMut {
267277
/// Type returned when the recursion returns early.
268278
type Break;
269279

280+
/// Invoked for any queries that appear in the AST before visiting children
281+
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
282+
ControlFlow::Continue(())
283+
}
284+
285+
/// Invoked for any queries that appear in the AST after visiting children
286+
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
287+
ControlFlow::Continue(())
288+
}
289+
270290
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
271291
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
272292
ControlFlow::Continue(())
@@ -626,6 +646,18 @@ mod tests {
626646
impl Visitor for TestVisitor {
627647
type Break = ();
628648

649+
/// Invoked for any queries that appear in the AST before visiting children
650+
fn pre_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
651+
self.visited.push(format!("PRE: QUERY: {query}"));
652+
ControlFlow::Continue(())
653+
}
654+
655+
/// Invoked for any queries that appear in the AST after visiting children
656+
fn post_visit_query(&mut self, query: &Query) -> ControlFlow<Self::Break> {
657+
self.visited.push(format!("POST: QUERY: {query}"));
658+
ControlFlow::Continue(())
659+
}
660+
629661
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<Self::Break> {
630662
self.visited.push(format!("PRE: RELATION: {relation}"));
631663
ControlFlow::Continue(())
@@ -695,17 +727,20 @@ mod tests {
695727
"SELECT * from table_name as my_table",
696728
vec![
697729
"PRE: STATEMENT: SELECT * FROM table_name AS my_table",
730+
"PRE: QUERY: SELECT * FROM table_name AS my_table",
698731
"PRE: TABLE FACTOR: table_name AS my_table",
699732
"PRE: RELATION: table_name",
700733
"POST: RELATION: table_name",
701734
"POST: TABLE FACTOR: table_name AS my_table",
735+
"POST: QUERY: SELECT * FROM table_name AS my_table",
702736
"POST: STATEMENT: SELECT * FROM table_name AS my_table",
703737
],
704738
),
705739
(
706740
"SELECT * from t1 join t2 on t1.id = t2.t1_id",
707741
vec![
708742
"PRE: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
743+
"PRE: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
709744
"PRE: TABLE FACTOR: t1",
710745
"PRE: RELATION: t1",
711746
"POST: RELATION: t1",
@@ -720,70 +755,108 @@ mod tests {
720755
"PRE: EXPR: t2.t1_id",
721756
"POST: EXPR: t2.t1_id",
722757
"POST: EXPR: t1.id = t2.t1_id",
758+
"POST: QUERY: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
723759
"POST: STATEMENT: SELECT * FROM t1 JOIN t2 ON t1.id = t2.t1_id",
724760
],
725761
),
726762
(
727763
"SELECT * from t1 where EXISTS(SELECT column from t2)",
728764
vec![
729765
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
766+
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
730767
"PRE: TABLE FACTOR: t1",
731768
"PRE: RELATION: t1",
732769
"POST: RELATION: t1",
733770
"POST: TABLE FACTOR: t1",
734771
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
772+
"PRE: QUERY: SELECT column FROM t2",
735773
"PRE: EXPR: column",
736774
"POST: EXPR: column",
737775
"PRE: TABLE FACTOR: t2",
738776
"PRE: RELATION: t2",
739777
"POST: RELATION: t2",
740778
"POST: TABLE FACTOR: t2",
779+
"POST: QUERY: SELECT column FROM t2",
741780
"POST: EXPR: EXISTS (SELECT column FROM t2)",
781+
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
742782
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
743783
],
744784
),
745785
(
746786
"SELECT * from t1 where EXISTS(SELECT column from t2)",
747787
vec![
748788
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
789+
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
749790
"PRE: TABLE FACTOR: t1",
750791
"PRE: RELATION: t1",
751792
"POST: RELATION: t1",
752793
"POST: TABLE FACTOR: t1",
753794
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
795+
"PRE: QUERY: SELECT column FROM t2",
754796
"PRE: EXPR: column",
755797
"POST: EXPR: column",
756798
"PRE: TABLE FACTOR: t2",
757799
"PRE: RELATION: t2",
758800
"POST: RELATION: t2",
759801
"POST: TABLE FACTOR: t2",
802+
"POST: QUERY: SELECT column FROM t2",
760803
"POST: EXPR: EXISTS (SELECT column FROM t2)",
804+
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
761805
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2)",
762806
],
763807
),
764808
(
765809
"SELECT * from t1 where EXISTS(SELECT column from t2) UNION SELECT * from t3",
766810
vec![
767811
"PRE: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
812+
"PRE: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
768813
"PRE: TABLE FACTOR: t1",
769814
"PRE: RELATION: t1",
770815
"POST: RELATION: t1",
771816
"POST: TABLE FACTOR: t1",
772817
"PRE: EXPR: EXISTS (SELECT column FROM t2)",
818+
"PRE: QUERY: SELECT column FROM t2",
773819
"PRE: EXPR: column",
774820
"POST: EXPR: column",
775821
"PRE: TABLE FACTOR: t2",
776822
"PRE: RELATION: t2",
777823
"POST: RELATION: t2",
778824
"POST: TABLE FACTOR: t2",
825+
"POST: QUERY: SELECT column FROM t2",
779826
"POST: EXPR: EXISTS (SELECT column FROM t2)",
780827
"PRE: TABLE FACTOR: t3",
781828
"PRE: RELATION: t3",
782829
"POST: RELATION: t3",
783830
"POST: TABLE FACTOR: t3",
831+
"POST: QUERY: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
784832
"POST: STATEMENT: SELECT * FROM t1 WHERE EXISTS (SELECT column FROM t2) UNION SELECT * FROM t3",
785833
],
786834
),
835+
(
836+
concat!(
837+
"SELECT * FROM monthly_sales ",
838+
"PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ",
839+
"ORDER BY EMPID"
840+
),
841+
vec![
842+
"PRE: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
843+
"PRE: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
844+
"PRE: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
845+
"PRE: TABLE FACTOR: monthly_sales",
846+
"PRE: RELATION: monthly_sales",
847+
"POST: RELATION: monthly_sales",
848+
"POST: TABLE FACTOR: monthly_sales",
849+
"PRE: EXPR: SUM(a.amount)",
850+
"PRE: EXPR: a.amount",
851+
"POST: EXPR: a.amount",
852+
"POST: EXPR: SUM(a.amount)",
853+
"POST: TABLE FACTOR: monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d)",
854+
"PRE: EXPR: EMPID",
855+
"POST: EXPR: EMPID",
856+
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
857+
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
858+
]
859+
)
787860
];
788861
for (sql, expected) in tests {
789862
let actual = do_visit(sql);

0 commit comments

Comments
 (0)