Skip to content

Commit 87220a0

Browse files
committed
Acquisition: update csv_importer
* rename RowValues to CsvRowValue
* make CsvRowValue into a dataclass
* remove argument in load_csv only used for test mocks; correctly mock test
* update tests
1 parent 04ed491 commit 87220a0

File tree

3 files changed

+45
-54
lines changed

3 files changed

+45
-54
lines changed

src/acquisition/covidcast/csv_importer.py

+21-26
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
"""Collects and reads covidcast data from a set of local CSV files."""
22

33
# standard library
4+
from dataclasses import dataclass
45
from datetime import date
56
import glob
67
import os
78
import re
89

910
# third party
10-
import pandas
11+
import pandas as pd
1112
import epiweeks as epi
1213

1314
# first party
1415
from delphi_utils import Nans
1516
from delphi.utils.epiweek import delta_epiweeks
16-
from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
17+
from .logger import get_structured_logger
18+
19+
@dataclass
20+
class CsvRowValue:
21+
"""A container for the values of a single validated covidcast CSV row."""
22+
geo_value: str
23+
value: float
24+
stderr: float
25+
sample_size: float
26+
missing_value: int
27+
missing_stderr: int
28+
missing_sample_size: int
1729

1830
class CsvImporter:
1931
"""Finds and parses covidcast CSV files."""
@@ -37,6 +49,7 @@ class CsvImporter:
3749
MIN_YEAR = 2019
3850
MAX_YEAR = 2030
3951

52+
# The datatypes expected by pandas.read_csv. "Int64" (capital I) is pandas' nullable integer dtype: like float, it can hold both numbers and missing values.
4053
DTYPES = {
4154
"geo_id": str,
4255
"val": float,
@@ -47,20 +60,6 @@ class CsvImporter:
4760
"missing_sample_size": "Int64"
4861
}
4962

