@@ -5,7 +5,11 @@ use futures_util::StreamExt;
 use sqlparser::ast::{
     CopyLegacyCsvOption, CopyLegacyOption, CopyOption, CopySource, CopyTarget, Statement,
 };
-use sqlx::{any::AnyArguments, AnyConnection, Arguments, Executor};
+use sqlx::{
+    any::{AnyArguments, AnyKind},
+    AnyConnection, Arguments, Executor,
+};
+use tokio::io::AsyncRead;
 
 use crate::webserver::http_request_info::RequestInfo;
 
@@ -152,7 +156,16 @@ pub(super) async fn run_csv_import(
     let file = tokio::fs::File::open(file_path)
         .await
         .with_context(|| "opening csv")?;
-    let insert_stmt = create_insert_stmt(db, csv_import);
+    let buffered = tokio::io::BufReader::new(file);
+    run_csv_import_on_path(db, csv_import, buffered).await
+}
+
+async fn run_csv_import_on_path(
+    db: &mut AnyConnection,
+    csv_import: &CsvImport,
+    file: impl AsyncRead + Unpin + Send,
+) -> anyhow::Result<()> {
+    let insert_stmt = create_insert_stmt(db.kind(), csv_import);
     log::debug!("CSV data insert statement: {insert_stmt}");
     let mut reader = make_csv_reader(csv_import, file);
     let col_idxs = compute_column_indices(&mut reader, csv_import).await?;
@@ -164,8 +177,8 @@ pub(super) async fn run_csv_import(
     Ok(())
 }
 
-async fn compute_column_indices(
-    reader: &mut csv_async::AsyncReader<tokio::fs::File>,
+async fn compute_column_indices<R: AsyncRead + Unpin + Send>(
+    reader: &mut csv_async::AsyncReader<R>,
     csv_import: &CsvImport,
 ) -> anyhow::Result<Vec<usize>> {
     let mut col_idxs = Vec::with_capacity(csv_import.columns.len());
@@ -189,16 +202,17 @@ async fn compute_column_indices(
     Ok(col_idxs)
 }
 
-fn create_insert_stmt(db: &mut AnyConnection, csv_import: &CsvImport) -> String {
-    let kind = db.kind();
+fn create_insert_stmt(kind: AnyKind, csv_import: &CsvImport) -> String {
     let columns = csv_import.columns.join(", ");
     let placeholders = csv_import
         .columns
         .iter()
         .enumerate()
-        .map(|(i, _)| make_placeholder(kind, i))
+        .map(|(i, _)| make_placeholder(kind, i + 1))
         .fold(String::new(), |mut acc, f| {
-            acc.push_str(", ");
+            if !acc.is_empty() {
+                acc.push_str(", ");
+            }
             acc.push_str(&f);
             acc
         });
@@ -225,10 +239,10 @@ async fn process_csv_record(
     Ok(())
 }
 
-fn make_csv_reader(
+fn make_csv_reader<R: AsyncRead + Unpin + Send>(
     csv_import: &CsvImport,
-    file: tokio::fs::File,
-) -> csv_async::AsyncReader<tokio::fs::File> {
+    file: R,
+) -> csv_async::AsyncReader<R> {
     let delimiter = csv_import
         .delimiter
         .and_then(|c| u8::try_from(c).ok())
@@ -246,3 +260,60 @@ fn make_csv_reader(
         .escape(escape)
         .create_reader(file)
 }
+
+#[test]
+fn test_make_statement() {
+    let csv_import = CsvImport {
+        query: "COPY my_table (col1, col2) FROM 'my_file.csv' WITH (DELIMITER ';', HEADER)".into(),
+        table_name: "my_table".into(),
+        columns: vec!["col1".into(), "col2".into()],
+        delimiter: Some(';'),
+        quote: None,
+        header: Some(true),
+        null_str: None,
+        escape: None,
+        uploaded_file: "my_file.csv".into(),
+    };
+    let insert_stmt = create_insert_stmt(AnyKind::Postgres, &csv_import);
+    assert_eq!(
+        insert_stmt,
+        "INSERT INTO my_table (col1, col2) VALUES ($1, $2)"
+    );
+}
+
+#[actix_web::test]
+async fn test_end_to_end() {
+    use sqlx::ConnectOptions;
+
+    let mut copy_stmt = sqlparser::parser::Parser::parse_sql(
+        &sqlparser::dialect::GenericDialect {},
+        "COPY my_table (col1, col2) FROM 'my_file.csv' WITH (DELIMITER ';', HEADER)",
+    )
+    .unwrap()
+    .into_iter()
+    .next()
+    .unwrap();
+    let csv_import = extract_csv_copy_statement(&mut copy_stmt).unwrap();
+    let mut conn = "sqlite::memory:"
+        .parse::<sqlx::any::AnyConnectOptions>()
+        .unwrap()
+        .connect()
+        .await
+        .unwrap();
+    conn.execute("CREATE TABLE my_table (col1 TEXT, col2 TEXT)")
+        .await
+        .unwrap();
+    let csv = "col2;col1\na;b\nc;d"; // order is different from the table
+    let file = csv.as_bytes();
+    run_csv_import_on_path(&mut conn, &csv_import, file)
+        .await
+        .unwrap();
+    let rows: Vec<(String, String)> = sqlx::query_as("SELECT * FROM my_table")
+        .fetch_all(&mut conn)
+        .await
+        .unwrap();
+    assert_eq!(
+        rows,
+        vec![("b".into(), "a".into()), ("d".into(), "c".into())]
+    );
+}
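
Note on the placeholder change in create_insert_stmt: Postgres numbers its bind parameters starting at $1, and test_make_statement expects the generated statement to end with VALUES ($1, $2), so create_insert_stmt now passes 1-based argument numbers to make_placeholder. The make_placeholder helper itself is defined elsewhere in the crate and is not part of this diff; the snippet below is only a minimal sketch of behaviour consistent with that test, assuming the helper emits "$N" for Postgres and a plain "?" for the other supported backends.

fn make_placeholder(kind: AnyKind, arg_number: usize) -> String {
    // Hypothetical sketch, not the crate's actual implementation:
    // Postgres uses 1-based numbered placeholders ($1, $2, ...),
    // while backends such as MySQL and SQLite accept anonymous "?" placeholders.
    match kind {
        AnyKind::Postgres => format!("${arg_number}"),
        _ => "?".to_string(),
    }
}

With a helper of this shape, create_insert_stmt(AnyKind::Postgres, &csv_import) yields "INSERT INTO my_table (col1, col2) VALUES ($1, $2)", matching the assertion in test_make_statement.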