Skip to content

Commit 4eee3ba

Browse files
authored
Merge pull request #840 from cmu-delphi/krivard/delete_csvs
Delete rows specified by CSV
2 parents 92c73be + d9a45bb commit 4eee3ba

File tree

4 files changed

+286
-1
lines changed

4 files changed

+286
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
geo_id,value,stderr,sample_size,issue,time_value,geo_type,signal,source
2+
d_nonlatest,0,0,0,1,0,geo,sig,src
3+
d_latest, 0,0,0,3,0,geo,sig,src
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Integration tests for covidcast's batch deletions."""
2+
3+
# standard library
4+
from collections import namedtuple
5+
import unittest
6+
from os import path
7+
8+
# third party
9+
import mysql.connector
10+
11+
# first party
12+
from delphi_utils import Nans
13+
from delphi.epidata.client.delphi_epidata import Epidata
14+
import delphi.operations.secrets as secrets
15+
16+
# py3tester coverage target (equivalent to `import *`)
17+
__test_target__ = 'delphi.epidata.acquisition.covidcast.database'
18+
19+
# Pairs a SQL query to run ("given") with the rows it must return ("expected").
Example = namedtuple("example", "given expected")

class DeleteBatch(unittest.TestCase):
    """Tests batch deletions"""


    def setUp(self):
        """Perform per-test setup."""

        # connect to the `epidata` database and clear the `covidcast` table
        cnx = mysql.connector.connect(
            user='user',
            password='pass',
            host='delphi_database_epidata',
            database='epidata')
        cur = cnx.cursor()
        cur.execute('truncate table covidcast')
        cnx.commit()
        cur.close()

        # make connection and cursor available to test cases
        self.cnx = cnx
        self.cur = cnx.cursor()

        # use the local instance of the epidata database
        secrets.db.host = 'delphi_database_epidata'
        secrets.db.epi = ('user', 'pass')

        # use the local instance of the Epidata API
        Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php'

        # will use secrets as set above
        from delphi.epidata.acquisition.covidcast.database import Database
        self.database = Database()
        self.database.connect()

    def tearDown(self):
        """Perform per-test teardown."""
        self.cur.close()
        self.cnx.close()

    @unittest.skip("Database user would require FILE privileges")
    def test_delete_from_file(self):
        # exercises the `LOAD DATA INFILE` code path by passing a filename
        self._test_delete_batch(path.join(path.dirname(__file__), "delete_batch.csv"))

    def test_delete_from_tuples(self):
        # exercises the executemany code path by passing a list of tuples
        with open(path.join(path.dirname(__file__), "delete_batch.csv")) as f:
            rows = []
            for line in f:
                rows.append(line.strip().split(","))
        # drop the CSV header row and append the `time_type` field that the
        # tuple form of delete_batch expects
        rows = [r + ["day"] for r in rows[1:]]
        self._test_delete_batch(rows)

    def _test_delete_batch(self, cc_deletions):
        """Shared body for both deletion tests.

        `cc_deletions` is forwarded to `Database.delete_batch`; it may be a
        CSV filename or a list of tuples.
        """
        # load sample data: two time_values x five rows, where each geo_value
        # has exactly one row flagged `is_latest_issue`
        rows = [
            # geo_value issue is_latest
            ["d_nonlatest", 1, 0],
            ["d_nonlatest", 2, 1],
            ["d_latest", 1, 0],
            ["d_latest", 2, 0],
            ["d_latest", 3, 1]
        ]
        for time_value in [0, 1]:
            self.cur.executemany(f'''
            INSERT INTO covidcast
            (`geo_value`, `issue`, `is_latest_issue`, `time_value`,
            `source`, `signal`, `time_type`, `geo_type`,
            value_updated_timestamp, direction_updated_timestamp, value, stderr, sample_size, lag, direction)
            VALUES
            (%s, %s, %s, {time_value},
            "src", "sig", "day", "geo",
            0, 0, 0, 0, 0, 0, 0)
            ''', rows)
        self.cnx.commit()

        # delete entries
        self.database.delete_batch(cc_deletions)

        # verify remaining data is still there: the CSV targets two rows at
        # time_value=0, so exactly two of the inserted rows should be gone
        self.cur.execute("select * from covidcast")
        result = list(self.cur)
        self.assertEqual(len(result), 2*len(rows)-2)

        examples = [
            # verify deletions are gone
            Example(
                'select * from covidcast where time_value=0 and geo_value="d_nonlatest" and issue=1',
                []
            ),
            Example(
                'select * from covidcast where time_value=0 and geo_value="d_latest" and issue=3',
                []
            ),
            # verify is_latest_issue flag was corrected
            Example(
                'select geo_value, issue from covidcast where time_value=0 and is_latest_issue=1',
                [('d_nonlatest', 2),
                 ('d_latest', 2)]
            )
        ]

        for ex in examples:
            self.cur.execute(ex.given)
            result = list(self.cur)
            self.assertEqual(result, ex.expected, ex.given)