50-
# NOTE: this should be a Python 3.7+ `dataclass`, but the server is on 3.4
51-
# See https://docs.python.org/3/library/dataclasses.html
52-
class RowValues:
53-
"""A container for the values of a single covidcast row."""
54-
55-
def __init__(self, geo_value, value, stderr, sample_size, missing_value, missing_stderr, missing_sample_size):
56-
self.geo_value = geo_value
57-
self.value = value
58-
self.stderr = stderr
59-
self.sample_size = sample_size
60-
self.missing_value = missing_value
61-
self.missing_stderr = missing_stderr
62-
self.missing_sample_size = missing_sample_size
63-
6463
@staticmethod
6564
def is_sane_day(value):
6665
"""Return whether `value` is a sane (maybe not valid) YYYYMMDD date.
@@ -184,7 +183,7 @@ def is_header_valid(columns):
184183
return set(columns) >= CsvImporter.REQUIRED_COLUMNS
185184

186185
@staticmethod
187-
def floaty_int(value):
186+
def floaty_int(value: str) -> int:
188187
"""Cast a string to an int, even if it looks like a float.
189188
190189
For example, "-1" and "-1.0" should both result in -1. Non-integer floats
@@ -253,7 +252,7 @@ def validate_missing_code(row, attr_quantity, attr_name, filepath=None, logger=N
253252

254253
@staticmethod
255254
def extract_and_check_row(row, geo_type, filepath=None):
256-
"""Extract and return `RowValues` from a CSV row, with sanity checks.
255+
"""Extract and return `CsvRowValue` from a CSV row, with sanity checks.
257256
258257
Also returns the name of the field which failed sanity check, or None.
259258
@@ -330,14 +329,10 @@ def extract_and_check_row(row, geo_type, filepath=None):
330329
missing_sample_size = CsvImporter.validate_missing_code(row, sample_size, "sample_size", filepath)
331330

332331
# return extracted and validated row values
333-
row_values = CsvImporter.RowValues(
334-
geo_id, value, stderr, sample_size,
335-
missing_value, missing_stderr, missing_sample_size
336-
)
337-
return (row_values, None)
332+
return (CsvRowValue(geo_id, value, stderr, sample_size, missing_value, missing_stderr, missing_sample_size), None)
338333

339334
@staticmethod
340-
def load_csv(filepath, geo_type, pandas=pandas):
335+
def load_csv(filepath, geo_type):
341336
"""Load, validate, and yield data as `CsvRowValue` from a CSV file.
342337
343338
filepath: the CSV file to be loaded
@@ -349,10 +344,10 @@ def load_csv(filepath, geo_type, pandas=pandas):
349344
logger = get_structured_logger('load_csv')
350345

351346
try:
352-
table = pandas.read_csv(filepath, dtype=CsvImporter.DTYPES)
347+
table = pd.read_csv(filepath, dtype=CsvImporter.DTYPES)
353348
except ValueError as e:
354349
logger.warning(event='Failed to open CSV with specified dtypes, switching to str', detail=str(e), file=filepath)
355-
table = pandas.read_csv(filepath, dtype='str')
350+
table = pd.read_csv(filepath, dtype='str')
356351

357352
if not CsvImporter.is_header_valid(table.columns):
358353
logger.warning(event='invalid header', detail=table.columns, file=filepath)

tests/acquisition/covidcast/test_csv_importer.py

+23-26
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,15 @@
55
from unittest.mock import MagicMock
66
from unittest.mock import patch
77
from datetime import date
8-
import math
98
import numpy as np
10-
import os
119

1210
# third party
13-
import pandas
11+
import pandas as pd
1412
import epiweeks as epi
1513

1614
from delphi_utils import Nans
17-
from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter
1815
from delphi.utils.epiweek import delta_epiweeks
16+
from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter, CsvRowValue
1917

2018
# py3tester coverage target
2119
__test_target__ = 'delphi.epidata.acquisition.covidcast.csv_importer'
@@ -208,37 +206,38 @@ def make_row(
208206
self.assertEqual(error, field)
209207

210208
success_cases = [
211-
(make_row(), CsvImporter.RowValues('vi', 1.23, 4.56, 100.5, Nans.NOT_MISSING, Nans.NOT_MISSING, Nans.NOT_MISSING)),
212-
(make_row(value=None, stderr=np.nan, sample_size='', missing_value=str(float(Nans.DELETED)), missing_stderr=str(float(Nans.DELETED)), missing_sample_size=str(float(Nans.DELETED))), CsvImporter.RowValues('vi', None, None, None, Nans.DELETED, Nans.DELETED, Nans.DELETED)),
213-
(make_row(stderr='', sample_size='NA', missing_stderr=str(float(Nans.OTHER)), missing_sample_size=str(float(Nans.OTHER))), CsvImporter.RowValues('vi', 1.23, None, None, Nans.NOT_MISSING, Nans.OTHER, Nans.OTHER)),
214-
(make_row(sample_size=None, missing_value='missing_value', missing_stderr=str(float(Nans.OTHER)), missing_sample_size=str(float(Nans.NOT_MISSING))), CsvImporter.RowValues('vi', 1.23, 4.56, None, Nans.NOT_MISSING, Nans.NOT_MISSING, Nans.OTHER)),
209+
(make_row(), CsvRowValue('vi', 1.23, 4.56, 100.5, Nans.NOT_MISSING, Nans.NOT_MISSING, Nans.NOT_MISSING)),
210+
(make_row(value=None, stderr=np.nan, sample_size='', missing_value=str(float(Nans.DELETED)), missing_stderr=str(float(Nans.DELETED)), missing_sample_size=str(float(Nans.DELETED))), CsvRowValue('vi', None, None, None, Nans.DELETED, Nans.DELETED, Nans.DELETED)),
211+
(make_row(stderr='', sample_size='NA', missing_stderr=str(float(Nans.OTHER)), missing_sample_size=str(float(Nans.OTHER))), CsvRowValue('vi', 1.23, None, None, Nans.NOT_MISSING, Nans.OTHER, Nans.OTHER)),
212+
(make_row(sample_size=None, missing_value='missing_value', missing_stderr=str(float(Nans.OTHER)), missing_sample_size=str(float(Nans.NOT_MISSING))), CsvRowValue('vi', 1.23, 4.56, None, Nans.NOT_MISSING, Nans.NOT_MISSING, Nans.OTHER)),
215213
]
216214

217215
for ((geo_type, row), field) in success_cases:
218216
values, error = CsvImporter.extract_and_check_row(row, geo_type)
219217
self.assertIsNone(error)
220-
self.assertIsInstance(values, CsvImporter.RowValues)
218+
self.assertIsInstance(values, CsvRowValue)
221219
self.assertEqual(values.geo_value, field.geo_value)
222220
self.assertEqual(values.value, field.value)
223221
self.assertEqual(values.stderr, field.stderr)
224222
self.assertEqual(values.sample_size, field.sample_size)
225223

226-
def test_load_csv_with_invalid_header(self):
224+
@patch("pandas.read_csv")
225+
def test_load_csv_with_invalid_header(self, mock_read_csv):
227226
"""Bail loading a CSV when the header is invalid."""
228227

229228
data = {'foo': [1, 2, 3]}
230-
mock_pandas = MagicMock()
231-
mock_pandas.read_csv.return_value = pandas.DataFrame(data=data)
232229
filepath = 'path/name.csv'
233230
geo_type = 'state'
234231

235-
rows = list(CsvImporter.load_csv(filepath, geo_type, pandas=mock_pandas))
232+
mock_read_csv.return_value = pd.DataFrame(data)
233+
rows = list(CsvImporter.load_csv(filepath, geo_type))
236234

237-
self.assertTrue(mock_pandas.read_csv.called)
238-
self.assertTrue(mock_pandas.read_csv.call_args[0][0], filepath)
235+
self.assertTrue(mock_read_csv.called)
236+
self.assertTrue(mock_read_csv.call_args[0][0], filepath)
239237
self.assertEqual(rows, [None])
240238

241-
def test_load_csv_with_valid_header(self):
239+
@patch("pandas.read_csv")
240+
def test_load_csv_with_valid_header(self, mock_read_csv):
242241
"""Yield sanity-checked `CsvRowValue` from a valid CSV file."""
243242

244243
# one invalid geo_id, but otherwise valid
@@ -248,15 +247,14 @@ def test_load_csv_with_valid_header(self):
248247
'se': ['2.1', '2.2', '2.3', '2.4'],
249248
'sample_size': ['301', '302', '303', '304'],
250249
}
251-
mock_pandas = MagicMock()
252-
mock_pandas.read_csv.return_value = pandas.DataFrame(data=data)
253250
filepath = 'path/name.csv'
254251
geo_type = 'state'
255252

256-
rows = list(CsvImporter.load_csv(filepath, geo_type, pandas=mock_pandas))
253+
mock_read_csv.return_value = pd.DataFrame(data=data)
254+
rows = list(CsvImporter.load_csv(filepath, geo_type))
257255

258-
self.assertTrue(mock_pandas.read_csv.called)
259-
self.assertTrue(mock_pandas.read_csv.call_args[0][0], filepath)
256+
self.assertTrue(mock_read_csv.called)
257+
self.assertTrue(mock_read_csv.call_args[0][0], filepath)
260258
self.assertEqual(len(rows), 4)
261259

262260
self.assertEqual(rows[0].geo_value, 'ca')
@@ -286,15 +284,14 @@ def test_load_csv_with_valid_header(self):
286284
'missing_stderr': [Nans.NOT_MISSING, Nans.REGION_EXCEPTION, Nans.NOT_MISSING, Nans.NOT_MISSING] + [None],
287285
'missing_sample_size': [Nans.NOT_MISSING] * 2 + [Nans.REGION_EXCEPTION] * 2 + [None]
288286
}
289-
mock_pandas = MagicMock()
290-
mock_pandas.read_csv.return_value = pandas.DataFrame(data=data)
291287
filepath = 'path/name.csv'
292288
geo_type = 'state'
293289

294-
rows = list(CsvImporter.load_csv(filepath, geo_type, pandas=mock_pandas))
290+
mock_read_csv.return_value = pd.DataFrame(data)
291+
rows = list(CsvImporter.load_csv(filepath, geo_type))
295292

296-
self.assertTrue(mock_pandas.read_csv.called)
297-
self.assertTrue(mock_pandas.read_csv.call_args[0][0], filepath)
293+
self.assertTrue(mock_read_csv.called)
294+
self.assertTrue(mock_read_csv.call_args[0][0], filepath)
298295
self.assertEqual(len(rows), 5)
299296

300297
self.assertEqual(rows[0].geo_value, 'ca')

tests/acquisition/covidcast/test_csv_to_database.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import unittest
77
from unittest.mock import MagicMock
88

9-
from delphi.epidata.acquisition.covidcast.csv_to_database import get_argument_parser, main, \
10-
collect_files, upload_archive, make_handlers
9+
from delphi.epidata.acquisition.covidcast.csv_to_database import get_argument_parser, main, collect_files, upload_archive, make_handlers
1110

1211
# py3tester coverage target
1312
__test_target__ = 'delphi.epidata.acquisition.covidcast.csv_to_database'

0 commit comments

Comments
 (0)