Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0031ee2

Browse files
committed Jun 25, 2021
refactor: add more type annotations
1 parent 5c5428d commit 0031ee2

File tree

1 file changed

+53
-96
lines changed

1 file changed

+53
-96
lines changed
 

‎src/acquisition/covidcast/signal_dash_data_generator.py

Lines changed: 53 additions & 96 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010

1111
from dataclasses import dataclass
12-
from typing import List
12+
from typing import Dict, List
1313

1414
# first party
1515
import covidcast
@@ -19,6 +19,7 @@
1919

2020
LOOKBACK_DAYS_FOR_COVERAGE = 28
2121

22+
2223
@dataclass
2324
class DashboardSignal:
2425
"""Container class for information about dashboard signals."""
@@ -54,20 +55,16 @@ class DashboardSignalStatus:
5455
class Database:
5556
"""Storage for dashboard data."""
5657

57-
DATABASE_NAME = 'epidata'
58-
SIGNAL_TABLE_NAME = 'dashboard_signal'
59-
STATUS_TABLE_NAME = 'dashboard_signal_status'
60-
COVERAGE_TABLE_NAME = 'dashboard_signal_coverage'
58+
DATABASE_NAME = "epidata"
59+
SIGNAL_TABLE_NAME = "dashboard_signal"
60+
STATUS_TABLE_NAME = "dashboard_signal_status"
61+
COVERAGE_TABLE_NAME = "dashboard_signal_coverage"
6162

6263
def __init__(self, connector_impl=mysql.connector):
6364
"""Establish a connection to the database."""
6465

6566
u, p = secrets.db.epi
66-
self._connection = connector_impl.connect(
67-
host=secrets.db.host,
68-
user=u,
69-
password=p,
70-
database=Database.DATABASE_NAME)
67+
self._connection = connector_impl.connect(host=secrets.db.host, user=u, password=p, database=Database.DATABASE_NAME)
7168
self._cursor = self._connection.cursor()
7269

7370
def rowcount(self) -> int:
@@ -76,97 +73,85 @@ def rowcount(self) -> int:
7673

7774
def write_status(self, status_list: List[DashboardSignalStatus]) -> None:
7875
"""Write the provided status to the database."""
79-
insert_statement = f'''INSERT INTO `{Database.STATUS_TABLE_NAME}`
76+
insert_statement = f"""INSERT INTO `{Database.STATUS_TABLE_NAME}`
8077
(`signal_id`, `date`, `latest_issue`, `latest_time_value`)
8178
VALUES
8279
(%s, %s, %s, %s)
8380
ON DUPLICATE KEY UPDATE
8481
`latest_issue`=VALUES(`latest_issue`),
8582
`latest_time_value`=VALUES(`latest_time_value`)
86-
'''
87-
status_as_tuples = [
88-
(x.signal_id, x.date, x.latest_issue, x.latest_time_value)
89-
for x in status_list]
83+
"""
84+
status_as_tuples = [(x.signal_id, x.date, x.latest_issue, x.latest_time_value) for x in status_list]
9085
self._cursor.executemany(insert_statement, status_as_tuples)
9186

92-
latest_status_dates = {}
87+
latest_status_dates: Dict[int, datetime.date] = {}
9388
for x in status_list:
9489
latest_status_date = latest_status_dates.get(x.signal_id)
9590
if not latest_status_date or x.date > latest_status_date:
96-
latest_status_dates.update({x.signal_id: x.date})
91+
latest_status_dates[x.signal_id] = x.date
9792
latest_status_tuples = [(v, k) for k, v in latest_status_dates.items()]
9893

99-
update_statement = f'''UPDATE `{Database.SIGNAL_TABLE_NAME}`
94+
update_statement = f"""UPDATE `{Database.SIGNAL_TABLE_NAME}`
10095
SET `latest_status_update` = GREATEST(`latest_status_update`, %s)
10196
WHERE `id` = %s
102-
'''
97+
"""
10398
self._cursor.executemany(update_statement, latest_status_tuples)
10499

105100
self._connection.commit()
106101

107-
def write_coverage(
108-
self, coverage_list: List[DashboardSignalCoverage]) -> None:
102+
def write_coverage(self, coverage_list: List[DashboardSignalCoverage]) -> None:
109103
"""Write the provided coverage to the database."""
110-
insert_statement = f'''INSERT INTO `{Database.COVERAGE_TABLE_NAME}`
104+
insert_statement = f"""INSERT INTO `{Database.COVERAGE_TABLE_NAME}`
111105
(`signal_id`, `date`, `geo_type`, `count`)
112106
VALUES
113107
(%s, %s, %s, %s)
114108
ON DUPLICATE KEY UPDATE `count` = VALUES(`count`)
115-
'''
116-
coverage_as_tuples = [
117-
(x.signal_id, x.date, x.geo_type, x.count)
118-
for x in coverage_list]
109+
"""
110+
coverage_as_tuples = [(x.signal_id, x.date, x.geo_type, x.count) for x in coverage_list]
119111
self._cursor.executemany(insert_statement, coverage_as_tuples)
120112