src/acquisition/covidcast/database.py

+98-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ class Database:
6969

7070
def connect(self, connector_impl=mysql.connector):
7171
"""Establish a connection to the database."""
72-
7372
u, p = secrets.db.epi
7473
self._connector_impl = connector_impl
7574
self._connection = self._connector_impl.connect(
@@ -247,6 +246,104 @@ def insert_or_update_batch(self, cc_rows, batch_size=2**20, commit_partial=False
247246
self._cursor.execute(drop_tmp_table_sql)
248247
return total
249248

249+
def delete_batch(self, cc_deletions):
    """
    Remove rows specified by a csv file or list of tuples.

    If cc_deletions is a filename, the file should include a header row and use the following field order:
    - geo_id
    - value (ignored)
    - stderr (ignored)
    - sample_size (ignored)
    - issue (YYYYMMDD format)
    - time_value (YYYYMMDD format)
    - geo_type
    - signal
    - source

    If cc_deletions is a list of tuples, the tuples should use the following field order (=same as above, plus time_type):
    - geo_id
    - value (ignored)
    - stderr (ignored)
    - sample_size (ignored)
    - issue (YYYYMMDD format)
    - time_value (YYYYMMDD format)
    - geo_type
    - signal
    - source
    - time_type

    Returns the number of rows deleted from the covidcast table, so callers
    (e.g. the csv_deletion script) can log a meaningful row count.
    """
    tmp_table_name = "tmp_delete_table"
    # staging table with the same shape as covidcast, plus a column to hold
    # the id of the matching covidcast row
    create_tmp_table_sql = f'''
CREATE OR REPLACE TABLE {tmp_table_name} LIKE covidcast;
'''

    amend_tmp_table_sql = f'''
ALTER TABLE {tmp_table_name} ADD COLUMN covidcast_id bigint unsigned;
'''

    # file path: bulk-load the CSV server-side (requires FILE privilege)
    load_tmp_table_infile_sql = f'''
LOAD DATA INFILE "{cc_deletions}"
INTO TABLE {tmp_table_name}
FIELDS TERMINATED BY ","
IGNORE 1 LINES
(`geo_value`, `value`, `stderr`, `sample_size`, `issue`, `time_value`, `geo_type`, `signal`, `source`)
SET time_type="day";
'''

    # list of tuples: insert client-side, one parameter set per deletion
    load_tmp_table_insert_sql = f'''
INSERT INTO {tmp_table_name}
(`geo_value`, `value`, `stderr`, `sample_size`, `issue`, `time_value`, `geo_type`, `signal`, `source`, `time_type`,
`value_updated_timestamp`, `direction_updated_timestamp`, `lag`, `direction`, `is_latest_issue`)
VALUES
(%s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
0, 0, 0, 0, 0)
'''

    # resolve each staged deletion to its covidcast row id, and remember
    # whether that row was the latest issue (needed for the fix-up below)
    add_id_sql = f'''
UPDATE {tmp_table_name} d INNER JOIN covidcast c USING
(`source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`, `issue`)
SET d.covidcast_id=c.id, d.is_latest_issue=c.is_latest_issue;
'''

    delete_sql = f'''
DELETE c FROM {tmp_table_name} d INNER JOIN covidcast c WHERE d.covidcast_id=c.id;
'''

    # for each series that lost its latest issue, promote the remaining row
    # with the greatest issue to is_latest_issue=1
    fix_latest_issue_sql = f'''
UPDATE
(SELECT `source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`, MAX(`issue`) AS `issue`
    FROM
    (SELECT DISTINCT `source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`
        FROM {tmp_table_name} WHERE `is_latest_issue`=1) AS was_latest
    LEFT JOIN covidcast c
    USING (`source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`)
    GROUP BY `source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`
) AS TMP
LEFT JOIN `covidcast`
USING (`source`, `signal`, `time_type`, `geo_type`, `time_value`, `geo_value`, `issue`)
SET `covidcast`.`is_latest_issue`=1;
'''

    drop_tmp_table_sql = f'DROP TABLE {tmp_table_name}'
    total = None
    try:
        self._cursor.execute(create_tmp_table_sql)
        self._cursor.execute(amend_tmp_table_sql)
        if isinstance(cc_deletions, str):
            self._cursor.execute(load_tmp_table_infile_sql)
        elif isinstance(cc_deletions, list):
            self._cursor.executemany(load_tmp_table_insert_sql, cc_deletions)
        else:
            raise Exception(f"Bad deletions argument: need a filename or a list of tuples; got a {type(cc_deletions)}")
        self._cursor.execute(add_id_sql)
        self._cursor.execute(delete_sql)
        # rowcount of the DELETE is the number of covidcast rows removed
        total = self._cursor.rowcount
        self._cursor.execute(fix_latest_issue_sql)
        self._connection.commit()
    finally:
        # always clean up the staging table, even when a statement failed
        # (the previous `except Exception as e: raise e` was a no-op re-raise
        # and has been removed)
        self._cursor.execute(drop_tmp_table_sql)
    return total
346+
250347
def compute_covidcast_meta(self, table_name='covidcast', use_index=True):
251348
"""Compute and return metadata on all non-WIP COVIDcast signals."""
252349
logger = get_structured_logger("compute_covidcast_meta")
+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Deletes large numbers of rows from covidcast based on a CSV"""
2+
3+
# standard library
import argparse
import glob
import os
import time

# first party
from delphi.epidata.acquisition.covidcast.database import Database
from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
11+
12+
13+
def get_argument_parser():
    """Define command line arguments."""

    # flag name -> help text, registered in insertion order
    flags = {
        '--deletion_dir': 'directory where deletion CSVs are stored',
        '--log_file': "filename for log output (defaults to stdout)",
    }
    parser = argparse.ArgumentParser()
    for flag, description in flags.items():
        parser.add_argument(flag, help=description)
    return parser
24+
25+
def handle_file(deletion_file, database, logger=None):
    """Delete the covidcast rows listed in a single deletion CSV.

    `deletion_file`: path to a CSV with a header row, in the 9-field format
        documented by `Database.delete_batch`.
    `database`: a connected `Database` instance.
    `logger` (optional, keyword): a structured logger; one is created if
        not supplied (default keeps the original call signature working).

    Returns the number of rows deleted, or 0 if deletion failed (the
    transaction is rolled back in that case).
    """
    if logger is None:
        # fix: the original body referenced a `logger` name that was only
        # defined locally inside main(), raising NameError on every call
        logger = get_structured_logger("csv_deletion")
    logger.info("Deleting from csv file", filename=deletion_file)
    rows = []
    with open(deletion_file) as f:
        for line in f:
            rows.append(line.strip().split(","))
    # drop the header row, and append the `time_type` field required by the
    # list-of-tuples form of delete_batch (its INSERT has 10 placeholders;
    # the CSV supplies only 9 fields) — same transformation the integration
    # test applies
    rows = [r + ["day"] for r in rows[1:]]
    try:
        n = database.delete_batch(rows)
        logger.info("Deleted database rows", row_count=n)
        return n
    except Exception as e:
        logger.exception('Exception while deleting rows:', e)
        database.rollback()
        return 0
40+
41+
def main(args):
    """Delete rows from covidcast.

    Scans `args.deletion_dir` for `*.csv` files, deletes the rows each file
    specifies, and logs a summary with the total row count and runtime.
    """

    logger = get_structured_logger("csv_deletion", filename=args.log_file)
    start_time = time.time()
    database = Database()
    database.connect()
    all_n = 0

    try:
        # sorted() gives a deterministic processing order; `glob` is now
        # imported at module level (it was previously used without import)
        for deletion_file in sorted(glob.glob(os.path.join(args.deletion_dir, '*.csv'))):
            # fix: handle_file requires the database argument; the original
            # call omitted it and would have raised TypeError
            all_n += handle_file(deletion_file, database)
    finally:
        # commit pending work and close the connection even on failure
        database.disconnect(True)

    logger.info(
        "Deleted CSVs from database",
        total_runtime_in_seconds=round(time.time() - start_time, 2), row_count=all_n)

if __name__ == '__main__':
    main(get_argument_parser().parse_args())

0 commit comments

Comments
 (0)