Skip to content

Commit 43419b8

Browse files
committed
Flask server: limit correlations
* add row limit to the request before forming dataframe
* add tests
1 parent 471763b commit 43419b8

File tree

3 files changed

+156
-4
lines changed

3 files changed

+156
-4
lines changed

src/server/_pandas.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
from sqlalchemy import text
55

66
from ._common import engine
7+
from ._config import MAX_RESULTS
78
from ._printer import create_printer, APrinter
8-
from ._query import filter_fields
9+
from ._query import filter_fields, limit_query
910
from ._exceptions import DatabaseErrorException
1011

1112

12-
def as_pandas(query: str, params: Dict[str, Any], parse_dates: Optional[Dict[str, str]] = None) -> pd.DataFrame:
13+
def as_pandas(query: str, params: Dict[str, Any], parse_dates: Optional[Dict[str, str]] = None, limit_rows: int = MAX_RESULTS + 1) -> pd.DataFrame:
    """Run a row-limited SQL query and return the result as a DataFrame.

    Args:
        query: SQL text without a trailing ``LIMIT`` clause; one is appended
            here via ``limit_query``.
        params: bind parameters forwarded to ``pd.read_sql_query``.
        parse_dates: forwarded unchanged to ``pd.read_sql_query``.
        limit_rows: value used for the appended ``LIMIT`` clause. Defaults to
            ``MAX_RESULTS + 1`` (presumably one extra row so callers can tell
            that the result was truncated; confirm against callers).

    Returns:
        The query result as a ``pd.DataFrame``.

    Raises:
        DatabaseErrorException: wraps any exception raised while building or
            executing the query; the original exception is attached as
            ``__cause__``.
    """
    try:
        # limit_query returns a plain str, so no extra str() conversion is needed
        limited_query = limit_query(query, limit_rows)
        return pd.read_sql_query(text(limited_query), engine, params=params, parse_dates=parse_dates)
    except Exception as e:
        # chain explicitly so the underlying driver/pandas error stays in the traceback
        raise DatabaseErrorException(str(e)) from e

src/server/_query.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,15 @@ def parse_result(
232232
return [parse_row(row, fields_string, fields_int, fields_float) for row in db.execute(text(query), **params)]
233233

234234

235+
def limit_query(query: str, limit: int) -> str:
    """Append a ``LIMIT`` clause to a SQL query string.

    Callers pass "rows wanted + 1" as ``limit`` so that they can detect
    whether more rows would have been available.

    Args:
        query: SQL text that does not already end in a ``LIMIT`` clause.
        limit: row limit; coerced with ``int()`` so that only an integer
            literal can ever be interpolated into the SQL text.

    Returns:
        ``query`` followed by `` LIMIT <limit>``.

    Raises:
        ValueError, TypeError: if ``limit`` cannot be converted to an int.
    """
    # int() keeps the output identical for integer inputs while rejecting
    # arbitrary text and normalizing values such as 5.0 -> 5
    row_limit = int(limit)
    return f"{query} LIMIT {row_limit}"
239+
240+
235241
def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]):
    """Execute a (query, params) pair with a row limit and streamed results.

    The limit is ``p.remaining_rows + 1``: one row more than the printer can
    still take, for detecting whether more rows would have been available.

    Args:
        p: printer whose ``remaining_rows`` determines the row limit.
        query_tuple: SQL text and its bind parameters.

    Returns:
        Whatever ``db.execution_options(stream_results=True).execute`` returns.
    """
    sql, bind_params = query_tuple
    limited_statement = text(limit_query(sql, p.remaining_rows + 1))
    app.logger.info("full_query: %s, params: %s", limited_statement, bind_params)
    streaming_db = db.execution_options(stream_results=True)
    return streaming_db.execute(limited_statement, **bind_params)
241246

