@@ -26,6 +26,10 @@ pub(super) enum ParsedSQLStatement {
26
26
Statement ( PreparedStatement ) ,
27
27
StaticSimpleSelect ( serde_json:: Map < String , serde_json:: Value > ) ,
28
28
Error ( anyhow:: Error ) ,
29
+ SetVariable {
30
+ variable : StmtParam ,
31
+ value : PreparedStatement ,
32
+ } ,
29
33
}
30
34
31
35
impl ParsedSqlFile {
@@ -40,20 +44,15 @@ impl ParsedSqlFile {
40
44
statements. push ( match parsed {
41
45
ParsedStatement :: StaticSimpleSelect ( s) => ParsedSQLStatement :: StaticSimpleSelect ( s) ,
42
46
ParsedStatement :: Error ( e) => ParsedSQLStatement :: Error ( e) ,
43
- ParsedStatement :: StmtWithParams { query, params } => {
44
- let param_types = get_param_types ( & params) ;
45
- match db. prepare_with ( & query, & param_types) . await {
46
- Ok ( statement) => {
47
- log:: debug!( "Successfully prepared SQL statement '{query}'" ) ;
48
- ParsedSQLStatement :: Statement ( PreparedStatement {
49
- statement,
50
- parameters : params,
51
- } )
52
- }
53
- Err ( err) => {
54
- log:: warn!( "Failed to prepare {query:?}: {err:#}" ) ;
55
- ParsedSQLStatement :: Error ( err)
47
+ ParsedStatement :: StmtWithParams ( stmt_with_params) => {
48
+ prepare_query_with_params ( db, stmt_with_params) . await
49
+ }
50
+ ParsedStatement :: SetVariable { variable, value } => {
51
+ match prepare_query_with_params ( db, value) . await {
52
+ ParsedSQLStatement :: Statement ( value) => {
53
+ ParsedSQLStatement :: SetVariable { variable, value }
56
54
}
55
+ err => err,
57
56
}
58
57
}
59
58
} ) ;
@@ -71,19 +70,45 @@ impl ParsedSqlFile {
71
70
}
72
71
}
73
72
73
+ async fn prepare_query_with_params (
74
+ db : & Database ,
75
+ StmtWithParams { query, params } : StmtWithParams ,
76
+ ) -> ParsedSQLStatement {
77
+ let param_types = get_param_types ( & params) ;
78
+ match db. prepare_with ( & query, & param_types) . await {
79
+ Ok ( statement) => {
80
+ log:: debug!( "Successfully prepared SQL statement '{query}'" ) ;
81
+ ParsedSQLStatement :: Statement ( PreparedStatement {
82
+ statement,
83
+ parameters : params,
84
+ } )
85
+ }
86
+ Err ( err) => {
87
+ log:: warn!( "Failed to prepare {query:?}: {err:#}" ) ;
88
+ ParsedSQLStatement :: Error ( err)
89
+ }
90
+ }
91
+ }
92
+
74
93
#[ async_trait( ? Send ) ]
75
94
impl AsyncFromStrWithState for ParsedSqlFile {
76
95
async fn from_str_with_state ( app_state : & AppState , source : & str ) -> anyhow:: Result < Self > {
77
96
Ok ( ParsedSqlFile :: new ( & app_state. db , source) . await )
78
97
}
79
98
}
80
99
100
+ struct StmtWithParams {
101
+ query : String ,
102
+ params : Vec < StmtParam > ,
103
+ }
104
+
81
105
enum ParsedStatement {
82
- StmtWithParams {
83
- query : String ,
84
- params : Vec < StmtParam > ,
85
- } ,
106
+ StmtWithParams ( StmtWithParams ) ,
86
107
StaticSimpleSelect ( serde_json:: Map < String , serde_json:: Value > ) ,
108
+ SetVariable {
109
+ variable : StmtParam ,
110
+ value : StmtWithParams ,
111
+ } ,
87
112
Error ( anyhow:: Error ) ,
88
113
}
89
114
@@ -113,8 +138,16 @@ fn parse_single_statement(parser: &mut Parser<'_>, db_kind: AnyKind) -> Option<P
113
138
return Some ( ParsedStatement :: StaticSimpleSelect ( static_statement) ) ;
114
139
}
115
140
let params = ParameterExtractor :: extract_parameters ( & mut stmt, db_kind) ;
116
- let query = stmt. to_string ( ) ;
117
- Some ( ParsedStatement :: StmtWithParams { query, params } )
141
+ if let Some ( ( variable, query) ) = extract_set_variable ( & mut stmt) {
142
+ return Some ( ParsedStatement :: SetVariable {
143
+ variable,
144
+ value : StmtWithParams { query, params } ,
145
+ } ) ;
146
+ }
147
+ Some ( ParsedStatement :: StmtWithParams ( StmtWithParams {
148
+ query : stmt. to_string ( ) ,
149
+ params,
150
+ } ) )
118
151
}
119
152
120
153
fn syntax_error ( err : ParserError , parser : & mut Parser ) -> ParsedStatement {
@@ -226,6 +259,24 @@ fn extract_static_simple_select(
226
259
Some ( map)
227
260
}
228
261
262
+ fn extract_set_variable ( stmt : & mut Statement ) -> Option < ( StmtParam , String ) > {
263
+ if let Statement :: SetVariable {
264
+ variable : ObjectName ( name) ,
265
+ value,
266
+ local : false ,
267
+ hivevar : false ,
268
+ } = stmt
269
+ {
270
+ if let ( [ ident] , [ value] ) = ( name. as_mut_slice ( ) , value. as_mut_slice ( ) ) {
271
+ if let Some ( variable) = extract_ident_param ( ident) {
272
+ let query = format ! ( "SELECT {value}" ) ;
273
+ return Some ( ( variable, query) ) ;
274
+ }
275
+ }
276
+ }
277
+ None
278
+ }
279
+
229
280
struct ParameterExtractor {
230
281
db_kind : AnyKind ,
231
282
parameters : Vec < StmtParam > ,
@@ -378,17 +429,24 @@ pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
378
429
DEFAULT_PLACEHOLDER . to_string ( )
379
430
}
380
431
432
+ fn extract_ident_param ( Ident { value, quote_style } : & mut Ident ) -> Option < StmtParam > {
433
+ if quote_style. is_none ( ) && value. starts_with ( '$' ) || value. starts_with ( ':' ) {
434
+ let name = std:: mem:: take ( value) ;
435
+ Some ( map_param ( name) )
436
+ } else {
437
+ None
438
+ }
439
+ }
440
+
381
441
impl VisitorMut for ParameterExtractor {
382
442
type Break = ( ) ;
383
443
fn pre_visit_expr ( & mut self , value : & mut Expr ) -> ControlFlow < Self :: Break > {
384
444
match value {
385
- Expr :: Identifier ( Ident {
386
- value : var_name,
387
- quote_style : None ,
388
- } ) if var_name. starts_with ( '$' ) || var_name. starts_with ( ':' ) => {
389
- let name = std:: mem:: take ( var_name) ;
390
- * value = self . make_placeholder ( ) ;
391
- self . parameters . push ( map_param ( name) ) ;
445
+ Expr :: Identifier ( ident) => {
446
+ if let Some ( param) = extract_ident_param ( ident) {
447
+ * value = self . make_placeholder ( ) ;
448
+ self . parameters . push ( param) ;
449
+ }
392
450
}
393
451
Expr :: Value ( Value :: Placeholder ( param) ) if !self . is_own_placeholder ( param) =>
394
452
// this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves
0 commit comments