Skip to content

Commit e5f3ccd

Browse files
committed
add csv import tests
1 parent a2a6e04 commit e5f3ccd

File tree

1 file changed

+82
-11
lines changed

1 file changed

+82
-11
lines changed

src/webserver/database/csv_import.rs

Lines changed: 82 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,11 @@ use futures_util::StreamExt;
55
use sqlparser::ast::{
66
CopyLegacyCsvOption, CopyLegacyOption, CopyOption, CopySource, CopyTarget, Statement,
77
};
8-
use sqlx::{any::AnyArguments, AnyConnection, Arguments, Executor};
8+
use sqlx::{
9+
any::{AnyArguments, AnyKind},
10+
AnyConnection, Arguments, Executor,
11+
};
12+
use tokio::io::AsyncRead;
913

1014
use crate::webserver::http_request_info::RequestInfo;
1115

@@ -152,7 +156,16 @@ pub(super) async fn run_csv_import(
152156
let file = tokio::fs::File::open(file_path)
153157
.await
154158
.with_context(|| "opening csv")?;
155-
let insert_stmt = create_insert_stmt(db, csv_import);
159+
let buffered = tokio::io::BufReader::new(file);
160+
run_csv_import_on_path(db, csv_import, buffered).await
161+
}
162+
163+
async fn run_csv_import_on_path(
164+
db: &mut AnyConnection,
165+
csv_import: &CsvImport,
166+
file: impl AsyncRead + Unpin + Send,
167+
) -> anyhow::Result<()> {
168+
let insert_stmt = create_insert_stmt(db.kind(), csv_import);
156169
log::debug!("CSV data insert statement: {insert_stmt}");
157170
let mut reader = make_csv_reader(csv_import, file);
158171
let col_idxs = compute_column_indices(&mut reader, csv_import).await?;
@@ -164,8 +177,8 @@ pub(super) async fn run_csv_import(
164177
Ok(())
165178
}
166179

167-
async fn compute_column_indices(
168-
reader: &mut csv_async::AsyncReader<tokio::fs::File>,
180+
async fn compute_column_indices<R: AsyncRead + Unpin + Send>(
181+
reader: &mut csv_async::AsyncReader<R>,
169182
csv_import: &CsvImport,
170183
) -> anyhow::Result<Vec<usize>> {
171184
let mut col_idxs = Vec::with_capacity(csv_import.columns.len());
@@ -189,16 +202,17 @@ async fn compute_column_indices(
189202
Ok(col_idxs)
190203
}
191204

192-
fn create_insert_stmt(db: &mut AnyConnection, csv_import: &CsvImport) -> String {
193-
let kind = db.kind();
205+
fn create_insert_stmt(kind: AnyKind, csv_import: &CsvImport) -> String {
194206
let columns = csv_import.columns.join(", ");
195207
let placeholders = csv_import
196208
.columns
197209
.iter()
198210
.enumerate()
199-
.map(|(i, _)| make_placeholder(kind, i))
211+
.map(|(i, _)| make_placeholder(kind, i + 1))
200212
.fold(String::new(), |mut acc, f| {
201-
acc.push_str(", ");
213+
if !acc.is_empty() {
214+
acc.push_str(", ");
215+
}
202216
acc.push_str(&f);
203217
acc
204218
});
@@ -225,10 +239,10 @@ async fn process_csv_record(
225239
Ok(())
226240
}
227241

228-
fn make_csv_reader(
242+
fn make_csv_reader<R: AsyncRead + Unpin + Send>(
229243
csv_import: &CsvImport,
230-
file: tokio::fs::File,
231-
) -> csv_async::AsyncReader<tokio::fs::File> {
244+
file: R,
245+
) -> csv_async::AsyncReader<R> {
232246
let delimiter = csv_import
233247
.delimiter
234248
.and_then(|c| u8::try_from(c).ok())
@@ -246,3 +260,60 @@ fn make_csv_reader(
246260
.escape(escape)
247261
.create_reader(file)
248262
}
263+
/// Verifies the `INSERT` statement that `create_insert_stmt` builds for Postgres:
/// the column list is comma-separated and the placeholders are numbered from `$1`.
#[test]
fn test_make_statement() {
    const EXPECTED: &str = "INSERT INTO my_table (col1, col2) VALUES ($1, $2)";

    let import = CsvImport {
        query: "COPY my_table (col1, col2) FROM 'my_file.csv' WITH (DELIMITER ';', HEADER)".into(),
        table_name: "my_table".into(),
        columns: vec!["col1".into(), "col2".into()],
        delimiter: Some(';'),
        quote: None,
        header: Some(true),
        null_str: None,
        escape: None,
        uploaded_file: "my_file.csv".into(),
    };

    let generated = create_insert_stmt(AnyKind::Postgres, &import);
    assert_eq!(generated, EXPECTED);
}
283+
284+
#[actix_web::test]
285+
async fn test_end_to_end() {
286+
use sqlx::ConnectOptions;
287+
288+
let mut copy_stmt = sqlparser::parser::Parser::parse_sql(
289+
&sqlparser::dialect::GenericDialect {},
290+
"COPY my_table (col1, col2) FROM 'my_file.csv' WITH (DELIMITER ';', HEADER)",
291+
)
292+
.unwrap()
293+
.into_iter()
294+
.next()
295+
.unwrap();
296+
let csv_import = extract_csv_copy_statement(&mut copy_stmt).unwrap();
297+
let mut conn = "sqlite::memory:"
298+
.parse::<sqlx::any::AnyConnectOptions>()
299+
.unwrap()
300+
.connect()
301+
.await
302+
.unwrap();
303+
conn.execute("CREATE TABLE my_table (col1 TEXT, col2 TEXT)")
304+
.await
305+
.unwrap();
306+
let csv = "col2;col1\na;b\nc;d"; // order is different from the table
307+
let file = csv.as_bytes();
308+
run_csv_import_on_path(&mut conn, &csv_import, file)
309+
.await
310+
.unwrap();
311+
let rows: Vec<(String, String)> = sqlx::query_as("SELECT * FROM my_table")
312+
.fetch_all(&mut conn)
313+
.await
314+
.unwrap();
315+
assert_eq!(
316+
rows,
317+
vec![("b".into(), "a".into()), ("d".into(), "c".into())]
318+
);
319+
}

0 commit comments

Comments
 (0)