@@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize};
6
6
use test_strategy:: Arbitrary ;
7
7
8
8
use crate :: {
9
- ast:: * , AstConversionError , Dialect , DialectDisplay , FromDialect , IntoDialect , TryFromDialect ,
9
+ ast:: * , AstConversionError , Dialect , DialectDisplay , IntoDialect , TryFromDialect ,
10
10
TryIntoDialect ,
11
11
} ;
12
12
@@ -36,10 +36,12 @@ impl TryFromDialect<sqlparser::ast::Statement> for SetStatement {
36
36
let name = variables
37
37
. into_iter ( )
38
38
. exactly_one ( )
39
- . map ( |mut object_name| object_name. 0 . pop ( ) . unwrap ( ) )
40
- . expect ( "Snowflake-style multiple variables not supported" )
39
+ . map_err ( |_| failed_err ! ( "Missing variable name" ) ) ?
40
+ . 0
41
+ . pop ( )
42
+ . unwrap ( )
41
43
. into_dialect ( dialect) ;
42
- let value: SetPostgresParameterValue = value. into_dialect ( dialect) ;
44
+ let value: SetPostgresParameterValue = value. try_into_dialect ( dialect) ? ;
43
45
let scope = if local {
44
46
Some ( PostgresParameterScope :: Local )
45
47
} else {
@@ -52,17 +54,18 @@ impl TryFromDialect<sqlparser::ast::Statement> for SetStatement {
52
54
} ) )
53
55
}
54
56
Dialect :: MySQL => {
55
- let name = variables
56
- . into_iter ( )
57
- . exactly_one ( )
58
- . expect ( "Snowflake-style multiple variables not supported" ) ;
57
+ let name = variables. into_iter ( ) . exactly_one ( ) . map_err ( |_| {
58
+ unsupported_err ! ( "Only single variable assignment supported" )
59
+ } ) ?;
59
60
Ok ( Self :: Variable ( SetVariables {
60
61
variables : vec ! [ (
61
62
name. try_into( ) ?,
62
63
value
63
64
. into_iter( )
64
65
. exactly_one( )
65
- . expect( "Multiple variable assignments not supported" )
66
+ . map_err( |_| {
67
+ unsupported_err!( "Only single variable assignment supported" )
68
+ } ) ?
66
69
. try_into_dialect( dialect) ?,
67
70
) ] ,
68
71
} ) )
@@ -123,23 +126,28 @@ pub enum SetPostgresParameterValue {
123
126
Value ( PostgresParameterValue ) ,
124
127
}
125
128
126
- impl FromDialect < Vec < sqlparser:: ast:: Expr > > for SetPostgresParameterValue {
127
- fn from_dialect ( value : Vec < sqlparser:: ast:: Expr > , dialect : Dialect ) -> Self {
129
+ impl TryFromDialect < Vec < sqlparser:: ast:: Expr > > for SetPostgresParameterValue {
130
+ fn try_from_dialect (
131
+ value : Vec < sqlparser:: ast:: Expr > ,
132
+ dialect : Dialect ,
133
+ ) -> Result < Self , AstConversionError > {
128
134
if value. len ( ) == 1 {
129
135
if let sqlparser:: ast:: Expr :: Identifier ( sqlparser:: ast:: Ident { value, .. } ) = & value[ 0 ]
130
136
{
131
137
if value. eq_ignore_ascii_case ( "DEFAULT" ) {
132
- return Self :: Default ;
138
+ return Ok ( Self :: Default ) ;
133
139
}
134
140
}
135
141
}
136
- let values = value. into_iter ( ) . map ( |expr| expr. into_dialect ( dialect) ) ;
142
+ let values = value. into_iter ( ) . map ( |expr| expr. try_into_dialect ( dialect) ) ;
137
143
if values. len ( ) == 1 {
138
- Self :: Value ( PostgresParameterValue :: Single (
139
- values. exactly_one ( ) . unwrap ( ) ,
140
- ) )
144
+ Ok ( Self :: Value ( PostgresParameterValue :: Single (
145
+ values. exactly_one ( ) . unwrap ( ) ? ,
146
+ ) ) )
141
147
} else {
142
- Self :: Value ( PostgresParameterValue :: List ( values. collect ( ) ) )
148
+ Ok ( Self :: Value ( PostgresParameterValue :: List (
149
+ values. try_collect ( ) ?,
150
+ ) ) )
143
151
}
144
152
}
145
153
}
@@ -160,14 +168,17 @@ pub enum PostgresParameterValueInner {
160
168
Literal ( Literal ) ,
161
169
}
162
170
163
- impl FromDialect < sqlparser:: ast:: Expr > for PostgresParameterValueInner {
164
- fn from_dialect ( value : sqlparser:: ast:: Expr , dialect : Dialect ) -> Self {
171
+ impl TryFromDialect < sqlparser:: ast:: Expr > for PostgresParameterValueInner {
172
+ fn try_from_dialect (
173
+ value : sqlparser:: ast:: Expr ,
174
+ dialect : Dialect ,
175
+ ) -> Result < Self , AstConversionError > {
165
176
match value {
166
- sqlparser:: ast:: Expr :: Value ( value) => Self :: Literal ( value. into ( ) ) ,
177
+ sqlparser:: ast:: Expr :: Value ( value) => Ok ( Self :: Literal ( value. try_into ( ) ? ) ) ,
167
178
sqlparser:: ast:: Expr :: Identifier ( ident) => {
168
- Self :: Identifier ( ident. into_dialect ( dialect) )
179
+ Ok ( Self :: Identifier ( ident. into_dialect ( dialect) ) )
169
180
}
170
- _ => unimplemented ! ( "unsupported postgres parameter value {value:?}" ) ,
181
+ _ => unsupported ! ( "unsupported Postgres parameter value {value:?}" ) ,
171
182
}
172
183
}
173
184
}
@@ -332,18 +343,20 @@ impl From<sqlparser::ast::Ident> for Variable {
332
343
}
333
344
}
334
345
335
- impl TryFrom < sqlparser:: ast:: ObjectName > for Variable {
346
+ impl TryFrom < Vec < sqlparser:: ast:: Ident > > for Variable {
336
347
type Error = AstConversionError ;
337
348
338
- fn try_from ( mut value : sqlparser:: ast:: ObjectName ) -> Result < Self , Self :: Error > {
339
- let name = match value. 0 . pop ( ) . unwrap ( ) {
340
- // XXX(mvzink): We lowercase across the board (even ignoring dialect) just to match nom-sql
341
- sqlparser:: ast:: ObjectNamePart :: Identifier ( ident) => ident. value . to_lowercase ( ) ,
342
- } ;
343
- if value. 0 . is_empty ( ) {
349
+ fn try_from ( mut value : Vec < sqlparser:: ast:: Ident > ) -> Result < Self , Self :: Error > {
350
+ // XXX(mvzink): We lowercase across the board (even ignoring dialect) just to match nom-sql
351
+ let name = value
352
+ . pop ( )
353
+ . ok_or_else ( || failed_err ! ( "Empty variable name" ) ) ?
354
+ . value
355
+ . to_lowercase ( ) ;
356
+ if value. is_empty ( ) {
344
357
Ok ( name. into ( ) )
345
- } else if value. 0 . len ( ) == 1 {
346
- let scope = value. 0 . pop ( ) . unwrap ( ) . into ( ) ;
358
+ } else if value. len ( ) == 1 {
359
+ let scope = value. pop ( ) . unwrap ( ) . value . as_str ( ) . into ( ) ;
347
360
Ok ( Self {
348
361
scope,
349
362
name : name. into ( ) ,
@@ -354,6 +367,21 @@ impl TryFrom<sqlparser::ast::ObjectName> for Variable {
354
367
}
355
368
}
356
369
370
+ impl TryFrom < sqlparser:: ast:: ObjectName > for Variable {
371
+ type Error = AstConversionError ;
372
+
373
+ fn try_from ( value : sqlparser:: ast:: ObjectName ) -> Result < Self , Self :: Error > {
374
+ value
375
+ . 0
376
+ . into_iter ( )
377
+ . map ( |part| match part {
378
+ sqlparser:: ast:: ObjectNamePart :: Identifier ( ident) => ident,
379
+ } )
380
+ . collect :: < Vec < _ > > ( )
381
+ . try_into ( )
382
+ }
383
+ }
384
+
357
385
#[ derive( Clone , Debug , Eq , Hash , PartialEq , Serialize , Deserialize , Arbitrary ) ]
358
386
pub struct SetVariables {
359
387
/// A list of variables and their assigned values
0 commit comments