From 43419b888f1b791b267ab934f39aafea3d03d1be Mon Sep 17 00:00:00 2001
From: Dmitry Shemetov
Date: Thu, 2 Sep 2021 17:52:24 -0700
Subject: [PATCH 1/3] Flask server: limit correlations

* add a row limit to the query before forming the DataFrame
* add tests
---
 src/server/_pandas.py       |   6 +-
 src/server/_query.py        |   9 ++-
 tests/server/test_pandas.py | 145 ++++++++++++++++++++++++++++++++++++
 3 files changed, 156 insertions(+), 4 deletions(-)
 create mode 100644 tests/server/test_pandas.py

diff --git a/src/server/_pandas.py b/src/server/_pandas.py
index 8ab10047b..7e56708d3 100644
--- a/src/server/_pandas.py
+++ b/src/server/_pandas.py
@@ -4,13 +4,15 @@
 from sqlalchemy import text
 
 from ._common import engine
+from ._config import MAX_RESULTS
 from ._printer import create_printer, APrinter
-from ._query import filter_fields
+from ._query import filter_fields, limit_query
 from ._exceptions import DatabaseErrorException
 
 
-def as_pandas(query: str, params: Dict[str, Any], parse_dates: Optional[Dict[str, str]] = None) -> pd.DataFrame:
+def as_pandas(query: str, params: Dict[str, Any], parse_dates: Optional[Dict[str, str]] = None, limit_rows = MAX_RESULTS+1) -> pd.DataFrame:
     try:
+        query = limit_query(query, limit_rows)
         return pd.read_sql_query(text(str(query)), engine, params=params, parse_dates=parse_dates)
     except Exception as e:
         raise DatabaseErrorException(str(e))

diff --git a/src/server/_query.py b/src/server/_query.py
index 1a915b85e..42346e954 100644
--- a/src/server/_query.py
+++ b/src/server/_query.py
@@ -232,10 +232,15 @@ def parse_result(
     return [parse_row(row, fields_string, fields_int, fields_float) for row in db.execute(text(query), **params)]
 
 
+def limit_query(query: str, limit: int) -> str:
+    # limit rows + 1 for detecting whether we would have more
+    full_query = f"{query} LIMIT {limit}"
+    return full_query
+
+
 def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]):
     query, params = query_tuple
-    # limit rows + 1 for detecting whether we would have more
-    full_query = text(f"{query} LIMIT {p.remaining_rows + 1}")
+    full_query = text(limit_query(query, p.remaining_rows + 1))
     app.logger.info("full_query: %s, params: %s", full_query, params)
     return db.execution_options(stream_results=True).execute(full_query, **params)

diff --git a/tests/server/test_pandas.py b/tests/server/test_pandas.py
new file mode 100644
index 000000000..a4b3a435f
--- /dev/null
+++ b/tests/server/test_pandas.py
@@ -0,0 +1,145 @@
+"""Unit tests for pandas helper."""
+
+# standard library
+from dataclasses import dataclass
+from typing import Any, Dict, Iterable
+import unittest
+
+# from flask.testing import FlaskClient
+import mysql.connector
+from delphi_utils import Nans
+
+from delphi.epidata.server._pandas import as_pandas
+
+# py3tester coverage target
+__test_target__ = "delphi.epidata.server._query"
+
+
+@dataclass
+class CovidcastRow:
+    id: int = 0
+    source: str = "src"
+    signal: str = "sig"
+    time_type: str = "day"
+    geo_type: str = "county"
+    time_value: int = 20200411
+    geo_value: str = "01234"
+    value_updated_timestamp: int = 20200202
+    value: float = 10.0
+    stderr: float = 0
+    sample_size: float = 10
+    direction_updated_timestamp: int = 20200202
+    direction: int = 0
+    issue: int = 20200202
+    lag: int = 0
+    is_latest_issue: bool = True
+    is_wip: bool = False
+    missing_value: int = Nans.NOT_MISSING
+    missing_stderr: int = Nans.NOT_MISSING
+    missing_sample_size: int = Nans.NOT_MISSING
+
+    def __str__(self):
+        return f"""(
+            {self.id},
+            '{self.source}',
+            '{self.signal}',
+            '{self.time_type}',
+            '{self.geo_type}',
+            {self.time_value},
+            '{self.geo_value}',
+            {self.value_updated_timestamp},
+            {self.value},
+            {self.stderr},
+            {self.sample_size},
+            {self.direction_updated_timestamp},
+            {self.direction},
+            {self.issue},
+            {self.lag},
+            {self.is_latest_issue},
+            {self.is_wip},
+            {self.missing_value},
+            {self.missing_stderr},
+            {self.missing_sample_size}
+        )"""
+
+    @staticmethod
+    def from_json(json: Dict[str, Any]) -> "CovidcastRow":
+        return CovidcastRow(
+            source=json["source"],
+            signal=json["signal"],
+            time_type=json["time_type"],
+            geo_type=json["geo_type"],
+            geo_value=json["geo_value"],
+            direction=json["direction"],
+            issue=json["issue"],
+            lag=json["lag"],
+            value=json["value"],
+            stderr=json["stderr"],
+            sample_size=json["sample_size"],
+            missing_value=json["missing_value"],
+            missing_stderr=json["missing_stderr"],
+            missing_sample_size=json["missing_sample_size"],
+        )
+
+    @property
+    def signal_pair(self):
+        return f"{self.source}:{self.signal}"
+
+    @property
+    def geo_pair(self):
+        return f"{self.geo_type}:{self.geo_value}"
+
+    @property
+    def time_pair(self):
+        return f"{self.time_type}:{self.time_value}"
+
+
+class UnitTests(unittest.TestCase):
+    """Basic unit tests."""
+
+    def setUp(self):
+        """Perform per-test setup."""
+
+        # connect to the `epidata` database and clear the `covidcast` table
+        cnx = mysql.connector.connect(user="user", password="pass", host="delphi_database_epidata", database="epidata")
+        cur = cnx.cursor()
+        cur.execute("truncate table covidcast")
+        cur.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')
+        cnx.commit()
+        cur.close()
+
+        # make connection and cursor available to test cases
+        self.cnx = cnx
+        self.cur = cnx.cursor()
+
+    def tearDown(self):
+        """Perform per-test teardown."""
+        self.cur.close()
+        self.cnx.close()
+
+    def _insert_rows(self, rows: Iterable[CovidcastRow]):
+        sql = ",\n".join((str(r) for r in rows))
+        self.cur.execute(
+            f"""
+            INSERT INTO
+                `covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
+                `time_value`, `geo_value`, `value_updated_timestamp`,
+                `value`, `stderr`, `sample_size`, `direction_updated_timestamp`,
+                `direction`, `issue`, `lag`, `is_latest_issue`, `is_wip`,`missing_value`,
+                `missing_stderr`,`missing_sample_size`)
+            VALUES
+                {sql}
+            """
+        )
+        self.cnx.commit()
+        return rows
+
+    def test_as_pandas(self):
+        rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)]
+        self._insert_rows(rows)
+
+        with self.subTest("simple"):
+            query = "select * from `covidcast`"
+            out = as_pandas(query, limit_rows=5)
+            self.assertEqual(len(out["epidata"]), 5)
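A note on the pattern patch 1 introduces: limit_query simply appends a SQL LIMIT clause, and run_query asks for one row more than the printer's remaining budget so that truncation is detectable without a second COUNT(*) query. A minimal sketch of that idea (MAX_RESULTS here is an illustrative value, the real cap lives in src/server/_config.py, and fetch is a hypothetical helper):

    def limit_query(query: str, limit: int) -> str:
        # append a LIMIT clause to a SQL statement
        return f"{query} LIMIT {limit}"

    MAX_RESULTS = 1000  # illustrative value only

    limited = limit_query("select * from covidcast", MAX_RESULTS + 1)
    # -> "select * from covidcast LIMIT 1001"
    rows = fetch(limited)  # hypothetical fetch helper
    truncated = len(rows) > MAX_RESULTS  # the extra row signals "more available"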
From 0f220084735bbaecd57fb8e0e3e16e086670df48 Mon Sep 17 00:00:00 2001
From: Dmitry Shemetov
Date: Fri, 3 Sep 2021 11:46:14 -0700
Subject: [PATCH 2/3] Limit correlations: update tests

* add a test to ensure row limits work
---
 tests/server/test_pandas.py | 53 +++++++++++++++++++++++++++++--------
 1 file changed, 42 insertions(+), 11 deletions(-)

diff --git a/tests/server/test_pandas.py b/tests/server/test_pandas.py
index a4b3a435f..210e7a857 100644
--- a/tests/server/test_pandas.py
+++ b/tests/server/test_pandas.py
@@ -1,15 +1,19 @@
 """Unit tests for pandas helper."""
 
 # standard library
-from dataclasses import dataclass
+from dataclasses import dataclass, astuple
 from typing import Any, Dict, Iterable
 import unittest
 
+import pandas as pd
+from pandas.core.frame import DataFrame
+from sqlalchemy import text
 import mysql.connector
 
 # from flask.testing import FlaskClient
 from delphi_utils import Nans
 from delphi.epidata.server._pandas import as_pandas
+from delphi.epidata.server._query import limit_query
 
 # py3tester coverage target
 __test_target__ = "delphi.epidata.server._query"

@@ -26,14 +30,14 @@ class CovidcastRow:
     geo_value: str = "01234"
     value_updated_timestamp: int = 20200202
     value: float = 10.0
-    stderr: float = 0
-    sample_size: float = 10
+    stderr: float = 0.
+    sample_size: float = 10.
     direction_updated_timestamp: int = 20200202
     direction: int = 0
     issue: int = 20200202
     lag: int = 0
     is_latest_issue: bool = True
-    is_wip: bool = False
+    is_wip: bool = True
     missing_value: int = Nans.NOT_MISSING
     missing_stderr: int = Nans.NOT_MISSING
     missing_sample_size: int = Nans.NOT_MISSING

@@ -93,6 +97,14 @@ def geo_pair(self):
     def time_pair(self):
         return f"{self.time_type}:{self.time_value}"
 
+    @property
+    def astuple(self):
+        return astuple(self)[1:]
+
+    @property
+    def aslist(self):
+        return list(self.astuple)
+
 
 class UnitTests(unittest.TestCase):
     """Basic unit tests."""

@@ -134,12 +146,31 @@ def _insert_rows(self, rows: Iterable[CovidcastRow]):
         self.cnx.commit()
         return rows
 
+    def _rows_to_df(self, rows: Iterable[CovidcastRow]) -> pd.DataFrame:
+        columns = [
+            'id', 'source', 'signal', 'time_type', 'geo_type', 'time_value',
+            'geo_value', 'value_updated_timestamp', 'value', 'stderr',
+            'sample_size', 'direction_updated_timestamp', 'direction', 'issue',
+            'lag', 'is_latest_issue', 'is_wip', 'missing_value', 'missing_stderr',
+            'missing_sample_size'
+        ]
+        return pd.DataFrame.from_records([[i] + row.aslist for i, row in enumerate(rows, start=1)], columns=columns)
+
     def test_as_pandas(self):
-        rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)]
+        rows = [CovidcastRow(time_value=20200401 + i, value=float(i)) for i in range(10)]
         self._insert_rows(rows)
 
         with self.subTest("simple"):
-            query = "select * from `covidcast`"
-            out = as_pandas(query, limit_rows=5)
-            self.assertEqual(len(out["epidata"]), 5)
-
+            query = """select * from `covidcast`"""
+            params = {}
+            parse_dates = None
+            engine = self.cnx
+            df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
+            df = df.astype({"is_latest_issue": bool, "is_wip": bool})
+            expected_df = self._rows_to_df(rows)
+            pd.testing.assert_frame_equal(df, expected_df)
+            query = limit_query(query, 5)
+            df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
+            df = df.astype({"is_latest_issue": bool, "is_wip": bool})
+            expected_df = self._rows_to_df(rows[:5])
+            pd.testing.assert_frame_equal(df, expected_df)
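A note on the astype({"is_latest_issue": bool, "is_wip": bool}) calls above: MySQL stores BOOLEAN columns as TINYINT, so read_sql_query returns them as integers, while pd.testing.assert_frame_equal checks dtypes as well as values. A self-contained sketch of why the cast is needed:

    import pandas as pd

    raw = pd.DataFrame({"is_wip": [1, 0]})              # what the MySQL driver returns
    expected = pd.DataFrame({"is_wip": [True, False]})  # what the test builds
    # without the cast this comparison would fail on dtype: int64 != bool
    pd.testing.assert_frame_equal(raw.astype({"is_wip": bool}), expected)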
From dec134a8ffbfea7aa3c8ecdb88365f15640a1f6c Mon Sep 17 00:00:00 2001
From: Dmitry Shemetov
Date: Sun, 12 Sep 2021 02:42:53 -0700
Subject: [PATCH 3/3] Limit correlations:

* add test scaffolding so as_pandas can be called from unit tests
---
 src/server/_pandas.py       |  7 ++++---
 src/server/_query.py        |  2 +-
 tests/server/test_pandas.py | 40 +++++++++++++++---------------------
 3 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/src/server/_pandas.py b/src/server/_pandas.py
index 7e56708d3..54f8f99dc 100644
--- a/src/server/_pandas.py
+++ b/src/server/_pandas.py
@@ -2,18 +2,19 @@
 import pandas as pd
 from sqlalchemy import text
+from sqlalchemy.engine.base import Engine
 
 from ._common import engine
 from ._config import MAX_RESULTS
-from ._printer import create_printer, APrinter
+from ._printer import create_printer
 from ._query import filter_fields, limit_query
 from ._exceptions import DatabaseErrorException
 
 
-def as_pandas(query: str, params: Dict[str, Any], parse_dates: Optional[Dict[str, str]] = None, limit_rows = MAX_RESULTS+1) -> pd.DataFrame:
+def as_pandas(query: str, params: Dict[str, Any], db_engine: Engine = engine, parse_dates: Optional[Dict[str, str]] = None, limit_rows = MAX_RESULTS+1) -> pd.DataFrame:
     try:
         query = limit_query(query, limit_rows)
-        return pd.read_sql_query(text(str(query)), engine, params=params, parse_dates=parse_dates)
+        return pd.read_sql_query(text(str(query)), db_engine, params=params, parse_dates=parse_dates)
     except Exception as e:
         raise DatabaseErrorException(str(e))

diff --git a/src/server/_query.py b/src/server/_query.py
index 42346e954..2e95d081a 100644
--- a/src/server/_query.py
+++ b/src/server/_query.py
@@ -233,13 +233,13 @@ def parse_result(
 
 
 def limit_query(query: str, limit: int) -> str:
-    # limit rows + 1 for detecting whether we would have more
    full_query = f"{query} LIMIT {limit}"
    return full_query
 
 
 def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]):
     query, params = query_tuple
+    # limit rows + 1 for detecting whether we would have more
     full_query = text(limit_query(query, p.remaining_rows + 1))
     app.logger.info("full_query: %s, params: %s", full_query, params)
     return db.execution_options(stream_results=True).execute(full_query, **params)
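With the engine now injectable, as_pandas can run against any SQLAlchemy engine, which is what makes the test rewrite below possible. A hedged sketch of a call site (the DSN is a placeholder, not a real deployment value):

    from sqlalchemy import create_engine
    from delphi.epidata.server._pandas import as_pandas

    test_engine = create_engine("mysql://user:pass@localhost/epidata")  # placeholder DSN
    df = as_pandas("select * from covidcast", params={}, db_engine=test_engine, limit_rows=5)
    assert len(df) <= 5  # the helper appends LIMIT 5 before querying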
diff --git a/tests/server/test_pandas.py b/tests/server/test_pandas.py
index 210e7a857..99cd69d19 100644
--- a/tests/server/test_pandas.py
+++ b/tests/server/test_pandas.py
@@ -6,14 +6,13 @@
 import unittest
 
 import pandas as pd
-from pandas.core.frame import DataFrame
-from sqlalchemy import text
-import mysql.connector
+from sqlalchemy import create_engine
 
 # from flask.testing import FlaskClient
 from delphi_utils import Nans
+from delphi.epidata.server.main import app
 from delphi.epidata.server._pandas import as_pandas
-from delphi.epidata.server._query import limit_query
+from delphi.epidata.server._query import QueryBuilder
 
 # py3tester coverage target
 __test_target__ = "delphi.epidata.server._query"

@@ -111,27 +110,26 @@ class UnitTests(unittest.TestCase):
 
     def setUp(self):
         """Perform per-test setup."""
+        app.config["TESTING"] = True
+        app.config["WTF_CSRF_ENABLED"] = False
+        app.config["DEBUG"] = False
 
         # connect to the `epidata` database and clear the `covidcast` table
-        cnx = mysql.connector.connect(user="user", password="pass", host="delphi_database_epidata", database="epidata")
-        cur = cnx.cursor()
-        cur.execute("truncate table covidcast")
-        cur.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')
-        cnx.commit()
-        cur.close()
+        engine = create_engine('mysql://user:pass@delphi_database_epidata/epidata')
+        cnx = engine.connect()
+        cnx.execute("truncate table covidcast")
+        cnx.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')
 
         # make connection and cursor available to test cases
         self.cnx = cnx
 
     def tearDown(self):
         """Perform per-test teardown."""
-        self.cur.close()
         self.cnx.close()
 
     def _insert_rows(self, rows: Iterable[CovidcastRow]):
         sql = ",\n".join((str(r) for r in rows))
-        self.cur.execute(
+        self.cnx.execute(
             f"""
             INSERT INTO
                 `covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
@@ -143,7 +141,6 @@ def _insert_rows(self, rows: Iterable[CovidcastRow]):
                 {sql}
             """
         )
-        self.cnx.commit()
         return rows

@@ -160,17 +157,12 @@ def test_as_pandas(self):
         rows = [CovidcastRow(time_value=20200401 + i, value=float(i)) for i in range(10)]
         self._insert_rows(rows)
 
-        with self.subTest("simple"):
-            query = """select * from `covidcast`"""
-            params = {}
-            parse_dates = None
-            engine = self.cnx
-            df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
-            df = df.astype({"is_latest_issue": bool, "is_wip": bool})
+        with app.test_request_context('/correlation'):
+            q = QueryBuilder("covidcast", "t")
+
+            df = as_pandas(str(q), params={}, db_engine=self.cnx, parse_dates=None).astype({"is_latest_issue": bool, "is_wip": bool})
             expected_df = self._rows_to_df(rows)
             pd.testing.assert_frame_equal(df, expected_df)
-            query = limit_query(query, 5)
-            df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
-            df = df.astype({"is_latest_issue": bool, "is_wip": bool})
+            df = as_pandas(str(q), params={}, db_engine=self.cnx, parse_dates=None, limit_rows=5).astype({"is_latest_issue": bool, "is_wip": bool})
             expected_df = self._rows_to_df(rows[:5])
             pd.testing.assert_frame_equal(df, expected_df)