Skip to content

Commit 23e42b8

Browse files
committed
visit_query
1 parent 4cdaa40 commit 23e42b8

File tree

4 files changed

+81
-17
lines changed

4 files changed

+81
-17
lines changed

derive/src/lib.rs

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use quote::{format_ident, quote, quote_spanned, ToTokens};
33
use syn::spanned::Spanned;
44
use syn::{
55
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
6-
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
6+
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta, Type
77
};
88

99

@@ -48,7 +48,7 @@ fn derive_visit(
4848
let generics = add_trait_bounds(input.generics, visit_type);
4949
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
5050

51-
let (pre_visit, post_visit) = attributes.visit(quote!(self));
51+
let (pre_visit, post_visit) = attributes.visit(quote!(self), false);
5252
let children = visit_children(&input.data, visit_type);
5353

5454
let expanded = quote! {
@@ -111,19 +111,48 @@ impl Attributes {
111111
}
112112

113113
/// Returns the pre and post visit token streams
114-
fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
114+
fn visit(&self, s: TokenStream, is_option: bool) -> (Option<TokenStream>, Option<TokenStream>) {
115115
let pre_visit = self.with.as_ref().map(|m| {
116116
let m = format_ident!("pre_{}", m);
117-
quote!(visitor.#m(#s)?;)
117+
if is_option {
118+
quote! {
119+
if let Some(f) = #s {
120+
visitor.#m(f)?;
121+
}
122+
}
123+
} else {
124+
quote!(visitor.#m(#s)?;)
125+
}
118126
});
119127
let post_visit = self.with.as_ref().map(|m| {
120128
let m = format_ident!("post_{}", m);
121-
quote!(visitor.#m(#s)?;)
129+
if is_option {
130+
quote! {
131+
if let Some(f) = #s {
132+
visitor.#m(f)?;
133+
}
134+
}
135+
} else {
136+
quote!(visitor.#m(#s)?;)
137+
}
122138
});
123139
(pre_visit, post_visit)
124140
}
125141
}
126142

143+
fn is_option(mut ty: &Type) -> bool {
144+
while let Type::Group(group) = ty {
145+
ty = &group.elem;
146+
}
147+
let Type::Path(ty) = &ty else {
148+
return false;
149+
};
150+
let Some(seg) = ty.path.segments.last() else {
151+
return false;
152+
};
153+
seg.ident == "Option"
154+
}
155+
127156
// Add a bound `T: Visit` to every type parameter T.
128157
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
129158
for param in &mut generics.params {
@@ -142,7 +171,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
142171
let recurse = fields.named.iter().map(|f| {
143172
let name = &f.ident;
144173
let attributes = Attributes::parse(&f.attrs);
145-
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
174+
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name), is_option(&f.ty));
146175
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
147176
});
148177
quote! {
@@ -153,7 +182,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
153182
let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
154183
let index = Index::from(i);
155184
let attributes = Attributes::parse(&f.attrs);
156-
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
185+
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index), is_option(&f.ty));
157186
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
158187
});
159188
quote! {
@@ -173,7 +202,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
173202
let visit = fields.named.iter().map(|f| {
174203
let name = &f.ident;
175204
let attributes = Attributes::parse(&f.attrs);
176-
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
205+
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream(), is_option(&f.ty));
177206
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
178207
});
179208

@@ -188,7 +217,7 @@ fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType)
188217
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
189218
let name = format_ident!("_{}", i);
190219
let attributes = Attributes::parse(&f.attrs);
191-
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
220+
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream(), is_option(&f.ty));
192221
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
193222
});
194223

