Skip to content

Commit c5c1989

Browse files
committed
Server: add CovidcastRow helper class for testing
* no default values * helper functions for creating rows
1 parent f949439 commit c5c1989

File tree

4 files changed

+357
-19
lines changed

4 files changed

+357
-19
lines changed

integrations/server/test_covidcast.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ def _insert_placeholder_set_four(self):
7777

7878
def _insert_placeholder_set_five(self):
7979
rows = [
80-
self._make_placeholder_row(time_value=2000_01_01, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03+i)[0]
80+
CovidcastRow(time_value=2000_01_01, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03+i)
8181
for i in [1, 2, 3]
8282
] + [
8383
# different time_values, same issues
84-
self._make_placeholder_row(time_value=2000_01_01+i-3, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03+i-3)[0]
84+
CovidcastRow(time_value=2000_01_01+i-3, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03+i-3)
8585
for i in [4, 5, 6]
8686
]
8787
self._insert_rows(rows)
@@ -254,18 +254,16 @@ def test_time_values_wildcard(self):
254254

255255
# insert placeholder data
256256
rows = self._insert_placeholder_set_three()
257-
expected_time_values = [
258-
self.expected_from_row(r) for r in rows[:3]
259-
]
257+
expected = [row.as_dict(ignore_fields=IGNORE_FIELDS) for row in rows[:3]]
260258

261259
# make the request
262-
response, _ = self.request_based_on_row(rows[0], time_values="*")
260+
response = self.request_based_on_row(rows[0], time_values="*")
263261

264262
self.maxDiff = None
265263
# assert that the right data came back
266264
self.assertEqual(response, {
267265
'result': 1,
268-
'epidata': expected_time_values,
266+
'epidata': expected,
269267
'message': 'success',
270268
})
271269

@@ -274,18 +272,16 @@ def test_issues_wildcard(self):
274272

275273
# insert placeholder data
276274
rows = self._insert_placeholder_set_five()
277-
expected_issues = [
278-
self.expected_from_row(r) for r in rows[:3]
279-
]
275+
expected = [row.as_dict(ignore_fields=IGNORE_FIELDS) for row in rows[:3]]
280276

281277
# make the request
282-
response, _ = self.request_based_on_row(rows[0], issues="*")
278+
response = self.request_based_on_row(rows[0], issues="*")
283279

284280
self.maxDiff = None
285281
# assert that the right data came back
286282
self.assertEqual(response, {
287283
'result': 1,
288-
'epidata': expected_issues,
284+
'epidata': expected,
289285
'message': 'success',
290286
})
291287