tests/server/test_pandas.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""Unit tests for pandas helper."""
2+
3+
# standard library
4+
from dataclasses import dataclass
5+
from typing import Any, Dict, Iterable
6+
import unittest
7+
8+
# from flask.testing import FlaskClient
9+
import mysql.connector
10+
from delphi_utils import Nans
11+
12+
from delphi.epidata.server._pandas import as_pandas
13+
14+
# py3tester coverage target
15+
# NOTE(review): this file imports and exercises `_pandas.as_pandas`, while the
# target below names `_query` — confirm which module coverage should be measured for.
__test_target__ = "delphi.epidata.server._query"
16+
17+
18+
@dataclass
class CovidcastRow:
    """One row of the `covidcast` table, with default values for every column.

    ``str(row)`` renders the row as a parenthesised SQL value tuple. The fields
    are emitted in declaration order, which is the column order used by the
    ``INSERT`` statement in ``UnitTests._insert_rows``.
    """

    id: int = 0
    source: str = "src"
    signal: str = "sig"
    time_type: str = "day"
    geo_type: str = "county"
    time_value: int = 20200411
    geo_value: str = "01234"
    value_updated_timestamp: int = 20200202
    value: float = 10.0
    stderr: float = 0
    sample_size: float = 10
    direction_updated_timestamp: int = 20200202
    direction: int = 0
    issue: int = 20200202
    lag: int = 0
    is_latest_issue: bool = True
    is_wip: bool = False
    missing_value: int = Nans.NOT_MISSING
    missing_stderr: int = Nans.NOT_MISSING
    missing_sample_size: int = Nans.NOT_MISSING

    def __str__(self):
        """Return the row as a SQL value tuple; string columns are single-quoted."""
        return f"""(
            {self.id},
            '{self.source}',
            '{self.signal}',
            '{self.time_type}',
            '{self.geo_type}',
            {self.time_value},
            '{self.geo_value}',
            {self.value_updated_timestamp},
            {self.value},
            {self.stderr},
            {self.sample_size},
            {self.direction_updated_timestamp},
            {self.direction},
            {self.issue},
            {self.lag},
            {self.is_latest_issue},
            {self.is_wip},
            {self.missing_value},
            {self.missing_stderr},
            {self.missing_sample_size}
        )"""

    @classmethod
    def from_json(cls, json: Dict[str, Any]) -> "CovidcastRow":
        """Build a row from a JSON-style dict.

        Only the keys listed below are read, and each of them is required
        (a missing key raises ``KeyError``). Fields that are not listed here,
        including ``time_value``, keep their dataclass defaults.
        """
        return cls(
            source=json["source"],
            signal=json["signal"],
            time_type=json["time_type"],
            geo_type=json["geo_type"],
            geo_value=json["geo_value"],
            direction=json["direction"],
            issue=json["issue"],
            lag=json["lag"],
            value=json["value"],
            stderr=json["stderr"],
            sample_size=json["sample_size"],
            missing_value=json["missing_value"],
            missing_stderr=json["missing_stderr"],
            missing_sample_size=json["missing_sample_size"],
        )

    @property
    def signal_pair(self):
        """``"<source>:<signal>"``."""
        return f"{self.source}:{self.signal}"

    @property
    def geo_pair(self):
        """``"<geo_type>:<geo_value>"``."""
        return f"{self.geo_type}:{self.geo_value}"

    @property
    def time_pair(self):
        """``"<time_type>:<time_value>"``."""
        return f"{self.time_type}:{self.time_value}"
95+
96+
97+
class UnitTests(unittest.TestCase):
    """Basic unit tests."""

    def setUp(self):
        """Perform per-test setup."""

        # connect to the `epidata` database and clear the `covidcast` table
        cnx = mysql.connector.connect(user="user", password="pass", host="delphi_database_epidata", database="epidata")
        cur = cnx.cursor()
        cur.execute("truncate table covidcast")
        cur.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')
        cnx.commit()
        cur.close()

        # make connection and cursor available to test cases
        self.cnx = cnx
        self.cur = cnx.cursor()

    def tearDown(self):
        """Perform per-test teardown."""
        self.cur.close()
        self.cnx.close()

    def _insert_rows(self, rows: Iterable[CovidcastRow]):
        """Insert the given rows into `covidcast` and return them as a list.

        The rows are materialized first so that a generator argument is not
        returned exhausted, and so that an empty input can be skipped instead
        of producing an ``INSERT`` with no values.
        """
        rows = list(rows)
        if not rows:
            return rows

        sql = ",\n".join(str(r) for r in rows)
        self.cur.execute(
            f"""
            INSERT INTO
                `covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
                `time_value`, `geo_value`, `value_updated_timestamp`,
                `value`, `stderr`, `sample_size`, `direction_updated_timestamp`,
                `direction`, `issue`, `lag`, `is_latest_issue`, `is_wip`,`missing_value`,
                `missing_stderr`,`missing_sample_size`)
            VALUES
            {sql}
            """
        )
        self.cnx.commit()
        return rows

    def test_as_pandas(self):
        """`as_pandas` returns at most `limit_rows` rows."""
        rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)]
        self._insert_rows(rows)

        with self.subTest("simple"):
            query = "select * from `covidcast`"
            # `params` is a required positional argument of as_pandas; the
            # query has no bind parameters, so pass an empty dict
            out = as_pandas(query, {}, limit_rows=5)
            # the result is a DataFrame of table columns, so count its rows directly
            self.assertEqual(len(out), 5)

0 commit comments

Comments
 (0)