121-
latest_coverage_dates = {}
122-
oldest_coverage_dates = {}
113+
latest_coverage_dates: Dict[int, datetime.date] = {}
114+
oldest_coverage_dates: Dict[int, datetime.date] = {}
123115
for x in coverage_list:
124116
latest_coverage_date = latest_coverage_dates.get(x.signal_id)
125117
oldest_coverage_date = oldest_coverage_dates.get(x.signal_id)
126118
if not latest_coverage_date or x.date > latest_coverage_date:
127-
latest_coverage_dates.update({x.signal_id: x.date})
119+
latest_coverage_dates[x.signal_id] = x.date
128120
if not oldest_coverage_date or x.date < oldest_coverage_date:
129-
oldest_coverage_dates.update({x.signal_id: x.date})
121+
oldest_coverage_dates[x.signal_id] = x.date
130122

131123
latest_coverage_tuples = [(v, k) for k, v in latest_coverage_dates.items()]
132124
oldest_coverage_tuples = [(v, k) for k, v in oldest_coverage_dates.items()]
133125

134-
update_statement = f'''UPDATE `{Database.SIGNAL_TABLE_NAME}`
126+
update_statement = f"""UPDATE `{Database.SIGNAL_TABLE_NAME}`
135127
SET `latest_coverage_update` = GREATEST(`latest_coverage_update`, %s)
136128
WHERE `id` = %s
137-
'''
129+
"""
138130
self._cursor.executemany(update_statement, latest_coverage_tuples)
139131

140-
delete_statement = f'''DELETE FROM `{Database.COVERAGE_TABLE_NAME}`
132+
delete_statement = f"""DELETE FROM `{Database.COVERAGE_TABLE_NAME}`
141133
WHERE `date` < %s
142134
AND `signal_id` = %s
143-
'''
135+
"""
144136
self._cursor.executemany(delete_statement, oldest_coverage_tuples)
145137

146138
self._connection.commit()
147139

148140
def get_enabled_signals(self) -> List[DashboardSignal]:
149141
"""Retrieve all enabled signals from the database"""
150-
select_statement = f'''SELECT `id`,
142+
select_statement = f"""SELECT `id`,
151143
`name`,
152144
`source`,
153145
`covidcast_signal`,
154-
`latest_coverage_update`,
146+
`latest_coverage_update`,
155147
`latest_status_update`
156148
FROM `{Database.SIGNAL_TABLE_NAME}`
157149
WHERE `enabled`
158-
'''
150+
"""
159151
self._cursor.execute(select_statement)
160-
enabled_signals = []
152+
enabled_signals: List[DashboardSignal] = []
161153
for result in self._cursor.fetchall():
162-
enabled_signals.append(
163-
DashboardSignal(
164-
db_id=result[0],
165-
name=result[1],
166-
source=result[2],
167-
covidcast_signal=result[3],
168-
latest_coverage_update=result[4],
169-
latest_status_update=result[5]))
154+
enabled_signals.append(DashboardSignal(db_id=result[0], name=result[1], source=result[2], covidcast_signal=result[3], latest_coverage_update=result[4], latest_status_update=result[5]))
170155
return enabled_signals
171156

172157

@@ -177,46 +162,31 @@ def get_argument_parser():
177162
return parser
178163

179164

180-
def get_latest_issue_from_metadata(dashboard_signal, metadata):
165+
def get_latest_issue_from_metadata(dashboard_signal: DashboardSignal, metadata: pd.DataFrame) -> datetime.date:
181166
"""Get the most recent issue date for the signal."""
182-
df_for_source = metadata[(metadata.data_source == dashboard_signal.source) & (
183-
metadata.signal == dashboard_signal.covidcast_signal)]
167+
df_for_source: pd.DataFrame = metadata[(metadata.data_source == dashboard_signal.source) & (metadata.signal == dashboard_signal.covidcast_signal)]
184168
max_issue = df_for_source["max_issue"].max()
185169
return pd.to_datetime(str(max_issue), format="%Y%m%d").date()
186170

187171

188-
def get_latest_time_value_from_metadata(dashboard_signal, metadata):
172+
def get_latest_time_value_from_metadata(dashboard_signal: DashboardSignal, metadata: pd.DataFrame) -> datetime.date:
189173
"""Get the most recent date with data for the signal."""
190-
df_for_source = metadata[(metadata.data_source == dashboard_signal.source) & (
191-
metadata.signal == dashboard_signal.covidcast_signal)]
174+
df_for_source: pd.DataFrame = metadata[(metadata.data_source == dashboard_signal.source) & (metadata.signal == dashboard_signal.covidcast_signal)]
192175
return df_for_source["max_time"].max().date()
193176

194177