src/ast/mod.rs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ pub enum Expr {
408408
/// `[ NOT ] IN (SELECT ...)`
409409
InSubquery {
410410
expr: Box<Expr>,
411+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
411412
subquery: Box<Query>,
412413
negated: bool,
413414
},
@@ -600,12 +601,16 @@ pub enum Expr {
600601
},
601602
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
602603
/// `WHERE [ NOT ] EXISTS (SELECT ...)`.
603-
Exists { subquery: Box<Query>, negated: bool },
604+
Exists {
605+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
606+
subquery: Box<Query>,
607+
negated: bool,
608+
},
604609
/// A parenthesized subquery `(SELECT ...)`, used in expression like
605610
/// `SELECT (subquery) AS x` or `WHERE (subquery) = x`
606-
Subquery(Box<Query>),
611+
Subquery(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
607612
/// An array subquery constructor, e.g. `SELECT ARRAY(SELECT 1 UNION SELECT 2)`
608-
ArraySubquery(Box<Query>),
613+
ArraySubquery(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
609614
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
610615
ListAgg(ListAgg),
611616
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
@@ -1368,7 +1373,7 @@ pub enum Statement {
13681373
partition_action: Option<AddDropSync>,
13691374
},
13701375
/// SELECT
1371-
Query(Box<Query>),
1376+
Query(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
13721377
/// INSERT
13731378
Insert {
13741379
/// Only for Sqlite
@@ -1385,6 +1390,7 @@ pub enum Statement {
13851390
/// Overwrite (Hive)
13861391
overwrite: bool,
13871392
/// A SQL query that specifies what to insert
1393+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
13881394
source: Box<Query>,
13891395
/// partitioned insert (Hive)
13901396
partitioned: Option<Vec<Expr>>,
@@ -1402,6 +1408,7 @@ pub enum Statement {
14021408
local: bool,
14031409
path: String,
14041410
file_format: Option<FileFormat>,
1411+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
14051412
source: Box<Query>,
14061413
},
14071414
Copy {
@@ -1480,6 +1487,7 @@ pub enum Statement {
14801487
/// View name
14811488
name: ObjectName,
14821489
columns: Vec<Ident>,
1490+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
14831491
query: Box<Query>,
14841492
with_options: Vec<SqlOption>,
14851493
cluster_by: Vec<Ident>,
@@ -1510,6 +1518,7 @@ pub enum Statement {
15101518
with_options: Vec<SqlOption>,
15111519
file_format: Option<FileFormat>,
15121520
location: Option<String>,
1521+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
15131522
query: Option<Box<Query>>,
15141523
without_rowid: bool,
15151524
like: Option<ObjectName>,
@@ -1598,6 +1607,7 @@ pub enum Statement {
15981607
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
15991608
name: ObjectName,
16001609
columns: Vec<Ident>,
1610+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
16011611
query: Box<Query>,
16021612
with_options: Vec<SqlOption>,
16031613
},
@@ -1665,6 +1675,7 @@ pub enum Statement {
16651675
/// Some(true) = WITH HOLD, specifies that the cursor can continue to be used after the transaction that created it successfully commits
16661676
/// Some(false) = WITHOUT HOLD, specifies that the cursor cannot be used outside of the transaction that created it
16671677
hold: Option<bool>,
1678+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
16681679
query: Box<Query>,
16691680
},
16701681
/// FETCH - retrieve rows from a query using a cursor
@@ -1969,6 +1980,7 @@ pub enum Statement {
19691980
/// Table confs
19701981
options: Vec<SqlOption>,
19711982
/// Cache table as a Query
1983+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
19721984
query: Option<Query>,
19731985
},
19741986
/// UNCACHE TABLE [ IF EXISTS ] <table_name>
@@ -4278,7 +4290,7 @@ pub enum CopySource {
42784290
/// are copied.
42794291
columns: Vec<Ident>,
42804292
},
4281-
Query(Box<Query>),
4293+
Query(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
42824294
}
42834295

42844296
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
@@ -4795,7 +4807,7 @@ impl fmt::Display for MacroArg {
47954807
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
47964808
pub enum MacroDefinition {
47974809
Expr(Expr),
4798-
Table(Query),
4810+
Table(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Query),
47994811
}
48004812

48014813
impl fmt::Display for MacroDefinition {

src/ast/query.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use crate::ast::*;
2626
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2727
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2828
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
29+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
2930
pub struct Query {
3031
/// WITH (common table expressions, or CTEs)
3132
pub with: Option<With>,
@@ -86,7 +87,7 @@ pub enum SetExpr {
8687
Select(Box<Select>),
8788
/// Parenthesized SELECT subquery, which may include more set operations
8889
/// in its body and an optional ORDER BY / LIMIT.
89-
Query(Box<Query>),
90+
Query(#[cfg_attr(feature = "visitor", visit(with = "visit_query"))] Box<Query>),
9091
/// UNION/EXCEPT/INTERSECT of two queries
9192
SetOperation {
9293
op: SetOperator,
@@ -377,6 +378,7 @@ impl fmt::Display for With {
377378
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
378379
pub struct Cte {
379380
pub alias: TableAlias,
381+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
380382
pub query: Box<Query>,
381383
pub from: Option<Ident>,
382384
}
@@ -687,6 +689,7 @@ pub enum TableFactor {
687689
},
688690
Derived {
689691
lateral: bool,
692+
#[cfg_attr(feature = "visitor", visit(with = "visit_query"))]
690693
subquery: Box<Query>,
691694
alias: Option<TableAlias>,
692695
},

src/ast/visitor.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
//! Recursive visitors for ast Nodes. See [`Visitor`] for more details.
1414
15-
use crate::ast::{Expr, ObjectName, Statement, TableFactor};
15+
use crate::ast::{Expr, ObjectName, Query, Statement, TableFactor};
1616
use core::ops::ControlFlow;
1717

1818
/// A type that can be visited by a [`Visitor`]. See [`Visitor`] for
@@ -179,6 +179,16 @@ pub trait Visitor {
179179
/// Type returned when the recursion returns early.
180180
type Break;
181181

182+
/// Invoked for any queries that appear in the AST before visiting children
183+
fn pre_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
184+
ControlFlow::Continue(())
185+
}
186+
187+
/// Invoked for any queries that appear in the AST after visiting children
188+
fn post_visit_query(&mut self, _query: &Query) -> ControlFlow<Self::Break> {
189+
ControlFlow::Continue(())
190+
}
191+
182192
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
183193
fn pre_visit_relation(&mut self, _relation: &ObjectName) -> ControlFlow<Self::Break> {
184194
ControlFlow::Continue(())
@@ -267,6 +277,16 @@ pub trait VisitorMut {
267277
/// Type returned when the recursion returns early.
268278
type Break;
269279

280+
/// Invoked for any queries that appear in the AST before visiting children
281+
fn pre_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
282+
ControlFlow::Continue(())
283+
}
284+
285+
/// Invoked for any queries that appear in the AST after visiting children
286+
fn post_visit_query(&mut self, _query: &mut Query) -> ControlFlow<Self::Break> {
287+
ControlFlow::Continue(())
288+
}
289+
270290
/// Invoked for any relations (e.g. tables) that appear in the AST before visiting children
271291
fn pre_visit_relation(&mut self, _relation: &mut ObjectName) -> ControlFlow<Self::Break> {
272292
ControlFlow::Continue(())

0 commit comments

Comments
 (0)