+241
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from dataclasses import asdict, dataclass, fields
2+
from datetime import date
3+
from typing import Any, ClassVar, Dict, Iterable, List, Optional
4+
5+
import pandas as pd
6+
from delphi_utils import Nans
7+
8+
from delphi.epidata.server.utils.dates import day_to_time_value, time_value_to_day
9+
from delphi.epidata.server.endpoints.covidcast_utils.model import PANDAS_DTYPES
10+
11+
12+
@dataclass
13+
class CovidcastRow:
14+
"""A container for the values of a single covidcast database row.
15+
16+
Used for:
17+
- inserting rows into the database
18+
- creating test rows with default fields for testing
19+
- converting from and to formats (dict, csv, df, kwargs)
20+
- creating consistent views, with consistent data types (dict, csv, df)
21+
22+
The rows are specified in 'v4_schema.sql'. The datatypes are made to match the database. When writing to Pandas, the dtypes match the JIT model.py schema.
23+
"""
24+
25+
# Arguments.
26+
source: str
27+
signal: str
28+
time_type: str
29+
geo_type: str
30+
time_value: int
31+
geo_value: str
32+
value: float
33+
stderr: float
34+
sample_size: float
35+
# The following three fields are Nans enums from delphi_utils.nans.
36+
missing_value: int
37+
missing_stderr: int
38+
missing_sample_size: int
39+
issue: Optional[int]
40+
lag: Optional[int]
41+
# The following four fields exist only in the database; they are not ingested at acquisition and not returned by the API.
42+
id: Optional[int]
43+
direction: Optional[int]
44+
direction_updated_timestamp: int
45+
value_updated_timestamp: int
46+
47+
# Classvars.
48+
_db_row_ignore_fields: ClassVar = []
49+
_api_row_ignore_fields: ClassVar = ["id", "direction_updated_timestamp", "value_updated_timestamp"]
50+
_api_row_compatibility_ignore_fields: ClassVar = ["id", "direction_updated_timestamp", "value_updated_timestamp", "source", "time_type", "geo_type"]
51+
_pandas_dtypes: ClassVar = PANDAS_DTYPES
52+
53+
@staticmethod
54+
def make_default_row(**kwargs) -> "CovidcastRow":
55+
default_args = {
56+
"source": "src",
57+
"signal": "sig",
58+
"time_type": "day",
59+
"geo_type": "county",
60+
"time_value": 20200202,
61+
"geo_value": "01234",
62+
"value": 10.0,
63+
"stderr": 10.0,
64+
"sample_size": 10.0,
65+
"missing_value": Nans.NOT_MISSING.value,
66+
"missing_stderr": Nans.NOT_MISSING.value,
67+
"missing_sample_size": Nans.NOT_MISSING.value,
68+
"issue": 20200202,
69+
"lag": 0,
70+
"id": None,
71+
"direction": None,
72+
"direction_updated_timestamp": 0,
73+
"value_updated_timestamp": 20200202,
74+
}
75+
default_args.update(kwargs)
76+
return CovidcastRow(**default_args)
77+
78+
def __post_init__(self):
79+
# Convert time values to ints by default.
80+
self.time_value = day_to_time_value(self.time_value) if isinstance(self.time_value, date) else self.time_value
81+
self.issue = day_to_time_value(self.issue) if isinstance(self.issue, date) else self.issue
82+
self.value_updated_timestamp = day_to_time_value(self.value_updated_timestamp) if isinstance(self.value_updated_timestamp, date) else self.value_updated_timestamp
83+
84+
def _sanity_check_fields(self, extra_checks: bool = True):
85+
if self.issue and self.issue < self.time_value:
86+
self.issue = self.time_value
87+
88+
if self.issue:
89+
self.lag = (time_value_to_day(self.issue) - time_value_to_day(self.time_value)).days
90+
else:
91+
self.lag = None
92+
93+
# This sanity checking is already done in CsvImporter, but it's here so the testing class gets it too.
94+
if pd.isna(self.value) and self.missing_value == Nans.NOT_MISSING:
95+
self.missing_value = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value
96+
97+
if pd.isna(self.stderr) and self.missing_stderr == Nans.NOT_MISSING:
98+
self.missing_stderr = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value
99+
100+
if pd.isna(self.sample_size) and self.missing_sample_size == Nans.NOT_MISSING:
101+
self.missing_sample_size = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value
102+
103+
return self
104+
105+
def as_dict(self, ignore_fields: Optional[List[str]] = None) -> dict:
106+
d = asdict(self)
107+
if ignore_fields:
108+
for key in ignore_fields:
109+
del d[key]
110+
return d
111+
112+
def as_dataframe(self, ignore_fields: Optional[List[str]] = None) -> pd.DataFrame:
113+
df = pd.DataFrame.from_records([self.as_dict(ignore_fields=ignore_fields)])
114+
# This is to mirror the types in model.py.
115+
df = set_df_dtypes(df, self._pandas_dtypes)
116+
return df
117+
118+
@property
119+
def api_row_df(self) -> pd.DataFrame:
120+
"""Returns a dataframe view into the row with the fields returned by the API server."""
121+
return self.as_dataframe(ignore_fields=self._api_row_ignore_fields)
122+
123+
@property
124+
def api_compatibility_row_df(self) -> pd.DataFrame:
125+
"""Returns a dataframe view into the row with the fields returned by the old API server (the PHP server)."""
126+
return self.as_dataframe(ignore_fields=self._api_row_compatibility_ignore_fields)
127+
128+
@property
129+
def db_row_df(self) -> pd.DataFrame:
130+
"""Returns a dataframe view into the row with the fields returned by an all-field database query."""
131+
return self.as_dataframe(ignore_fields=self._db_row_ignore_fields)
132+
133+
@property
134+
def signal_pair(self):
135+
return f"{self.source}:{self.signal}"
136+
137+
@property
138+
def geo_pair(self):
139+
return f"{self.geo_type}:{self.geo_value}"
140+
141+
@property
142+
def time_pair(self):
143+
return f"{self.time_type}:{self.time_value}"
144+
145+
146+
def covidcast_rows_from_args(sanity_check: bool = True, test_mode: bool = True, **kwargs: Dict[str, Iterable]) -> List[CovidcastRow]:
147+
"""A convenience constructor.
148+
149+
Handy for constructing batches of test cases.
150+
151+
Example:
152+
covidcast_rows_from_args(value=[1, 2, 3], time_value=[1, 2, 3]) will yield
153+
[CovidcastRow.make_default_row(value=1, time_value=1), CovidcastRow.make_default_row(value=2, time_value=2), CovidcastRow.make_default_row(value=3, time_value=3)]
154+
with all the defaults from CovidcastRow.
155+
"""
156+
# If any iterables were passed instead of lists, convert them to lists.
157+
kwargs = {key: list(value) for key, value in kwargs.items()}
158+
# All the arg values must be lists of the same length.
159+
assert len(set(len(lst) for lst in kwargs.values())) == 1
160+
161+
if sanity_check:
162+
return [CovidcastRow.make_default_row(**_kwargs)._sanity_check_fields(extra_checks=test_mode) for _kwargs in transpose_dict(kwargs)]
163+
else:
164+
return [CovidcastRow.make_default_row(**_kwargs) for _kwargs in transpose_dict(kwargs)]
165+
166+
167+
def covidcast_rows_from_records(records: Iterable[dict], sanity_check: bool = False) -> List[CovidcastRow]:
168+
"""A convenience constructor.
169+
170+
Default is different from from_args, because from_records is usually called on faux-API returns in tests,
171+
where we don't want any values getting default filled in.
172+
173+
You can use csv.DictReader before this to read a CSV file.
174+
"""
175+
records = list(records)
176+
return [CovidcastRow.make_default_row(**record) if not sanity_check else CovidcastRow.make_default_row(**record)._sanity_check_fields() for record in records]
177+
178+
179+
def covidcast_rows_as_dicts(rows: Iterable[CovidcastRow], ignore_fields: Optional[List[str]] = None) -> List[dict]:
180+
return [row.as_dict(ignore_fields=ignore_fields) for row in rows]
181+
182+
183+
def covidcast_rows_as_dataframe(rows: Iterable[CovidcastRow], ignore_fields: Optional[List[str]] = None) -> pd.DataFrame:
184+
if ignore_fields is None:
185+
ignore_fields = []
186+
187+
columns = [field.name for field in fields(CovidcastRow) if field.name not in ignore_fields]
188+
189+
if rows:
190+
df = pd.concat([row.as_dataframe(ignore_fields=ignore_fields) for row in rows], ignore_index=True)
191+
return df[columns]
192+
else:
193+
return pd.DataFrame(columns=columns)
194+
195+
196+
def covidcast_rows_as_api_row_df(rows: Iterable[CovidcastRow]) -> pd.DataFrame:
197+
return covidcast_rows_as_dataframe(rows, ignore_fields=CovidcastRow._api_row_ignore_fields)
198+
199+
200+
def covidcast_rows_as_api_compatibility_row_df(rows: Iterable[CovidcastRow]) -> pd.DataFrame:
201+
return covidcast_rows_as_dataframe(rows, ignore_fields=CovidcastRow._api_row_compatibility_ignore_fields)
202+
203+
204+
def covidcast_rows_as_db_row_df(rows: Iterable[CovidcastRow]) -> pd.DataFrame:
205+
return covidcast_rows_as_dataframe(rows, ignore_fields=CovidcastRow._db_row_ignore_fields)
206+
207+
208+
def transpose_dict(d: Dict[Any, List[Any]]) -> List[Dict[Any, Any]]:
209+
"""Given a dictionary whose values are lists of the same length, turn it into a list of dictionaries whose values are the individual list entries.
210+
211+
Example:
212+
>>> transpose_dict(dict([["a", [2, 4, 6]], ["b", [3, 5, 7]], ["c", [10, 20, 30]]]))
213+
[{"a": 2, "b": 3, "c": 10}, {"a": 4, "b": 5, "c": 20}, {"a": 6, "b": 7, "c": 30}]
214+
"""
215+
return [dict(zip(d.keys(), values)) for values in zip(*d.values())]
216+
217+
218+
def check_valid_dtype(dtype):
219+
try:
220+
pd.api.types.pandas_dtype(dtype)
221+
except TypeError:
222+
raise ValueError(f"Invalid dtype {dtype}")
223+
224+
225+
def set_df_dtypes(df: pd.DataFrame, dtypes: Dict[str, Any]) -> pd.DataFrame:
226+
"""Set the dataframe column datatypes."""
227+
[check_valid_dtype(d) for d in dtypes.values()]
228+
229+
df = df.copy()
230+
for k, v in dtypes.items():
231+
if k in df.columns:
232+
df[k] = df[k].astype(v)
233+
return df
234+
235+
236+
def assert_frame_equal_no_order(df1: pd.DataFrame, df2: pd.DataFrame, index: List[str], **kwargs: Any) -> None:
237+
"""Assert that two DataFrames are equal, ignoring the order of rows."""
238+
# Remove any existing index. If it wasn't named, drop it. Set a new index and sort it.
239+
df1 = df1.reset_index().drop(columns="index").set_index(index).sort_index()
240+
df2 = df2.reset_index().drop(columns="index").set_index(index).sort_index()
241+
pd.testing.assert_frame_equal(df1, df2, **kwargs)

src/acquisition/covidcast/database.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,18 @@
22
33
See src/ddl/covidcast.sql for an explanation of each field.
44
"""
5+
import threading
6+
from math import ceil
7+
from multiprocessing import cpu_count
8+
from queue import Queue, Empty
9+
from typing import List
510

611
# third party
712
import json
813
import mysql.connector
9-
import numpy as np
10-
from math import ceil
11-
12-
from queue import Queue, Empty
13-
import threading
14-
from multiprocessing import cpu_count
1514

1615
# first party
1716
import delphi.operations.secrets as secrets
18-
1917
from delphi.epidata.acquisition.covidcast.logger import get_structured_logger
2018

2119
class CovidcastRow():

0 commit comments

Comments
 (0)