195-
def get_coverage(dashboard_signal: DashboardSignal,
196-
metadata) -> List[DashboardSignalCoverage]:
178+
def get_coverage(dashboard_signal: DashboardSignal, metadata: pd.DataFrame) -> List[DashboardSignalCoverage]:
197179
"""Get the most recent coverage for the signal."""
198-
latest_time_value = get_latest_time_value_from_metadata(
199-
dashboard_signal, metadata)
200-
start_day = latest_time_value - datetime.timedelta(days = LOOKBACK_DAYS_FOR_COVERAGE)
201-
latest_data = covidcast.signal(
202-
dashboard_signal.source,
203-
dashboard_signal.covidcast_signal,
204-
end_day = latest_time_value,
205-
start_day = start_day)
206-
latest_data_without_megacounties = latest_data[~latest_data['geo_value'].str.endswith(
207-
'000')]
208-
count_by_geo_type_df = latest_data_without_megacounties.groupby(
209-
['geo_type', 'data_source', 'time_value', 'signal']).size().to_frame(
210-
'count').reset_index()
211-
212-
signal_coverage_list = []
213-
180+
latest_time_value = get_latest_time_value_from_metadata(dashboard_signal, metadata)
181+
start_day = latest_time_value - datetime.timedelta(days=LOOKBACK_DAYS_FOR_COVERAGE)
182+
latest_data: pd.DataFrame = covidcast.signal(dashboard_signal.source, dashboard_signal.covidcast_signal, end_day=latest_time_value, start_day=start_day)
183+
latest_data_without_megacounties: pd.DataFrame = latest_data[~latest_data["geo_value"].str.endswith("000")]
184+
count_by_geo_type_df = latest_data_without_megacounties.groupby(["geo_type", "data_source", "time_value", "signal"]).size().to_frame("count").reset_index()
185+
186+
signal_coverage_list: List[DashboardSignalCoverage] = []
187+
214188
for _, row in count_by_geo_type_df.iterrows():
215-
signal_coverage = DashboardSignalCoverage(
216-
signal_id=dashboard_signal.db_id,
217-
date=row['time_value'].date(),
218-
geo_type=row['geo_type'],
219-
count=row['count'])
189+
signal_coverage = DashboardSignalCoverage(signal_id=dashboard_signal.db_id, date=row["time_value"].date(), geo_type=row["geo_type"], count=row["count"])
220190
signal_coverage_list.append(signal_coverage)
221191

222192
return signal_coverage_list
@@ -231,36 +201,25 @@ def main(args):
231201
if args:
232202
log_file = args.log_file
233203

234-
logger = get_structured_logger(
235-
"signal_dash_data_generator",
236-
filename=log_file, log_exceptions=False)
204+
logger = get_structured_logger("signal_dash_data_generator", filename=log_file, log_exceptions=False)
237205
start_time = time.time()
238206

239207
database = Database()
240208

241209
signals_to_generate = database.get_enabled_signals()
242-
logger.info("Starting generating dashboard data.", enabled_signals=[
243-
signal.name for signal in signals_to_generate])
210+
logger.info("Starting generating dashboard data.", enabled_signals=[signal.name for signal in signals_to_generate])
244211

245-
metadata = covidcast.metadata()
212+
metadata: pd.DataFrame = covidcast.metadata()
246213

247214
signal_status_list: List[DashboardSignalStatus] = []
248215
coverage_list: List[DashboardSignalCoverage] = []
249216

250217
for dashboard_signal in signals_to_generate:
251-
latest_issue = get_latest_issue_from_metadata(
252-
dashboard_signal,
253-
metadata)
254-
latest_time_value = get_latest_time_value_from_metadata(
255-
dashboard_signal, metadata)
218+
latest_issue = get_latest_issue_from_metadata(dashboard_signal, metadata)
219+
latest_time_value = get_latest_time_value_from_metadata(dashboard_signal, metadata)
256220
latest_coverage = get_coverage(dashboard_signal, metadata)
257221

258-
signal_status_list.append(
259-
DashboardSignalStatus(
260-
signal_id=dashboard_signal.db_id,
261-
date=datetime.date.today(),
262-
latest_issue=latest_issue,
263-
latest_time_value=latest_time_value))
222+
signal_status_list.append(DashboardSignalStatus(signal_id=dashboard_signal.db_id, date=datetime.date.today(), latest_issue=latest_issue, latest_time_value=latest_time_value))
264223
coverage_list.extend(latest_coverage)
265224

266225
try:
@@ -275,12 +234,10 @@ def main(args):
275234
except mysql.connector.Error as exception:
276235
logger.exception(exception)
277236

278-
logger.info(
279-
"Generated signal dashboard data",
280-
total_runtime_in_seconds=round(time.time() - start_time, 2))
237+
logger.info("Generated signal dashboard data", total_runtime_in_seconds=round(time.time() - start_time, 2))
281238
return True
282239

283240

284-
if __name__ == '__main__':
241+
if __name__ == "__main__":
285242
if not main(get_argument_parser().parse_args()):
286243
sys.exit(1)

0 commit comments

Comments
 (0)
Please sign in to comment.