Skip to content

Commit dec134a

Browse files
committed
Limit correlations:
* add test magic so as_pandas can be called
1 parent 0f22008 commit dec134a

File tree

3 files changed

+21
-28
lines changed

3 files changed

+21
-28
lines changed

src/server/_pandas.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22
import pandas as pd
33

44
from sqlalchemy import text
5+
from sqlalchemy.engine.base import Engine
56

67
from ._common import engine
78
from ._config import MAX_RESULTS
8-
from ._printer import create_printer, APrinter
9+
from ._printer import create_printer
910
from ._query import filter_fields, limit_query
1011
from ._exceptions import DatabaseErrorException
1112

1213

13-
def as_pandas(query: str, params: Dict[str, Any], parse_dates: Optional[Dict[str, str]] = None, limit_rows = MAX_RESULTS+1) -> pd.DataFrame:
14+
def as_pandas(query: str, params: Dict[str, Any], db_engine: Engine = engine, parse_dates: Optional[Dict[str, str]] = None, limit_rows = MAX_RESULTS+1) -> pd.DataFrame:
1415
try:
1516
query = limit_query(query, limit_rows)
16-
return pd.read_sql_query(text(str(query)), engine, params=params, parse_dates=parse_dates)
17+
return pd.read_sql_query(text(str(query)), db_engine, params=params, parse_dates=parse_dates)
1718
except Exception as e:
1819
raise DatabaseErrorException(str(e))
1920

src/server/_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,13 @@ def parse_result(
233233

234234

235235
def limit_query(query: str, limit: int) -> str:
236-
# limit rows + 1 for detecting whether we would have more
237236
full_query = f"{query} LIMIT {limit}"
238237
return full_query
239238

240239

241240
def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]):
242241
query, params = query_tuple
242+
# limit rows + 1 for detecting whether we would have more
243243
full_query = text(limit_query(query, p.remaining_rows + 1))
244244
app.logger.info("full_query: %s, params: %s", full_query, params)
245245
return db.execution_options(stream_results=True).execute(full_query, **params)

tests/server/test_pandas.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
import unittest
77

88
import pandas as pd
9-
from pandas.core.frame import DataFrame
10-
from sqlalchemy import text
11-
import mysql.connector
9+
from sqlalchemy import create_engine
1210

1311
# from flask.testing import FlaskClient
1412
from delphi_utils import Nans
13+
from delphi.epidata.server.main import app
1514
from delphi.epidata.server._pandas import as_pandas
16-
from delphi.epidata.server._query import limit_query
15+
from delphi.epidata.server._query import QueryBuilder
1716

1817
# py3tester coverage target
1918
__test_target__ = "delphi.epidata.server._query"
@@ -111,27 +110,26 @@ class UnitTests(unittest.TestCase):
111110

112111
def setUp(self):
113112
"""Perform per-test setup."""
113+
app.config["TESTING"] = True
114+
app.config["WTF_CSRF_ENABLED"] = False
115+
app.config["DEBUG"] = False
114116

115117
# connect to the `epidata` database and clear the `covidcast` table
116-
cnx = mysql.connector.connect(user="user", password="pass", host="delphi_database_epidata", database="epidata")
117-
cur = cnx.cursor()
118-
cur.execute("truncate table covidcast")
119-
cur.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')
120-
cnx.commit()
121-
cur.close()
118+
engine = create_engine('mysql://user:pass@delphi_database_epidata/epidata')
119+
cnx = engine.connect()
120+
cnx.execute("truncate table covidcast")
121+
cnx.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')
122122

123123
# make connection and cursor available to test cases
124124
self.cnx = cnx
125-
self.cur = cnx.cursor()
126125

127126
def tearDown(self):
128127
"""Perform per-test teardown."""
129-
self.cur.close()
130128
self.cnx.close()
131129

132130
def _insert_rows(self, rows: Iterable[CovidcastRow]):
133131
sql = ",\n".join((str(r) for r in rows))
134-
self.cur.execute(
132+
self.cnx.execute(
135133
f"""
136134
INSERT INTO
137135
`covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
@@ -143,7 +141,6 @@ def _insert_rows(self, rows: Iterable[CovidcastRow]):
143141
{sql}
144142
"""
145143
)
146-
self.cnx.commit()
147144
return rows
148145

149146
def _rows_to_df(self, rows: Iterable[CovidcastRow]) -> pd.DataFrame:
@@ -160,17 +157,12 @@ def test_as_pandas(self):
160157
rows = [CovidcastRow(time_value=20200401 + i, value=float(i)) for i in range(10)]
161158
self._insert_rows(rows)
162159

163-
with self.subTest("simple"):
164-
query = """select * from `covidcast`"""
165-
params = {}
166-
parse_dates = None
167-
engine = self.cnx
168-
df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
169-
df = df.astype({"is_latest_issue": bool, "is_wip": bool})
160+
with app.test_request_context('/correlation'):
161+
q = QueryBuilder("covidcast", "t")
162+
163+
df = as_pandas(str(q), params={}, db_engine=self.cnx, parse_dates=None).astype({"is_latest_issue": bool, "is_wip": bool})
170164
expected_df = self._rows_to_df(rows)
171165
pd.testing.assert_frame_equal(df, expected_df)
172-
query = limit_query(query, 5)
173-
df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
174-
df = df.astype({"is_latest_issue": bool, "is_wip": bool})
166+
df = as_pandas(str(q), params={}, db_engine=self.cnx, parse_dates=None, limit_rows=5).astype({"is_latest_issue": bool, "is_wip": bool})
175167
expected_df = self._rows_to_df(rows[:5])
176168
pd.testing.assert_frame_equal(df, expected_df)

0 commit comments

Comments
 (0)