Skip to content

Commit 92c6e7f

Browse files
authored
Support relation visitor to visit the Option field (#1556)
1 parent 6291afb commit 92c6e7f

File tree

4 files changed

+89
-8
lines changed

4 files changed

+89
-8
lines changed

derive/README.md

+49
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,55 @@ visitor.post_visit_expr(<is null operand>)
151151
visitor.post_visit_expr(<is null expr>)
152152
```
153153

154+
If the field is a `Option` and add `#[with = "visit_xxx"]` to the field, the generated code
155+
will try to access the field only if it is `Some`:
156+
157+
```rust
158+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
159+
pub struct ShowStatementIn {
160+
pub clause: ShowStatementInClause,
161+
pub parent_type: Option<ShowStatementInParentType>,
162+
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
163+
pub parent_name: Option<ObjectName>,
164+
}
165+
```
166+
167+
This will generate
168+
169+
```rust
170+
impl sqlparser::ast::Visit for ShowStatementIn {
171+
fn visit<V: sqlparser::ast::Visitor>(
172+
&self,
173+
visitor: &mut V,
174+
) -> ::std::ops::ControlFlow<V::Break> {
175+
sqlparser::ast::Visit::visit(&self.clause, visitor)?;
176+
sqlparser::ast::Visit::visit(&self.parent_type, visitor)?;
177+
if let Some(value) = &self.parent_name {
178+
visitor.pre_visit_relation(value)?;
179+
sqlparser::ast::Visit::visit(value, visitor)?;
180+
visitor.post_visit_relation(value)?;
181+
}
182+
::std::ops::ControlFlow::Continue(())
183+
}
184+
}
185+
186+
impl sqlparser::ast::VisitMut for ShowStatementIn {
187+
fn visit<V: sqlparser::ast::VisitorMut>(
188+
&mut self,
189+
visitor: &mut V,
190+
) -> ::std::ops::ControlFlow<V::Break> {
191+
sqlparser::ast::VisitMut::visit(&mut self.clause, visitor)?;
192+
sqlparser::ast::VisitMut::visit(&mut self.parent_type, visitor)?;
193+
if let Some(value) = &mut self.parent_name {
194+
visitor.pre_visit_relation(value)?;
195+
sqlparser::ast::VisitMut::visit(value, visitor)?;
196+
visitor.post_visit_relation(value)?;
197+
}
198+
::std::ops::ControlFlow::Continue(())
199+
}
200+
}
201+
```
202+
154203
## Releasing
155204

156205
This crate's release is not automated. Instead it is released manually as needed

derive/src/lib.rs

+29-7
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,8 @@
1818
use proc_macro2::TokenStream;
1919
use quote::{format_ident, quote, quote_spanned, ToTokens};
2020
use syn::spanned::Spanned;
21-
use syn::{
22-
parse::{Parse, ParseStream},
23-
parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
24-
Ident, Index, LitStr, Meta, Token,
25-
};
21+
use syn::{parse::{Parse, ParseStream}, parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta, Token, Type, TypePath};
22+
use syn::{Path, PathArguments};
2623

2724
/// Implementation of `[#derive(Visit)]`
2825
#[proc_macro_derive(VisitMut, attributes(visit))]
@@ -182,9 +179,21 @@ fn visit_children(
182179
Fields::Named(fields) => {
183180
let recurse = fields.named.iter().map(|f| {
184181
let name = &f.ident;
182+
let is_option = is_option(&f.ty);
185183
let attributes = Attributes::parse(&f.attrs);
186-
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
187-
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
184+
if is_option && attributes.with.is_some() {
185+
let (pre_visit, post_visit) = attributes.visit(quote!(value));
186+
quote_spanned!(f.span() =>
187+
if let Some(value) = &#modifier self.#name {
188+
#pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
189+
}
190+
)
191+
} else {
192+
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
193+
quote_spanned!(f.span() =>
194+
#pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
195+
)
196+
}
188197
});
189198
quote! {
190199
#(#recurse)*
@@ -256,3 +265,16 @@ fn visit_children(
256265
Data::Union(_) => unimplemented!(),
257266
}
258267
}
268+
269+
fn is_option(ty: &Type) -> bool {
270+
if let Type::Path(TypePath { path: Path { segments, .. }, .. }) = ty {
271+
if let Some(segment) = segments.last() {
272+
if segment.ident == "Option" {
273+
if let PathArguments::AngleBracketed(args) = &segment.arguments {
274+
return args.args.len() == 1;
275+
}
276+
}
277+
}
278+
}
279+
false
280+
}

src/ast/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -7653,6 +7653,7 @@ impl fmt::Display for ShowStatementInParentType {
76537653
pub struct ShowStatementIn {
76547654
pub clause: ShowStatementInClause,
76557655
pub parent_type: Option<ShowStatementInParentType>,
7656+
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
76567657
pub parent_name: Option<ObjectName>,
76577658
}
76587659

src/ast/visitor.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -876,7 +876,16 @@ mod tests {
876876
"POST: QUERY: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
877877
"POST: STATEMENT: SELECT * FROM monthly_sales PIVOT(SUM(a.amount) FOR a.MONTH IN ('JAN', 'FEB', 'MAR', 'APR')) AS p (c, d) ORDER BY EMPID",
878878
]
879-
)
879+
),
880+
(
881+
"SHOW COLUMNS FROM t1",
882+
vec![
883+
"PRE: STATEMENT: SHOW COLUMNS FROM t1",
884+
"PRE: RELATION: t1",
885+
"POST: RELATION: t1",
886+
"POST: STATEMENT: SHOW COLUMNS FROM t1",
887+
],
888+
),
880889
];
881890
for (sql, expected) in tests {
882891
let actual = do_visit(sql);

0 commit comments

Comments
 (0)