Commit 9ab7840

Merge pull request #1072 from cmu-delphi/ds/refactor-csv-importer
[Refactor] Further cleans of `csv_importer`
2 parents c21b940 + 57aa137 commit 9ab7840

6 files changed: +230 −186 lines changed

src/acquisition/covidcast/csv_importer.py (+58 −17)
@@ -1,20 +1,26 @@
 """Collects and reads covidcast data from a set of local CSV files."""
 
 # standard library
-from dataclasses import dataclass
-from datetime import date
-import glob
 import os
 import re
+from dataclasses import dataclass
+from datetime import date
+from glob import glob
+from typing import Iterator, NamedTuple, Optional, Tuple
 
 # third party
-import pandas as pd
 import epiweeks as epi
+import pandas as pd
 
 # first party
 from delphi_utils import Nans
 from delphi.utils.epiweek import delta_epiweeks
-from .logger import get_structured_logger
+from delphi.epidata.acquisition.covidcast.database import CovidcastRow
+from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
+
+DFRow = NamedTuple('DFRow', [('geo_id', str), ('value', float), ('stderr', float), ('sample_size', float), ('missing_value', int), ('missing_stderr', int), ('missing_sample_size', int)])
+PathDetails = NamedTuple('PathDetails', [('issue', int), ('lag', int), ('source', str), ('signal', str), ('time_type', str), ('time_value', int), ('geo_type', str)])
+
 
 @dataclass
 class CsvRowValue:
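
The `DFRow` and `PathDetails` `NamedTuple`s introduced above replace the anonymous 7-tuples that previously flowed between `find_csv_files` and `load_csv`. A minimal sketch of what that buys, with made-up values for illustration:

    # hypothetical values, for illustration only
    details = PathDetails(
      issue=20201130,
      lag=1,
      source='fb-survey',
      signal='smoothed_cli',
      time_type='day',
      time_value=20201129,
      geo_type='county',
    )

    # callers now read fields by name:
    details.geo_type  # 'county'
    # ...instead of the old, order-sensitive unpacking:
    # (source, signal, time_type, geo_type, time_value, issue, lag) = details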
@@ -27,6 +33,7 @@ class CsvRowValue:
   missing_stderr: int
   missing_sample_size: int
 
+
 class CsvImporter:
   """Finds and parses covidcast CSV files."""
 
@@ -60,6 +67,7 @@ class CsvImporter:
     "missing_sample_size": "Int64"
   }
 
+
   @staticmethod
   def is_sane_day(value):
     """Return whether `value` is a sane (maybe not valid) YYYYMMDD date.
@@ -76,6 +84,7 @@ def is_sane_day(value):
       return False
     return date(year=year,month=month,day=day)
 
+
   @staticmethod
   def is_sane_week(value):
     """Return whether `value` is a sane (maybe not valid) YYYYWW epiweek.
@@ -91,22 +100,24 @@ def is_sane_week(value):
      return False
     return value
 
+
   @staticmethod
-  def find_issue_specific_csv_files(scan_dir, glob=glob):
+  def find_issue_specific_csv_files(scan_dir):
     logger = get_structured_logger('find_issue_specific_csv_files')
-    for path in sorted(glob.glob(os.path.join(scan_dir, '*'))):
+    for path in sorted(glob(os.path.join(scan_dir, '*'))):
       issuedir_match = CsvImporter.PATTERN_ISSUE_DIR.match(path.lower())
       if issuedir_match and os.path.isdir(path):
         issue_date_value = int(issuedir_match.group(2))
         issue_date = CsvImporter.is_sane_day(issue_date_value)
         if issue_date:
           logger.info(event='processing csv files from issue', detail=issue_date, file=path)
-          yield from CsvImporter.find_csv_files(path, issue=(issue_date, epi.Week.fromdate(issue_date)), glob=glob)
+          yield from CsvImporter.find_csv_files(path, issue=(issue_date, epi.Week.fromdate(issue_date)))
         else:
           logger.warning(event='invalid issue directory day', detail=issue_date_value, file=path)
 
+
   @staticmethod
-  def find_csv_files(scan_dir, issue=(date.today(), epi.Week.fromdate(date.today())), glob=glob):
+  def find_csv_files(scan_dir, issue=(date.today(), epi.Week.fromdate(date.today()))):
     """Recursively search for and yield covidcast-format CSV files.
 
     scan_dir: the directory to scan (recursively)
@@ -122,11 +133,11 @@ def find_csv_files(scan_dir, issue=(date.today(), epi.Week.fromdate(date.today()
     issue_value=-1
     lag_value=-1
 
-    for path in sorted(glob.glob(os.path.join(scan_dir, '*', '*'))):
-
+    for path in sorted(glob(os.path.join(scan_dir, '*', '*'))):
+      # safe to ignore this file
       if not path.lower().endswith('.csv'):
-        # safe to ignore this file
         continue
+
       # match a daily or weekly naming pattern
       daily_match = CsvImporter.PATTERN_DAILY.match(path.lower())
       weekly_match = CsvImporter.PATTERN_WEEKLY.match(path.lower())
@@ -174,14 +185,16 @@ def find_csv_files(scan_dir, issue=(date.today(), epi.Week.fromdate(date.today()
         yield (path, None)
         continue
 
-      yield (path, (source, signal, time_type, geo_type, time_value, issue_value, lag_value))
+      yield (path, PathDetails(issue_value, lag_value, source, signal, time_type, time_value, geo_type))
+
 
   @staticmethod
   def is_header_valid(columns):
     """Return whether the given pandas columns contains the required fields."""
 
     return set(columns) >= CsvImporter.REQUIRED_COLUMNS
 
+
   @staticmethod
   def floaty_int(value: str) -> int:
     """Cast a string to an int, even if it looks like a float.
@@ -195,6 +208,7 @@ def floaty_int(value: str) -> int:
       raise ValueError('not an int: "%s"' % str(value))
     return int(float_value)
 
+
   @staticmethod
   def maybe_apply(func, quantity):
     """Apply the given function to the given quantity if not null-ish."""
@@ -205,6 +219,7 @@ def maybe_apply(func, quantity):
     else:
       return func(quantity)
 
+
   @staticmethod
   def validate_quantity(row, attr_quantity):
     """Take a row and validate a given associated quantity (e.g., val, se, stderr).
@@ -218,6 +233,7 @@ def validate_quantity(row, attr_quantity):
       # val was a string or another data
       return "Error"
 
+
   @staticmethod
   def validate_missing_code(row, attr_quantity, attr_name, filepath=None, logger=None):
     """Take a row and validate the missing code associated with
@@ -250,8 +266,9 @@ def validate_missing_code(row, attr_quantity, attr_name, filepath=None, logger=N
 
     return missing_entry
 
+
   @staticmethod
-  def extract_and_check_row(row, geo_type, filepath=None):
+  def extract_and_check_row(row: DFRow, geo_type: str, filepath: Optional[str] = None) -> Tuple[Optional[CsvRowValue], Optional[str]]:
     """Extract and return `CsvRowValue` from a CSV row, with sanity checks.
 
     Also returns the name of the field which failed sanity check, or None.
@@ -331,8 +348,9 @@ def extract_and_check_row(row, geo_type, filepath=None):
     # return extracted and validated row values
     return (CsvRowValue(geo_id, value, stderr, sample_size, missing_value, missing_stderr, missing_sample_size), None)
 
+
   @staticmethod
-  def load_csv(filepath, geo_type):
+  def load_csv(filepath: str, details: PathDetails) -> Iterator[Optional[CovidcastRow]]:
     """Load, validate, and yield data as `RowValues` from a CSV file.
 
     filepath: the CSV file to be loaded
@@ -357,9 +375,32 @@ def load_csv(filepath, geo_type):
     table.rename(columns={"val": "value", "se": "stderr", "missing_val": "missing_value", "missing_se": "missing_stderr"}, inplace=True)
 
     for row in table.itertuples(index=False):
-      row_values, error = CsvImporter.extract_and_check_row(row, geo_type, filepath)
+      csv_row_values, error = CsvImporter.extract_and_check_row(row, details.geo_type, filepath)
+
       if error:
         logger.warning(event = 'invalid value for row', detail=(str(row), error), file=filepath)
         yield None
         continue
-      yield row_values
+
+      yield CovidcastRow(
+        details.source,
+        details.signal,
+        details.time_type,
+        details.geo_type,
+        details.time_value,
+        csv_row_values.geo_value,
+        csv_row_values.value,
+        csv_row_values.stderr,
+        csv_row_values.sample_size,
+        csv_row_values.missing_value,
+        csv_row_values.missing_stderr,
+        csv_row_values.missing_sample_size,
+        details.issue,
+        details.lag,
+        # These four fields are unused by database acquisition
+        # TODO: These will be used when CovidcastRow is updated.
+        # id=None,
+        # direction=None,
+        # direction_updated_timestamp=0,
+        # value_updated_timestamp=0,
+      )
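
With this change, `load_csv` maps a file plus its `PathDetails` directly to database-ready `CovidcastRow` objects (or `None` for rows that fail validation), instead of yielding intermediate row values that the caller had to combine with path metadata. A minimal sketch of the resulting call pattern, using a hypothetical receiving directory:

    # hypothetical directory; find_csv_files yields (path, details) pairs,
    # where details is None if the file name doesn't match the expected pattern
    for path, details in CsvImporter.find_csv_files('/common/covidcast/receiving'):
      if details is None:
        continue  # invalid file name; the caller archives it as failed
      for row in CsvImporter.load_csv(path, details):
        if row is None:
          continue  # row failed a sanity check
        # row is a CovidcastRow, ready for database insertion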

src/acquisition/covidcast/csv_to_database.py (+40 −39)
@@ -4,10 +4,12 @@
 import argparse
 import os
 import time
+from logging import Logger
+from typing import Callable, Iterable, Optional, Tuple
 
 # first party
-from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter
-from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow, DBLoadStateException
+from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter, PathDetails
+from delphi.epidata.acquisition.covidcast.database import Database, DBLoadStateException
 from delphi.epidata.acquisition.covidcast.file_archiver import FileArchiver
 from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
 
@@ -28,17 +30,19 @@ def get_argument_parser():
     help="filename for log output (defaults to stdout)")
   return parser
 
-def collect_files(data_dir, specific_issue_date,csv_importer_impl=CsvImporter):
+
+def collect_files(data_dir: str, specific_issue_date: bool):
   """Fetch path and data profile details for each file to upload."""
   logger= get_structured_logger('collect_files')
   if specific_issue_date:
-    results = list(csv_importer_impl.find_issue_specific_csv_files(data_dir))
+    results = list(CsvImporter.find_issue_specific_csv_files(data_dir))
   else:
-    results = list(csv_importer_impl.find_csv_files(os.path.join(data_dir, 'receiving')))
+    results = list(CsvImporter.find_csv_files(os.path.join(data_dir, 'receiving')))
   logger.info(f'found {len(results)} files')
   return results
 
-def make_handlers(data_dir, specific_issue_date, file_archiver_impl=FileArchiver):
+
+def make_handlers(data_dir: str, specific_issue_date: bool):
   if specific_issue_date:
     # issue-specific uploads are always one-offs, so we can leave all
     # files in place without worrying about cleaning up
@@ -47,7 +51,7 @@ def handle_failed(path_src, filename, source, logger):
 
     def handle_successful(path_src, filename, source, logger):
       logger.info(event='archiving as successful',file=filename)
-      file_archiver_impl.archive_inplace(path_src, filename)
+      FileArchiver.archive_inplace(path_src, filename)
   else:
     # normal automation runs require some shuffling to remove files
     # from receiving and place them in the archive
@@ -59,22 +63,24 @@ def handle_failed(path_src, filename, source, logger):
       logger.info(event='archiving as failed - ', detail=source, file=filename)
       path_dst = os.path.join(archive_failed_dir, source)
       compress = False
-      file_archiver_impl.archive_file(path_src, path_dst, filename, compress)
+      FileArchiver.archive_file(path_src, path_dst, filename, compress)
 
     # helper to archive a successful file with compression
     def handle_successful(path_src, filename, source, logger):
       logger.info(event='archiving as successful',file=filename)
       path_dst = os.path.join(archive_successful_dir, source)
       compress = True
-      file_archiver_impl.archive_file(path_src, path_dst, filename, compress)
+      FileArchiver.archive_file(path_src, path_dst, filename, compress)
+
   return handle_successful, handle_failed
 
+
 def upload_archive(
-  path_details,
-  database,
-  handlers,
-  logger,
-  csv_importer_impl=CsvImporter):
+  path_details: Iterable[Tuple[str, Optional[PathDetails]]],
+  database: Database,
+  handlers: Tuple[Callable],
+  logger: Logger
+):
   """Upload CSVs to the database and archive them using the specified handlers.
 
   :path_details: output from CsvImporter.find*_csv_files
@@ -89,20 +95,16 @@ def upload_archive(
   total_modified_row_count = 0
   # iterate over each file
   for path, details in path_details:
-    logger.info(event='handling',dest=path)
+    logger.info(event='handling', dest=path)
     path_src, filename = os.path.split(path)
 
+    # file path or name was invalid, source is unknown
     if not details:
-      # file path or name was invalid, source is unknown
       archive_as_failed(path_src, filename, 'unknown',logger)
       continue
 
-    (source, signal, time_type, geo_type, time_value, issue, lag) = details
-
-    csv_rows = csv_importer_impl.load_csv(path, geo_type)
-
-    cc_rows = CovidcastRow.fromCsvRows(csv_rows, source, signal, time_type, geo_type, time_value, issue, lag)
-    rows_list = list(cc_rows)
+    csv_rows = CsvImporter.load_csv(path, details)
+    rows_list = list(csv_rows)
     all_rows_valid = rows_list and all(r is not None for r in rows_list)
     if all_rows_valid:
       try:
@@ -111,12 +113,13 @@ def upload_archive(
         logger.info(
           "Inserted database rows",
           row_count = modified_row_count,
-          source = source,
-          signal = signal,
-          geo_type = geo_type,
-          time_value = time_value,
-          issue = issue,
-          lag = lag)
+          source = details.source,
+          signal = details.signal,
+          geo_type = details.geo_type,
+          time_value = details.time_value,
+          issue = details.issue,
+          lag = details.lag
+        )
         if modified_row_count is None or modified_row_count: # else would indicate zero rows inserted
           total_modified_row_count += (modified_row_count if modified_row_count else 0)
         database.commit()
@@ -131,40 +134,37 @@ def upload_archive(
 
     # archive the current file based on validation results
     if all_rows_valid:
-      archive_as_successful(path_src, filename, source, logger)
+      archive_as_successful(path_src, filename, details.source, logger)
     else:
-      archive_as_failed(path_src, filename, source,logger)
+      archive_as_failed(path_src, filename, details.source, logger)
 
   return total_modified_row_count
 
 
-def main(
-  args,
-  database_impl=Database,
-  collect_files_impl=collect_files,
-  upload_archive_impl=upload_archive):
+def main(args):
   """Find, parse, and upload covidcast signals."""
 
   logger = get_structured_logger("csv_ingestion", filename=args.log_file)
   start_time = time.time()
 
   # shortcut escape without hitting db if nothing to do
-  path_details = collect_files_impl(args.data_dir, args.specific_issue_date)
+  path_details = collect_files(args.data_dir, args.specific_issue_date)
   if not path_details:
     logger.info('nothing to do; exiting...')
     return
 
   logger.info("Ingesting CSVs", csv_count = len(path_details))
 
-  database = database_impl()
+  database = Database()
   database.connect()
 
   try:
-    modified_row_count = upload_archive_impl(
+    modified_row_count = upload_archive(
       path_details,
       database,
       make_handlers(args.data_dir, args.specific_issue_date),
-      logger)
+      logger
+    )
     logger.info("Finished inserting/updating database rows", row_count = modified_row_count)
   finally:
     database.do_analyze()
@@ -175,5 +175,6 @@ def main(
     "Ingested CSVs into database",
     total_runtime_in_seconds=round(time.time() - start_time, 2))
 
+
 if __name__ == '__main__':
   main(get_argument_parser().parse_args())
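
Note that this file loses its dependency-injection parameters (`csv_importer_impl`, `file_archiver_impl`, `database_impl`, `collect_files_impl`, `upload_archive_impl`) in favor of direct references to `CsvImporter`, `FileArchiver`, and `Database`. A hedged sketch of how a test might substitute a collaborator after this change, assuming standard `unittest.mock` patching of the module-level names (illustrative only, not part of this PR):

    from unittest import mock

    from delphi.epidata.acquisition.covidcast import csv_to_database

    # patch the module-level name that collect_files now references directly
    with mock.patch.object(csv_to_database, 'CsvImporter') as mock_importer:
      mock_importer.find_csv_files.return_value = iter([])
      assert csv_to_database.collect_files('/tmp/data', specific_issue_date=False) == []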
