Skip to content

Commit 524b8a7

Browse files
authored
Add a mutable visitor (#782)
* Add a mutable visitor This adds the ability to mutate parsed sql queries. Previously, only visitors taking an immutable reference to the visited structures were allowed. * add utility functions for mutable visits * bump version numbers
1 parent 86d71f2 commit 524b8a7

14 files changed

+428
-150
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
3333
# of dev-dependencies because of
3434
# https://github.com/rust-lang/cargo/issues/1596
3535
serde_json = { version = "1.0", optional = true }
36-
sqlparser_derive = { version = "0.1", path = "derive", optional = true }
36+
sqlparser_derive = { version = "0.1.1", path = "derive", optional = true }
3737

3838
[dev-dependencies]
3939
simple_logger = "4.0"

derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "sqlparser_derive"
33
description = "proc macro for sqlparser"
4-
version = "0.1.0"
4+
version = "0.1.1"
55
authors = ["sqlparser-rs authors"]
66
homepage = "https://github.com/sqlparser-rs/sqlparser-rs"
77
documentation = "https://docs.rs/sqlparser_derive/"

derive/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ This crate contains a procedural macro that can automatically derive
66
implementations of the `Visit` trait in the [sqlparser](https://crates.io/crates/sqlparser) crate
77

88
```rust
9-
#[derive(Visit)]
9+
#[derive(Visit, VisitMut)]
1010
struct Foo {
1111
boolean: bool,
1212
bar: Bar,
1313
}
1414

15-
#[derive(Visit)]
15+
#[derive(Visit, VisitMut)]
1616
enum Bar {
1717
A(),
1818
B(String, bool),
@@ -51,7 +51,7 @@ impl Visit for Bar {
5151
Additionally certain types may wish to call a corresponding method on visitor before recursing
5252

5353
```rust
54-
#[derive(Visit)]
54+
#[derive(Visit, VisitMut)]
5555
#[visit(with = "visit_expr")]
5656
enum Expr {
5757
A(),

derive/src/lib.rs

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,58 @@ use syn::{
66
Ident, Index, Lit, Meta, MetaNameValue, NestedMeta,
77
};
88

9+
10+
/// Implementation of `[#derive(Visit)]`
11+
#[proc_macro_derive(VisitMut, attributes(visit))]
12+
pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13+
derive_visit(input, &VisitType {
14+
visit_trait: quote!(VisitMut),
15+
visitor_trait: quote!(VisitorMut),
16+
modifier: Some(quote!(mut)),
17+
})
18+
}
19+
920
/// Implementation of `[#derive(Visit)]`
1021
#[proc_macro_derive(Visit, attributes(visit))]
11-
pub fn derive_visit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
22+
pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
23+
derive_visit(input, &VisitType {
24+
visit_trait: quote!(Visit),
25+
visitor_trait: quote!(Visitor),
26+
modifier: None,
27+
})
28+
}
29+
30+
struct VisitType {
31+
visit_trait: TokenStream,
32+
visitor_trait: TokenStream,
33+
modifier: Option<TokenStream>,
34+
}
35+
36+
fn derive_visit(
37+
input: proc_macro::TokenStream,
38+
visit_type: &VisitType,
39+
) -> proc_macro::TokenStream {
1240
// Parse the input tokens into a syntax tree.
1341
let input = parse_macro_input!(input as DeriveInput);
1442
let name = input.ident;
1543

44+
let VisitType { visit_trait, visitor_trait, modifier } = visit_type;
45+
1646
let attributes = Attributes::parse(&input.attrs);
17-
// Add a bound `T: HeapSize` to every type parameter T.
18-
let generics = add_trait_bounds(input.generics);
47+
// Add a bound `T: Visit` to every type parameter T.
48+
let generics = add_trait_bounds(input.generics, visit_type);
1949
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
2050

2151
let (pre_visit, post_visit) = attributes.visit(quote!(self));
22-
let children = visit_children(&input.data);
52+
let children = visit_children(&input.data, visit_type);
2353

2454
let expanded = quote! {
2555
// The generated impl.
26-
impl #impl_generics sqlparser::ast::Visit for #name #ty_generics #where_clause {
27-
fn visit<V: sqlparser::ast::Visitor>(&self, visitor: &mut V) -> ::std::ops::ControlFlow<V::Break> {
56+
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
57+
fn visit<V: sqlparser::ast::#visitor_trait>(
58+
&#modifier self,
59+
visitor: &mut V
60+
) -> ::std::ops::ControlFlow<V::Break> {
2861
#pre_visit
2962
#children
3063
#post_visit
@@ -92,25 +125,25 @@ impl Attributes {
92125
}
93126

94127
// Add a bound `T: Visit` to every type parameter T.
95-
fn add_trait_bounds(mut generics: Generics) -> Generics {
128+
fn add_trait_bounds(mut generics: Generics, VisitType{visit_trait, ..}: &VisitType) -> Generics {
96129
for param in &mut generics.params {
97130
if let GenericParam::Type(ref mut type_param) = *param {
98-
type_param.bounds.push(parse_quote!(sqlparser::ast::Visit));
131+
type_param.bounds.push(parse_quote!(sqlparser::ast::#visit_trait));
99132
}
100133
}
101134
generics
102135
}
103136

104137
// Generate the body of the visit implementation for the given type
105-
fn visit_children(data: &Data) -> TokenStream {
138+
fn visit_children(data: &Data, VisitType{visit_trait, modifier, ..}: &VisitType) -> TokenStream {
106139
match data {
107140
Data::Struct(data) => match &data.fields {
108141
Fields::Named(fields) => {
109142
let recurse = fields.named.iter().map(|f| {
110143
let name = &f.ident;
111144
let attributes = Attributes::parse(&f.attrs);
112-
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#name));
113-
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#name, visitor)?; #post_visit)
145+
let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
146+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit)
114147
});
115148
quote! {
116149
#(#recurse)*
@@ -121,7 +154,7 @@ fn visit_children(data: &Data) -> TokenStream {
121154
let index = Index::from(i);
122155
let attributes = Attributes::parse(&f.attrs);
123156
let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
124-
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(&self.#index, visitor)?; #post_visit)
157+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
125158
});
126159
quote! {
127160
#(#recurse)*
@@ -140,8 +173,8 @@ fn visit_children(data: &Data) -> TokenStream {
140173
let visit = fields.named.iter().map(|f| {
141174
let name = &f.ident;
142175
let attributes = Attributes::parse(&f.attrs);
143-
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
144-
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
176+
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
177+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
145178
});
146179

147180
quote!(
@@ -155,8 +188,8 @@ fn visit_children(data: &Data) -> TokenStream {
155188
let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
156189
let name = format_ident!("_{}", i);
157190
let attributes = Attributes::parse(&f.attrs);
158-
let (pre_visit, post_visit) = attributes.visit(quote!(&#name));
159-
quote_spanned!(f.span() => #pre_visit sqlparser::ast::Visit::visit(#name, visitor)?; #post_visit)
191+
let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
192+
quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
160193
});
161194

162195
quote! {

src/ast/data_type.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use core::fmt;
1818
use serde::{Deserialize, Serialize};
1919

2020
#[cfg(feature = "visitor")]
21-
use sqlparser_derive::Visit;
21+
use sqlparser_derive::{Visit, VisitMut};
2222

2323
use crate::ast::ObjectName;
2424

@@ -27,7 +27,7 @@ use super::value::escape_single_quote_string;
2727
/// SQL data types
2828
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
2929
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30-
#[cfg_attr(feature = "visitor", derive(Visit))]
30+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
3131
pub enum DataType {
3232
/// Fixed-length character type e.g. CHARACTER(10)
3333
Character(Option<CharacterLength>),
@@ -341,7 +341,7 @@ fn format_datetime_precision_and_tz(
341341
/// guarantee compatibility with the input query we must maintain its exact information.
342342
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
343343
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
344-
#[cfg_attr(feature = "visitor", derive(Visit))]
344+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
345345
pub enum TimezoneInfo {
346346
/// No information about time zone. E.g., TIMESTAMP
347347
None,
@@ -389,7 +389,7 @@ impl fmt::Display for TimezoneInfo {
389389
/// [standard]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#exact-numeric-type
390390
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
391391
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
392-
#[cfg_attr(feature = "visitor", derive(Visit))]
392+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
393393
pub enum ExactNumberInfo {
394394
/// No additional information e.g. `DECIMAL`
395395
None,
@@ -420,7 +420,7 @@ impl fmt::Display for ExactNumberInfo {
420420
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#character-length
421421
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
422422
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
423-
#[cfg_attr(feature = "visitor", derive(Visit))]
423+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
424424
pub struct CharacterLength {
425425
/// Default (if VARYING) or maximum (if not VARYING) length
426426
pub length: u64,
@@ -443,7 +443,7 @@ impl fmt::Display for CharacterLength {
443443
/// [1]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#char-length-units
444444
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
445445
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
446-
#[cfg_attr(feature = "visitor", derive(Visit))]
446+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
447447
pub enum CharLengthUnits {
448448
/// CHARACTERS unit
449449
Characters,

src/ast/ddl.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use core::fmt;
2121
use serde::{Deserialize, Serialize};
2222

2323
#[cfg(feature = "visitor")]
24-
use sqlparser_derive::Visit;
24+
use sqlparser_derive::{Visit, VisitMut};
2525

2626
use crate::ast::value::escape_single_quote_string;
2727
use crate::ast::{display_comma_separated, display_separated, DataType, Expr, Ident, ObjectName};
@@ -30,7 +30,7 @@ use crate::tokenizer::Token;
3030
/// An `ALTER TABLE` (`Statement::AlterTable`) operation
3131
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
3232
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
33-
#[cfg_attr(feature = "visitor", derive(Visit))]
33+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
3434
pub enum AlterTableOperation {
3535
/// `ADD <table_constraint>`
3636
AddConstraint(TableConstraint),
@@ -100,7 +100,7 @@ pub enum AlterTableOperation {
100100

101101
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
102102
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103-
#[cfg_attr(feature = "visitor", derive(Visit))]
103+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
104104
pub enum AlterIndexOperation {
105105
RenameIndex { index_name: ObjectName },
106106
}
@@ -224,7 +224,7 @@ impl fmt::Display for AlterIndexOperation {
224224
/// An `ALTER COLUMN` (`Statement::AlterTable`) operation
225225
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
226226
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
227-
#[cfg_attr(feature = "visitor", derive(Visit))]
227+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
228228
pub enum AlterColumnOperation {
229229
/// `SET NOT NULL`
230230
SetNotNull,
@@ -268,7 +268,7 @@ impl fmt::Display for AlterColumnOperation {
268268
/// `ALTER TABLE ADD <constraint>` statement.
269269
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
270270
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
271-
#[cfg_attr(feature = "visitor", derive(Visit))]
271+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
272272
pub enum TableConstraint {
273273
/// `[ CONSTRAINT <name> ] { PRIMARY KEY | UNIQUE } (<columns>)`
274274
Unique {
@@ -433,7 +433,7 @@ impl fmt::Display for TableConstraint {
433433
/// [1]: https://dev.mysql.com/doc/refman/8.0/en/create-table.html
434434
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
435435
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
436-
#[cfg_attr(feature = "visitor", derive(Visit))]
436+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
437437
pub enum KeyOrIndexDisplay {
438438
/// Nothing to display
439439
None,
@@ -469,7 +469,7 @@ impl fmt::Display for KeyOrIndexDisplay {
469469
/// [3]: https://www.postgresql.org/docs/14/sql-createindex.html
470470
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
471471
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
472-
#[cfg_attr(feature = "visitor", derive(Visit))]
472+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
473473
pub enum IndexType {
474474
BTree,
475475
Hash,
@@ -488,7 +488,7 @@ impl fmt::Display for IndexType {
488488
/// SQL column definition
489489
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
490490
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
491-
#[cfg_attr(feature = "visitor", derive(Visit))]
491+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
492492
pub struct ColumnDef {
493493
pub name: Ident,
494494
pub data_type: DataType,
@@ -524,7 +524,7 @@ impl fmt::Display for ColumnDef {
524524
/// "column options," and we allow any column option to be named.
525525
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
526526
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
527-
#[cfg_attr(feature = "visitor", derive(Visit))]
527+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
528528
pub struct ColumnOptionDef {
529529
pub name: Option<Ident>,
530530
pub option: ColumnOption,
@@ -540,7 +540,7 @@ impl fmt::Display for ColumnOptionDef {
540540
/// TABLE` statement.
541541
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
542542
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
543-
#[cfg_attr(feature = "visitor", derive(Visit))]
543+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
544544
pub enum ColumnOption {
545545
/// `NULL`
546546
Null,
@@ -630,7 +630,7 @@ fn display_constraint_name(name: &'_ Option<Ident>) -> impl fmt::Display + '_ {
630630
/// Used in foreign key constraints in `ON UPDATE` and `ON DELETE` options.
631631
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
632632
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
633-
#[cfg_attr(feature = "visitor", derive(Visit))]
633+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
634634
pub enum ReferentialAction {
635635
Restrict,
636636
Cascade,

src/ast/helpers/stmt_create_table.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use alloc::{boxed::Box, format, string::String, vec, vec::Vec};
55
use serde::{Deserialize, Serialize};
66

77
#[cfg(feature = "visitor")]
8-
use sqlparser_derive::Visit;
8+
use sqlparser_derive::{Visit, VisitMut};
99

1010
use crate::ast::{
1111
ColumnDef, FileFormat, HiveDistributionStyle, HiveFormat, ObjectName, OnCommit, Query,
@@ -43,7 +43,7 @@ use crate::parser::ParserError;
4343
/// [1]: crate::ast::Statement::CreateTable
4444
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
4545
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
46-
#[cfg_attr(feature = "visitor", derive(Visit))]
46+
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
4747
pub struct CreateTableBuilder {
4848
pub or_replace: bool,
4949
pub temporary: bool,

0 commit comments

Comments
 (0)