From 40b25c53c50e662f705f973d5ae8e8d4db80f7ef Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 13:57:52 -0700 Subject: [PATCH 01/29] Server: add CovidcastRow helper class for testing --- src/acquisition/covidcast/covidcast_row.py | 269 ++++++++++++++++++ .../covidcast/test_covidcast_row.py | 90 ++++++ 2 files changed, 359 insertions(+) create mode 100644 src/acquisition/covidcast/covidcast_row.py create mode 100644 tests/acquisition/covidcast/test_covidcast_row.py diff --git a/src/acquisition/covidcast/covidcast_row.py b/src/acquisition/covidcast/covidcast_row.py new file mode 100644 index 000000000..af57b0b28 --- /dev/null +++ b/src/acquisition/covidcast/covidcast_row.py @@ -0,0 +1,269 @@ +from dataclasses import asdict, dataclass, field, fields +from datetime import date +from typing import Any, Dict, Iterable, List, Optional, Union, get_args, get_origin + +from delphi_utils import Nans +from numpy import isnan +from pandas import DataFrame, concat + +from .csv_importer import CsvImporter +from ...server.utils.dates import date_to_time_value, time_value_to_date + + +def _is_none(v: Optional[float]) -> bool: + return True if v is None or (v is not None and isnan(v)) else False + +@dataclass +class CovidcastRow: + """A container for (most of) the values of a single covidcast database row. + + Used for: + - inserting rows into the database + - creating test rows with default fields for testing + - created from many formats (dict, csv, df, kwargs) + - can be viewed in many formats (dict, csv, df) + + The rows are specified in 'v4_schema.sql'. + """ + + source: str = "src" + signal: str = "sig" + time_type: str = "day" + geo_type: str = "county" + time_value: int = 20200202 # Can be initialized with datetime.date + geo_value: str = "01234" + value: float = 10.0 + stderr: float = 10.0 + sample_size: float = 10.0 + missing_value: int = Nans.NOT_MISSING.value + missing_stderr: int = Nans.NOT_MISSING.value + missing_sample_size: int = Nans.NOT_MISSING.value + issue: Optional[int] = 20200202 # Can be initialized with datetime.date + lag: Optional[int] = 0 + id: Optional[int] = None + direction: Optional[int] = None + direction_updated_timestamp: int = 0 + value_updated_timestamp: int = 20200202 # Can be initialized with datetime.date + + def __post_init__(self): + # Convert time values to ints by default. + self.time_value = date_to_time_value(self.time_value) if isinstance(self.time_value, date) else self.time_value + self.issue = date_to_time_value(self.issue) if isinstance(self.issue, date) else self.issue + self.value_updated_timestamp = date_to_time_value(self.value_updated_timestamp) if isinstance(self.value_updated_timestamp, date) else self.value_updated_timestamp + + # These specify common views into this object: + # - 1. If this row was returned by an API request + self._api_row_ignore_fields = ["id", "direction", "direction_updated_timestamp", "value_updated_timestamp"] + # - 2. If this row was returned by an old API request (PHP server) + self._api_row_compatibility_ignore_fields = ["id", "direction", "direction_updated_timestamp", "value_updated_timestamp", "source"] + # - 3. If this row was returned by the database. 
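+        #      (an empty list: a full database row keeps every field of this object)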
+ self._db_row_ignore_fields = [] + + def _sanity_check_fields(self, extra_checks: bool = True): + if self.issue and self.issue < self.time_value: + self.issue = self.time_value + + if self.issue: + self.lag = (time_value_to_date(self.issue) - time_value_to_date(self.time_value)).days + else: + self.lag = None + + # This sanity checking is already done in CsvImporter, but it's here so the testing class gets it too. + if _is_none(self.value) and self.missing_value == Nans.NOT_MISSING: + self.missing_value = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value + + if _is_none(self.stderr) and self.missing_stderr == Nans.NOT_MISSING: + self.missing_stderr = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value + + if _is_none(self.sample_size) and self.missing_sample_size == Nans.NOT_MISSING: + self.missing_sample_size = Nans.NOT_APPLICABLE.value if extra_checks else Nans.OTHER.value + + return self + + @staticmethod + def fromCsvRowValue(row_value: Optional[CsvImporter.RowValues], source: str, signal: str, time_type: str, geo_type: str, time_value: int, issue: int, lag: int): + if row_value is None: + return None + return CovidcastRow( + source, + signal, + time_type, + geo_type, + time_value, + row_value.geo_value, + row_value.value, + row_value.stderr, + row_value.sample_size, + row_value.missing_value, + row_value.missing_stderr, + row_value.missing_sample_size, + issue, + lag, + ) + + @staticmethod + def fromCsvRows(row_values: Iterable[Optional[CsvImporter.RowValues]], source: str, signal: str, time_type: str, geo_type: str, time_value: int, issue: int, lag: int): + # NOTE: returns a generator, as row_values is expected to be a generator + return (CovidcastRow.fromCsvRowValue(row_value, source, signal, time_type, geo_type, time_value, issue, lag) for row_value in row_values) + + @staticmethod + def from_json(json: Dict[str, Any]) -> "CovidcastRow": + return CovidcastRow( + source=json["source"], + signal=json["signal"], + time_type=json["time_type"], + geo_type=json["geo_type"], + geo_value=json["geo_value"], + issue=json["issue"], + lag=json["lag"], + value=json["value"], + stderr=json["stderr"], + sample_size=json["sample_size"], + missing_value=json["missing_value"], + missing_stderr=json["missing_stderr"], + missing_sample_size=json["missing_sample_size"], + ) + + def as_dict(self, ignore_fields: Optional[List[str]] = None) -> dict: + d = asdict(self) + if ignore_fields: + for key in ignore_fields: + del d[key] + return d + + def as_dataframe(self, ignore_fields: Optional[List[str]] = None) -> DataFrame: + return DataFrame.from_records([self.as_dict(ignore_fields=ignore_fields)]) + + @property + def api_row_df(self) -> DataFrame: + """Returns a dataframe view into the row with the fields returned by the API server.""" + return self.as_dataframe(ignore_fields=self._api_row_ignore_fields) + + @property + def api_compatibility_row_df(self) -> DataFrame: + """Returns a dataframe view into the row with the fields returned by the old API server (the PHP server).""" + return self.as_dataframe(ignore_fields=self._api_row_compatibility_ignore_fields) + + @property + def db_row_df(self) -> DataFrame: + """Returns a dataframe view into the row with the fields returned by an all-field database query.""" + return self.as_dataframe(ignore_fields=self._db_row_ignore_fields) + + @property + def signal_pair(self): + return f"{self.source}:{self.signal}" + + @property + def geo_pair(self): + return f"{self.geo_type}:{self.geo_value}" + + @property + def time_pair(self): + return 
f"{self.time_type}:{self.time_value}" + +# TODO: Deprecate this class in favor of functions over the List[CovidcastRow] datatype. +# All the inner variables of this class are derived from the CovidcastRow class. +@dataclass +class CovidcastRows: + rows: List[CovidcastRow] = field(default_factory=list) + + def __post_init__(self): + # These specify common views into this object: + # - 1. If this row was returned by an API request + self._api_row_ignore_fields = CovidcastRow()._api_row_ignore_fields + # - 2. If this row was returned by an old API request (PHP server) + self._api_row_compatibility_ignore_fields = CovidcastRow()._api_row_compatibility_ignore_fields + # - 3. If this row was returned by the database. + self._db_row_ignore_fields = CovidcastRow()._db_row_ignore_fields + + # Used to create a consistent DataFrame for tests. + dtypes = {field.name: field.type if get_origin(field.type) is not Union else get_args(field.type)[0] for field in fields(CovidcastRow)} + # Sometimes the int fields have None values, so we expand their scope using pandas.Int64DType. + self._DTYPES = {key: value if value is not int else "Int64" for key, value in dtypes.items()} + + @staticmethod + def from_args(sanity_check: bool = True, test_mode: bool = True, **kwargs: Dict[str, Iterable]): + """A convenience constructor. + + Handy for constructing batches of test cases. + + Example: + CovidcastRows.from_args(value=[1, 2, 3], time_value=[1, 2, 3]) will yield + CovidcastRows(rows=[CovidcastRow(value=1, time_value=1), CovidcastRow(value=2, time_value=2), CovidcastRow(value=3, time_value=3)]) + with all the defaults from CovidcastRow. + """ + # All the args must be fields of CovidcastRow. + assert set(kwargs.keys()) <= set(field.name for field in fields(CovidcastRow)) + + # If any iterables were passed instead of lists, convert them to lists. + kwargs = {key: list(value) for key, value in kwargs.items()} + + # All the arg values must be lists of the same length. + assert len(set(len(lst) for lst in kwargs.values())) == 1 + + return CovidcastRows(rows=[CovidcastRow(**_kwargs)._sanity_check_fields(extra_checks=test_mode) if sanity_check else CovidcastRow(**_kwargs) for _kwargs in transpose_dict(kwargs)]) + + @staticmethod + def from_records(records: Iterable[dict], sanity_check: bool = False): + """A convenience constructor. + + Default is different from from_args, because from_records is usually called on faux-API returns in tests, + where we don't want any values getting default filled in. 
+ """ + records = list(records) + assert set().union(*[record.keys() for record in records]) <= set(field.name for field in fields(CovidcastRow)) + + return CovidcastRows(rows=[CovidcastRow(**record) if not sanity_check else CovidcastRow(**record)._sanity_check_fields() for record in records]) + + def as_dicts(self, ignore_fields: Optional[List[str]] = None) -> List[dict]: + return [row.as_dict(ignore_fields=ignore_fields) for row in self.rows] + + def as_dataframe(self, ignore_fields: Optional[List[str]] = None) -> DataFrame: + if ignore_fields is None: + ignore_fields = [] + columns = [field.name for field in fields(CovidcastRow) if field.name not in ignore_fields] + if self.rows: + df = concat([row.as_dataframe(ignore_fields=ignore_fields) for row in self.rows], ignore_index=True) + df = set_df_dtypes(df, self._DTYPES) + return df[columns] + else: + return DataFrame(columns=columns) + + @property + def api_row_df(self) -> DataFrame: + return self.as_dataframe(ignore_fields=self._api_row_ignore_fields) + + @property + def api_compatibility_row_df(self) -> DataFrame: + return self.as_dataframe(ignore_fields=self._api_row_compatibility_ignore_fields) + + @property + def db_row_df(self) -> DataFrame: + return self.as_dataframe(ignore_fields=self._db_row_ignore_fields) + + +def transpose_dict(d: Dict[Any, List[Any]]) -> List[Dict[Any, Any]]: + """Given a dictionary whose values are lists of the same length, turn it into a list of dictionaries whose values are the individual list entries. + + Example: + >>> transpose_dict(dict([["a", [2, 4, 6]], ["b", [3, 5, 7]], ["c", [10, 20, 30]]])) + [{"a": 2, "b": 3, "c": 10}, {"a": 4, "b": 5, "c": 20}, {"a": 6, "b": 7, "c": 30}] + """ + return [dict(zip(d.keys(), values)) for values in zip(*d.values())] + + +def set_df_dtypes(df: DataFrame, dtypes: Dict[str, Any]) -> DataFrame: + """Set the dataframe column datatypes. + df: pd.DataFrame + The dataframe to change. + dtypes: Dict[str, Any] + The keys of the dict are columns and the values are either types or Pandas + string aliases for types. Not all columns are required. + """ + assert all(isinstance(e, type) or isinstance(e, str) for e in dtypes.values()), "Values must be types or Pandas string aliases for types." 
+ + df = df.copy() + for k, v in dtypes.items(): + if k in df.columns: + df[k] = df[k].astype(v) + return df diff --git a/tests/acquisition/covidcast/test_covidcast_row.py b/tests/acquisition/covidcast/test_covidcast_row.py new file mode 100644 index 000000000..969b521b9 --- /dev/null +++ b/tests/acquisition/covidcast/test_covidcast_row.py @@ -0,0 +1,90 @@ +import unittest + +from pandas import DataFrame, date_range +from pandas.testing import assert_frame_equal + +from delphi_utils.nancodes import Nans +from delphi.epidata.server.utils.dates import date_to_time_value +from delphi.epidata.acquisition.covidcast.covidcast_row import set_df_dtypes, transpose_dict, CovidcastRow, CovidcastRows + +class TestCovidcastRows(unittest.TestCase): + def test_transpose_dict(self): + assert transpose_dict(dict([["a", [2, 4, 6]], ["b", [3, 5, 7]], ["c", [10, 20, 30]]])) == [{"a": 2, "b": 3, "c": 10}, {"a": 4, "b": 5, "c": 20}, {"a": 6, "b": 7, "c": 30}] + + def test_CovidcastRow(self): + df = CovidcastRow(value=5.0).api_row_df + expected_df = DataFrame.from_records([{ + "source": "src", + "signal": "sig", + "time_type": "day", + "geo_type": "county", + "time_value": 20200202, + "geo_value": "01234", + "value": 5.0, + "stderr": 10.0, + "sample_size": 10.0, + "missing_value": Nans.NOT_MISSING, + "missing_stderr": Nans.NOT_MISSING, + "missing_sample_size": Nans.NOT_MISSING, + "issue": 20200202, + "lag": 0, + }]) + assert_frame_equal(df, expected_df) + + df = CovidcastRow(value=5.0).api_compatibility_row_df + expected_df = DataFrame.from_records([{ + "signal": "sig", + "time_type": "day", + "geo_type": "county", + "time_value": 20200202, + "geo_value": "01234", + "value": 5.0, + "stderr": 10.0, + "sample_size": 10.0, + "missing_value": Nans.NOT_MISSING, + "missing_stderr": Nans.NOT_MISSING, + "missing_sample_size": Nans.NOT_MISSING, + "issue": 20200202, + "lag": 0, + }]) + assert_frame_equal(df, expected_df) + + def test_CovidcastRows(self): + df = CovidcastRows.from_args(signal=["sig_base"] * 5 + ["sig_other"] * 5, time_value=date_range("2021-05-01", "2021-05-05").to_list() * 2, value=list(range(10))).api_row_df + expected_df = set_df_dtypes(DataFrame({ + "source": ["src"] * 10, + "signal": ["sig_base"] * 5 + ["sig_other"] * 5, + "time_type": ["day"] * 10, + "geo_type": ["county"] * 10, + "time_value": map(date_to_time_value, date_range("2021-05-01", "2021-05-5").to_list() * 2), + "geo_value": ["01234"] * 10, + "value": range(10), + "stderr": [10.0] * 10, + "sample_size": [10.0] * 10, + "missing_value": [Nans.NOT_MISSING] * 10, + "missing_stderr": [Nans.NOT_MISSING] * 10, + "missing_sample_size": [Nans.NOT_MISSING] * 10, + "issue": map(date_to_time_value, date_range("2021-05-01", "2021-05-5").to_list() * 2), + "lag": [0] * 10, + }), CovidcastRows()._DTYPES) + assert_frame_equal(df, expected_df) + + df = CovidcastRows.from_args( + signal=["sig_base"] * 5 + ["sig_other"] * 5, time_value=date_range("2021-05-01", "2021-05-05").to_list() * 2, value=list(range(10)) + ).api_compatibility_row_df + expected_df = set_df_dtypes(DataFrame({ + "signal": ["sig_base"] * 5 + ["sig_other"] * 5, + "time_type": ["day"] * 10, + "geo_type": ["county"] * 10, + "time_value": map(date_to_time_value, date_range("2021-05-01", "2021-05-5").to_list() * 2), + "geo_value": ["01234"] * 10, + "value": range(10), + "stderr": [10.0] * 10, + "sample_size": [10.0] * 10, + "missing_value": [Nans.NOT_MISSING] * 10, + "missing_stderr": [Nans.NOT_MISSING] * 10, + "missing_sample_size": [Nans.NOT_MISSING] * 10, + "issue": map(date_to_time_value, 
date_range("2021-05-01", "2021-05-5").to_list() * 2), + "lag": [0] * 10, + }), CovidcastRows()._DTYPES) + assert_frame_equal(df, expected_df) From 4f8c34655d2185cfa477e368d32a1ce2a77ed572 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:33:05 -0700 Subject: [PATCH 02/29] Server: update csv_to_database to use CovidcastRow --- src/acquisition/covidcast/csv_to_database.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/acquisition/covidcast/csv_to_database.py b/src/acquisition/covidcast/csv_to_database.py index 34cbad663..0abe53f1f 100644 --- a/src/acquisition/covidcast/csv_to_database.py +++ b/src/acquisition/covidcast/csv_to_database.py @@ -7,7 +7,8 @@ # first party from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter -from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow, DBLoadStateException +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow +from delphi.epidata.acquisition.covidcast.database import Database, DBLoadStateException from delphi.epidata.acquisition.covidcast.file_archiver import FileArchiver from delphi.epidata.acquisition.covidcast.logger import get_structured_logger From c1191529917904cfffd2b3ca2e7f6a9b8a94dfd5 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:31:46 -0700 Subject: [PATCH 03/29] Server: update test_db to use CovidcastRow --- integrations/acquisition/covidcast/test_db.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/integrations/acquisition/covidcast/test_db.py b/integrations/acquisition/covidcast/test_db.py index 3cd7e91a7..5daf8d272 100644 --- a/integrations/acquisition/covidcast/test_db.py +++ b/integrations/acquisition/covidcast/test_db.py @@ -1,10 +1,11 @@ -import unittest - from delphi_utils import Nans -from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow, DBLoadStateException + +from delphi.epidata.acquisition.covidcast.database import DBLoadStateException +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase import delphi.operations.secrets as secrets + # all the Nans we use here are just one value, so this is a shortcut to it: nmv = Nans.NOT_MISSING.value @@ -31,8 +32,8 @@ def _find_matches_for_row(self, row): def test_insert_or_update_with_nonempty_load_table(self): # make rows - a_row = self._make_placeholder_row()[0] - another_row = self._make_placeholder_row(time_value=self.DEFAULT_TIME_VALUE+1, issue=self.DEFAULT_ISSUE+1)[0] + a_row = CovidcastRow(time_value=20200202) + another_row = CovidcastRow(time_value=20200203, issue=20200203) # insert one self._db.insert_or_update_bulk([a_row]) # put something into the load table @@ -61,7 +62,7 @@ def test_id_sync(self): latest_view = 'epimetric_latest_v' # add a data point - base_row, _ = self._make_placeholder_row() + base_row = CovidcastRow() self._insert_rows([base_row]) # ensure the primary keys match in the latest and history tables matches = self._find_matches_for_row(base_row) @@ -71,7 +72,7 @@ def test_id_sync(self): old_pk_id = matches[latest_view][pk_column] # add a reissue for said data point - next_row, _ = self._make_placeholder_row() + next_row = CovidcastRow() next_row.issue += 1 self._insert_rows([next_row]) # ensure the new keys also match From 565aad4920cd62c9e983a8a6cbaf74d4400a5184 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:32:06 -0700 Subject: [PATCH 
04/29] Server: update test_delete_batch to use CovidcastRow --- integrations/acquisition/covidcast/test_delete_batch.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/integrations/acquisition/covidcast/test_delete_batch.py b/integrations/acquisition/covidcast/test_delete_batch.py index 915c9341b..15ae7e2e2 100644 --- a/integrations/acquisition/covidcast/test_delete_batch.py +++ b/integrations/acquisition/covidcast/test_delete_batch.py @@ -5,13 +5,10 @@ import unittest from os import path -# third party -import mysql.connector - # first party -from delphi_utils import Nans -from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow import delphi.operations.secrets as secrets +from delphi.epidata.acquisition.covidcast.database import Database +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow # py3tester coverage target (equivalent to `import *`) __test_target__ = 'delphi.epidata.acquisition.covidcast.database' From c4e5675859db0e25d45623a75966e6b0bd16e9a3 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:32:20 -0700 Subject: [PATCH 05/29] Server: update test_delphi_epidata to use CovidcastRow --- integrations/client/test_delphi_epidata.py | 110 ++++++++++----------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/integrations/client/test_delphi_epidata.py b/integrations/client/test_delphi_epidata.py index 625d2859d..cfeb83bd4 100644 --- a/integrations/client/test_delphi_epidata.py +++ b/integrations/client/test_delphi_epidata.py @@ -1,26 +1,28 @@ """Integration tests for delphi_epidata.py.""" # standard library -import unittest import time -from unittest.mock import patch, MagicMock from json import JSONDecodeError +from unittest.mock import MagicMock, patch -# third party -from aiohttp.client_exceptions import ClientResponseError -import mysql.connector +# first party import pytest +from aiohttp.client_exceptions import ClientResponseError -# first party -from delphi_utils import Nans -from delphi.epidata.client.delphi_epidata import Epidata -from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow +# third party +import delphi.operations.secrets as secrets from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main as update_covidcast_meta_cache +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase -import delphi.operations.secrets as secrets +from delphi.epidata.client.delphi_epidata import Epidata +from delphi_utils import Nans + # py3tester coverage target __test_target__ = 'delphi.epidata.client.delphi_epidata' +# all the Nans we use here are just one value, so this is a shortcut to it: +nmv = Nans.NOT_MISSING.value +IGNORE_FIELDS = ["id", "direction_updated_timestamp", "value_updated_timestamp", "source", "time_type", "geo_type"] def fake_epidata_endpoint(func): """This can be used as a decorator to enable a bogus Epidata endpoint to return 404 responses.""" @@ -30,9 +32,6 @@ def wrapper(*args): Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' return wrapper -# all the Nans we use here are just one value, so this is a shortcut to it: -nmv = Nans.NOT_MISSING.value - class DelphiEpidataPythonClientTests(CovidcastBase): """Tests the Python client.""" @@ -54,12 +53,12 @@ def test_covidcast(self): # insert placeholder data: three issues of one signal, one issue of another rows = [ - 
self._make_placeholder_row(issue=self.DEFAULT_ISSUE + i, value=i, lag=i)[0] + CovidcastRow(issue=20200202 + i, value=i, lag=i) for i in range(3) ] row_latest_issue = rows[-1] rows.append( - self._make_placeholder_row(signal="sig2")[0] + CovidcastRow(signal="sig2") ) self._insert_rows(rows) @@ -70,10 +69,11 @@ def test_covidcast(self): ) expected = [ - self.expected_from_row(row_latest_issue), - self.expected_from_row(rows[-1]) + row_latest_issue.as_dict(ignore_fields=IGNORE_FIELDS), + rows[-1].as_dict(ignore_fields=IGNORE_FIELDS) ] + self.assertEqual(response['epidata'], expected) # check result self.assertEqual(response, { 'result': 1, @@ -89,10 +89,10 @@ def test_covidcast(self): expected = [{ rows[0].signal: [ - self.expected_from_row(row_latest_issue, self.DEFAULT_MINUS + ['signal']), + row_latest_issue.as_dict(ignore_fields=IGNORE_FIELDS + ['signal']), ], rows[-1].signal: [ - self.expected_from_row(rows[-1], self.DEFAULT_MINUS + ['signal']), + rows[-1].as_dict(ignore_fields=IGNORE_FIELDS + ['signal']), ], }] @@ -109,12 +109,12 @@ def test_covidcast(self): **self.params_from_row(rows[0]) ) - expected = self.expected_from_row(row_latest_issue) + expected = [row_latest_issue.as_dict(ignore_fields=IGNORE_FIELDS)] # check result self.assertEqual(response_1, { 'result': 1, - 'epidata': [expected], + 'epidata': expected, 'message': 'success', }) @@ -124,13 +124,13 @@ def test_covidcast(self): **self.params_from_row(rows[0], as_of=rows[1].issue) ) - expected = self.expected_from_row(rows[1]) + expected = [rows[1].as_dict(ignore_fields=IGNORE_FIELDS)] # check result self.maxDiff=None self.assertEqual(response_1a, { 'result': 1, - 'epidata': [expected], + 'epidata': expected, 'message': 'success', }) @@ -141,8 +141,8 @@ def test_covidcast(self): ) expected = [ - self.expected_from_row(rows[0]), - self.expected_from_row(rows[1]) + rows[0].as_dict(ignore_fields=IGNORE_FIELDS), + rows[1].as_dict(ignore_fields=IGNORE_FIELDS) ] # check result @@ -158,12 +158,12 @@ def test_covidcast(self): **self.params_from_row(rows[0], lag=2) ) - expected = self.expected_from_row(row_latest_issue) + expected = [row_latest_issue.as_dict(ignore_fields=IGNORE_FIELDS)] # check result self.assertDictEqual(response_3, { 'result': 1, - 'epidata': [expected], + 'epidata': expected, 'message': 'success', }) with self.subTest(name='long request'): @@ -223,16 +223,16 @@ def test_geo_value(self): # insert placeholder data: three counties, three MSAs N = 3 rows = [ - self._make_placeholder_row(geo_type="county", geo_value=str(i)*5, value=i)[0] + CovidcastRow(geo_type="county", geo_value=str(i)*5, value=i) for i in range(N) ] + [ - self._make_placeholder_row(geo_type="msa", geo_value=str(i)*5, value=i*10)[0] + CovidcastRow(geo_type="msa", geo_value=str(i)*5, value=i*10) for i in range(N) ] self._insert_rows(rows) counties = [ - self.expected_from_row(rows[i]) for i in range(N) + rows[i].as_dict(ignore_fields=IGNORE_FIELDS) for i in range(N) ] def fetch(geo): @@ -241,31 +241,31 @@ def fetch(geo): ) # test fetch all - r = fetch('*') - self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], counties) + request = fetch('*') + self.assertEqual(request['message'], 'success') + self.assertEqual(request['epidata'], counties) # test fetch a specific region - r = fetch('11111') - self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [counties[1]]) + request = fetch('11111') + self.assertEqual(request['message'], 'success') + self.assertEqual(request['epidata'], [counties[1]]) # test fetch a 
specific yet not existing region - r = fetch('55555') - self.assertEqual(r['message'], 'no results') + request = fetch('55555') + self.assertEqual(request['message'], 'no results') # test fetch a multiple regions - r = fetch(['11111', '22222']) - self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [counties[1], counties[2]]) + request = fetch(['11111', '22222']) + self.assertEqual(request['message'], 'success') + self.assertEqual(request['epidata'], [counties[1], counties[2]]) # test fetch a multiple regions in another variant - r = fetch(['00000', '22222']) - self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [counties[0], counties[2]]) + request = fetch(['00000', '22222']) + self.assertEqual(request['message'], 'success') + self.assertEqual(request['epidata'], [counties[0], counties[2]]) # test fetch a multiple regions but one is not existing - r = fetch(['11111', '55555']) - self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [counties[1]]) + request = fetch(['11111', '55555']) + self.assertEqual(request['message'], 'success') + self.assertEqual(request['epidata'], [counties[1]]) # test fetch a multiple regions but specify no region - r = fetch([]) - self.assertEqual(r['message'], 'no results') + request = fetch([]) + self.assertEqual(request['message'], 'no results') def test_covidcast_meta(self): """Test that the covidcast_meta endpoint returns expected data.""" @@ -275,7 +275,7 @@ def test_covidcast_meta(self): # 2nd issue: 1 11 21 # 3rd issue: 2 12 22 rows = [ - self._make_placeholder_row(time_value=self.DEFAULT_TIME_VALUE + t, issue=self.DEFAULT_ISSUE + i, value=t*10 + i)[0] + CovidcastRow(time_value=2020_02_02 + t, issue=2020_02_02 + i, value=t*10 + i) for i in range(3) for t in range(3) ] self._insert_rows(rows) @@ -299,14 +299,14 @@ def test_covidcast_meta(self): signal=rows[0].signal, time_type=rows[0].time_type, geo_type=rows[0].geo_type, - min_time=self.DEFAULT_TIME_VALUE, - max_time=self.DEFAULT_TIME_VALUE + 2, + min_time=2020_02_02, + max_time=2020_02_02 + 2, num_locations=1, min_value=2., mean_value=12., max_value=22., stdev_value=8.1649658, # population stdev, not sample, which is 10. 
- max_issue=self.DEFAULT_ISSUE + 2, + max_issue=2020_02_02 + 2, min_lag=0, max_lag=0, # we didn't set lag when inputting data ) @@ -322,10 +322,10 @@ def test_async_epidata(self): # insert placeholder data: three counties, three MSAs N = 3 rows = [ - self._make_placeholder_row(geo_type="county", geo_value=str(i)*5, value=i)[0] + CovidcastRow(geo_type="county", geo_value=str(i)*5, value=i) for i in range(N) ] + [ - self._make_placeholder_row(geo_type="msa", geo_value=str(i)*5, value=i*10)[0] + CovidcastRow(geo_type="msa", geo_value=str(i)*5, value=i*10) for i in range(N) ] self._insert_rows(rows) From 166802a8cfa3659c49ec6efad943c64c9caa6cc9 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:32:32 -0700 Subject: [PATCH 06/29] Server: update test_covidcast_endpoints to use CovidcastRow --- .../server/test_covidcast_endpoints.py | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/integrations/server/test_covidcast_endpoints.py b/integrations/server/test_covidcast_endpoints.py index 54974a874..2d342db6f 100644 --- a/integrations/server/test_covidcast_endpoints.py +++ b/integrations/server/test_covidcast_endpoints.py @@ -1,7 +1,9 @@ """Integration tests for the custom `covidcast/*` endpoints.""" # standard library -from typing import Iterable, Dict, Any +from copy import copy +from itertools import accumulate, chain +from typing import Iterable, Dict, Any, List, Sequence import unittest from io import StringIO @@ -10,17 +12,21 @@ # third party import mysql.connector +from more_itertools import interleave_longest, windowed import requests import pandas as pd +import numpy as np from delphi_utils import Nans from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main as update_cache +from delphi.epidata.server.endpoints.covidcast_utils.model import DataSignal, DataSource from delphi.epidata.acquisition.covidcast.database import Database from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase # use the local instance of the Epidata API BASE_URL = "http://delphi_web_epidata/epidata/covidcast" +BASE_URL_OLD = "http://delphi_web_epidata/epidata/api.php" class CovidcastEndpointTests(CovidcastBase): @@ -32,15 +38,25 @@ def localSetUp(self): # reset the `covidcast_meta_cache` table (it should always have one row) self._db._cursor.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') - def _fetch(self, endpoint="/", **params): + def _fetch(self, endpoint="/", is_compatibility=False, **params): # make the request - response = requests.get( - f"{BASE_URL}{endpoint}", - params=params, - ) + if is_compatibility: + url = BASE_URL_OLD + params.setdefault("endpoint", "covidcast") + if params.get("source"): + params.setdefault("data_source", params.get("source")) + else: + url = f"{BASE_URL}{endpoint}" + response = requests.get(url, params=params) response.raise_for_status() return response.json() + def _diff_rows(self, rows: Sequence[float]): + return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])] + + def _smooth_rows(self, rows: Sequence[float]): + return [sum(e)/len(e) if None not in e else None for e in windowed(rows, 7)] + def test_basic(self): """Request a signal from the / endpoint.""" @@ -56,6 +72,17 @@ def test_basic(self): out = self._fetch("/", signal=first.signal_pair(), geo=first.geo_pair(), time="day:*") self.assertEqual(len(out["epidata"]), len(rows)) + with self.subTest("unknown signal"): + rows = [CovidcastRow(source="jhu-csse", 
signal="confirmed_unknown", time_value=20200401 + i, value=i) for i in range(10)] + first = rows[0] + self._insert_rows(rows) + + out = self._fetch("/", signal="jhu-csse:confirmed_unknown", geo=first.geo_pair, time="day:*") + self.assertEqual(len(out["epidata"]), len(rows)) + out_values = [row["value"] for row in out["epidata"]] + expected_values = [float(row.value) for row in rows] + self.assertEqual(out_values, expected_values) + def test_trend(self): """Request a signal from the /trend endpoint.""" From 62dd20f0d8f76a7bde2d70556cd6277fdfd8fa8a Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:32:40 -0700 Subject: [PATCH 07/29] Server: update test_covidcast to use CovidcastRow --- integrations/server/test_covidcast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integrations/server/test_covidcast.py b/integrations/server/test_covidcast.py index 86ce0c53d..aaebe04d6 100644 --- a/integrations/server/test_covidcast.py +++ b/integrations/server/test_covidcast.py @@ -358,22 +358,22 @@ def test_date_formats(self): response, expected = self.request_based_on_row(rows[0], time_values="20000102,20000103", geo_value="*") # assert that the right data came back - self.assertEqual(len(response['epidata']), 4) + self.assertEqual(len(response['epidata']), 2 * 2) # make the request response, expected = self.request_based_on_row(rows[0], time_values="2000-01-02,2000-01-03", geo_value="*") # assert that the right data came back - self.assertEqual(len(response['epidata']), 4) + self.assertEqual(len(response['epidata']), 2 * 2) # make the request response, expected = self.request_based_on_row(rows[0], time_values="20000102-20000104", geo_value="*") # assert that the right data came back - self.assertEqual(len(response['epidata']), 6) + self.assertEqual(len(response['epidata']), 2 * 3) # make the request response, expected = self.request_based_on_row(rows[0], time_values="2000-01-02:2000-01-04", geo_value="*") # assert that the right data came back - self.assertEqual(len(response['epidata']), 6) + self.assertEqual(len(response['epidata']), 2 * 3) From 2d68be7b420a08fc711d1df66e249ac301538582 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:33:14 -0700 Subject: [PATCH 08/29] Server: update test_utils to use CovidcastRow --- src/acquisition/covidcast/test_utils.py | 44 ++++++------------------- 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/src/acquisition/covidcast/test_utils.py b/src/acquisition/covidcast/test_utils.py index 181dfac68..45f9fbfd0 100644 --- a/src/acquisition/covidcast/test_utils.py +++ b/src/acquisition/covidcast/test_utils.py @@ -1,7 +1,9 @@ +from typing import Sequence import unittest from delphi_utils import Nans -from delphi.epidata.acquisition.covidcast.database import Database, CovidcastRow +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow +from delphi.epidata.acquisition.covidcast.database import Database import delphi.operations.secrets as secrets # all the Nans we use here are just one value, so this is a shortcut to it: @@ -31,36 +33,20 @@ def tearDown(self): # close and destroy conenction to the database self._db.disconnect(False) del self._db + self.localTearDown() - DEFAULT_TIME_VALUE=2000_01_01 - DEFAULT_ISSUE=2000_01_01 - def _make_placeholder_row(self, **kwargs): - settings = { - 'source': 'src', - 'signal': 'sig', - 'geo_type': 'state', - 'geo_value': 'pa', - 'time_type': 'day', - 'time_value': self.DEFAULT_TIME_VALUE, - 'value': 0.0, - 'stderr': 1.0, - 
'sample_size': 2.0, - 'missing_value': nmv, - 'missing_stderr': nmv, - 'missing_sample_size': nmv, - 'issue': self.DEFAULT_ISSUE, - 'lag': 0 - } - settings.update(kwargs) - return (CovidcastRow(**settings), settings) + def localTearDown(self): + # stub; override in subclasses to perform custom teardown. + # runs after database changes have been committed + pass - def _insert_rows(self, rows): + def _insert_rows(self, rows: Sequence[CovidcastRow]): # inserts rows into the database using the full acquisition process, including 'dbjobs' load into history & latest tables n = self._db.insert_or_update_bulk(rows) print(f"{n} rows added to load table & dispatched to v4 schema") self._db._connection.commit() # NOTE: this isnt expressly needed for our test cases, but would be if using external access (like through client lib) to ensure changes are visible outside of this db session - def params_from_row(self, row, **kwargs): + def params_from_row(self, row: CovidcastRow, **kwargs): ret = { 'data_source': row.source, 'signals': row.signal, @@ -71,13 +57,3 @@ def params_from_row(self, row, **kwargs): } ret.update(kwargs) return ret - - DEFAULT_MINUS=['time_type', 'geo_type', 'source'] - def expected_from_row(self, row, minus=DEFAULT_MINUS): - expected = dict(vars(row)) - # remove columns commonly excluded from output - # nb may need to add source or *_type back in for multiplexed queries - for key in ['id', 'direction_updated_timestamp'] + minus: - del expected[key] - return expected - From 9c53fa038ce49f015665902182ca83a33fea5497 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:49:10 -0700 Subject: [PATCH 09/29] Server: update test_covidcast to use CovidcastRow --- integrations/server/test_covidcast.py | 131 ++++++++++++-------------- 1 file changed, 62 insertions(+), 69 deletions(-) diff --git a/integrations/server/test_covidcast.py b/integrations/server/test_covidcast.py index aaebe04d6..c003f8d07 100644 --- a/integrations/server/test_covidcast.py +++ b/integrations/server/test_covidcast.py @@ -1,7 +1,7 @@ """Integration tests for the `covidcast` endpoint.""" # standard library -import json +from typing import Callable import unittest # third party @@ -10,12 +10,13 @@ # first party from delphi_utils import Nans +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase # use the local instance of the Epidata API +# TODO: should we still be using this? 
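+# (api.php is the legacy PHP-era endpoint; the newer /covidcast routes are exercised in test_covidcast_endpoints.py)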
BASE_URL = 'http://delphi_web_epidata/epidata/api.php' - - +IGNORE_FIELDS = ["id", "direction_updated_timestamp", "value_updated_timestamp", "source", "time_type", "geo_type"] class CovidcastTests(CovidcastBase): """Tests the `covidcast` endpoint.""" @@ -24,28 +25,26 @@ def localSetUp(self): """Perform per-test setup.""" self._db._cursor.execute('update covidcast_meta_cache set timestamp = 0, epidata = "[]"') - def request_based_on_row(self, row, extract_response=lambda x: x.json(), **kwargs): + def request_based_on_row(self, row: CovidcastRow, extract_response: Callable = lambda x: x.json(), **kwargs): params = self.params_from_row(row, endpoint='covidcast', **kwargs) response = requests.get(BASE_URL, params=params) response.raise_for_status() response = extract_response(response) - expected = self.expected_from_row(row) - - return response, expected + return response def _insert_placeholder_set_one(self): - row, settings = self._make_placeholder_row() + row = CovidcastRow() self._insert_rows([row]) return row def _insert_placeholder_set_two(self): rows = [ - self._make_placeholder_row(geo_type='county', geo_value=str(i)*5, value=i*1., stderr=i*10., sample_size=i*100.)[0] + CovidcastRow(geo_type='county', geo_value=str(i)*5, value=i*1., stderr=i*10., sample_size=i*100.) for i in [1, 2, 3] ] + [ # geo value intended to overlap with counties above - self._make_placeholder_row(geo_type='msa', geo_value=str(i-3)*5, value=i*1., stderr=i*10., sample_size=i*100.)[0] + CovidcastRow(geo_type='msa', geo_value=str(i-3)*5, value=i*1., stderr=i*10., sample_size=i*100.) for i in [4, 5, 6] ] self._insert_rows(rows) @@ -53,11 +52,11 @@ def _insert_placeholder_set_two(self): def _insert_placeholder_set_three(self): rows = [ - self._make_placeholder_row(geo_type='county', geo_value='11111', time_value=2000_01_01+i, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03, lag=2-i)[0] + CovidcastRow(geo_type='county', geo_value='11111', time_value=2000_01_01+i, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03, lag=2-i) for i in [1, 2, 3] ] + [ # time value intended to overlap with 11111 above, with disjoint geo values - self._make_placeholder_row(geo_type='county', geo_value=str(i)*5, time_value=2000_01_01+i-3, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03, lag=5-i)[0] + CovidcastRow(geo_type='county', geo_value=str(i)*5, time_value=2000_01_01+i-3, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03, lag=5-i) for i in [4, 5, 6] ] self._insert_rows(rows) @@ -70,10 +69,13 @@ def test_round_trip(self): row = self._insert_placeholder_set_one() # make the request - response, expected = self.request_based_on_row(row) + response = self.request_based_on_row(row) + + expected = [row.as_dict(ignore_fields=IGNORE_FIELDS)] + self.assertEqual(response, { 'result': 1, - 'epidata': [expected], + 'epidata': expected, 'message': 'success', }) @@ -130,32 +132,25 @@ def test_csv_format(self): # make the request # NB 'format' is a Python reserved word - response, _ = self.request_based_on_row( + response = self.request_based_on_row( row, extract_response=lambda resp: resp.text, **{'format':'csv'} ) - expected_response = ( - "geo_value,signal,time_value,direction,issue,lag,missing_value," + - "missing_stderr,missing_sample_size,value,stderr,sample_size\n" + - ",".join("" if x is None else str(x) for x in [ - row.geo_value, - row.signal, - row.time_value, - row.direction, - row.issue, - row.lag, - row.missing_value, - row.missing_stderr, - row.missing_sample_size, - row.value, - 
row.stderr, - row.sample_size - ]) + "\n" + + # TODO: This is a mess because of api.php. + column_order = [ + "geo_value", "signal", "time_value", "direction", "issue", "lag", "missing_value", + "missing_stderr", "missing_sample_size", "value", "stderr", "sample_size" + ] + expected = ( + row.api_compatibility_row_df + .assign(direction = None) + .to_csv(columns=column_order, index=False) ) # assert that the right data came back - self.assertEqual(response, expected_response) + self.assertEqual(response, expected) def test_raw_json_format(self): """Test generate raw json data.""" @@ -164,10 +159,12 @@ def test_raw_json_format(self): row = self._insert_placeholder_set_one() # make the request - response, expected = self.request_based_on_row(row, **{'format':'json'}) + response = self.request_based_on_row(row, **{'format':'json'}) + + expected = [row.as_dict(ignore_fields=IGNORE_FIELDS)] # assert that the right data came back - self.assertEqual(response, [expected]) + self.assertEqual(response, expected) def test_fields(self): """Test fields parameter""" @@ -176,7 +173,9 @@ def test_fields(self): row = self._insert_placeholder_set_one() # limit fields - response, expected = self.request_based_on_row(row, fields='time_value,geo_value') + response = self.request_based_on_row(row, fields='time_value,geo_value') + + expected = row.as_dict(ignore_fields=IGNORE_FIELDS) expected_all = { 'result': 1, 'epidata': [{ @@ -189,15 +188,14 @@ def test_fields(self): self.assertEqual(response, expected_all) # limit using invalid fields - response, _ = self.request_based_on_row(row, fields='time_value,geo_value,doesnt_exist') + response = self.request_based_on_row(row, fields='time_value,geo_value,doesnt_exist') # assert that the right data came back (only valid fields) self.assertEqual(response, expected_all) # limit exclude fields: exclude all except time_value and geo_value - - response, _ = self.request_based_on_row(row, fields=( + response = self.request_based_on_row(row, fields=( '-value,-stderr,-sample_size,-direction,-issue,-lag,-signal,' + '-missing_value,-missing_stderr,-missing_sample_size' )) @@ -210,18 +208,15 @@ def test_location_wildcard(self): # insert placeholder data rows = self._insert_placeholder_set_two() - expected_counties = [ - self.expected_from_row(r) for r in rows[:3] - ] - + expected = [row.as_dict(ignore_fields=IGNORE_FIELDS) for row in rows[:3]] # make the request - response, _ = self.request_based_on_row(rows[0], geo_value="*") + response = self.request_based_on_row(rows[0], geo_value="*") self.maxDiff = None # assert that the right data came back self.assertEqual(response, { 'result': 1, - 'epidata': expected_counties, + 'epidata': expected, 'message': 'success', }) @@ -230,35 +225,33 @@ def test_geo_value(self): # insert placeholder data rows = self._insert_placeholder_set_two() - expected_counties = [ - self.expected_from_row(r) for r in rows[:3] - ] + expected = [row.as_dict(ignore_fields=IGNORE_FIELDS) for row in rows[:3]] def fetch(geo_value): # make the request - response, _ = self.request_based_on_row(rows[0], geo_value=geo_value) + response = self.request_based_on_row(rows[0], geo_value=geo_value) return response # test fetch a specific region r = fetch('11111') self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [expected_counties[0]]) + self.assertEqual(r['epidata'], expected[0:1]) # test fetch a specific yet not existing region r = fetch('55555') self.assertEqual(r['message'], 'no results') # test fetch multiple regions r = fetch('11111,22222') 
self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [expected_counties[0], expected_counties[1]]) + self.assertEqual(r['epidata'], expected[0:2]) # test fetch multiple noncontiguous regions r = fetch('11111,33333') self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [expected_counties[0], expected_counties[2]]) + self.assertEqual(r['epidata'], [expected[0], expected[2]]) # test fetch multiple regions but one is not existing r = fetch('11111,55555') self.assertEqual(r['message'], 'success') - self.assertEqual(r['epidata'], [expected_counties[0]]) + self.assertEqual(r['epidata'], expected[0:1]) # test fetch empty region r = fetch('') self.assertEqual(r['message'], 'no results') @@ -268,12 +261,10 @@ def test_location_timeline(self): # insert placeholder data rows = self._insert_placeholder_set_three() - expected_timeseries = [ - self.expected_from_row(r) for r in rows[:3] - ] + expected_timeseries = [row.as_dict(ignore_fields=IGNORE_FIELDS) for row in rows[:3]] # make the request - response, _ = self.request_based_on_row(rows[0], time_values='20000101-20000105') + response = self.request_based_on_row(rows[0], time_values='20000101-20000105') # assert that the right data came back self.assertEqual(response, { @@ -299,15 +290,16 @@ def test_unique_key_constraint(self): def test_nullable_columns(self): """Missing values should be surfaced as null.""" - row, _ = self._make_placeholder_row( + row = CovidcastRow( stderr=None, sample_size=None, missing_stderr=Nans.OTHER.value, missing_sample_size=Nans.OTHER.value ) self._insert_rows([row]) # make the request - response, expected = self.request_based_on_row(row) - expected.update(stderr=None, sample_size=None) + response = self.request_based_on_row(row) + expected = row.as_dict(ignore_fields=IGNORE_FIELDS) + # expected.update(stderr=None, sample_size=None) # assert that the right data came back self.assertEqual(response, { @@ -321,18 +313,19 @@ def test_temporal_partitioning(self): # insert placeholder data rows = [ - self._make_placeholder_row(time_type=tt)[0] + CovidcastRow(time_type=tt) for tt in "hour day week month year".split() ] self._insert_rows(rows) # make the request - response, expected = self.request_based_on_row(rows[1], time_values="0-99999999") + response = self.request_based_on_row(rows[1], time_values="20000101-30010201") + expected = [rows[1].as_dict(ignore_fields=IGNORE_FIELDS)] # assert that the right data came back self.assertEqual(response, { 'result': 1, - 'epidata': [expected], + 'epidata': expected, 'message': 'success', }) @@ -343,37 +336,37 @@ def test_date_formats(self): rows = self._insert_placeholder_set_three() # make the request - response, expected = self.request_based_on_row(rows[0], time_values="20000102", geo_value="*") + response = self.request_based_on_row(rows[0], time_values="20000102", geo_value="*") # assert that the right data came back self.assertEqual(len(response['epidata']), 2) # make the request - response, expected = self.request_based_on_row(rows[0], time_values="2000-01-02", geo_value="*") + response = self.request_based_on_row(rows[0], time_values="2000-01-02", geo_value="*") # assert that the right data came back self.assertEqual(len(response['epidata']), 2) # make the request - response, expected = self.request_based_on_row(rows[0], time_values="20000102,20000103", geo_value="*") + response = self.request_based_on_row(rows[0], time_values="20000102,20000103", geo_value="*") # assert that the right data came back 
self.assertEqual(len(response['epidata']), 2 * 2) # make the request - response, expected = self.request_based_on_row(rows[0], time_values="2000-01-02,2000-01-03", geo_value="*") + response = self.request_based_on_row(rows[0], time_values="2000-01-02,2000-01-03", geo_value="*") # assert that the right data came back self.assertEqual(len(response['epidata']), 2 * 2) # make the request - response, expected = self.request_based_on_row(rows[0], time_values="20000102-20000104", geo_value="*") + response = self.request_based_on_row(rows[0], time_values="20000102-20000104", geo_value="*") # assert that the right data came back self.assertEqual(len(response['epidata']), 2 * 3) # make the request - response, expected = self.request_based_on_row(rows[0], time_values="2000-01-02:2000-01-04", geo_value="*") + response = self.request_based_on_row(rows[0], time_values="2000-01-02:2000-01-04", geo_value="*") # assert that the right data came back self.assertEqual(len(response['epidata']), 2 * 3) From 9850fba1cf9efa2525c303b271759989f77f6e15 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 15:50:17 -0700 Subject: [PATCH 10/29] Server: update TimePair to auto-sort tuples --- src/server/_params.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/server/_params.py b/src/server/_params.py index fa4f63483..0e206e548 100644 --- a/src/server/_params.py +++ b/src/server/_params.py @@ -110,6 +110,10 @@ class TimePair: time_type: str time_values: Union[bool, Sequence[Union[int, Tuple[int, int]]]] + def __post_init__(self): + if isinstance(self.time_values, list): + self.time_values = [(min(time_value), max(time_value)) if isinstance(time_value, tuple) else time_value for time_value in self.time_values] + @property def is_week(self) -> bool: return self.time_type == 'week' From 0c7466c5ef5ed35dedeb5f92c860ea12a0823c12 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 16:02:09 -0700 Subject: [PATCH 11/29] Server: minor model.py data_source_by_id name update --- src/server/endpoints/covidcast_utils/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index 28b398580..520cb9c37 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -202,7 +202,7 @@ def _load_data_sources(): data_sources, data_sources_df = _load_data_sources() -data_source_by_id = {d.source: d for d in data_sources} +data_sources_by_id = {d.source: d for d in data_sources} def _load_data_signals(sources: List[DataSource]): @@ -231,12 +231,11 @@ def _load_data_signals(sources: List[DataSource]): data_signals_by_key = {d.key: d for d in data_signals} # also add the resolved signal version to the signal lookup for d in data_signals: - source = data_source_by_id.get(d.source) + source = data_sources_by_id.get(d.source) if source and source.uses_db_alias: data_signals_by_key[(source.db_source, d.signal)] = d - def get_related_signals(signal: DataSignal) -> List[DataSignal]: return [s for s in data_signals if s != signal and s.signal_basename == signal.signal_basename] @@ -266,7 +265,7 @@ def create_source_signal_alias_mapper(source_signals: List[SourceSignalPair]) -> alias_to_data_sources: Dict[str, List[DataSource]] = {} transformed_pairs: List[SourceSignalPair] = [] for pair in source_signals: - source = data_source_by_id.get(pair.source) + source = data_sources_by_id.get(pair.source) if not source or not source.uses_db_alias: 
transformed_pairs.append(pair) continue From b5818ecb3eb0166b466667c25413142f1371912b Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 16:03:03 -0700 Subject: [PATCH 12/29] Server: minor dates.py spacing fix --- src/server/utils/dates.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/server/utils/dates.py b/src/server/utils/dates.py index ef34a50b9..f2b21f87b 100644 --- a/src/server/utils/dates.py +++ b/src/server/utils/dates.py @@ -36,14 +36,12 @@ def guess_time_value_is_week(value: int) -> bool: def date_to_time_value(d: date) -> int: return int(d.strftime("%Y%m%d")) - def week_to_time_value(w: Week) -> int: return w.year * 100 + w.week def time_value_to_iso(value: int) -> str: return time_value_to_date(value).strftime("%Y-%m-%d") - def shift_time_value(time_value: int, days: int) -> int: if days == 0: return time_value From 01baf4fa1469f349fe2ed286127f264f2621e478 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 16:17:48 -0700 Subject: [PATCH 13/29] Server: remove unused imports test_covidcast --- tests/server/endpoints/test_covidcast.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/server/endpoints/test_covidcast.py b/tests/server/endpoints/test_covidcast.py index b7ecdc263..823f9126a 100644 --- a/tests/server/endpoints/test_covidcast.py +++ b/tests/server/endpoints/test_covidcast.py @@ -5,10 +5,6 @@ from flask import Response from delphi.epidata.server.main import app -from delphi.epidata.server._params import ( - GeoPair, - TimePair, -) # py3tester coverage target __test_target__ = "delphi.epidata.server.endpoints.covidcast" From 24dbd2b0c83bf13cd04e559b7e532c7242a04a4c Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 7 Oct 2022 16:31:15 -0700 Subject: [PATCH 14/29] JIT: major feature commit * add smooth_diff * add model updates * add /trend endpoint * add /trendseries endpoint * add /csv endpoint * params with utility functions * update date utility functions --- deploy.json | 8 +- .../server/test_covidcast_endpoints.py | 384 +++++++++--- requirements.txt | 2 + src/acquisition/covidcast/covidcast_row.py | 4 +- src/server/_config.py | 1 + src/server/_params.py | 33 +- src/server/_validate.py | 10 + src/server/endpoints/covidcast.py | 260 ++++++-- src/server/endpoints/covidcast_utils/model.py | 316 +++++++++- .../endpoints/covidcast_utils/smooth_diff.py | 179 ++++++ src/server/endpoints/covidcast_utils/trend.py | 4 + src/server/utils/__init__.py | 18 +- src/server/utils/dates.py | 45 ++ .../covidcast/test_covidcast_row.py | 3 + .../endpoints/covidcast_utils/test_model.py | 557 ++++++++++++++++++ .../covidcast_utils/test_smooth_diff.py | 163 +++++ tests/server/test_params.py | 36 +- tests/server/utils/test_dates.py | 18 +- 18 files changed, 1905 insertions(+), 136 deletions(-) create mode 100644 src/server/endpoints/covidcast_utils/smooth_diff.py create mode 100644 tests/server/endpoints/covidcast_utils/test_model.py create mode 100644 tests/server/endpoints/covidcast_utils/test_smooth_diff.py diff --git a/deploy.json b/deploy.json index 45b45883e..b50bec4aa 100644 --- a/deploy.json +++ b/deploy.json @@ -32,6 +32,13 @@ "match": "^.*\\.(py)$", "add-header-comment": true }, + { + "type": "move", + "src": "src/server/utils", + "dst": "[[package]]/server/utils/", + "match": "^.*\\.(py)$", + "add-header-comment": true + }, { "type": "move", "src": "src/server/endpoints/covidcast_utils", @@ -39,7 +46,6 @@ "match": "^.*\\.(py)$", "add-header-comment": true }, - "// acquisition - fluview", { "type": "move", diff --git 
a/integrations/server/test_covidcast_endpoints.py b/integrations/server/test_covidcast_endpoints.py index 2d342db6f..aa96aae3d 100644 --- a/integrations/server/test_covidcast_endpoints.py +++ b/integrations/server/test_covidcast_endpoints.py @@ -3,34 +3,39 @@ # standard library from copy import copy from itertools import accumulate, chain -from typing import Iterable, Dict, Any, List, Sequence -import unittest +from typing import List, Sequence from io import StringIO -# from typing import Optional -from dataclasses import dataclass - # third party -import mysql.connector from more_itertools import interleave_longest, windowed import requests import pandas as pd -import numpy as np -from delphi_utils import Nans from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main as update_cache -from delphi.epidata.server.endpoints.covidcast_utils.model import DataSignal, DataSource - -from delphi.epidata.acquisition.covidcast.database import Database from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow, CovidcastRows, set_df_dtypes # use the local instance of the Epidata API BASE_URL = "http://delphi_web_epidata/epidata/covidcast" BASE_URL_OLD = "http://delphi_web_epidata/epidata/api.php" -class CovidcastEndpointTests(CovidcastBase): +def _read_csv(txt: str) -> pd.DataFrame: + df = pd.read_csv(StringIO(txt), index_col=0).rename(columns={"data_source": "source"}) + df.time_value = pd.to_datetime(df.time_value).dt.strftime("%Y%m%d").astype(int) + df.issue = pd.to_datetime(df.issue).dt.strftime("%Y%m%d").astype(int) + df = set_df_dtypes(df, CovidcastRows()._DTYPES) + df.geo_value = df.geo_value.str.zfill(5) + return df + +def _diff_rows(rows: Sequence[float]): + return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])] +def _smooth_rows(rows: Sequence[float]): + return [sum(e)/len(e) if None not in e else None for e in windowed(rows, 7)] + + +class CovidcastEndpointTests(CovidcastBase): """Tests the `covidcast/*` endpoint.""" def localSetUp(self): @@ -51,16 +56,10 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params): response.raise_for_status() return response.json() - def _diff_rows(self, rows: Sequence[float]): - return [float(x - y) if x is not None and y is not None else None for x, y in zip(rows[1:], rows[:-1])] - - def _smooth_rows(self, rows: Sequence[float]): - return [sum(e)/len(e) if None not in e else None for e in windowed(rows, 7)] - def test_basic(self): """Request a signal from the / endpoint.""" - rows = [self._make_placeholder_row(time_value=20200401 + i, value=i)[0] for i in range(10)] + rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)] first = rows[0] self._insert_rows(rows) @@ -69,7 +68,7 @@ def test_basic(self): self.assertEqual(out["result"], -1) with self.subTest("simple"): - out = self._fetch("/", signal=first.signal_pair(), geo=first.geo_pair(), time="day:*") + out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*") self.assertEqual(len(out["epidata"]), len(rows)) with self.subTest("unknown signal"): @@ -78,55 +77,198 @@ def test_basic(self): self._insert_rows(rows) out = self._fetch("/", signal="jhu-csse:confirmed_unknown", geo=first.geo_pair, time="day:*") + out_values = [row["value"] for row in out["epidata"]] + expected_values = [float(row.value) for row in rows] + self.assertEqual(out_values, expected_values) + + def 
test_compatibility(self): + """Request at the /api.php endpoint.""" + rows = [CovidcastRow(source="src", signal="sig", time_value=20200401 + i, value=i) for i in range(10)] + first = rows[0] + self._insert_rows(rows) + + with self.subTest("simple"): + out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*") self.assertEqual(len(out["epidata"]), len(rows)) + + with self.subTest("unknown signal"): + rows = [CovidcastRow(source="jhu-csse", signal="confirmed_unknown", time_value=20200401 + i, value=i) for i in range(10)] + first = rows[0] + self._insert_rows(rows) + + out = self._fetch("/", signal="jhu-csse:confirmed_unknown", geo=first.geo_pair, time="day:*") out_values = [row["value"] for row in out["epidata"]] expected_values = [float(row.value) for row in rows] self.assertEqual(out_values, expected_values) + # JIT tests + def test_derived_signals(self): + time_value_pairs = [(20200401 + i, i ** 2) for i in range(10)] + rows01 = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=time_value, value=value, geo_value="01") for time_value, value in time_value_pairs] + rows02 = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=time_value, value=2 * value, geo_value="02") for time_value, value in time_value_pairs] + first = rows01[0] + self._insert_rows(rows01 + rows02) + + with self.subTest("diffed signal"): + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:*") + assert out['result'] == -2 + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") + out_values = [row["value"] for row in out["epidata"]] + values = [value for _, value in time_value_pairs] + expected_values = _diff_rows(values) + self.assertAlmostEqual(out_values, expected_values) + + with self.subTest("diffed signal, multiple geos"): + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:01,02", time="day:20200401-20200410") + out_values = [row["value"] for row in out["epidata"]] + values1 = [value for _, value in time_value_pairs] + values2 = [2 * value for _, value in time_value_pairs] + expected_values = _diff_rows(values1) + _diff_rows(values2) + self.assertAlmostEqual(out_values, expected_values) + + with self.subTest("diffed signal, multiple geos using geo:*"): + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:*", time="day:20200401-20200410") + values1 = [value for _, value in time_value_pairs] + values2 = [2 * value for _, value in time_value_pairs] + expected_values = _diff_rows(values1) + _diff_rows(values2) + self.assertAlmostEqual(out_values, expected_values) + + with self.subTest("smooth diffed signal"): + out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") + out_values = [row["value"] for row in out["epidata"]] + values = [value for _, value in time_value_pairs] + expected_values = _smooth_rows(_diff_rows(values)) + self.assertAlmostEqual(out_values, expected_values) + + with self.subTest("diffed signal and smoothed signal in one request"): + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num;jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") + out_values = [row["value"] for row in out["epidata"]] + values = [value for _, value in time_value_pairs] + expected_diff = _diff_rows(values) + expected_smoothed = _smooth_rows(expected_diff) + expected_values = 
list(interleave_longest(expected_smoothed, expected_diff)) + self.assertAlmostEqual(out_values, expected_values) + + time_value_pairs = [(20200401 + i, i ** 2) for i in chain(range(10), range(15, 20))] + rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", geo_value="03", time_value=time_value, value=value) for time_value, value in time_value_pairs] + first = rows[0] + self._insert_rows(rows) + + with self.subTest("diffing with a time gap"): + # should fetch 1 extra day + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:20200401-20200420") + out_values = [row["value"] for row in out["epidata"]] + values = [value for _, value in time_value_pairs][:10] + [None] * 5 + [value for _, value in time_value_pairs][10:] + expected_values = _diff_rows(values) + self.assertAlmostEqual(out_values, expected_values) + + with self.subTest("smoothing and diffing with a time gap"): + # should fetch 1 extra day + out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200420") + out_values = [row["value"] for row in out["epidata"]] + values = [value for _, value in time_value_pairs][:10] + [None] * 5 + [value for _, value in time_value_pairs][10:] + expected_values = _smooth_rows(_diff_rows(values)) + self.assertAlmostEqual(out_values, expected_values) + + def test_compatibility(self): + """Request at the /api.php endpoint.""" + rows = [CovidcastRow(source="src", signal="sig", time_value=20200401 + i, value=i) for i in range(10)] + first = rows[0] + self._insert_rows(rows) + + with self.subTest("simple"): + out = self._fetch(is_compatibility=True, source=first.source, signal=first.signal, geo=first.geo_pair, time="day:*") + self.assertEqual(len(out["epidata"]), len(rows)) + + def _diff_covidcast_rows(self, rows: List[CovidcastRow]) -> List[CovidcastRow]: + new_rows = list() + for x, y in zip(rows[1:], rows[:-1]): + new_row = copy(x) + new_row.value = x.value - y.value + new_rows.append(new_row) + return new_rows + def test_trend(self): """Request a signal from the /trend endpoint.""" num_rows = 30 - rows = [self._make_placeholder_row(time_value=20200401 + i, value=i)[0] for i in range(num_rows)] + rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(num_rows)] first = rows[0] last = rows[-1] ref = rows[num_rows // 2] self._insert_rows(rows) - out = self._fetch("/trend", signal=first.signal_pair(), geo=first.geo_pair(), date=last.time_value, window="20200401-20201212", basis=ref.time_value) + with self.subTest("no JIT"): + out = self._fetch("/trend", signal=first.signal_pair, geo=first.geo_pair, date=last.time_value, window="20200401-20201212", basis=ref.time_value) + + self.assertEqual(out["result"], 1) + self.assertEqual(len(out["epidata"]), 1) + trend = out["epidata"][0] + self.assertEqual(trend["geo_type"], last.geo_type) + self.assertEqual(trend["geo_value"], last.geo_value) + self.assertEqual(trend["signal_source"], last.source) + self.assertEqual(trend["signal_signal"], last.signal) + + self.assertEqual(trend["date"], last.time_value) + self.assertEqual(trend["value"], last.value) + + self.assertEqual(trend["basis_date"], ref.time_value) + self.assertEqual(trend["basis_value"], ref.value) + self.assertEqual(trend["basis_trend"], "increasing") + + self.assertEqual(trend["min_date"], first.time_value) + self.assertEqual(trend["min_value"], first.value) + self.assertEqual(trend["min_trend"], "increasing") + self.assertEqual(trend["max_date"], last.time_value) + 
self.assertEqual(trend["max_value"], last.value) + self.assertEqual(trend["max_trend"], "steady") + + num_rows = 30 + time_value_pairs = [(20200331, 0)] + [(20200401 + i, v) for i, v in enumerate(accumulate(range(num_rows)))] + rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=t, value=v) for t, v in time_value_pairs] + self._insert_rows(rows) + diffed_rows = self._diff_covidcast_rows(rows) + for row in diffed_rows: + row.signal = "confirmed_incidence_num" + first = diffed_rows[0] + last = diffed_rows[-1] + ref = diffed_rows[num_rows // 2] + with self.subTest("use JIT"): + out = self._fetch("/trend", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, date=last.time_value, window="20200401-20201212", basis=ref.time_value) + + self.assertEqual(out["result"], 1) + self.assertEqual(len(out["epidata"]), 1) + trend = out["epidata"][0] + self.assertEqual(trend["geo_type"], last.geo_type) + self.assertEqual(trend["geo_value"], last.geo_value) + self.assertEqual(trend["signal_source"], last.source) + self.assertEqual(trend["signal_signal"], last.signal) + + self.assertEqual(trend["date"], last.time_value) + self.assertEqual(trend["value"], last.value) + + self.assertEqual(trend["basis_date"], ref.time_value) + self.assertEqual(trend["basis_value"], ref.value) + self.assertEqual(trend["basis_trend"], "increasing") + + self.assertEqual(trend["min_date"], first.time_value) + self.assertEqual(trend["min_value"], first.value) + self.assertEqual(trend["min_trend"], "increasing") + self.assertEqual(trend["max_date"], last.time_value) + self.assertEqual(trend["max_value"], last.value) + self.assertEqual(trend["max_trend"], "steady") - self.assertEqual(out["result"], 1) - self.assertEqual(len(out["epidata"]), 1) - trend = out["epidata"][0] - self.assertEqual(trend["geo_type"], last.geo_type) - self.assertEqual(trend["geo_value"], last.geo_value) - self.assertEqual(trend["signal_source"], last.source) - self.assertEqual(trend["signal_signal"], last.signal) - - self.assertEqual(trend["date"], last.time_value) - self.assertEqual(trend["value"], last.value) - - self.assertEqual(trend["basis_date"], ref.time_value) - self.assertEqual(trend["basis_value"], ref.value) - self.assertEqual(trend["basis_trend"], "increasing") - - self.assertEqual(trend["min_date"], first.time_value) - self.assertEqual(trend["min_value"], first.value) - self.assertEqual(trend["min_trend"], "increasing") - self.assertEqual(trend["max_date"], last.time_value) - self.assertEqual(trend["max_value"], last.value) - self.assertEqual(trend["max_trend"], "steady") def test_trendseries(self): """Request a signal from the /trendseries endpoint.""" num_rows = 3 - rows = [self._make_placeholder_row(time_value=20200401 + i, value=num_rows - i)[0] for i in range(num_rows)] + rows = [CovidcastRow(time_value=20200401 + i, value=num_rows - i) for i in range(num_rows)] first = rows[0] last = rows[-1] self._insert_rows(rows) - out = self._fetch("/trendseries", signal=first.signal_pair(), geo=first.geo_pair(), date=last.time_value, window="20200401-20200410", basis=1) + out = self._fetch("/trendseries", signal=first.signal_pair, geo=first.geo_pair, date=last.time_value, window="20200401-20200410", basis=1) self.assertEqual(out["result"], 1) self.assertEqual(len(out["epidata"]), 3) @@ -154,6 +296,7 @@ def match_row(trend, row): self.assertEqual(trend["max_date"], first.time_value) self.assertEqual(trend["max_value"], first.value) self.assertEqual(trend["max_trend"], "steady") + with self.subTest("trend1"): 
trend = trends[1] match_row(trend, rows[1]) @@ -182,19 +325,78 @@ def match_row(trend, row): self.assertEqual(trend["max_value"], first.value) self.assertEqual(trend["max_trend"], "decreasing") + num_rows = 3 + time_value_pairs = [(20200331, 0)] + [(20200401 + i, v) for i, v in enumerate(accumulate([num_rows - i for i in range(num_rows)]))] + rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=t, value=v) for t, v in time_value_pairs] + self._insert_rows(rows) + diffed_rows = self._diff_covidcast_rows(rows) + for row in diffed_rows: + row.signal = "confirmed_incidence_num" + first = diffed_rows[0] + last = diffed_rows[-1] + + out = self._fetch("/trendseries", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, date=last.time_value, window="20200401-20200410", basis=1) + + self.assertEqual(out["result"], 1) + self.assertEqual(len(out["epidata"]), 3) + trends = out["epidata"] + + with self.subTest("trend0, JIT"): + trend = trends[0] + match_row(trend, first) + self.assertEqual(trend["basis_date"], None) + self.assertEqual(trend["basis_value"], None) + self.assertEqual(trend["basis_trend"], "unknown") + + self.assertEqual(trend["min_date"], last.time_value) + self.assertEqual(trend["min_value"], last.value) + self.assertEqual(trend["min_trend"], "increasing") + self.assertEqual(trend["max_date"], first.time_value) + self.assertEqual(trend["max_value"], first.value) + self.assertEqual(trend["max_trend"], "steady") + + with self.subTest("trend1"): + trend = trends[1] + match_row(trend, diffed_rows[1]) + self.assertEqual(trend["basis_date"], first.time_value) + self.assertEqual(trend["basis_value"], first.value) + self.assertEqual(trend["basis_trend"], "decreasing") + + self.assertEqual(trend["min_date"], last.time_value) + self.assertEqual(trend["min_value"], last.value) + self.assertEqual(trend["min_trend"], "increasing") + self.assertEqual(trend["max_date"], first.time_value) + self.assertEqual(trend["max_value"], first.value) + self.assertEqual(trend["max_trend"], "decreasing") + + with self.subTest("trend2"): + trend = trends[2] + match_row(trend, last) + self.assertEqual(trend["basis_date"], diffed_rows[1].time_value) + self.assertEqual(trend["basis_value"], diffed_rows[1].value) + self.assertEqual(trend["basis_trend"], "decreasing") + + self.assertEqual(trend["min_date"], last.time_value) + self.assertEqual(trend["min_value"], last.value) + self.assertEqual(trend["min_trend"], "steady") + self.assertEqual(trend["max_date"], first.time_value) + self.assertEqual(trend["max_value"], first.value) + self.assertEqual(trend["max_trend"], "decreasing") + + def test_correlation(self): """Request a signal from the /correlation endpoint.""" num_rows = 30 - reference_rows = [self._make_placeholder_row(signal="ref", time_value=20200401 + i, value=i)[0] for i in range(num_rows)] + reference_rows = [CovidcastRow(signal="ref", time_value=20200401 + i, value=i) for i in range(num_rows)] first = reference_rows[0] self._insert_rows(reference_rows) - other_rows = [self._make_placeholder_row(signal="other", time_value=20200401 + i, value=i)[0] for i in range(num_rows)] + other_rows = [CovidcastRow(signal="other", time_value=20200401 + i, value=i) for i in range(num_rows)] other = other_rows[0] self._insert_rows(other_rows) max_lag = 3 - out = self._fetch("/correlation", reference=first.signal_pair(), others=other.signal_pair(), geo=first.geo_pair(), window="20200401-20201212", lag=max_lag) + out = self._fetch("/correlation", reference=first.signal_pair, 
others=other.signal_pair, geo=first.geo_pair, window="20200401-20201212", lag=max_lag) self.assertEqual(out["result"], 1) df = pd.DataFrame(out["epidata"]) self.assertEqual(len(df), max_lag * 2 + 1) # -...0...+ @@ -212,31 +414,75 @@ def test_correlation(self): def test_csv(self): """Request a signal from the /csv endpoint.""" - rows = [self._make_placeholder_row(time_value=20200401 + i, value=i)[0] for i in range(10)] - first = rows[0] - self._insert_rows(rows) - - response = requests.get( - f"{BASE_URL}/csv", - params=dict(signal=first.signal_pair(), start_day="2020-04-01", end_day="2020-12-12", geo_type=first.geo_type), + expected_columns = ["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "data_source"] + data = CovidcastRows.from_args( + time_value=pd.date_range("2020-04-01", "2020-04-10"), + value=range(10) ) - response.raise_for_status() - out = response.text - df = pd.read_csv(StringIO(out), index_col=0) - self.assertEqual(df.shape, (len(rows), 10)) - self.assertEqual(list(df.columns), ["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "data_source"]) + self._insert_rows(data.rows) + first = data.rows[0] + with self.subTest("no JIT"): + response = requests.get( + f"{BASE_URL}/csv", + params=dict(signal=first.signal_pair, start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type), + ) + response.raise_for_status() + out = response.text + df = pd.read_csv(StringIO(out), index_col=0) + + self.assertEqual(df.shape, (len(data.rows), 10)) + self.assertEqual(list(df.columns), expected_columns) + + data = CovidcastRows.from_args( + source=["jhu-csse"] * 11, + signal=["confirmed_cumulative_num"] * 11, + time_value=pd.date_range("2020-03-31", "2020-04-10"), + value=accumulate(range(11)), + ) + self._insert_rows(data.rows) + first = data.rows[0] + with self.subTest("use JIT"): + response = requests.get( + f"{BASE_URL}/csv", + params=dict(signal="jhu-csse:confirmed_cumulative_num", start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type), + ) + response.raise_for_status() + df = _read_csv(response.text) + expected_df = CovidcastRows.from_args( + source=["jhu-csse"] * 10, + signal=["confirmed_cumulative_num"] * 10, + time_value=pd.date_range("2020-04-01", "2020-04-10"), + value=list(accumulate(range(11)))[1:], + ).api_row_df[df.columns] + pd.testing.assert_frame_equal(df, expected_df) + + response = requests.get( + f"{BASE_URL}/csv", + params=dict(signal="jhu-csse:confirmed_incidence_num", start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type), + ) + response.raise_for_status() + df_diffed = _read_csv(response.text) + expected_df = CovidcastRows.from_args( + source=["jhu-csse"] * 10, + signal=["confirmed_incidence_num"] * 10, + time_value=pd.date_range("2020-04-01", "2020-04-10"), + value=range(1, 11), + stderr=[None] * 10, + sample_size=[None] * 10 + ).api_row_df[df_diffed.columns] + pd.testing.assert_frame_equal(df_diffed, expected_df) def test_backfill(self): """Request a signal from the /backfill endpoint.""" num_rows = 10 - issue_0 = [self._make_placeholder_row(time_value=20200401 + i, value=i, sample_size=1, lag=0, issue=20200401 + i)[0] for i in range(num_rows)] - issue_1 = [self._make_placeholder_row(time_value=20200401 + i, value=i + 1, sample_size=2, lag=1, issue=20200401 + i + 1)[0] for i in range(num_rows)] - last_issue = [self._make_placeholder_row(time_value=20200401 + i, value=i + 2, sample_size=3, lag=2, issue=20200401 + i + 2)[0] for i in 
range(num_rows)] # <-- the latest issues + issue_0 = [CovidcastRow(time_value=20200401 + i, value=i, sample_size=1, lag=0, issue=20200401 + i) for i in range(num_rows)] + issue_1 = [CovidcastRow(time_value=20200401 + i, value=i + 1, sample_size=2, lag=1, issue=20200401 + i + 1) for i in range(num_rows)] + last_issue = [CovidcastRow(time_value=20200401 + i, value=i + 2, sample_size=3, lag=2, issue=20200401 + i + 2) for i in range(num_rows)] # <-- the latest issues self._insert_rows([*issue_0, *issue_1, *last_issue]) first = issue_0[0] - out = self._fetch("/backfill", signal=first.signal_pair(), geo=first.geo_pair(), time="day:20200401-20201212", anchor_lag=3) + out = self._fetch("/backfill", signal=first.signal_pair, geo=first.geo_pair, time="day:20200401-20201212", anchor_lag=3) self.assertEqual(out["result"], 1) df = pd.DataFrame(out["epidata"]) self.assertEqual(len(df), 3 * num_rows) # num issues @@ -258,7 +504,7 @@ def test_meta(self): """Request a signal from the /meta endpoint.""" num_rows = 10 - rows = [self._make_placeholder_row(time_value=20200401 + i, value=i, source="fb-survey", signal="smoothed_cli")[0] for i in range(num_rows)] + rows = [CovidcastRow(time_value=20200401 + i, value=i, source="fb-survey", signal="smoothed_cli") for i in range(num_rows)] self._insert_rows(rows) first = rows[0] last = rows[-1] @@ -299,22 +545,22 @@ def test_coverage(self): num_geos_per_date = [10, 20, 30, 40, 44] dates = [20200401 + i for i in range(len(num_geos_per_date))] - rows = [self._make_placeholder_row(time_value=dates[i], value=i, geo_value=str(geo_value))[0] for i, num_geo in enumerate(num_geos_per_date) for geo_value in range(num_geo)] + rows = [CovidcastRow(time_value=dates[i], value=i, geo_value=str(geo_value)) for i, num_geo in enumerate(num_geos_per_date) for geo_value in range(num_geo)] self._insert_rows(rows) first = rows[0] with self.subTest("default"): - out = self._fetch("/coverage", signal=first.signal_pair(), geo_type=first.geo_type, latest=dates[-1], format="json") + out = self._fetch("/coverage", signal=first.signal_pair, geo_type=first.geo_type, latest=dates[-1], format="json") self.assertEqual(len(out), len(num_geos_per_date)) self.assertEqual([o["time_value"] for o in out], dates) self.assertEqual([o["count"] for o in out], num_geos_per_date) with self.subTest("specify window"): - out = self._fetch("/coverage", signal=first.signal_pair(), geo_type=first.geo_type, window=f"{dates[0]}-{dates[1]}", format="json") + out = self._fetch("/coverage", signal=first.signal_pair, geo_type=first.geo_type, window=f"{dates[0]}-{dates[1]}", format="json") self.assertEqual(len(out), 2) self.assertEqual([o["time_value"] for o in out], dates[:2]) self.assertEqual([o["count"] for o in out], num_geos_per_date[:2]) with self.subTest("invalid geo_type"): - out = self._fetch("/coverage", signal=first.signal_pair(), geo_type="doesnt_exist", format="json") + out = self._fetch("/coverage", signal=first.signal_pair, geo_type="doesnt_exist", format="json") self.assertEqual(len(out), 0) diff --git a/requirements.txt b/requirements.txt index 945ac11ea..51abf274d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,5 @@ scipy==1.6.2 tenacity==7.0.0 newrelic epiweeks==2.1.2 +delphi_utils +more_itertools==8.4.0 \ No newline at end of file diff --git a/src/acquisition/covidcast/covidcast_row.py b/src/acquisition/covidcast/covidcast_row.py index af57b0b28..e982a0250 100644 --- a/src/acquisition/covidcast/covidcast_row.py +++ b/src/acquisition/covidcast/covidcast_row.py @@ -6,8 +6,8 @@ from 
numpy import isnan from pandas import DataFrame, concat -from .csv_importer import CsvImporter -from ...server.utils.dates import date_to_time_value, time_value_to_date +from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter +from delphi.epidata.server.utils.dates import date_to_time_value, time_value_to_date def _is_none(v: Optional[float]) -> bool: diff --git a/src/server/_config.py b/src/server/_config.py index 47688a8ef..7973f9872 100644 --- a/src/server/_config.py +++ b/src/server/_config.py @@ -9,6 +9,7 @@ MAX_RESULTS = int(10e6) MAX_COMPATIBILITY_RESULTS = int(3650) +MAX_SMOOTHER_WINDOW = 30 SQLALCHEMY_DATABASE_URI = os.environ.get("SQLALCHEMY_DATABASE_URI", "sqlite:///test.db") diff --git a/src/server/_params.py b/src/server/_params.py index 0e206e548..74ee8540d 100644 --- a/src/server/_params.py +++ b/src/server/_params.py @@ -1,10 +1,11 @@ -from math import inf import re from dataclasses import dataclass +from itertools import groupby +from math import inf from typing import List, Optional, Sequence, Tuple, Union from flask import request - +from more_itertools import flatten from ._exceptions import ValidationFailedException from .utils import days_in_range, weeks_in_range, guess_time_value_is_day @@ -92,9 +93,35 @@ def count(self) -> float: return inf if self.signal else 0 return len(self.signal) + def add_signal(self, signal: str) -> None: + if not isinstance(self.signal, bool): + self.signal.append(signal) + + def __hash__(self) -> int: + return hash((self.source, self.signal if self.signal is isinstance(self.signal, bool) else tuple(self.signal))) + + +def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) -> List[SourceSignalPair]: + """Combine SourceSignalPairs with the same source into a single SourceSignalPair object. + + Example: + [SourceSignalPair("src", ["sig1", "sig2"]), SourceSignalPair("src", ["sig2", "sig3"])] will be merged + into [SourceSignalPair("src", ["sig1", "sig2", "sig3])]. 
+ """ + source_signal_pairs_grouped = groupby(sorted(source_signal_pairs, key=lambda x: x.source), lambda x: x.source) + source_signal_pairs_combined = [] + for source, group in source_signal_pairs_grouped: + group = list(group) + if any(x.signal == True for x in group): + combined_signals = True + else: + combined_signals = sorted(set(flatten(x.signal for x in group))) + source_signal_pairs_combined.append(SourceSignalPair(source, combined_signals)) + return source_signal_pairs_combined + def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalPair]: - return [SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)] + return _combine_source_signal_pairs([SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)]) def parse_single_source_signal_arg(key: str) -> SourceSignalPair: diff --git a/src/server/_validate.py b/src/server/_validate.py index 3b91e5570..af4f3e4d0 100644 --- a/src/server/_validate.py +++ b/src/server/_validate.py @@ -189,3 +189,13 @@ def push_range(first: str, last: str): values.append(parse_date(part)) # success, return the list return values + +def extract_bool(key: Union[str, Sequence[str]]) -> Optional[bool]: + s = _extract_value(key) + if not s: + return None + if s.lower() == "true": + return True + if s.lower() == "false": + return False + raise ValidationFailedException(f"{key}: not a boolean: {s}") diff --git a/src/server/endpoints/covidcast.py b/src/server/endpoints/covidcast.py index 4a636d891..4d9738286 100644 --- a/src/server/endpoints/covidcast.py +++ b/src/server/endpoints/covidcast.py @@ -1,14 +1,18 @@ +from numbers import Number from typing import List, Optional, Union, Tuple, Dict, Any from itertools import groupby from datetime import date, timedelta +from bisect import bisect_right from epiweeks import Week from flask import Blueprint, request from flask.json import loads, jsonify -from bisect import bisect_right +from more_itertools import peekable +from numpy import nan from sqlalchemy import text from pandas import read_csv, to_datetime -from .._common import is_compatibility_mode, db +from .._common import is_compatibility_mode, app, db +from .._config import MAX_SMOOTHER_WINDOW from .._exceptions import ValidationFailedException, DatabaseErrorException from .._params import ( GeoPair, @@ -26,6 +30,7 @@ from .._query import QueryBuilder, execute_query, run_query, parse_row, filter_fields from .._printer import create_printer, CSVPrinter from .._validate import ( + extract_bool, extract_date, extract_dates, extract_integer, @@ -36,11 +41,13 @@ from .._pandas import as_pandas, print_pandas from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry from ..utils import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, shift_week_value, week_value_to_week, guess_time_value_is_day, week_to_time_value -from .covidcast_utils.model import TimeType, count_signal_time_types, data_sources, create_source_signal_alias_mapper +from .covidcast_utils.model import TimeType, count_signal_time_types, data_sources, create_source_signal_alias_mapper, get_basename_signal_and_jit_generator, get_pad_length, pad_time_pairs, pad_time_window +from .covidcast_utils.smooth_diff import SmootherKernelValue # first argument is the endpoint name bp = Blueprint("covidcast", __name__) alias = None +JIT_COMPUTE_ON = True latest_table = "epimetric_latest_v" history_table = "epimetric_full_v" @@ -81,12 +88,17 @@ def 
parse_time_pairs() -> List[TimePair]: # old version require_all("time_type", "time_values") time_values = extract_dates("time_values") + # TODO: Put a bound on the number of time_values? + # if time_values and len(time_values) > 30: + # raise ValidationFailedException("parameter value exceed: too many time pairs requested, consider using a timerange instead YYYYMMDD-YYYYMMDD") return [TimePair(time_type, time_values)] if ":" not in request.values.get("time", ""): raise ValidationFailedException("missing parameter: time or (time_type and time_values)") - return parse_time_arg() + time_pairs = parse_time_arg() + # TODO: Put a bound on the number of time_values? (see above) + return time_pairs def _handle_lag_issues_as_of(q: QueryBuilder, issues: Optional[List[Union[Tuple[int, int], int]]] = None, lag: Optional[int] = None, as_of: Optional[int] = None): @@ -111,50 +123,117 @@ def _handle_lag_issues_as_of(q: QueryBuilder, issues: Optional[List[Union[Tuple[ pass +def parse_transform_args(): + # The length of the window to smooth over. + smoother_window_length = extract_integer("smoother_window_length") + if smoother_window_length is None: + smoother_window_length = 7 + + # TODO: Add support for floats inputs here. + # The value to fill for missing date values. + pad_fill_value = extract_integer("pad_fill_value") + if pad_fill_value is None: + pad_fill_value = nan + + # The value to fill for None or nan values. + nan_fill_value = extract_integer("nans_fill_value") + if nan_fill_value is None: + nan_fill_value = nan + + smoother_args = { + "smoother_kernel": SmootherKernelValue.average, + "smoother_window_length": smoother_window_length if isinstance(smoother_window_length, Number) and smoother_window_length <= MAX_SMOOTHER_WINDOW else MAX_SMOOTHER_WINDOW, + "pad_fill_value": pad_fill_value if isinstance(pad_fill_value, Number) else nan, + "nans_fill_value": nan_fill_value if isinstance(nan_fill_value, Number) else nan + } + return smoother_args + + +def parse_jit_bypass(): + jit_bypass = extract_bool("jit_bypass") + if jit_bypass is None: + return False + else: + return jit_bypass + + @bp.route("/", methods=("GET", "POST")) def handle(): source_signal_pairs = parse_source_signal_pairs() source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) time_pairs = parse_time_pairs() geo_pairs = parse_geo_pairs() + jit_bypass = parse_jit_bypass() as_of = extract_date("as_of") issues = extract_dates("issues") lag = extract_integer("lag") + is_time_type_week = any(time_pair.time_type == "week" for time_pair in time_pairs) + is_time_value_true = any(isinstance(time_pair.time_values, bool) for time_pair in time_pairs) + + is_compatibility = is_compatibility_mode() + def alias_row(row): + if is_compatibility: + # old api returned fewer fields + remove_fields = ["geo_type", "source", "time_type"] + for field in remove_fields: + if field in row: + del row[field] + if is_compatibility or not alias_mapper or "source" not in row: + return row + row["source"] = alias_mapper(row["source"], row["signal"]) + return row # build query q = QueryBuilder(latest_table, "t") - fields_string = ["geo_value", "signal"] + fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"] fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"] fields_float = ["value", "stderr", "sample_size"] - is_compatibility = is_compatibility_mode() - if is_compatibility: - q.set_order("signal", "time_value", "geo_value", "issue") + + # TODO: 
JIT computations don't support time_value = *; there may be a clever way to implement this. + use_jit_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE_ON and not jit_bypass + if use_jit_compute: + transform_args = parse_transform_args() + pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) + time_pairs = pad_time_pairs(time_pairs, pad_length) + app.logger.info(f"JIT compute enabled for route '/': {source_signal_pairs}") + source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, transform_args=transform_args) + app.logger.info(f"JIT base signals: {source_signal_pairs}") + + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=time_pairs, transform_args=transform_args) + for row in transformed_rows: + yield alias_row(row) else: - # transfer also the new detail columns - fields_string.extend(["source", "geo_type", "time_type"]) - q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue") + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + for row in parsed_rows: + yield alias_row(row) + + q.set_order("source", "signal", "geo_type", "geo_value", "time_type", "time_value", "issue") q.set_fields(fields_string, fields_int, fields_float) # basic query info # data type of each field # build the source, signal, time, and location (type and id) filters - q.where_source_signal_pairs("source", "signal", source_signal_pairs) q.where_geo_pairs("geo_type", "geo_value", geo_pairs) q.where_time_pairs("time_type", "time_value", time_pairs) _handle_lag_issues_as_of(q, issues, lag, as_of) - def transform_row(row, proxy): - if is_compatibility or not alias_mapper or "source" not in row: - return row - row["source"] = alias_mapper(row["source"], proxy["signal"]) - return row + p = create_printer() + + # execute first query + try: + r = run_query(p, (str(q), q.params)) + except Exception as e: + raise DatabaseErrorException(str(e)) - # send query - return execute_query(str(q), q.params, fields_string, fields_int, fields_float, transform=transform_row) + # now use a generator for sending the rows and execute all the other queries + return p(filter_fields(gen_transform(r))) def _verify_argument_time_type_matches(is_day_argument: bool, count_daily_signal: int, count_weekly_signal: int) -> None: @@ -171,12 +250,17 @@ def handle_trend(): daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) geo_pairs = parse_geo_pairs() + transform_args = parse_transform_args() + jit_bypass = parse_jit_bypass() time_window, is_day = parse_day_or_week_range_arg("window") time_value, is_also_day = parse_day_or_week_arg("date") + if is_day != is_also_day: raise ValidationFailedException("mixing weeks with day arguments") + _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) + basis_time_value = extract_date("basis") if basis_time_value is None: base_shift = extract_integer("basis_shift") @@ -184,14 +268,42 @@ def handle_trend(): base_shift = 7 basis_time_value = shift_time_value(time_value, -1 * base_shift) if is_day else shift_week_value(time_value, -1 * base_shift) + def gen_trend(rows): + for key, group in groupby(rows, lambda row: 
(row["source"], row["signal"], row["geo_type"], row["geo_value"])): + source, signal, geo_type, geo_value = key + if alias_mapper: + source = alias_mapper(source, signal) + trend = compute_trend(geo_type, geo_value, source, signal, time_value, basis_time_value, ((row["time_value"], row["value"]) for row in group)) + yield trend.asdict() + # build query q = QueryBuilder(latest_table, "t") fields_string = ["geo_type", "geo_value", "source", "signal"] fields_int = ["time_value"] fields_float = ["value"] + + use_jit_compute = all((is_day, is_also_day)) and JIT_COMPUTE_ON and not jit_bypass + if use_jit_compute: + pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) + app.logger.info(f"JIT compute enabled for route '/trend': {source_signal_pairs}") + source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) + app.logger.info(f"JIT base signals: {source_signal_pairs}") + time_window = pad_time_window(time_window, pad_length) + + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args) + for row in transformed_rows: + yield row + else: + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + for row in parsed_rows: + yield row + q.set_fields(fields_string, fields_int, fields_float) - q.set_order("geo_type", "geo_value", "source", "signal", "time_value") + q.set_order("source", "signal", "geo_type", "geo_value", "time_value") q.where_source_signal_pairs("source", "signal", source_signal_pairs) q.where_geo_pairs("geo_type", "geo_value", geo_pairs) @@ -202,13 +314,6 @@ def handle_trend(): p = create_printer() - def gen(rows): - for key, group in groupby((parse_row(row, fields_string, fields_int, fields_float) for row in rows), lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])): - geo_type, geo_value, source, signal = key - if alias_mapper: - source = alias_mapper(source, signal) - trend = compute_trend(geo_type, geo_value, source, signal, time_value, basis_time_value, ((row["time_value"], row["value"]) for row in group)) - yield trend.asdict() # execute first query try: @@ -217,7 +322,7 @@ def gen(rows): raise DatabaseErrorException(str(e)) # now use a generator for sending the rows and execute all the other queries - return p(filter_fields(gen(r))) + return p(filter_fields(gen_trend(gen_transform(r)))) @bp.route("/trendseries", methods=("GET", "POST")) @@ -227,21 +332,58 @@ def handle_trendseries(): daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) geo_pairs = parse_geo_pairs() + transform_args = parse_transform_args() + jit_bypass = parse_jit_bypass() time_window, is_day = parse_day_or_week_range_arg("window") + _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) + basis_shift = extract_integer(("basis", "basis_shift")) if basis_shift is None: basis_shift = 7 + shifter = lambda x: shift_time_value(x, -basis_shift) + if not is_day: + shifter = lambda x: shift_week_value(x, -basis_shift) + + def gen_trend(rows): + for key, group in groupby(rows, lambda row: (row["source"], row["signal"], row["geo_type"], row["geo_value"])): + source, signal, geo_type, geo_value = key + if alias_mapper: 
+ source = alias_mapper(source, signal) + trends = compute_trends(geo_type, geo_value, source, signal, shifter, ((row["time_value"], row["value"]) for row in group)) + for t in trends: + yield t.asdict() + # build query q = QueryBuilder(latest_table, "t") fields_string = ["geo_type", "geo_value", "source", "signal"] fields_int = ["time_value"] fields_float = ["value"] + + use_jit_compute = is_day and JIT_COMPUTE_ON and not jit_bypass + if use_jit_compute: + pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) + app.logger.info(f"JIT compute enabled for route '/trendseries': {source_signal_pairs}") + source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) + app.logger.info(f"JIT base signals: {source_signal_pairs}") + time_window = pad_time_window(time_window, pad_length) + + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args) + for row in transformed_rows: + yield row + else: + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + for row in parsed_rows: + yield row + q.set_fields(fields_string, fields_int, fields_float) - q.set_order("geo_type", "geo_value", "source", "signal", "time_value") + q.set_order("source", "signal", "geo_type", "geo_value", "time_value") q.where_source_signal_pairs("source", "signal", source_signal_pairs) q.where_geo_pairs("geo_type", "geo_value", geo_pairs) @@ -252,19 +394,6 @@ def handle_trendseries(): p = create_printer() - shifter = lambda x: shift_time_value(x, -basis_shift) - if not is_day: - shifter = lambda x: shift_week_value(x, -basis_shift) - - def gen(rows): - for key, group in groupby((parse_row(row, fields_string, fields_int, fields_float) for row in rows), lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])): - geo_type, geo_value, source, signal = key - if alias_mapper: - source = alias_mapper(source, signal) - trends = compute_trends(geo_type, geo_value, source, signal, shifter, ((row["time_value"], row["value"]) for row in group)) - for t in trends: - yield t.asdict() - # execute first query try: r = run_query(p, (str(q), q.params)) @@ -272,7 +401,7 @@ def gen(rows): raise DatabaseErrorException(str(e)) # now use a generator for sending the rows and execute all the other queries - return p(filter_fields(gen(r))) + return p(filter_fields(gen_trend(gen_transform(r)))) @bp.route("/correlation", methods=("GET", "POST")) @@ -355,10 +484,15 @@ def handle_export(): source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) start_day, is_day = parse_day_or_week_arg("start_day", 202001 if weekly_signals > 0 else 20200401) end_day, is_end_day = parse_day_or_week_arg("end_day", 202020 if weekly_signals > 0 else 20200901) + time_window = (start_day, end_day) if is_day != is_end_day: raise ValidationFailedException("mixing weeks with day arguments") + _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) + transform_args = parse_transform_args() + jit_bypass = parse_jit_bypass() + geo_type = request.args.get("geo_type", "county") geo_values = request.args.get("geo_values", "*") @@ -372,10 +506,32 @@ def handle_export(): # build query q = QueryBuilder(latest_table, "t") - q.set_fields(["geo_value", "signal", "time_value", 
"issue", "lag", "value", "stderr", "sample_size", "geo_type", "source"], [], []) - q.set_order("time_value", "geo_value") + fields_string = ["geo_value", "signal", "geo_type", "source"] + fields_int = ["time_value", "issue", "lag"] + fields_float = ["value", "stderr", "sample_size"] + + use_jit_compute = all([is_day, is_end_day]) and JIT_COMPUTE_ON and not jit_bypass + if use_jit_compute: + pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) + app.logger.info(f"JIT compute enabled for route '/csv': {source_signal_pairs}") + source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) + app.logger.info(f"JIT base signals: {source_signal_pairs}") + time_window = pad_time_window(time_window, pad_length) + + def gen_transform(rows): + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) + transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args) + for row in transformed_rows: + yield row + else: + def gen_transform(rows): + for row in rows: + yield row + + q.set_fields(fields_string, fields_int, fields_float) + q.set_order("geo_value", "time_value") q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [(start_day, end_day)])]) + q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])]) q.where_geo_pairs("geo_type", "geo_value", [GeoPair(geo_type, True if geo_values == "*" else geo_values)]) _handle_lag_issues_as_of(q, None, None, as_of) @@ -386,7 +542,7 @@ def handle_export(): filename = "covidcast-{source}-{signal}-{start_day}-to-{end_day}{as_of}".format(source=source, signal=signal, start_day=format_date(start_day), end_day=format_date(end_day), as_of=as_of_str) p = CSVPrinter(filename) - def parse_row(i, row): + def parse_csv_row(i, row): # '',geo_value,signal,{time_value,issue},lag,value,stderr,sample_size,geo_type,data_source return { "": i, @@ -402,10 +558,9 @@ def parse_row(i, row): "data_source": alias_mapper(row["source"], row["signal"]) if alias_mapper else row["source"], } - def gen(first_row, rows): - yield parse_row(0, first_row) + def gen_parse(rows): for i, row in enumerate(rows): - yield parse_row(i + 1, row) + yield parse_csv_row(i, row) # execute query try: @@ -414,14 +569,15 @@ def gen(first_row, rows): raise DatabaseErrorException(str(e)) # special case for no data to be compatible with the CSV server - first_row = next(r, None) + transformed_query = peekable(gen_transform(r)) + first_row = transformed_query.peek(None) if not first_row: return "No matching data found for signal {source}:{signal} " "at {geo} level from {start_day} to {end_day}, as of {as_of}.".format( source=source, signal=signal, geo=geo_type, start_day=format_date(start_day), end_day=format_date(end_day), as_of=(date.today().isoformat() if as_of is None else format_date(as_of)) ) # now use a generator for sending the rows and execute all the other queries - return p(gen(first_row, r)) + return p(gen_parse(transformed_query)) @bp.route("/backfill", methods=("GET", "POST")) diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index 520cb9c37..fea526f12 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -1,12 +1,28 @@ from dataclasses import 
asdict, dataclass, field -from typing import Callable, Optional, Dict, List, Set, Tuple from enum import Enum +from functools import partial +from itertools import groupby, repeat, tee +from numbers import Number +from typing import Callable, Generator, Iterator, Optional, Dict, List, Set, Tuple, Union + from pathlib import Path import re +from more_itertools import flatten, interleave_longest, peekable import pandas as pd import numpy as np -from ..._params import SourceSignalPair +from delphi_utils.nancodes import Nans +from ..._params import SourceSignalPair, TimePair +from .smooth_diff import generate_smoothed_rows, generate_diffed_rows +from ...utils import shift_time_value, iterate_over_ints_and_ranges + + +IDENTITY: Callable = lambda rows, **kwargs: rows +DIFF: Callable = lambda rows, **kwargs: generate_diffed_rows(rows, **kwargs) +SMOOTH: Callable = lambda rows, **kwargs: generate_smoothed_rows(rows, **kwargs) +DIFF_SMOOTH: Callable = lambda rows, **kwargs: generate_smoothed_rows(generate_diffed_rows(rows, **kwargs), **kwargs) + +SignalTransforms = Dict[SourceSignalPair, SourceSignalPair] class HighValuesAre(str, Enum): @@ -21,6 +37,7 @@ class SignalFormat(str, Enum): fraction = "fraction" raw_count = "raw_count" raw = "raw" + count = "count" class SignalCategory(str, Enum): @@ -299,3 +316,298 @@ def map_row(source: str, signal: str) -> str: return signal_source.source return transformed_pairs, map_row + + +def _resolve_bool_source_signals(source_signals: Union[SourceSignalPair, List[SourceSignalPair]], data_sources_by_id: Dict[str, DataSource]) -> Union[SourceSignalPair, List[SourceSignalPair]]: + """Expand a request for all signals to an explicit list of signal names. + + Example: SourceSignalPair("jhu-csse", signal=True) would return SourceSignalPair("jhu-csse", []). + """ + if isinstance(source_signals, SourceSignalPair): + if source_signals.signal == True: + source = data_sources_by_id.get(source_signals.source) + if source: + return SourceSignalPair(source.source, [s.signal for s in source.signals]) + return source_signals + + if isinstance(source_signals, list): + return [_resolve_bool_source_signals(pair, data_sources_by_id) for pair in source_signals] + + raise TypeError("source_signals is not Union[SourceSignalPair, List[SourceSignalPair]].") + + +def _reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePair]], fill_value: Optional[Number] = None) -> Iterator[Dict]: + """Produces an iterator that fills in gaps in the time window of another iterator. + + Used to produce an iterator with a contiguous time index for time series operations. + + We iterate over contiguous range of days made from time_pairs. If `iterator`, which is assumed to be sorted by its "time_value" key, + is missing a time_value in the range, a row is returned with the missing date and dummy fields. + """ + # Iterate as normal if time_pairs is empty or None. + if not time_pairs: + yield from iterator + return + + _iterator = peekable(iterator) + + # If the iterator is empty, we halt immediately. + try: + first_item = _iterator.peek() + except StopIteration: + return + + _default_item = first_item.copy() + _default_item.update({"stderr": None, "sample_size": None, "issue": None, "lag": None, "missing_stderr": Nans.NOT_APPLICABLE, "missing_sample_size": Nans.NOT_APPLICABLE}) + + # Non-trivial operations otherwise. 
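+    # For example (hypothetical values): if time_pairs span 20210401-20210405 but the iterator
+    # only holds rows for 20210401 and 20210404, this loop yields the two real rows plus
+    # fill_value placeholder rows for 20210402 and 20210403; dates before the first row or
+    # after the last row of the iterator are not generated.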
+ min_time_value = first_item.get("time_value") + for expected_time_value in get_day_range(time_pairs): + if expected_time_value < min_time_value: + continue + + try: + # This will stay the same until the peeked element is consumed. + new_item = _iterator.peek() + except StopIteration: + return + + if expected_time_value == new_item.get("time_value"): + # Get the value we just peeked. + yield next(_iterator) + else: + # Return a default row instead. + # Copy to avoid Python by-reference memory issues. + default_item = _default_item.copy() + default_item.update({"time_value": expected_time_value, "value": fill_value, "missing_value": Nans.NOT_MISSING if fill_value and not np.isnan(fill_value) else Nans.NOT_APPLICABLE}) + yield default_item + + +def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Callable: + """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal.""" + if isinstance(signal, DataSignal): + base_signal = data_signals_by_key.get((signal.source, signal.signal_basename)) + if signal.format not in [SignalFormat.raw, SignalFormat.raw_count, SignalFormat.count] or not signal.compute_from_base or not base_signal: + return IDENTITY + if signal.is_cumulative and signal.is_smoothed: + return SMOOTH + if not signal.is_cumulative and not signal.is_smoothed: + return DIFF if base_signal.is_cumulative else IDENTITY + if not signal.is_cumulative and signal.is_smoothed: + return DIFF_SMOOTH if base_signal.is_cumulative else SMOOTH + return IDENTITY + + if isinstance(signal, tuple): + if signal := data_signals_by_key.get(signal): + return _get_base_signal_transform(signal, data_signals_by_key) + return IDENTITY + + raise TypeError("signal must be either Tuple[str, str] or DataSignal.") + + +def get_transform_types( + source_signal_pairs: List[SourceSignalPair], + data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, + data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key +) -> Set[Callable]: + """Return a collection of the unique transforms required for transforming a given source-signal pair list. + + Example: + SourceSignalPair("src", ["sig", "sig_smoothed", "sig_diff"]) would return {IDENTITY, SMOOTH, DIFF}. + + Used to pad the user DB query with extra days. + """ + source_signal_pairs = _resolve_bool_source_signals(source_signal_pairs, data_sources_by_id) + + transform_types = set() + for source_signal_pair in source_signal_pairs: + source_name = source_signal_pair.source + signal_names = source_signal_pair.signal + + if isinstance(signal_names, bool): + continue + + transform_types |= {_get_base_signal_transform((source_name, signal_name), data_signals_by_key=data_signals_by_key) for signal_name in signal_names} + + return transform_types + + +def get_pad_length( + source_signal_pairs: List[SourceSignalPair], + smoother_window_length: int, + data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, + data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, +): + """Returns the size of the extra date padding needed, depending on the transformations the source-signal pair list requires. + + If smoothing is required, we fetch an extra smoother_window_length - 1 days (6 by default). If both diffing and smoothing is required on the same signal, + then we fetch extra smoother_window_length days (7 by default). + + Used to pad the user DB query with extra days. 
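+
+    Example:
+    A query that requires both DIFF and DIFF_SMOOTH with the default smoother_window_length of 7 is padded by max(1, 7) = 7 days.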
+ """ + transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=data_sources_by_id, data_signals_by_key=data_signals_by_key) + pad_length = [0] + if DIFF_SMOOTH in transform_types: + pad_length.append(smoother_window_length) + if SMOOTH in transform_types: + pad_length.append(smoother_window_length - 1) + if DIFF in transform_types: + pad_length.append(1) + return max(pad_length) + + +def pad_time_pairs(time_pairs: List[TimePair], pad_length: int) -> List[TimePair]: + """Pads a list of TimePairs with another TimePair that extends the smallest time value by the pad_length, if needed. + + Assumes day time_type, since this function is only called for JIT computations which share the same assumption. + + Example: + [TimePair("day", [20210407])] with pad_length 6 would return [TimePair("day", [20210407]), TimePair("day", [(20210401, 20210407)])]. + """ + if pad_length < 0: + raise ValueError("pad_length should non-negative.") + if pad_length == 0: + return time_pairs.copy() + + extracted_non_bool_time_values = flatten(time_pair.time_values for time_pair in time_pairs if not isinstance(time_pair.time_values, bool)) + min_time = min(time_value if isinstance(time_value, int) else time_value[0] for time_value in extracted_non_bool_time_values) + + padded_time = TimePair("day", [(shift_time_value(min_time, -1 * pad_length), min_time)]) + + return time_pairs + [padded_time] + + +def pad_time_window(time_window: Tuple[int, int], pad_length: int) -> Tuple[int, int]: + """Extend a time window on the left by pad_length. + + Example: + (20210407, 20210413) with pad_length 6 would return (20210401, 20210413). + + Used to pad the user DB query with extra days. + """ + if pad_length < 0: + raise ValueError("pad_length should non-negative.") + if pad_length == 0: + return time_window + min_time, max_time = time_window + return (shift_time_value(min_time, -1 * pad_length), max_time) + + +def get_day_range(time_pairs: List[TimePair]) -> Iterator[int]: + """Iterate over a list of TimePair time_values, including the values contained in the ranges. + + Example: + [TimePair("day", [20210407, 20210408]), TimePair("day", [20210405, (20210408, 20210411)])] would iterate over + [20210405, 20210407, 20210408, 20210409, 20210410, 20210411]. + """ + time_values_flattened = [] + + for time_pair in time_pairs: + if time_pair.time_type != "day": + raise ValueError("get_day_range only supports 'day' time_type.") + + if isinstance(time_pair.time_values, int): + time_values_flattened.append(time_pair.time_values) + elif isinstance(time_pair.time_values, list): + time_values_flattened.extend(time_pair.time_values) + else: + raise ValueError("get_day_range only supports int or list time_values.") + + return iterate_over_ints_and_ranges(time_values_flattened) + + +def _generate_transformed_rows( + parsed_rows: Iterator[Dict], + time_pairs: Optional[List[TimePair]] = None, + transform_dict: Optional[SignalTransforms] = None, + transform_args: Optional[Dict] = None, + group_keyfunc: Optional[Callable] = None, + data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, +) -> Iterator[Dict]: + """Applies time-series transformations to streamed rows from a database. + + Parameters: + parsed_rows: Iterator[Dict] + An iterator streaming rows from a database query. Assumed to be sorted by source, signal, geo_type, geo_value, time_type, and time_value. 
+ time_pairs: Optional[List[TimePair]], default None + A list of TimePairs, which can be used to create a continguous time index for time-series operations. + The min and max dates in the TimePairs list is used. + transform_dict: Optional[SignalTransforms], default None + A dictionary mapping base sources to a list of their derived signals that the user wishes to query. + For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}. + transform_args: Optional[Dict], default None + A dictionary of keyword arguments for the transformer functions. + group_keyfunc: Optional[Callable], default None + The groupby function to use to order the streamed rows. Note that Python groupby does not do any sorting, so + parsed_rows are assumed to be sorted in accord with this groupby. + data_signals_by_key: Dict[Tuple[str, str], DataSignal], default data_signals_by_key + The dictionary of DataSignals which is used to find the base signal transforms. + + Yields: + transformed rows: Dict + The transformed rows returned in an interleaved fashion. Non-transformed rows have the IDENTITY operation applied. + """ + if not transform_args: + transform_args = dict() + if not transform_dict: + transform_dict = dict() + if not group_keyfunc: + group_keyfunc = lambda row: (row["source"], row["signal"], row["geo_type"], row["geo_value"]) + + for key, source_signal_geo_rows in groupby(parsed_rows, group_keyfunc): + base_source_name, base_signal_name, _, _ = key + # Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map. + derived_signal_transform_map: SourceSignalPair = transform_dict.get(SourceSignalPair(base_source_name, [base_signal_name]), SourceSignalPair(base_source_name, [base_signal_name])) + # Create a list of source-signal pairs along with the transformation required for the signal. + signal_names_and_transforms: List[Tuple[Tuple[str, str], Callable]] = [(derived_signal, _get_base_signal_transform((base_source_name, derived_signal), data_signals_by_key)) for derived_signal in derived_signal_transform_map.signal] + # Put the current time series on a contiguous time index. + source_signal_geo_rows = _reindex_iterable(source_signal_geo_rows, time_pairs, fill_value=transform_args.get("pad_fill_value")) + # Create copies of the iterable, with smart memory usage. + source_signal_geo_rows_copies: Iterator[Iterator[Dict]] = tee(source_signal_geo_rows, len(signal_names_and_transforms)) + # Create a list of transformed group iterables, remembering their derived name as needed. + transformed_signals_iterator: Iterator[Tuple[str, Iterator[Dict]]] = (zip(repeat(derived_signal), transform(rows, **transform_args)) for (derived_signal, transform), rows in zip(signal_names_and_transforms, source_signal_geo_rows_copies)) + # Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window + # of the original iterable (group) is stored in memory. 
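+        # For example, when one base series feeds two derived signals, rows are emitted
+        # alternating between the two derived streams (one row per stream per step) rather than
+        # one complete series followed by the other.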
+ for derived_signal_name, row in interleave_longest(*transformed_signals_iterator): + row["signal"] = derived_signal_name + yield row + + +def get_basename_signal_and_jit_generator( + source_signal_pairs: List[SourceSignalPair], + transform_args: Optional[Dict[str, Union[str, int]]] = None, + data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, + data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, +) -> Tuple[List[SourceSignalPair], Generator]: + """From a list of SourceSignalPairs, return the base signals required to derive them and a transformation function to take a stream + of the base signals and return the transformed signals. + + Example: + SourceSignalPair("src", signal=["sig_base", "sig_smoothed"]) would return SourceSignalPair("src", signal=["sig_base"]) and a transformation function + that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series. transform_dict in this case + would be {("src", "sig_base"): [("src", "sig_base"), ("src", "sig_smooth")]}. + """ + source_signal_pairs = _resolve_bool_source_signals(source_signal_pairs, data_sources_by_id) + base_signal_pairs: List[SourceSignalPair] = [] + transform_dict: SignalTransforms = dict() + + for pair in source_signal_pairs: + # Should only occur when the SourceSignalPair was unrecognized by _resolve_bool_source_signals. Useful for testing with fake signal names. + if isinstance(pair.signal, bool): + base_signal_pairs.append(pair) + continue + + signals = [] + for signal_name in pair.signal: + signal = data_signals_by_key.get((pair.source, signal_name)) + if not signal or not signal.compute_from_base: + transform_dict.setdefault(SourceSignalPair(source=pair.source, signal=[signal_name]), SourceSignalPair(source=pair.source, signal=[])).add_signal(signal_name) + signals.append(signal_name) + else: + transform_dict.setdefault(SourceSignalPair(source=pair.source, signal=[signal.signal_basename]), SourceSignalPair(source=pair.source, signal=[])).add_signal(signal_name) + signals.append(signal.signal_basename) + base_signal_pairs.append(SourceSignalPair(pair.source, signals)) + + row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, transform_args=transform_args, data_signals_by_key=data_signals_by_key) + + return base_signal_pairs, row_transform_generator diff --git a/src/server/endpoints/covidcast_utils/smooth_diff.py b/src/server/endpoints/covidcast_utils/smooth_diff.py new file mode 100644 index 000000000..d4a986c97 --- /dev/null +++ b/src/server/endpoints/covidcast_utils/smooth_diff.py @@ -0,0 +1,179 @@ +from enum import Enum +from logging import getLogger +from numbers import Number +from typing import Dict, Iterable, List, Union + +from delphi_utils.nancodes import Nans +from more_itertools import windowed +from numpy import array, dot, isnan, nan, nan_to_num, ndarray + +from ...utils.dates import time_value_to_date + + +class SmootherKernelValue(str, Enum): + average = "average" + + +def generate_smoothed_rows( + rows: Iterable[Dict], + smoother_kernel: Union[List[Number], SmootherKernelValue] = SmootherKernelValue.average, + smoother_window_length: int = 7, + nan_fill_value: Number = nan, + **kwargs +) -> Iterable[Dict]: + """Generate smoothed row entries. 
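+ Each yielded row applies the smoothing kernel to a sliding window of the input rows; for example, with the
+ default 7-day average kernel and a gap-free input, the row emitted for 2021-05-07 averages the values for
+ 2021-05-01 through 2021-05-07 (an illustrative sketch).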
+
+ There are roughly two modes of boundary handling:
+ * no padding, the windows start at length 1 on the left boundary and grow to size
+ smoother_window_length (achieved with pad_fill_value = None)
+ * value padding, smoother_window_length - 1 many fill_values are appended at the start of the
+ given date (achieved with any other pad_fill_value)
+
+ Note that this function crucially relies on the assumption that the iterable rows
+ have been sorted by time_value. If this assumption is violated, the results will likely be
+ incoherent.
+
+ Parameters
+ ----------
+ rows: Iterable[Dict]
+ An iterable over the rows a database query returns. The rows are assumed to be
+ dicts containing the "geo_type", "geo_value", and "time_value" keys. Assumes the
+ rows have been sorted by geo and time_value beforehand.
+ smoother_kernel: Union[List[Number], SmootherKernelValue], default SmootherKernelValue.average
+ Either a SmootherKernelValue or a custom list of numbers for weighted averaging.
+ smoother_window_length: int, default 7
+ The length of the averaging window for the smoother.
+ nan_fill_value: Number, default nan
+ The value to use when encountering nans (e.g. None and numpy.nan types); uses nan by default.
+ **kwargs:
+ Container for non-shared parameters with other computation functions.
+ """
+ # Validate params.
+ if not isinstance(smoother_window_length, int) or smoother_window_length < 1:
+ smoother_window_length = 7
+ if isinstance(smoother_kernel, list):
+ smoother_window_length = len(smoother_kernel)
+ if not isinstance(nan_fill_value, Number):
+ nan_fill_value = nan
+ if not isinstance(smoother_kernel, (list, SmootherKernelValue)):
+ smoother_kernel = SmootherKernelValue.average
+
+ for window in windowed(rows, smoother_window_length): # Iterable[List[Dict]]
+ # This occurs only if len(rows) < smoother_window_length.
+ if None in window:
+ continue
+
+ new_value = _smoother(_get_validated_window_values(window, nan_fill_value), kernel=smoother_kernel)
+ # The database returns NULL values as None, so we stay consistent with that.
+ new_value = float(round(new_value, 7)) if not isnan(new_value) else None
+
+ new_item = _fill_remaining_row_values(window)
+ new_item.update({"value": new_value, "missing_value": Nans.NOT_MISSING if new_value is not None else Nans.NOT_APPLICABLE})
+
+ yield new_item
+
+
+def generate_diffed_rows(rows: Iterable[Dict], nan_fill_value: Number = nan, **kwargs) -> Iterable[Dict]:
+ """Generate differences between row values.
+
+ Note that this function crucially relies on the assumption that the iterable rows have been
+ sorted by time_value. If this assumption is violated, the results will likely be incoherent.
+
+ rows: Iterable[Dict]
+ An iterable over the rows a database query returns. The rows are assumed to be dicts
+ containing the "geo_type", "geo_value", and "time_value" keys. Assumes the rows have been
+ sorted by geo and time_value beforehand.
+ nan_fill_value: Number, default nan
+ The value to use when encountering nans (e.g. None and numpy.nan types); uses nan by default.
+ **kwargs:
+ Container for non-shared parameters with other computation functions.
+ """
+ if not isinstance(nan_fill_value, Number):
+ nan_fill_value = nan
+
+ for window in windowed(rows, 2):
+ # This occurs only if len(rows) < 2.
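+ # more_itertools.windowed pads a too-short iterable with None, so a single-row input yields one
+ # (row, None) window, which is skipped here instead of producing a difference.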
+ if None in window:
+ continue
+
+ first_value, second_value = _get_validated_window_values(window, nan_fill_value)
+ new_value = round(second_value - first_value, 7)
+ # The database returns NULL values as None, so we stay consistent with that.
+ new_value = float(new_value) if not isnan(new_value) else None
+
+ new_item = _fill_remaining_row_values(window)
+ new_item.update({"value": new_value, "missing_value": Nans.NOT_MISSING if new_value is not None else Nans.NOT_APPLICABLE})
+
+ yield new_item
+
+
+def _smoother(values: List[Number], kernel: Union[List[Number], SmootherKernelValue] = SmootherKernelValue.average) -> Number:
+ """Basic smoother.
+
+ If a list kernel is passed, it is used as summation weights (a dot product). If the average
+ SmootherKernelValue is passed, the plain mean is used. Otherwise a ValueError is raised.
+ """
+
+ if kernel and isinstance(kernel, list):
+ kernel = array(kernel, copy=False)
+ values = array(values, copy=False)
+ smoothed_value = dot(values, kernel)
+ elif kernel and isinstance(kernel, SmootherKernelValue):
+ if kernel == SmootherKernelValue.average:
+ smoothed_value = array(values, copy=False).mean()
+ else:
+ raise ValueError("Unimplemented SmootherKernelValue.")
+ else:
+ raise ValueError("Kernel must be specified in _smoother.")
+
+ return smoothed_value
+
+
+def _get_validated_window_values(window: List[dict], nan_fill_value: Number) -> ndarray:
+ """Extracts and validates the values in the window, returning an array of floats.
+
+ The main objective is to produce a consistent nan type from None or np.nan values: None is replaced
+ with np.nan, and nans are then filled with nan_fill_value.
+
+ Assumes any None entries (from windowed padding) were filtered out of window, so it is a list of dicts only.
+ """
+ return nan_to_num([e.get("value") if e.get("value") is not None else nan for e in window], nan=nan_fill_value)
+
+
+def _fill_remaining_row_values(window: Iterable[dict]) -> dict:
+ """Set a few default fields for the covidcast row."""
+ logger = getLogger("gunicorn.error")
+
+ # Start by defaulting to the field values of the last window member.
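+ # The issue and lag fields are recomputed from the whole window below, and stderr/sample_size are
+ # cleared, since they are not meaningful for a derived (diffed or smoothed) value.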
+ new_item = window[-1].copy() + + try: + issues = [e.get("issue") for e in window] + if None in issues: + new_issue = None + else: + new_issue = max(issues) + except (TypeError, ValueError): + logger.warn(f"There was an error computing an issue field for {new_item.get('source')}:{new_item.get('signal')}.") + new_issue = None + + try: + if new_issue is None: + new_lag = None + else: + new_lag = (time_value_to_date(new_issue) - time_value_to_date(new_item["time_value"])).days + except (TypeError, ValueError): + logger.warn(f"There was an error computing a lag field for {new_item.get('source')}:{new_item.get('signal')}.") + new_lag = None + + new_item.update({ + "issue": new_issue, + "lag": new_lag, + "stderr": None, + "sample_size": None, + "missing_stderr": Nans.NOT_APPLICABLE, + "missing_sample_size": Nans.NOT_APPLICABLE + }) + + return new_item diff --git a/src/server/endpoints/covidcast_utils/trend.py b/src/server/endpoints/covidcast_utils/trend.py index 43c4ac21b..9a2825208 100644 --- a/src/server/endpoints/covidcast_utils/trend.py +++ b/src/server/endpoints/covidcast_utils/trend.py @@ -42,6 +42,8 @@ def compute_trend(geo_type: str, geo_value: str, signal_source: str, signal_sign # find all needed rows for time, value in rows: + if value is None: + continue if time == current_time: t.value = value if time == basis_time: @@ -73,6 +75,8 @@ def compute_trends(geo_type: str, geo_value: str, signal_source: str, signal_sig lookup: Dict[int, float] = OrderedDict() # find all needed rows for time, value in rows: + if value is None: + continue lookup[time] = value if min_value is None or min_value > value: min_date = time diff --git a/src/server/utils/__init__.py b/src/server/utils/__init__.py index 3198779d0..bddb00c43 100644 --- a/src/server/utils/__init__.py +++ b/src/server/utils/__init__.py @@ -1 +1,17 @@ -from .dates import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, days_in_range, weeks_in_range, shift_week_value, week_to_time_value, week_value_to_week, guess_time_value_is_day, time_values_to_ranges, days_to_ranges, weeks_to_ranges +from .dates import ( + shift_time_value, + date_to_time_value, + time_value_to_iso, + time_value_to_date, + days_in_range, + weeks_in_range, + shift_week_value, + week_to_time_value, + week_value_to_week, + guess_time_value_is_day, + time_values_to_ranges, + days_to_ranges, + weeks_to_ranges, + iterate_over_range, + iterate_over_ints_and_ranges, +) diff --git a/src/server/utils/dates.py b/src/server/utils/dates.py index f2b21f87b..5a2fb0205 100644 --- a/src/server/utils/dates.py +++ b/src/server/utils/dates.py @@ -1,5 +1,6 @@ from typing import ( Callable, + Iterator, Optional, Sequence, Tuple, @@ -138,3 +139,47 @@ def _to_ranges(values: Sequence[Union[Tuple[int, int], int]], value_to_date: Cal except Exception as e: logging.info('bad input to date ranges', input=values, exception=e) return values + +def iterate_over_range(start: int, end: int) -> Iterator[int]: + """Iterate over ints corresponding to dates in a time range. + + Left inclusive, right exclusive to mimic behavior of Python's built-in range. 
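+ For example, iterate_over_range(20210428, 20210502) yields 20210428, 20210429, 20210430, and 20210501.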
+ """ + if start > end: + return + + current_date, final_date = time_value_to_date(start), time_value_to_date(end) + while current_date < final_date: + yield date_to_time_value(current_date) + current_date = current_date + timedelta(days=1) + +def iterate_over_ints_and_ranges(lst: Iterator[Union[int, Tuple[int, int]]], use_dates: bool = True) -> Iterator[int]: + """A generator that iterates over the unique values in a list of integers and ranges in ascending order. + + The tuples are assumed to be left- and right-inclusive. If use_dates is True, then the integers are interpreted as + YYYYMMDD dates. + + Examples: + >>> list(iterate_over_ints_and_ranges([(5, 8), 0], False)) + [0, 5, 6, 7, 8] + >>> list(iterate_over_ints_and_ranges([(5, 8), (4, 6), (3, 5)], False)) + [3, 4, 5, 6, 7, 8] + >>> list(iterate_over_ints_and_ranges([(7, 8), (5, 7), (3, 8), 8], False)) + [3, 4, 5, 6, 7, 8] + """ + lst = sorted((x, x) if isinstance(x, int) else x for x in lst) + if not lst: + return + + if use_dates: + increment = lambda x, y: date_to_time_value(time_value_to_date(x) + timedelta(days=y)) + range_handler = iterate_over_range + else: + increment = lambda x, y: x + y + range_handler = range + + biggest_seen = increment(lst[0][0], -1) + for a, b in lst: + for y in range_handler(max(a, increment(biggest_seen, 1)), increment(b, 1)): + yield y + biggest_seen = max(biggest_seen, b) diff --git a/tests/acquisition/covidcast/test_covidcast_row.py b/tests/acquisition/covidcast/test_covidcast_row.py index 969b521b9..ef3793226 100644 --- a/tests/acquisition/covidcast/test_covidcast_row.py +++ b/tests/acquisition/covidcast/test_covidcast_row.py @@ -7,6 +7,9 @@ from delphi.epidata.server.utils.dates import date_to_time_value from delphi.epidata.acquisition.covidcast.covidcast_row import set_df_dtypes, transpose_dict, CovidcastRow, CovidcastRows +# py3tester coverage target (equivalent to `import *`) +__test_target__ = 'delphi.epidata.acquisition.covidcast.covidcast_row' + class TestCovidcastRows(unittest.TestCase): def test_transpose_dict(self): assert transpose_dict(dict([["a", [2, 4, 6]], ["b", [3, 5, 7]], ["c", [10, 20, 30]]])) == [{"a": 2, "b": 3, "c": 10}, {"a": 4, "b": 5, "c": 20}, {"a": 6, "b": 7, "c": 30}] diff --git a/tests/server/endpoints/covidcast_utils/test_model.py b/tests/server/endpoints/covidcast_utils/test_model.py new file mode 100644 index 000000000..246f6702a --- /dev/null +++ b/tests/server/endpoints/covidcast_utils/test_model.py @@ -0,0 +1,557 @@ +import unittest +from itertools import chain +from numbers import Number +from typing import Iterable, List, Optional + +import pandas as pd +from more_itertools import interleave_longest, windowed +from pandas.testing import assert_frame_equal + +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows +from delphi.epidata.server._params import SourceSignalPair, TimePair +from delphi.epidata.server.endpoints.covidcast_utils.model import ( + _generate_transformed_rows, + _get_base_signal_transform, + _reindex_iterable, + _resolve_bool_source_signals, + DataSignal, + DataSource, + DIFF_SMOOTH, + DIFF, + get_basename_signal_and_jit_generator, + get_day_range, + get_pad_length, + get_transform_types, + IDENTITY, + pad_time_pairs, + SMOOTH, +) +from delphi_utils.nancodes import Nans + +# fmt: off +DATA_SIGNALS_BY_KEY = { + ("src", "sig_diff"): DataSignal( + source="src", + signal="sig_diff", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + 
is_cumulative=False, + compute_from_base=True, + ), + ("src", "sig_smooth"): DataSignal( + source="src", + signal="sig_smooth", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, + is_smoothed=True, + compute_from_base=True, + ), + ("src", "sig_diff_smooth"): DataSignal( + source="src", + signal="sig_diff_smooth", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=False, + is_smoothed=True, + compute_from_base=True, + ), + ("src", "sig_base"): DataSignal( + source="src", + signal="sig_base", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, + ), + ("src2", "sig_base"): DataSignal( + source="src2", + signal="sig_base", + signal_basename="sig_base", + name="sig_base", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, + ), + ("src2", "sig_diff_smooth"): DataSignal( + source="src2", + signal="sig_diff_smooth", + signal_basename="sig_base", + name="sig_smooth", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=False, + is_smoothed=True, + compute_from_base=True, + ), +} + +DATA_SOURCES_BY_ID = { + "src": DataSource( + source="src", + db_source="src", + name="src", + description="", + reference_signal="sig_base", + signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src"], + ), + "src2": DataSource( + source="src2", + db_source="src2", + name="src2", + description="", + reference_signal="sig_base", + signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src2"], + ), +} +# fmt: on + + +def _diff_rows(rows: Iterable[Number]) -> List[Number]: + return [round(float(y - x), 8) if not (x is None or y is None) else None for x, y in windowed(rows, 2)] + + +def _smooth_rows(rows: Iterable[Number], window_length: int = 7, kernel: Optional[List[Number]] = None): + if not kernel: + kernel = [1.0 / window_length] * window_length + return [round(sum(x * y for x, y in zip(window, kernel)), 8) if None not in window else None for window in windowed(rows, len(kernel))] + + +def _reindex_windowed(lst: list, window_length: int) -> list: + return [max(window) if None not in window else None for window in windowed(lst, window_length)] + + +class TestModel(unittest.TestCase): + def test__resolve_bool_source_signals(self): + source_signal_pair = [SourceSignalPair(source="src", signal=True), SourceSignalPair(source="src", signal=["sig_unknown"])] + resolved_source_signal_pair = _resolve_bool_source_signals(source_signal_pair, DATA_SOURCES_BY_ID) + expected_source_signal_pair = [ + SourceSignalPair(source="src", signal=["sig_diff", "sig_smooth", "sig_diff_smooth", "sig_base"]), + SourceSignalPair(source="src", signal=["sig_unknown"]), + ] + assert resolved_source_signal_pair == expected_source_signal_pair + + def test__reindex_iterable(self): + # Trivial test. 
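+ # An empty row stream should stay empty regardless of the requested time_pairs.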
+ time_pairs = [(20210503, 20210508)] + assert list(_reindex_iterable([], time_pairs)) == [] + + data = CovidcastRows.from_args(time_value=pd.date_range("2021-05-03", "2021-05-08").to_list()).api_row_df + for time_pairs in [[TimePair("day", [(20210503, 20210508)])], [], None]: + with self.subTest(f"Identity operations: {time_pairs}"): + df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df + assert_frame_equal(df, data) + + data = CovidcastRows.from_args(time_value=pd.date_range("2021-05-03", "2021-05-08").to_list() + pd.date_range("2021-05-11", "2021-05-14").to_list()).api_row_df + with self.subTest("Non-trivial operations"): + time_pairs = [TimePair("day", [(20210501, 20210513)])] + + df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df + expected_df = CovidcastRows.from_args( + time_value=pd.date_range("2021-05-03", "2021-05-13"), + issue=pd.date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + pd.date_range("2021-05-11", "2021-05-13").to_list(), + lag=[0] * 6 + [None] * 2 + [0] * 3, + value=chain([10.0] * 6, [None] * 2, [10.0] * 3), + stderr=chain([10.0] * 6, [None] * 2, [10.0] * 3), + sample_size=chain([10.0] * 6, [None] * 2, [10.0] * 3), + ).api_row_df + assert_frame_equal(df, expected_df) + + df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs, fill_value=2.0)).api_row_df + expected_df = CovidcastRows.from_args( + time_value=pd.date_range("2021-05-03", "2021-05-13"), + issue=pd.date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + pd.date_range("2021-05-11", "2021-05-13").to_list(), + lag=[0] * 6 + [None] * 2 + [0] * 3, + value=chain([10.0] * 6, [2.0] * 2, [10.0] * 3), + stderr=chain([10.0] * 6, [None] * 2, [10.0] * 3), + sample_size=chain([10.0] * 6, [None] * 2, [10.0] * 3), + ).api_row_df + assert_frame_equal(df, expected_df) + + def test__get_base_signal_transform(self): + assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_smooth")], DATA_SIGNALS_BY_KEY) == SMOOTH + assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_diff_smooth")], DATA_SIGNALS_BY_KEY) == DIFF_SMOOTH + assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_diff")], DATA_SIGNALS_BY_KEY) == DIFF + assert _get_base_signal_transform(("src", "sig_diff"), DATA_SIGNALS_BY_KEY) == DIFF + assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_base")], DATA_SIGNALS_BY_KEY) == IDENTITY + assert _get_base_signal_transform(("src", "sig_unknown"), DATA_SIGNALS_BY_KEY) == IDENTITY + + def test_get_transform_types(self): + source_signal_pairs = [SourceSignalPair(source="src", signal=True)] + transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_transform_types = {IDENTITY, DIFF, SMOOTH, DIFF_SMOOTH} + assert transform_types == expected_transform_types + + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] + transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_transform_types = {DIFF} + assert transform_types == expected_transform_types + + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] + transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_transform_types = {SMOOTH} + 
assert transform_types == expected_transform_types + + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] + transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_transform_types = {DIFF_SMOOTH} + assert transform_types == expected_transform_types + + def test_get_pad_length(self): + source_signal_pairs = [SourceSignalPair(source="src", signal=True)] + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + assert pad_length == 7 + + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + assert pad_length == 1 + + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=5, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + assert pad_length == 4 + + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=10, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + assert pad_length == 10 + + def test_pad_time_pairs(self): + # fmt: off + time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]) + ] + expected_padded_time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), + TimePair("day", [(20210803, 20210810)]) + ] + assert pad_time_pairs(time_pairs, pad_length=7) == expected_padded_time_pairs + + time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), + TimePair("day", [20210809]) + ] + expected_padded_time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), + TimePair("day", [20210809]), + TimePair("day", [(20210801, 20210809)]), + ] + assert pad_time_pairs(time_pairs, pad_length=8) == expected_padded_time_pairs + + time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]) + ] + expected_padded_time_pairs = [ + TimePair("day", [20210817, (20210810, 20210815)]) + ] + assert pad_time_pairs(time_pairs, pad_length=0) == expected_padded_time_pairs + # fmt: on + + def test__generate_transformed_rows(self): + # fmt: off + with self.subTest("diffed signal test"): + data = CovidcastRows.from_args( + signal=["sig_base"] * 5, + time_value=pd.date_range("2021-05-01", "2021-05-05"), + value=range(5) + ).api_row_df + transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff"])} + df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df + + expected_df = CovidcastRows.from_args( + signal=["sig_diff"] * 4, + time_value=pd.date_range("2021-05-02", "2021-05-05"), + value=[1.0] * 4, + stderr=[None] * 4, + sample_size=[None] * 4, + missing_stderr=[Nans.NOT_APPLICABLE] * 4, + missing_sample_size=[Nans.NOT_APPLICABLE] * 4, + ).api_row_df + + assert_frame_equal(df, expected_df) + + with self.subTest("smoothed and diffed signals on one base 
test"): + data = CovidcastRows.from_args( + signal=["sig_base"] * 10, + time_value=pd.date_range("2021-05-01", "2021-05-10"), + value=range(10), + stderr=range(10), + sample_size=range(10) + ).api_row_df + transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} + df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df + + expected_df = CovidcastRows.from_args( + signal=interleave_longest(["sig_diff"] * 9, ["sig_smooth"] * 4), + time_value=interleave_longest(pd.date_range("2021-05-02", "2021-05-10"), pd.date_range("2021-05-07", "2021-05-10")), + value=interleave_longest(_diff_rows(data.value.to_list()), _smooth_rows(data.value.to_list())), + stderr=[None] * 13, + sample_size=[None] * 13, + ).api_row_df + + # Test no order. + idx = ["source", "signal", "time_value"] + assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) + # Test order. + assert_frame_equal(df, expected_df) + + with self.subTest("smoothed and diffed signal on two non-continguous regions"): + data = CovidcastRows.from_args( + signal=["sig_base"] * 15, + time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-16", "2021-05-20")), + value=range(15), + stderr=range(15), + sample_size=range(15), + ).api_row_df + transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} + time_pairs = [TimePair("day", [(20210501, 20210520)])] + df = CovidcastRows.from_records( + _generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY) + ).api_row_df + + filled_values = data.value.to_list()[:10] + [None] * 5 + data.value.to_list()[10:] + filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 5, pd.date_range("2021-05-16", "2021-05-20"))) + + expected_df = CovidcastRows.from_args( + signal=interleave_longest(["sig_diff"] * 19, ["sig_smooth"] * 14), + time_value=interleave_longest(pd.date_range("2021-05-02", "2021-05-20"), pd.date_range("2021-05-07", "2021-05-20")), + value=interleave_longest(_diff_rows(filled_values), _smooth_rows(filled_values)), + stderr=[None] * 33, + sample_size=[None] * 33, + issue=interleave_longest(_reindex_windowed(filled_time_values, 2), _reindex_windowed(filled_time_values, 7)), + ).api_row_df + # Test no order. + idx = ["source", "signal", "time_value"] + assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) + # Test order. 
+ assert_frame_equal(df, expected_df) + # fmt: on + + def test_get_basename_signals(self): + with self.subTest("none to transform"): + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_base"])] + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_basename_pairs = [SourceSignalPair(source="src", signal=["sig_base"])] + assert basename_pairs == expected_basename_pairs + + with self.subTest("unrecognized signal"): + source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_unknown"])] + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_basename_pairs = [SourceSignalPair(source="src", signal=["sig_unknown"])] + assert basename_pairs == expected_basename_pairs + + with self.subTest("plain"): + source_signal_pairs = [ + SourceSignalPair(source="src", signal=["sig_diff", "sig_smooth", "sig_diff_smooth", "sig_base"]), + SourceSignalPair(source="src2", signal=["sig"]), + ] + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_basename_pairs = [ + SourceSignalPair(source="src", signal=["sig_base", "sig_base", "sig_base", "sig_base"]), + SourceSignalPair(source="src2", signal=["sig"]), + ] + assert basename_pairs == expected_basename_pairs + + with self.subTest("resolve"): + source_signal_pairs = [SourceSignalPair(source="src", signal=True)] + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + expected_basename_pairs = [SourceSignalPair("src", ["sig_base"] * 4)] + assert basename_pairs == expected_basename_pairs + + with self.subTest("test base, diff, smooth"): + # fmt: off + data = CovidcastRows.from_args( + signal=["sig_base"] * 20 + ["sig_other"] * 5, + time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30"), pd.date_range("2021-05-01", "2021-05-05")), + value=chain(range(20), range(5)), + stderr=chain(range(20), range(5)), + sample_size=chain(range(20), range(5)), + ).api_row_df + source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])] + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + time_pairs = [TimePair("day", [(20210501, 20210530)])] + df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df + + filled_values = list(chain(range(10), [None] * 10, range(10, 20))) + filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30"))) + + expected_df = CovidcastRows.from_args( + signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_other"] * 5 + ["sig_smooth"] * 24, + time_value=chain( + pd.date_range("2021-05-01", "2021-05-30"), + pd.date_range("2021-05-02", "2021-05-30"), + pd.date_range("2021-05-01", "2021-05-05"), + pd.date_range("2021-05-07", "2021-05-30") + ), + value=chain( + filled_values, + _diff_rows(filled_values), + range(5), + _smooth_rows(filled_values) + ), + stderr=chain( + chain(range(10), [None] * 10, range(10, 20)), + chain([None] * 29), + range(5), + 
chain([None] * 24), + ), + sample_size=chain( + chain(range(10), [None] * 10, range(10, 20)), + chain([None] * 29), + range(5), + chain([None] * 24), + ), + issue=chain(filled_time_values, _reindex_windowed(filled_time_values, 2), pd.date_range("2021-05-01", "2021-05-05"), _reindex_windowed(filled_time_values, 7)), + ).api_row_df + # fmt: on + # Test no order. + idx = ["source", "signal", "time_value"] + assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) + + with self.subTest("test base, diff, smooth; multiple geos"): + # fmt: off + data = CovidcastRows.from_args( + signal=["sig_base"] * 40, + geo_value=["ak"] * 20 + ["ca"] * 20, + time_value=chain(pd.date_range("2021-05-01", "2021-05-20"), pd.date_range("2021-05-01", "2021-05-20")), + value=chain(range(20), range(0, 40, 2)), + stderr=chain(range(20), range(0, 40, 2)), + sample_size=chain(range(20), range(0, 40, 2)), + ).api_row_df + source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])] + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"))).api_row_df + + expected_df = CovidcastRows.from_args( + signal=["sig_base"] * 40 + ["sig_diff"] * 38 + ["sig_smooth"] * 28, + geo_value=["ak"] * 20 + ["ca"] * 20 + ["ak"] * 19 + ["ca"] * 19 + ["ak"] * 14 + ["ca"] * 14, + time_value=chain( + pd.date_range("2021-05-01", "2021-05-20"), + pd.date_range("2021-05-01", "2021-05-20"), + pd.date_range("2021-05-02", "2021-05-20"), + pd.date_range("2021-05-02", "2021-05-20"), + pd.date_range("2021-05-07", "2021-05-20"), + pd.date_range("2021-05-07", "2021-05-20"), + ), + value=chain( + chain(range(20), range(0, 40, 2)), + chain([1] * 19, [2] * 19), + chain([sum(x) / len(x) for x in windowed(range(20), 7)], + [sum(x) / len(x) for x in windowed(range(0, 40, 2), 7)]) + ), + stderr=chain( + chain(range(20), range(0, 40, 2)), + chain([None] * 38), + chain([None] * 28), + ), + sample_size=chain( + chain(range(20), range(0, 40, 2)), + chain([None] * 38), + chain([None] * 28), + ), + ).api_row_df + # fmt: on + # Test no order. 
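+ # "No order" means the frames are compared after indexing and sorting by (source, signal, time_value),
+ # so only the row contents are asserted, not the streaming order.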
+ idx = ["source", "signal", "time_value"] + assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) + + with self.subTest("resolve signals called"): + data = CovidcastRows.from_args( + signal=["sig_base"] * 20 + ["sig_other"] * 5, + time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30"), pd.date_range("2021-05-01", "2021-05-05")), + value=chain(range(20), range(5)), + stderr=chain(range(20), range(5)), + sample_size=chain(range(20), range(5)), + ).api_row_df + source_signal_pairs = [SourceSignalPair("src", True)] + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + time_pairs = [TimePair("day", [(20210501, 20210530)])] + df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df + + filled_values = list(chain(range(10), [None] * 10, range(10, 20))) + filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30"))) + + # fmt: off + expected_df = CovidcastRows.from_args( + signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_diff_smooth"] * 23 + ["sig_other"] * 5 + ["sig_smooth"] * 24, + time_value=chain( + pd.date_range("2021-05-01", "2021-05-30"), + pd.date_range("2021-05-02", "2021-05-30"), + pd.date_range("2021-05-08", "2021-05-30"), + pd.date_range("2021-05-01", "2021-05-05"), + pd.date_range("2021-05-07", "2021-05-30"), + ), + value=chain( + filled_values, + _diff_rows(filled_values), + _smooth_rows(_diff_rows(filled_values)), + range(5), + _smooth_rows(filled_values) + ), + stderr=chain( + chain(range(10), [None] * 10, range(10, 20)), + chain([None] * 29), + chain([None] * 23), + range(5), + chain([None] * 24), + ), + sample_size=chain( + chain(range(10), [None] * 10, range(10, 20)), + chain([None] * 29), + chain([None] * 23), + range(5), + chain([None] * 24), + ), + issue=chain( + filled_time_values, + _reindex_windowed(filled_time_values, 2), + _reindex_windowed(filled_time_values, 8), + pd.date_range("2021-05-01", "2021-05-05"), + _reindex_windowed(filled_time_values, 7), + ), + ).api_row_df + # fmt: off + # Test no order. 
+ idx = ["source", "signal", "time_value"] + assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) + + with self.subTest("empty iterator"): + source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_smooth"])] + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + assert list(row_transform_generator({})) == [] + + def test_get_day_range(self): + assert list(get_day_range([TimePair("day", [20210817])])) == [20210817] + assert list(get_day_range([TimePair("day", [20210817, (20210810, 20210815)])])) == [20210810, 20210811, 20210812, 20210813, 20210814, 20210815, 20210817] + assert list(get_day_range([TimePair("day", [(20210801, 20210805)]), TimePair("day", [(20210803, 20210807)])])) == [20210801, 20210802, 20210803, 20210804, 20210805, 20210806, 20210807] diff --git a/tests/server/endpoints/covidcast_utils/test_smooth_diff.py b/tests/server/endpoints/covidcast_utils/test_smooth_diff.py new file mode 100644 index 000000000..5ce3e0b8a --- /dev/null +++ b/tests/server/endpoints/covidcast_utils/test_smooth_diff.py @@ -0,0 +1,163 @@ +from pandas import DataFrame, date_range +from pandas.testing import assert_frame_equal +from numpy import nan, isnan +from itertools import chain +from pytest import raises +import unittest + +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows +from delphi.epidata.server.endpoints.covidcast_utils.smooth_diff import generate_diffed_rows, generate_smoothed_rows, _smoother +from .test_model import _diff_rows, _smooth_rows + + +class TestStreaming(unittest.TestCase): + def test__smoother(self): + assert _smoother(list(range(1, 7)), [1] * 6) == sum(range(1, 7)) + assert _smoother([1] * 6, list(range(1, 7))) == sum(range(1, 7)) + assert isnan(_smoother([1, nan, nan])) + with raises(TypeError, match=r"unsupported operand type*"): + _smoother([1, nan, None]) + + + def test_generate_smoothed_rows(self): + with self.subTest("an empty dataframe should return an empty dataframe"): + data = DataFrame({}) + smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'))).api_row_df + expected_df = CovidcastRows(rows=[]).api_row_df + assert_frame_equal(smoothed_df, expected_df) + + with self.subTest("a dataframe with not enough entries to make a single smoothed value, should return an empty dataframe"): + data = CovidcastRows.from_args( + time_value=[20210501] * 6, + value=[1.0] * 6 + ).api_row_df + + smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'))).api_row_df + expected_df = CovidcastRows(rows=[]).api_row_df + assert_frame_equal(smoothed_df, expected_df) + + data = CovidcastRows.from_args( + time_value=date_range("2021-05-01", "2021-05-13"), + value=chain(range(10), [None, 2., 1.]) + ).api_row_df + + with self.subTest("regular window, nan fill"): + smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'))).api_row_df + + smoothed_values = _smooth_rows(data.value.to_list()) + reduced_time_values = data.time_value.to_list()[-len(smoothed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=smoothed_values, + stderr=[None] * len(smoothed_values), + sample_size=[None] * len(smoothed_values), + ).api_row_df + + assert_frame_equal(smoothed_df, expected_df) + + with self.subTest("regular window, 0 fill"): + smoothed_df = 
CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'), nan_fill_value=0.)).api_row_df + + smoothed_values = _smooth_rows([v if v is not None and not isnan(v) else 0. for v in data.value.to_list()]) + reduced_time_values = data.time_value.to_list()[-len(smoothed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=smoothed_values, + stderr=[None] * len(smoothed_values), + sample_size=[None] * len(smoothed_values), + ).api_row_df + + assert_frame_equal(smoothed_df, expected_df) + + with self.subTest("regular window, different window length"): + smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'), smoother_window_length=8)).api_row_df + + smoothed_values = _smooth_rows(data.value.to_list(), window_length=8) + reduced_time_values = data.time_value.to_list()[-len(smoothed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=smoothed_values, + stderr=[None] * len(smoothed_values), + sample_size=[None] * len(smoothed_values), + ).api_row_df + assert_frame_equal(smoothed_df, expected_df) + + with self.subTest("regular window, different kernel"): + smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'), smoother_kernel=list(range(8)))).api_row_df + + smoothed_values = _smooth_rows(data.value.to_list(), kernel=list(range(8))) + reduced_time_values = data.time_value.to_list()[-len(smoothed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=smoothed_values, + stderr=[None] * len(smoothed_values), + sample_size=[None] * len(smoothed_values), + ).api_row_df + assert_frame_equal(smoothed_df, expected_df) + + with self.subTest("conflicting smoother args validation, smoother kernel should overwrite window length"): + smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'), smoother_kernel=[1/7.]*7, smoother_window_length=10)).api_row_df + + smoothed_values = _smooth_rows(data.value.to_list(), kernel=[1/7.]*7) + reduced_time_values = data.time_value.to_list()[-len(smoothed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=smoothed_values, + stderr=[None] * len(smoothed_values), + sample_size=[None] * len(smoothed_values), + ).api_row_df + assert_frame_equal(smoothed_df, expected_df) + + + def test_generate_diffed_rows(self): + with self.subTest("an empty dataframe should return an empty dataframe"): + data = DataFrame({}) + diffs_df = CovidcastRows.from_records(generate_diffed_rows(data.to_dict(orient='records'))).api_row_df + expected_df = CovidcastRows(rows=[]).api_row_df + assert_frame_equal(diffs_df, expected_df) + + with self.subTest("a dataframe with not enough data to make one row should return an empty dataframe"): + data = CovidcastRows.from_args(time_value=[20210501], value=[1.0]).api_row_df + diffs_df = CovidcastRows.from_records(generate_diffed_rows(data.to_dict(orient='records'))).api_row_df + expected_df = CovidcastRows(rows=[]).api_row_df + assert_frame_equal(diffs_df, expected_df) + + data = CovidcastRows.from_args( + time_value=date_range("2021-05-01", "2021-05-10"), + value=chain(range(7), [None, 2., 1.]) + ).api_row_df + + with self.subTest("no fill"): + diffs_df = CovidcastRows.from_records(generate_diffed_rows(data.to_dict(orient='records'))).api_row_df + + diffed_values = _diff_rows(data.value.to_list()) + reduced_time_values = 
data.time_value.to_list()[-len(diffed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=diffed_values, + stderr=[None] * len(diffed_values), + sample_size=[None] * len(diffed_values), + ).api_row_df + + assert_frame_equal(diffs_df, expected_df) + + with self.subTest("yes fill"): + diffs_df = CovidcastRows.from_records(generate_diffed_rows(data.to_dict(orient='records'), nan_fill_value=2.)).api_row_df + + diffed_values = _diff_rows([v if v is not None and not isnan(v) else 2. for v in data.value.to_list()]) + reduced_time_values = data.time_value.to_list()[-len(diffed_values):] + + expected_df = CovidcastRows.from_args( + time_value=reduced_time_values, + value=diffed_values, + stderr=[None] * len(diffed_values), + sample_size=[None] * len(diffed_values), + ).api_row_df + + assert_frame_equal(diffs_df, expected_df) diff --git a/tests/server/test_params.py b/tests/server/test_params.py index fffea0043..d2299dd02 100644 --- a/tests/server/test_params.py +++ b/tests/server/test_params.py @@ -19,6 +19,7 @@ GeoPair, TimePair, SourceSignalPair, + _combine_source_signal_pairs ) from delphi.epidata.server._exceptions import ( ValidationFailedException, @@ -182,8 +183,7 @@ def test_parse_source_signal_arg(self): self.assertEqual( parse_source_signal_arg(), [ - SourceSignalPair("src1", ["sig1"]), - SourceSignalPair("src1", ["sig4"]), + SourceSignalPair("src1", ["sig1", "sig4"]), ], ) with self.subTest("multi list"): @@ -191,17 +191,17 @@ def test_parse_source_signal_arg(self): self.assertEqual( parse_source_signal_arg(), [ - SourceSignalPair("src1", ["sig1", "sig2"]), SourceSignalPair("county", ["sig5", "sig6"]), + SourceSignalPair("src1", ["sig1", "sig2"]), ], ) with self.subTest("hybrid"): - with app.test_request_context("/?signal=src2:*;src1:sig4;src3:sig5,sig6"): + with app.test_request_context("/?signal=src2:*;src1:sig4;src3:sig5,sig6;src1:sig5;src2:sig1"): self.assertEqual( parse_source_signal_arg(), [ + SourceSignalPair("src1", ["sig4", "sig5"]), SourceSignalPair("src2", True), - SourceSignalPair("src1", ["sig4"]), SourceSignalPair("src3", ["sig5", "sig6"]), ], ) @@ -357,3 +357,29 @@ def test_parse_day_arg(self): self.assertRaises(ValidationFailedException, parse_day_arg, "time") with app.test_request_context("/?time=week:20121010"): self.assertRaises(ValidationFailedException, parse_day_arg, "time") + + def test__combine_source_signal_pairs(self): + source_signal_pairs = [ + SourceSignalPair("src1", ["sig1", "sig2"]), + SourceSignalPair("src2", ["sig1"]), + SourceSignalPair("src1", ["sig1", "sig3"]), + SourceSignalPair("src3", ["sig1"]), + SourceSignalPair("src3", ["sig2"]), + SourceSignalPair("src3", ["sig1"]), + SourceSignalPair("src4", ["sig2"]), + SourceSignalPair("src4", True), + ] + expected_source_signal_pairs = [ + SourceSignalPair("src1", ["sig1", "sig2", "sig3"]), + SourceSignalPair("src2", ["sig1"]), + SourceSignalPair("src3", ["sig1", "sig2"]), + SourceSignalPair("src4", True), + ] + combined_pairs = _combine_source_signal_pairs(source_signal_pairs) + for i, x in enumerate(combined_pairs): + if isinstance(x, list): + sorted(x) == expected_source_signal_pairs[i] + if isinstance(x, bool): + x == expected_source_signal_pairs[i] + + assert _combine_source_signal_pairs(source_signal_pairs) == expected_source_signal_pairs diff --git a/tests/server/utils/test_dates.py b/tests/server/utils/test_dates.py index e825bbd9b..c450a9b5b 100644 --- a/tests/server/utils/test_dates.py +++ b/tests/server/utils/test_dates.py @@ -2,7 +2,7 @@ from datetime 
import date from epiweeks import Week -from delphi.epidata.server.utils.dates import time_value_to_date, date_to_time_value, shift_time_value, time_value_to_iso, days_in_range, weeks_in_range, week_to_time_value, week_value_to_week, time_values_to_ranges +from delphi.epidata.server.utils.dates import time_value_to_date, date_to_time_value, shift_time_value, time_value_to_iso, days_in_range, weeks_in_range, week_to_time_value, week_value_to_week, time_values_to_ranges, iterate_over_range, iterate_over_ints_and_ranges class UnitTests(unittest.TestCase): @@ -59,3 +59,19 @@ def test_time_values_to_ranges(self): self.assertEqual(time_values_to_ranges([20210228, 20210301]), [(20210228, 20210301)]) # this becomes a range because these dates are indeed consecutive # individual weeks become a range (2020 is a rare year with 53 weeks) self.assertEqual(time_values_to_ranges([202051, 202052, 202053, 202101, 202102]), [(202051, 202102)]) + + def test_iterate_over_range(self): + self.assertEqual(list(iterate_over_range(20210801, 20210805)), [20210801, 20210802, 20210803, 20210804]) + self.assertEqual(list(iterate_over_range(20210801, 20210801)), []) + self.assertEqual(list(iterate_over_range(20210801, 20210701)), []) + + def test_iterate_over_ints_and_ranges(self): + assert list(iterate_over_ints_and_ranges([0, (5, 8)], use_dates=False)) == [0, 5, 6, 7, 8] + assert list(iterate_over_ints_and_ranges([(5, 8), (4, 6), (3, 5)], use_dates=False)) == [3, 4, 5, 6, 7, 8] + assert list(iterate_over_ints_and_ranges([(7, 8), (5, 7), (3, 8), 8], use_dates=False)) == [3, 4, 5, 6, 7, 8] + assert list(iterate_over_ints_and_ranges([2, (2, 3)], use_dates=False)) == [2, 3] + assert list(iterate_over_ints_and_ranges([20, 50, 25, (21, 25), 23, 30, 31, (24, 26)], use_dates=False)) == [20, 21, 22, 23, 24, 25, 26, 30, 31, 50] + + assert list(iterate_over_ints_and_ranges([20210817])) == [20210817] + assert list(iterate_over_ints_and_ranges([20210817, (20210810, 20210815)])) == [20210810, 20210811, 20210812, 20210813, 20210814, 20210815, 20210817] + assert list(iterate_over_ints_and_ranges([(20210801, 20210905), (20210815, 20210915)])) == list(iterate_over_range(20210801, 20210916)) # right-exclusive From 9f9dfb3eef1340c4c340660063ef9271e8482616 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Tue, 11 Oct 2022 14:46:48 -0700 Subject: [PATCH 15/29] Acquisition: update test_csv_uploading to remove Pandas warning --- integrations/acquisition/covidcast/test_csv_uploading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/acquisition/covidcast/test_csv_uploading.py b/integrations/acquisition/covidcast/test_csv_uploading.py index de3eb5f13..f975ecfa0 100644 --- a/integrations/acquisition/covidcast/test_csv_uploading.py +++ b/integrations/acquisition/covidcast/test_csv_uploading.py @@ -213,8 +213,8 @@ def test_uploading(self): "time_value": [20200419], "signal": [signal_name], "direction": [None]})], axis=1).rename(columns=uploader_column_rename) - expected_values_df["missing_value"].iloc[0] = Nans.OTHER - expected_values_df["missing_sample_size"].iloc[0] = Nans.NOT_MISSING + expected_values_df.loc[0, "missing_value"] = Nans.OTHER + expected_values_df.loc[0, "missing_sample_size"] = Nans.NOT_MISSING expected_values = expected_values_df.to_dict(orient="records") expected_response = {'result': 1, 'epidata': self.apply_lag(expected_values), 'message': 'success'} From 0b1696a377d1dc60bbb75eca4baf67857f09a4bc Mon Sep 17 00:00:00 2001 From: Brian Clark Date: Tue, 11 Oct 2022 17:42:49 -0400 Subject: 
[PATCH 16/29] Build a container image from this branch --- .github/workflows/ci.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f9d447331..e597a4dbf 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -108,7 +108,8 @@ jobs: image: needs: build # only on main and dev branch - if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' + #if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/jit_computations' runs-on: ubuntu-latest steps: From 8ec310d55332f46a24d7cd980e77d04a6e24f55f Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Sat, 29 Oct 2022 02:31:28 -0700 Subject: [PATCH 17/29] Server: add assert_frame_equal_no_order test util --- src/acquisition/covidcast/covidcast_row.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/acquisition/covidcast/covidcast_row.py b/src/acquisition/covidcast/covidcast_row.py index e982a0250..1417a14b0 100644 --- a/src/acquisition/covidcast/covidcast_row.py +++ b/src/acquisition/covidcast/covidcast_row.py @@ -5,6 +5,7 @@ from delphi_utils import Nans from numpy import isnan from pandas import DataFrame, concat +from pandas.testing import assert_frame_equal from delphi.epidata.acquisition.covidcast.csv_importer import CsvImporter from delphi.epidata.server.utils.dates import date_to_time_value, time_value_to_date @@ -267,3 +268,11 @@ def set_df_dtypes(df: DataFrame, dtypes: Dict[str, Any]) -> DataFrame: if k in df.columns: df[k] = df[k].astype(v) return df + + +def assert_frame_equal_no_order(df1: DataFrame, df2: DataFrame, index: List[str], **kwargs: Any) -> None: + """Assert that two DataFrames are equal, ignoring the order of rows.""" + # Remove any existing index. If it wasn't named, drop it. Set a new index and sort it. + df1 = df1.reset_index().drop(columns="index").set_index(index).sort_index() + df2 = df2.reset_index().drop(columns="index").set_index(index).sort_index() + assert_frame_equal(df1, df2, **kwargs) From 8a37d82723ed4e360ed6df4512a58d7b7b7a137d Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Sat, 29 Oct 2022 02:32:29 -0700 Subject: [PATCH 18/29] Server: generalize iterate_over_range boundary inclusion --- src/server/utils/dates.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/server/utils/dates.py b/src/server/utils/dates.py index 5a2fb0205..75ab9ef88 100644 --- a/src/server/utils/dates.py +++ b/src/server/utils/dates.py @@ -8,6 +8,7 @@ ) from datetime import date, timedelta from epiweeks import Week, Year +from operator import lt, le import logging def time_value_to_date(value: int) -> date: @@ -140,16 +141,21 @@ def _to_ranges(values: Sequence[Union[Tuple[int, int], int]], value_to_date: Cal logging.info('bad input to date ranges', input=values, exception=e) return values -def iterate_over_range(start: int, end: int) -> Iterator[int]: +def iterate_over_range(start: int, end: int, inclusive: bool = False) -> Iterator[int]: """Iterate over ints corresponding to dates in a time range. - Left inclusive, right exclusive to mimic behavior of Python's built-in range. + By default left inclusive, right exclusive to mimic the behavior of the built-in range. 
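+ For example, iterate_over_range(20210801, 20210803) yields 20210801 and 20210802, while
+ iterate_over_range(20210801, 20210803, inclusive=True) also yields 20210803.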
""" + if inclusive: + op = le + else: + op = lt + if start > end: return current_date, final_date = time_value_to_date(start), time_value_to_date(end) - while current_date < final_date: + while op(current_date, final_date): yield date_to_time_value(current_date) current_date = current_date + timedelta(days=1) From e7740eeb75d57a77a8f61f2c3fcd341c7686a724 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Tue, 1 Nov 2022 14:44:01 -0700 Subject: [PATCH 19/29] Server: convert TODO comments to #1017 --- src/server/endpoints/covidcast.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/server/endpoints/covidcast.py b/src/server/endpoints/covidcast.py index 4d9738286..dcb475c2e 100644 --- a/src/server/endpoints/covidcast.py +++ b/src/server/endpoints/covidcast.py @@ -88,16 +88,12 @@ def parse_time_pairs() -> List[TimePair]: # old version require_all("time_type", "time_values") time_values = extract_dates("time_values") - # TODO: Put a bound on the number of time_values? - # if time_values and len(time_values) > 30: - # raise ValidationFailedException("parameter value exceed: too many time pairs requested, consider using a timerange instead YYYYMMDD-YYYYMMDD") return [TimePair(time_type, time_values)] if ":" not in request.values.get("time", ""): raise ValidationFailedException("missing parameter: time or (time_type and time_values)") time_pairs = parse_time_arg() - # TODO: Put a bound on the number of time_values? (see above) return time_pairs From c54cde95200e238761f62c440fc4bbc40c97cf40 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Tue, 1 Nov 2022 13:29:00 -0700 Subject: [PATCH 20/29] Server: remove unused Flask import in _config.py --- src/server/_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/server/_config.py b/src/server/_config.py index 7973f9872..b6f08903d 100644 --- a/src/server/_config.py +++ b/src/server/_config.py @@ -1,6 +1,5 @@ import os from dotenv import load_dotenv -from flask import Flask import json load_dotenv() From 5ce82b506a6db43483f70b3ac295a012df723cc5 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Tue, 1 Nov 2022 18:31:08 -0700 Subject: [PATCH 21/29] Server: remove model.py:data_signals threading and use mock --- src/server/endpoints/covidcast.py | 15 +- src/server/endpoints/covidcast_utils/model.py | 58 +--- .../endpoints/covidcast_utils/test_utils.py | 123 +++++++++ .../endpoints/covidcast_utils/test_model.py | 260 ++---------------- .../covidcast_utils/test_smooth_diff.py | 2 +- tests/server/test_params.py | 4 + 6 files changed, 180 insertions(+), 282 deletions(-) create mode 100644 src/server/endpoints/covidcast_utils/test_utils.py diff --git a/src/server/endpoints/covidcast.py b/src/server/endpoints/covidcast.py index dcb475c2e..d8863146e 100644 --- a/src/server/endpoints/covidcast.py +++ b/src/server/endpoints/covidcast.py @@ -41,7 +41,7 @@ from .._pandas import as_pandas, print_pandas from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry from ..utils import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, shift_week_value, week_value_to_week, guess_time_value_is_day, week_to_time_value -from .covidcast_utils.model import TimeType, count_signal_time_types, data_sources, create_source_signal_alias_mapper, get_basename_signal_and_jit_generator, get_pad_length, pad_time_pairs, pad_time_window +from .covidcast_utils.model import TimeType, TransformType, count_signal_time_types, data_sources, data_sources_by_id, 
create_source_signal_alias_mapper, get_pad_length, pad_time_pairs, pad_time_window, get_basename_signal_and_jit_generator from .covidcast_utils.smooth_diff import SmootherKernelValue # first argument is the endpoint name @@ -63,7 +63,18 @@ def parse_source_signal_pairs() -> List[SourceSignalPair]: if ":" not in request.values.get("signal", ""): raise ValidationFailedException("missing parameter: signal or (data_source and signal[s])") - return parse_source_signal_arg() + # Convert source_signal_pairs with signal == True out to an explicit list of signals. + expanded_bool_source_signal_pairs = [] + for source_signal_pair in parse_source_signal_arg(): + if source_signal_pair.signal is True: + if data_source := data_sources_by_id.get(source_signal_pair.source): + expanded_bool_source_signal_pairs.append(SourceSignalPair(data_source.source, [s.signal for s in data_source.signals])) + else: + expanded_bool_source_signal_pairs.append(source_signal_pair) + else: + expanded_bool_source_signal_pairs.append(source_signal_pair) + + return expanded_bool_source_signal_pairs def parse_geo_pairs() -> List[GeoPair]: diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index fea526f12..424699e85 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -23,7 +23,7 @@ DIFF_SMOOTH: Callable = lambda rows, **kwargs: generate_smoothed_rows(generate_diffed_rows(rows, **kwargs), **kwargs) SignalTransforms = Dict[SourceSignalPair, SourceSignalPair] - +TransformType = Callable[[Iterator[Dict]], Iterator[Dict]] class HighValuesAre(str, Enum): bad = "bad" @@ -318,24 +318,6 @@ def map_row(source: str, signal: str) -> str: return transformed_pairs, map_row -def _resolve_bool_source_signals(source_signals: Union[SourceSignalPair, List[SourceSignalPair]], data_sources_by_id: Dict[str, DataSource]) -> Union[SourceSignalPair, List[SourceSignalPair]]: - """Expand a request for all signals to an explicit list of signal names. - - Example: SourceSignalPair("jhu-csse", signal=True) would return SourceSignalPair("jhu-csse", []). - """ - if isinstance(source_signals, SourceSignalPair): - if source_signals.signal == True: - source = data_sources_by_id.get(source_signals.source) - if source: - return SourceSignalPair(source.source, [s.signal for s in source.signals]) - return source_signals - - if isinstance(source_signals, list): - return [_resolve_bool_source_signals(pair, data_sources_by_id) for pair in source_signals] - - raise TypeError("source_signals is not Union[SourceSignalPair, List[SourceSignalPair]].") - - def _reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePair]], fill_value: Optional[Number] = None) -> Iterator[Dict]: """Produces an iterator that fills in gaps in the time window of another iterator. 
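A minimal sketch of the gap-filling idea described in the docstring above (illustrative only; this is not the patch's `_reindex_iterable`, and it assumes daily YYYYMMDD integer time values and rows already sorted by time):

    from datetime import date, timedelta
    from typing import Dict, Iterable, Iterator, Optional

    def fill_daily_gaps(rows: Iterable[Dict], fill_value: Optional[float] = None) -> Iterator[Dict]:
        """Yield rows on a contiguous daily index, padding missing days with fill_value."""
        prev = None
        for row in rows:
            if prev is not None:
                day = date(prev // 10000, prev // 100 % 100, prev % 100) + timedelta(days=1)
                cur = date(row["time_value"] // 10000, row["time_value"] // 100 % 100, row["time_value"] % 100)
                while day < cur:
                    # Padding rows carry only the fields needed for this sketch.
                    yield {"time_value": int(day.strftime("%Y%m%d")), "value": fill_value}
                    day += timedelta(days=1)
            yield row
            prev = row["time_value"]

    # fill_daily_gaps([{"time_value": 20210503, "value": 1.0}, {"time_value": 20210506, "value": 2.0}])
    # yields padding rows for 20210504 and 20210505 (value=None) between the two input rows.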
@@ -383,7 +365,7 @@ def _reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePa yield default_item -def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key) -> Callable: +def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]]) -> Callable: """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal.""" if isinstance(signal, DataSignal): base_signal = data_signals_by_key.get((signal.source, signal.signal_basename)) @@ -399,17 +381,13 @@ def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]], data_ if isinstance(signal, tuple): if signal := data_signals_by_key.get(signal): - return _get_base_signal_transform(signal, data_signals_by_key) + return _get_base_signal_transform(signal) return IDENTITY raise TypeError("signal must be either Tuple[str, str] or DataSignal.") -def get_transform_types( - source_signal_pairs: List[SourceSignalPair], - data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, - data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key -) -> Set[Callable]: +def get_transform_types(source_signal_pairs: List[SourceSignalPair]) -> Set[Callable]: """Return a collection of the unique transforms required for transforming a given source-signal pair list. Example: @@ -417,8 +395,6 @@ def get_transform_types( Used to pad the user DB query with extra days. """ - source_signal_pairs = _resolve_bool_source_signals(source_signal_pairs, data_sources_by_id) - transform_types = set() for source_signal_pair in source_signal_pairs: source_name = source_signal_pair.source @@ -427,17 +403,12 @@ def get_transform_types( if isinstance(signal_names, bool): continue - transform_types |= {_get_base_signal_transform((source_name, signal_name), data_signals_by_key=data_signals_by_key) for signal_name in signal_names} + transform_types |= {_get_base_signal_transform((source_name, signal_name)) for signal_name in signal_names} return transform_types -def get_pad_length( - source_signal_pairs: List[SourceSignalPair], - smoother_window_length: int, - data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, - data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, -): +def get_pad_length(source_signal_pairs: List[SourceSignalPair], smoother_window_length: int): """Returns the size of the extra date padding needed, depending on the transformations the source-signal pair list requires. If smoothing is required, we fetch an extra smoother_window_length - 1 days (6 by default). If both diffing and smoothing is required on the same signal, @@ -445,7 +416,7 @@ def get_pad_length( Used to pad the user DB query with extra days. """ - transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=data_sources_by_id, data_signals_by_key=data_signals_by_key) + transform_types = get_transform_types(source_signal_pairs) pad_length = [0] if DIFF_SMOOTH in transform_types: pad_length.append(smoother_window_length) @@ -522,7 +493,6 @@ def _generate_transformed_rows( transform_dict: Optional[SignalTransforms] = None, transform_args: Optional[Dict] = None, group_keyfunc: Optional[Callable] = None, - data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, ) -> Iterator[Dict]: """Applies time-series transformations to streamed rows from a database. 
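The transformation step can be pictured as a streaming groupby over the sorted base rows. A minimal sketch of that idea (illustrative only, with a hypothetical "_diffed" signal name; it is not the `_generate_transformed_rows` implementation and assumes the stream is already sorted by source, signal, and geo):

    from itertools import groupby
    from typing import Dict, Iterable, Iterator

    def stream_day_over_day_diff(rows: Iterable[Dict]) -> Iterator[Dict]:
        """Derive a day-over-day difference from a sorted stream of base-signal rows."""
        for _, group in groupby(rows, key=lambda r: (r["source"], r["signal"], r["geo_value"])):
            prev = None
            for row in group:
                if prev is not None and prev["value"] is not None and row["value"] is not None:
                    # Emit a derived row; the "_diffed" suffix here is just a placeholder name.
                    yield dict(row, signal=row["signal"] + "_diffed", value=row["value"] - prev["value"])
                prev = row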
@@ -540,8 +510,6 @@ def _generate_transformed_rows( group_keyfunc: Optional[Callable], default None The groupby function to use to order the streamed rows. Note that Python groupby does not do any sorting, so parsed_rows are assumed to be sorted in accord with this groupby. - data_signals_by_key: Dict[Tuple[str, str], DataSignal], default data_signals_by_key - The dictionary of DataSignals which is used to find the base signal transforms. Yields: transformed rows: Dict @@ -559,7 +527,7 @@ def _generate_transformed_rows( # Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map. derived_signal_transform_map: SourceSignalPair = transform_dict.get(SourceSignalPair(base_source_name, [base_signal_name]), SourceSignalPair(base_source_name, [base_signal_name])) # Create a list of source-signal pairs along with the transformation required for the signal. - signal_names_and_transforms: List[Tuple[Tuple[str, str], Callable]] = [(derived_signal, _get_base_signal_transform((base_source_name, derived_signal), data_signals_by_key)) for derived_signal in derived_signal_transform_map.signal] + signal_names_and_transforms: List[Tuple[Tuple[str, str], Callable]] = [(derived_signal, _get_base_signal_transform((base_source_name, derived_signal))) for derived_signal in derived_signal_transform_map.signal] # Put the current time series on a contiguous time index. source_signal_geo_rows = _reindex_iterable(source_signal_geo_rows, time_pairs, fill_value=transform_args.get("pad_fill_value")) # Create copies of the iterable, with smart memory usage. @@ -573,12 +541,7 @@ def _generate_transformed_rows( yield row -def get_basename_signal_and_jit_generator( - source_signal_pairs: List[SourceSignalPair], - transform_args: Optional[Dict[str, Union[str, int]]] = None, - data_sources_by_id: Dict[str, DataSource] = data_sources_by_id, - data_signals_by_key: Dict[Tuple[str, str], DataSignal] = data_signals_by_key, -) -> Tuple[List[SourceSignalPair], Generator]: +def get_basename_signal_and_jit_generator(source_signal_pairs: List[SourceSignalPair], transform_args: Optional[Dict[str, Union[str, int]]] = None) -> Tuple[List[SourceSignalPair], Generator]: """From a list of SourceSignalPairs, return the base signals required to derive them and a transformation function to take a stream of the base signals and return the transformed signals. @@ -587,7 +550,6 @@ def get_basename_signal_and_jit_generator( that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series. transform_dict in this case would be {("src", "sig_base"): [("src", "sig_base"), ("src", "sig_smooth")]}. 
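+
+    A usage sketch (hypothetical signal names, mirroring the example above):
+
+        pairs = [SourceSignalPair("src", ["sig_base", "sig_smooth"])]
+        base_pairs, generator = get_basename_signal_and_jit_generator(pairs)
+        # base_pairs == [SourceSignalPair("src", ["sig_base", "sig_base"])]
+        # generator(parsed_rows=db_rows, time_pairs=time_pairs) yields the base rows and the smoothed rows.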
""" - source_signal_pairs = _resolve_bool_source_signals(source_signal_pairs, data_sources_by_id) base_signal_pairs: List[SourceSignalPair] = [] transform_dict: SignalTransforms = dict() @@ -608,6 +570,6 @@ def get_basename_signal_and_jit_generator( signals.append(signal.signal_basename) base_signal_pairs.append(SourceSignalPair(pair.source, signals)) - row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, transform_args=transform_args, data_signals_by_key=data_signals_by_key) + row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, transform_args=transform_args) return base_signal_pairs, row_transform_generator diff --git a/src/server/endpoints/covidcast_utils/test_utils.py b/src/server/endpoints/covidcast_utils/test_utils.py new file mode 100644 index 000000000..e885378d0 --- /dev/null +++ b/src/server/endpoints/covidcast_utils/test_utils.py @@ -0,0 +1,123 @@ +from numbers import Number +from typing import Iterable, List, Optional +from delphi.epidata.server.endpoints.covidcast_utils.model import DataSource, DataSignal +from more_itertools import windowed + + +# fmt: off +DATA_SIGNALS_BY_KEY = { + ("src", "sig_diff"): DataSignal( + source="src", + signal="sig_diff", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=False, + compute_from_base=True, + ), + ("src", "sig_smooth"): DataSignal( + source="src", + signal="sig_smooth", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, + is_smoothed=True, + compute_from_base=True, + ), + ("src", "sig_diff_smooth"): DataSignal( + source="src", + signal="sig_diff_smooth", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=False, + is_smoothed=True, + compute_from_base=True, + ), + ("src", "sig_base"): DataSignal( + source="src", + signal="sig_base", + signal_basename="sig_base", + name="src", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, + ), + ("src2", "sig_base"): DataSignal( + source="src2", + signal="sig_base", + signal_basename="sig_base", + name="sig_base", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=True, + ), + ("src2", "sig_diff_smooth"): DataSignal( + source="src2", + signal="sig_diff_smooth", + signal_basename="sig_base", + name="sig_smooth", + active=True, + short_description="", + description="", + time_label="", + value_label="", + is_cumulative=False, + is_smoothed=True, + compute_from_base=True, + ), +} + +DATA_SOURCES_BY_ID = { + "src": DataSource( + source="src", + db_source="src", + name="src", + description="", + reference_signal="sig_base", + signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src"], + ), + "src2": DataSource( + source="src2", + db_source="src2", + name="src2", + description="", + reference_signal="sig_base", + signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src2"], + ), +} +# fmt: on + + +def _diff_rows(rows: Iterable[Number]) -> List[Number]: + return [round(float(y - x), 8) if not (x is None or y is None) else None for x, y in windowed(rows, 2)] + + +def _smooth_rows(rows: Iterable[Number], window_length: int = 7, kernel: 
Optional[List[Number]] = None): + if not kernel: + kernel = [1.0 / window_length] * window_length + return [round(sum(x * y for x, y in zip(window, kernel)), 8) if None not in window else None for window in windowed(rows, len(kernel))] + + +def _reindex_windowed(lst: list, window_length: int) -> list: + return [max(window) if None not in window else None for window in windowed(lst, window_length)] + diff --git a/tests/server/endpoints/covidcast_utils/test_model.py b/tests/server/endpoints/covidcast_utils/test_model.py index 246f6702a..ffe05405f 100644 --- a/tests/server/endpoints/covidcast_utils/test_model.py +++ b/tests/server/endpoints/covidcast_utils/test_model.py @@ -1,7 +1,6 @@ import unittest from itertools import chain -from numbers import Number -from typing import Iterable, List, Optional +from unittest.mock import patch import pandas as pd from more_itertools import interleave_longest, windowed @@ -10,152 +9,26 @@ from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows from delphi.epidata.server._params import SourceSignalPair, TimePair from delphi.epidata.server.endpoints.covidcast_utils.model import ( + DIFF, + DIFF_SMOOTH, + IDENTITY, + SMOOTH, _generate_transformed_rows, _get_base_signal_transform, _reindex_iterable, - _resolve_bool_source_signals, - DataSignal, - DataSource, - DIFF_SMOOTH, - DIFF, get_basename_signal_and_jit_generator, get_day_range, get_pad_length, get_transform_types, - IDENTITY, pad_time_pairs, - SMOOTH, ) from delphi_utils.nancodes import Nans - -# fmt: off -DATA_SIGNALS_BY_KEY = { - ("src", "sig_diff"): DataSignal( - source="src", - signal="sig_diff", - signal_basename="sig_base", - name="src", - active=True, - short_description="", - description="", - time_label="", - value_label="", - is_cumulative=False, - compute_from_base=True, - ), - ("src", "sig_smooth"): DataSignal( - source="src", - signal="sig_smooth", - signal_basename="sig_base", - name="src", - active=True, - short_description="", - description="", - time_label="", - value_label="", - is_cumulative=True, - is_smoothed=True, - compute_from_base=True, - ), - ("src", "sig_diff_smooth"): DataSignal( - source="src", - signal="sig_diff_smooth", - signal_basename="sig_base", - name="src", - active=True, - short_description="", - description="", - time_label="", - value_label="", - is_cumulative=False, - is_smoothed=True, - compute_from_base=True, - ), - ("src", "sig_base"): DataSignal( - source="src", - signal="sig_base", - signal_basename="sig_base", - name="src", - active=True, - short_description="", - description="", - time_label="", - value_label="", - is_cumulative=True, - ), - ("src2", "sig_base"): DataSignal( - source="src2", - signal="sig_base", - signal_basename="sig_base", - name="sig_base", - active=True, - short_description="", - description="", - time_label="", - value_label="", - is_cumulative=True, - ), - ("src2", "sig_diff_smooth"): DataSignal( - source="src2", - signal="sig_diff_smooth", - signal_basename="sig_base", - name="sig_smooth", - active=True, - short_description="", - description="", - time_label="", - value_label="", - is_cumulative=False, - is_smoothed=True, - compute_from_base=True, - ), -} - -DATA_SOURCES_BY_ID = { - "src": DataSource( - source="src", - db_source="src", - name="src", - description="", - reference_signal="sig_base", - signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src"], - ), - "src2": DataSource( - source="src2", - db_source="src2", - name="src2", - description="", - reference_signal="sig_base", - 
signals=[DATA_SIGNALS_BY_KEY[key] for key in DATA_SIGNALS_BY_KEY if key[0] == "src2"], - ), -} -# fmt: on - - -def _diff_rows(rows: Iterable[Number]) -> List[Number]: - return [round(float(y - x), 8) if not (x is None or y is None) else None for x, y in windowed(rows, 2)] - - -def _smooth_rows(rows: Iterable[Number], window_length: int = 7, kernel: Optional[List[Number]] = None): - if not kernel: - kernel = [1.0 / window_length] * window_length - return [round(sum(x * y for x, y in zip(window, kernel)), 8) if None not in window else None for window in windowed(rows, len(kernel))] - - -def _reindex_windowed(lst: list, window_length: int) -> list: - return [max(window) if None not in window else None for window in windowed(lst, window_length)] +from delphi.epidata.server.endpoints.covidcast_utils.test_utils import DATA_SOURCES_BY_ID, DATA_SIGNALS_BY_KEY, _diff_rows, _smooth_rows, _reindex_windowed +@patch("delphi.epidata.server.endpoints.covidcast_utils.model.data_sources_by_id", DATA_SOURCES_BY_ID) +@patch("delphi.epidata.server.endpoints.covidcast_utils.model.data_signals_by_key", DATA_SIGNALS_BY_KEY) class TestModel(unittest.TestCase): - def test__resolve_bool_source_signals(self): - source_signal_pair = [SourceSignalPair(source="src", signal=True), SourceSignalPair(source="src", signal=["sig_unknown"])] - resolved_source_signal_pair = _resolve_bool_source_signals(source_signal_pair, DATA_SOURCES_BY_ID) - expected_source_signal_pair = [ - SourceSignalPair(source="src", signal=["sig_diff", "sig_smooth", "sig_diff_smooth", "sig_base"]), - SourceSignalPair(source="src", signal=["sig_unknown"]), - ] - assert resolved_source_signal_pair == expected_source_signal_pair - def test__reindex_iterable(self): # Trivial test. time_pairs = [(20210503, 20210508)] @@ -194,49 +67,40 @@ def test__reindex_iterable(self): assert_frame_equal(df, expected_df) def test__get_base_signal_transform(self): - assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_smooth")], DATA_SIGNALS_BY_KEY) == SMOOTH - assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_diff_smooth")], DATA_SIGNALS_BY_KEY) == DIFF_SMOOTH - assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_diff")], DATA_SIGNALS_BY_KEY) == DIFF - assert _get_base_signal_transform(("src", "sig_diff"), DATA_SIGNALS_BY_KEY) == DIFF - assert _get_base_signal_transform(DATA_SIGNALS_BY_KEY[("src", "sig_base")], DATA_SIGNALS_BY_KEY) == IDENTITY - assert _get_base_signal_transform(("src", "sig_unknown"), DATA_SIGNALS_BY_KEY) == IDENTITY + assert _get_base_signal_transform(("src", "sig_smooth")) == SMOOTH + assert _get_base_signal_transform(("src", "sig_diff_smooth")) == DIFF_SMOOTH + assert _get_base_signal_transform(("src", "sig_diff")) == DIFF + assert _get_base_signal_transform(("src", "sig_diff")) == DIFF + assert _get_base_signal_transform(("src", "sig_base")) == IDENTITY + assert _get_base_signal_transform(("src", "sig_unknown")) == IDENTITY def test_get_transform_types(self): - source_signal_pairs = [SourceSignalPair(source="src", signal=True)] - transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) - expected_transform_types = {IDENTITY, DIFF, SMOOTH, DIFF_SMOOTH} - assert transform_types == expected_transform_types - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] - transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + 
transform_types = get_transform_types(source_signal_pairs) expected_transform_types = {DIFF} assert transform_types == expected_transform_types source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] - transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + transform_types = get_transform_types(source_signal_pairs) expected_transform_types = {SMOOTH} assert transform_types == expected_transform_types source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] - transform_types = get_transform_types(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + transform_types = get_transform_types(source_signal_pairs) expected_transform_types = {DIFF_SMOOTH} assert transform_types == expected_transform_types def test_get_pad_length(self): - source_signal_pairs = [SourceSignalPair(source="src", signal=True)] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) - assert pad_length == 7 - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7) assert pad_length == 1 source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=5, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=5) assert pad_length == 4 source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=10, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + pad_length = get_pad_length(source_signal_pairs, smoother_window_length=10) assert pad_length == 10 def test_pad_time_pairs(self): @@ -287,7 +151,7 @@ def test__generate_transformed_rows(self): value=range(5) ).api_row_df transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff"])} - df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df + df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict)).api_row_df expected_df = CovidcastRows.from_args( signal=["sig_diff"] * 4, @@ -310,7 +174,7 @@ def test__generate_transformed_rows(self): sample_size=range(10) ).api_row_df transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} - df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY)).api_row_df + df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict)).api_row_df expected_df = CovidcastRows.from_args( signal=interleave_longest(["sig_diff"] * 9, ["sig_smooth"] * 4), @@ -337,7 +201,7 @@ def test__generate_transformed_rows(self): transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", 
["sig_diff", "sig_smooth"])} time_pairs = [TimePair("day", [(20210501, 20210520)])] df = CovidcastRows.from_records( - _generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict, data_signals_by_key=DATA_SIGNALS_BY_KEY) + _generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict) ).api_row_df filled_values = data.value.to_list()[:10] + [None] * 5 + data.value.to_list()[10:] @@ -361,13 +225,13 @@ def test__generate_transformed_rows(self): def test_get_basename_signals(self): with self.subTest("none to transform"): source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_base"])] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs) expected_basename_pairs = [SourceSignalPair(source="src", signal=["sig_base"])] assert basename_pairs == expected_basename_pairs with self.subTest("unrecognized signal"): source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_unknown"])] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs) expected_basename_pairs = [SourceSignalPair(source="src", signal=["sig_unknown"])] assert basename_pairs == expected_basename_pairs @@ -376,19 +240,13 @@ def test_get_basename_signals(self): SourceSignalPair(source="src", signal=["sig_diff", "sig_smooth", "sig_diff_smooth", "sig_base"]), SourceSignalPair(source="src2", signal=["sig"]), ] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs) expected_basename_pairs = [ SourceSignalPair(source="src", signal=["sig_base", "sig_base", "sig_base", "sig_base"]), SourceSignalPair(source="src2", signal=["sig"]), ] assert basename_pairs == expected_basename_pairs - with self.subTest("resolve"): - source_signal_pairs = [SourceSignalPair(source="src", signal=True)] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) - expected_basename_pairs = [SourceSignalPair("src", ["sig_base"] * 4)] - assert basename_pairs == expected_basename_pairs - with self.subTest("test base, diff, smooth"): # fmt: off data = CovidcastRows.from_args( @@ -399,7 +257,7 @@ def test_get_basename_signals(self): sample_size=chain(range(20), range(5)), ).api_row_df source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) time_pairs = [TimePair("day", [(20210501, 20210530)])] df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df @@ -450,7 +308,7 @@ def test_get_basename_signals(self): sample_size=chain(range(20), range(0, 40, 2)), ).api_row_df source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", 
"sig_other", "sig_smooth"])] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"))).api_row_df expected_df = CovidcastRows.from_args( @@ -486,69 +344,9 @@ def test_get_basename_signals(self): idx = ["source", "signal", "time_value"] assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - with self.subTest("resolve signals called"): - data = CovidcastRows.from_args( - signal=["sig_base"] * 20 + ["sig_other"] * 5, - time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30"), pd.date_range("2021-05-01", "2021-05-05")), - value=chain(range(20), range(5)), - stderr=chain(range(20), range(5)), - sample_size=chain(range(20), range(5)), - ).api_row_df - source_signal_pairs = [SourceSignalPair("src", True)] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) - time_pairs = [TimePair("day", [(20210501, 20210530)])] - df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df - - filled_values = list(chain(range(10), [None] * 10, range(10, 20))) - filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30"))) - - # fmt: off - expected_df = CovidcastRows.from_args( - signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_diff_smooth"] * 23 + ["sig_other"] * 5 + ["sig_smooth"] * 24, - time_value=chain( - pd.date_range("2021-05-01", "2021-05-30"), - pd.date_range("2021-05-02", "2021-05-30"), - pd.date_range("2021-05-08", "2021-05-30"), - pd.date_range("2021-05-01", "2021-05-05"), - pd.date_range("2021-05-07", "2021-05-30"), - ), - value=chain( - filled_values, - _diff_rows(filled_values), - _smooth_rows(_diff_rows(filled_values)), - range(5), - _smooth_rows(filled_values) - ), - stderr=chain( - chain(range(10), [None] * 10, range(10, 20)), - chain([None] * 29), - chain([None] * 23), - range(5), - chain([None] * 24), - ), - sample_size=chain( - chain(range(10), [None] * 10, range(10, 20)), - chain([None] * 29), - chain([None] * 23), - range(5), - chain([None] * 24), - ), - issue=chain( - filled_time_values, - _reindex_windowed(filled_time_values, 2), - _reindex_windowed(filled_time_values, 8), - pd.date_range("2021-05-01", "2021-05-05"), - _reindex_windowed(filled_time_values, 7), - ), - ).api_row_df - # fmt: off - # Test no order. 
- idx = ["source", "signal", "time_value"] - assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - with self.subTest("empty iterator"): source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_smooth"])] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, data_sources_by_id=DATA_SOURCES_BY_ID, data_signals_by_key=DATA_SIGNALS_BY_KEY) + _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) assert list(row_transform_generator({})) == [] def test_get_day_range(self): diff --git a/tests/server/endpoints/covidcast_utils/test_smooth_diff.py b/tests/server/endpoints/covidcast_utils/test_smooth_diff.py index 5ce3e0b8a..ad72e2e2f 100644 --- a/tests/server/endpoints/covidcast_utils/test_smooth_diff.py +++ b/tests/server/endpoints/covidcast_utils/test_smooth_diff.py @@ -7,7 +7,7 @@ from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows from delphi.epidata.server.endpoints.covidcast_utils.smooth_diff import generate_diffed_rows, generate_smoothed_rows, _smoother -from .test_model import _diff_rows, _smooth_rows +from delphi.epidata.server.endpoints.covidcast_utils.test_utils import _diff_rows, _smooth_rows class TestStreaming(unittest.TestCase): diff --git a/tests/server/test_params.py b/tests/server/test_params.py index d2299dd02..1039d5371 100644 --- a/tests/server/test_params.py +++ b/tests/server/test_params.py @@ -3,6 +3,7 @@ # standard library from math import inf import unittest +from unittest.mock import patch # from flask.testing import FlaskClient from delphi.epidata.server._common import app @@ -24,11 +25,14 @@ from delphi.epidata.server._exceptions import ( ValidationFailedException, ) +from delphi.epidata.server.endpoints.covidcast_utils.test_utils import DATA_SOURCES_BY_ID, DATA_SIGNALS_BY_KEY # py3tester coverage target __test_target__ = "delphi.epidata.server._params" +@patch("delphi.epidata.server.endpoints.covidcast_utils.model.data_sources_by_id", DATA_SOURCES_BY_ID) +@patch("delphi.epidata.server.endpoints.covidcast_utils.model.data_signals_by_key", DATA_SIGNALS_BY_KEY) class UnitTests(unittest.TestCase): """Basic unit tests.""" From d7a9219f66ffeae01d3bcef5205e49fb61ce94c6 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Wed, 2 Nov 2022 14:59:27 -0700 Subject: [PATCH 22/29] Server: a small naming change in model.py --- src/server/endpoints/covidcast_utils/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index 424699e85..7c3df337f 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -253,8 +253,8 @@ def _load_data_signals(sources: List[DataSource]): data_signals_by_key[(source.db_source, d.signal)] = d -def get_related_signals(signal: DataSignal) -> List[DataSignal]: - return [s for s in data_signals if s != signal and s.signal_basename == signal.signal_basename] +def get_related_signals(data_signal: DataSignal) -> List[DataSignal]: + return [s for s in data_signals if s != data_signal and s.signal_basename == data_signal.signal_basename] def count_signal_time_types(source_signals: List[SourceSignalPair]) -> Tuple[int, int]: From 9e06226e6c649519366e8101beb44dcf5604fe8d Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 4 Nov 2022 16:57:57 -0700 Subject: [PATCH 23/29] Server: add type hints to _query --- src/server/_query.py | 19 
++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/server/_query.py b/src/server/_query.py index 1029c5e2c..4fae166ec 100644 --- a/src/server/_query.py +++ b/src/server/_query.py @@ -12,6 +12,7 @@ cast, Mapping, ) +from flask import Response from sqlalchemy import text from sqlalchemy.engine import Row @@ -56,7 +57,7 @@ def filter_values( param_key: str, params: Dict[str, Any], formatter=lambda x: x, -): +) -> str: if not values: return "FALSE" # builds a SQL expression to filter strings (ex: locations) @@ -71,7 +72,7 @@ def filter_strings( values: Optional[Sequence[Union[Tuple[str, str], str]]], param_key: str, params: Dict[str, Any], -): +) -> str: return filter_values(field, values, param_key, params) @@ -80,7 +81,7 @@ def filter_integers( values: Optional[Sequence[Union[Tuple[int, int], int]]], param_key: str, params: Dict[str, Any], -): +) -> str: return filter_values(field, values, param_key, params) @@ -89,7 +90,7 @@ def filter_dates( values: Optional[Sequence[Union[Tuple[int, int], int]]], param_key: str, params: Dict[str, Any], -): +) -> str: ranges = time_values_to_ranges(values) return filter_values(field, ranges, param_key, params, date_string) @@ -204,7 +205,7 @@ def parse_row( fields_string: Optional[Sequence[str]] = None, fields_int: Optional[Sequence[str]] = None, fields_float: Optional[Sequence[str]] = None, -): +) -> Dict[str, Any]: keys = set(row.keys()) parsed = dict() if fields_string: @@ -240,7 +241,7 @@ def limit_query(query: str, limit: int) -> str: return full_query -def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]): +def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]) -> Iterable[Row]: query, params = query_tuple # limit rows + 1 for detecting whether we would have more full_query = text(limit_query(query, p.remaining_rows + 1)) @@ -261,7 +262,7 @@ def execute_queries( fields_int: Sequence[str], fields_float: Sequence[str], transform: Callable[[Dict[str, Any], Row], Dict[str, Any]] = _identity_transform, -): +) -> Response: """ execute the given queries and return the response to send them """ @@ -321,14 +322,14 @@ def execute_query( fields_int: Sequence[str], fields_float: Sequence[str], transform: Callable[[Dict[str, Any], Row], Dict[str, Any]] = _identity_transform, -): +) -> Response: """ execute the given query and return the response to send it """ return execute_queries([(query, params)], fields_string, fields_int, fields_float, transform) -def _join_l(value: Union[str, List[str]]): +def _join_l(value: Union[str, List[str]]) -> str: return ", ".join(value) if isinstance(value, (list, tuple)) else value From 520dec880d8743d1254b319efa9de02944e6f8ba Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 4 Nov 2022 16:58:08 -0700 Subject: [PATCH 24/29] Server: remove unused imports in _query --- src/server/_query.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/server/_query.py b/src/server/_query.py index 4fae166ec..c36a56b0a 100644 --- a/src/server/_query.py +++ b/src/server/_query.py @@ -10,7 +10,6 @@ Tuple, Union, cast, - Mapping, ) from flask import Response @@ -18,10 +17,9 @@ from sqlalchemy.engine import Row from ._common import db, app -from ._db import metadata from ._printer import create_printer, APrinter from ._exceptions import DatabaseErrorException -from ._validate import DateRange, extract_strings +from ._validate import extract_strings from ._params import GeoPair, SourceSignalPair, TimePair from .utils import time_values_to_ranges, 
days_to_ranges, weeks_to_ranges From 7ce36d209c440ac18de9e3dd9da636b8fe021fae Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 4 Nov 2022 23:57:03 -0700 Subject: [PATCH 25/29] Server: tiny _params change --- src/server/_params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/_params.py b/src/server/_params.py index 74ee8540d..29b72d26e 100644 --- a/src/server/_params.py +++ b/src/server/_params.py @@ -121,7 +121,7 @@ def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) -> def parse_source_signal_arg(key: str = "signal") -> List[SourceSignalPair]: - return _combine_source_signal_pairs([SourceSignalPair(source, signals) for [source, signals] in _parse_common_multi_arg(key)]) + return _combine_source_signal_pairs([SourceSignalPair(source, signals) for source, signals in _parse_common_multi_arg(key)]) def parse_single_source_signal_arg(key: str) -> SourceSignalPair: From 0586952156600efbad4a5a4a6afbfa8dd999b878 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 4 Nov 2022 23:57:28 -0700 Subject: [PATCH 26/29] JIT: Make JIT integration tests robust (use dataframes) --- .../server/test_covidcast_endpoints.py | 138 ++++++++++-------- 1 file changed, 75 insertions(+), 63 deletions(-) diff --git a/integrations/server/test_covidcast_endpoints.py b/integrations/server/test_covidcast_endpoints.py index aa96aae3d..2744d3446 100644 --- a/integrations/server/test_covidcast_endpoints.py +++ b/integrations/server/test_covidcast_endpoints.py @@ -2,25 +2,27 @@ # standard library from copy import copy +from io import StringIO from itertools import accumulate, chain from typing import List, Sequence -from io import StringIO +from delphi.epidata.server.utils.dates import iterate_over_range # third party -from more_itertools import interleave_longest, windowed -import requests import pandas as pd +import pytest +import requests +from more_itertools import windowed from delphi.epidata.acquisition.covidcast.covidcast_meta_cache_updater import main as update_cache +from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow, CovidcastRows, set_df_dtypes, assert_frame_equal_no_order from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase -from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow, CovidcastRows, set_df_dtypes # use the local instance of the Epidata API BASE_URL = "http://delphi_web_epidata/epidata/covidcast" BASE_URL_OLD = "http://delphi_web_epidata/epidata/api.php" -def _read_csv(txt: str) -> pd.DataFrame: +def _read_csv_str(txt: str) -> pd.DataFrame: df = pd.read_csv(StringIO(txt), index_col=0).rename(columns={"data_source": "source"}) df.time_value = pd.to_datetime(df.time_value).dt.strftime("%Y%m%d").astype(int) df.issue = pd.to_datetime(df.issue).dt.strftime("%Y%m%d").astype(int) @@ -88,6 +90,7 @@ def test_compatibility(self): self._insert_rows(rows) with self.subTest("simple"): + # TODO: These tests aren't actually testing the compatibility endpoint. 
out = self._fetch("/", signal=first.signal_pair, geo=first.geo_pair, time="day:*") self.assertEqual(len(out["epidata"]), len(rows)) @@ -103,72 +106,81 @@ def test_compatibility(self): # JIT tests def test_derived_signals(self): - time_value_pairs = [(20200401 + i, i ** 2) for i in range(10)] - rows01 = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=time_value, value=value, geo_value="01") for time_value, value in time_value_pairs] - rows02 = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=time_value, value=2 * value, geo_value="02") for time_value, value in time_value_pairs] - first = rows01[0] - self._insert_rows(rows01 + rows02) + # The base signal data. + data1 = CovidcastRows.from_args( + source = ["jhu-csse"] * 10, + signal = ["confirmed_cumulative_num"] * 10, + time_value = iterate_over_range(20200401, 20200410, inclusive=True), + geo_value = ["01"] * 10, + value = [i ** 2 for i in range(10)], + ) + data2 = CovidcastRows.from_args( + source = ["jhu-csse"] * 10, + signal = ["confirmed_cumulative_num"] * 10, + time_value = iterate_over_range(20200401, 20200410, inclusive=True), + geo_value = ["02"] * 10, + value = [2 * i ** 2 for i in range(10)], + ) + # A base signal with a time gap. + data3 = CovidcastRows.from_args( + source = ["jhu-csse"] * 15, + signal = ["confirmed_cumulative_num"] * 15, + time_value = chain(iterate_over_range(20200401, 20200410, inclusive=True), iterate_over_range(20200416, 20200420, inclusive=True)), + geo_value = ["03"] * 15, + value = [i ** 2 for i in chain(range(10), range(15, 20))], + ) + self._insert_rows(data1.rows + data2.rows + data3.rows) + data3_reindexed = data3.api_row_df.set_index("time_value").reindex(iterate_over_range(20200401, 20200420, inclusive=True)).assign( + source = lambda df: df.source.fillna(method="ffill"), + signal = lambda df: df.signal.fillna(method="ffill"), + geo_value = lambda df: df.geo_value.fillna(method="ffill") + ).reset_index() + # Get the expected derived signal values. + data_df = pd.concat([data1.api_row_df, data2.api_row_df, data3_reindexed]).reset_index().set_index(["signal", "geo_value", "time_value"]) + expected_diffed_df = data_df.groupby(["geo_value"]).value.diff() + expected_diffed_df.index.set_levels(["confirmed_incidence_num"], level=0, inplace=True) + expected_smoothed_df = data_df.groupby(["geo_value"]).value.diff().rolling(7).mean() + expected_smoothed_df.index.set_levels(["confirmed_7dav_incidence_num"], level=0, inplace=True) + expected_df = pd.concat([data_df.value, expected_diffed_df, expected_smoothed_df]) with self.subTest("diffed signal"): - out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:*") + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:01", time="day:*") + # TODO: This test will be updated when JIT can handle *. 
assert out['result'] == -2 - out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") - out_values = [row["value"] for row in out["epidata"]] - values = [value for _, value in time_value_pairs] - expected_values = _diff_rows(values) - self.assertAlmostEqual(out_values, expected_values) + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:01", time="day:20200401-20200410") + out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"]) + merged_df = pd.merge(out_df, expected_df, left_index=True, right_index=True, suffixes=["_out", "_expected"])[["value_out", "value_expected"]] + assert merged_df.empty is False + assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected, nan_ok=True) with self.subTest("diffed signal, multiple geos"): out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:01,02", time="day:20200401-20200410") - out_values = [row["value"] for row in out["epidata"]] - values1 = [value for _, value in time_value_pairs] - values2 = [2 * value for _, value in time_value_pairs] - expected_values = _diff_rows(values1) + _diff_rows(values2) - self.assertAlmostEqual(out_values, expected_values) - - with self.subTest("diffed signal, multiple geos using geo:*"): - out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo="county:*", time="day:20200401-20200410") - values1 = [value for _, value in time_value_pairs] - values2 = [2 * value for _, value in time_value_pairs] - expected_values = _diff_rows(values1) + _diff_rows(values2) - self.assertAlmostEqual(out_values, expected_values) + out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"]) + merged_df = pd.merge(out_df, expected_df, left_index=True, right_index=True, suffixes=["_out", "_expected"])[["value_out", "value_expected"]] + assert merged_df.empty is False + assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected, nan_ok=True) with self.subTest("smooth diffed signal"): - out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") - out_values = [row["value"] for row in out["epidata"]] - values = [value for _, value in time_value_pairs] - expected_values = _smooth_rows(_diff_rows(values)) - self.assertAlmostEqual(out_values, expected_values) + out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo="county:01", time="day:20200401-20200410") + out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"]) + merged_df = pd.merge(out_df, expected_df, left_index=True, right_index=True, suffixes=["_out", "_expected"])[["value_out", "value_expected"]] + assert merged_df.empty is False + assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected, nan_ok=True) with self.subTest("diffed signal and smoothed signal in one request"): - out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num;jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200410") - out_values = [row["value"] for row in out["epidata"]] - values = [value for _, value in time_value_pairs] - expected_diff = _diff_rows(values) - expected_smoothed = _smooth_rows(expected_diff) - expected_values = list(interleave_longest(expected_smoothed, expected_diff)) - self.assertAlmostEqual(out_values, expected_values) - - time_value_pairs = [(20200401 + i, i ** 2) for i 
in chain(range(10), range(15, 20))] - rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", geo_value="03", time_value=time_value, value=value) for time_value, value in time_value_pairs] - first = rows[0] - self._insert_rows(rows) - - with self.subTest("diffing with a time gap"): - # should fetch 1 extra day - out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, time="day:20200401-20200420") - out_values = [row["value"] for row in out["epidata"]] - values = [value for _, value in time_value_pairs][:10] + [None] * 5 + [value for _, value in time_value_pairs][10:] - expected_values = _diff_rows(values) - self.assertAlmostEqual(out_values, expected_values) - - with self.subTest("smoothing and diffing with a time gap"): - # should fetch 1 extra day - out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo=first.geo_pair, time="day:20200401-20200420") - out_values = [row["value"] for row in out["epidata"]] - values = [value for _, value in time_value_pairs][:10] + [None] * 5 + [value for _, value in time_value_pairs][10:] - expected_values = _smooth_rows(_diff_rows(values)) - self.assertAlmostEqual(out_values, expected_values) + out = self._fetch("/", signal="jhu-csse:confirmed_incidence_num;jhu-csse:confirmed_7dav_incidence_num", geo="county:01", time="day:20200401-20200410") + out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"]) + merged_df = pd.merge(out_df, expected_df, left_index=True, right_index=True, suffixes=["_out", "_expected"])[["value_out", "value_expected"]] + assert merged_df.empty is False + assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected, nan_ok=True) + + with self.subTest("smoothing and diffing with a time gap and geo=*"): + # should fetch 7 extra day + out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo="county:*", time="day:20200407-20200420") + out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"]) + merged_df = pd.merge(out_df, expected_df, left_index=True, right_index=True, suffixes=["_out", "_expected"])[["value_out", "value_expected"]] + assert merged_df.empty is False + assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected, nan_ok=True) def test_compatibility(self): """Request at the /api.php endpoint.""" @@ -447,7 +459,7 @@ def test_csv(self): params=dict(signal="jhu-csse:confirmed_cumulative_num", start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type), ) response.raise_for_status() - df = _read_csv(response.text) + df = _read_csv_str(response.text) expected_df = CovidcastRows.from_args( source=["jhu-csse"] * 10, signal=["confirmed_cumulative_num"] * 10, @@ -461,7 +473,7 @@ def test_csv(self): params=dict(signal="jhu-csse:confirmed_incidence_num", start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type), ) response.raise_for_status() - df_diffed = _read_csv(response.text) + df_diffed = _read_csv_str(response.text) expected_df = CovidcastRows.from_args( source=["jhu-csse"] * 10, signal=["confirmed_incidence_num"] * 10, From 487e2ce0f2569c6ee335ccb35343d9de449c580d Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Mon, 31 Oct 2022 13:47:51 -0700 Subject: [PATCH 27/29] CI: Update to build JIT Multi SQL image --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e597a4dbf..6219e05fc 100644 --- 
a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -109,7 +109,7 @@ jobs: needs: build # only on main and dev branch #if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' - if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/jit_computations' + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/jit_computations' || github.ref == 'refs/heads/ds/jit-multi-sql' runs-on: ubuntu-latest steps: From b2836a3c6ad9a089d8df73770ec166a7e99c97c4 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Mon, 31 Oct 2022 14:40:00 -0700 Subject: [PATCH 28/29] CI: Update image --- .github/workflows/ci.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6219e05fc..01b688bc7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -134,6 +134,9 @@ jobs: if [ "$imageTag" = "main" ] ; then imageTag="latest" fi + if [ "$imageTag" = "ds/jit-multi-sql" ] ; then + imageTag="jit-multi-sql" + fi echo "::set-output name=tag::$imageTag" echo "::set-output name=repo::ghcr.io/${{ github.repository }}" - name: Push Dev Tag From 344d342ab271784ac99a655b253c75385d0caf87 Mon Sep 17 00:00:00 2001 From: Dmitry Shemetov Date: Fri, 4 Nov 2022 16:59:50 -0700 Subject: [PATCH 29/29] JIT: update model.py with multi-sql functions --- src/server/endpoints/covidcast.py | 453 ++++++++++-------- src/server/endpoints/covidcast_utils/model.py | 197 ++------ .../endpoints/covidcast_utils/test_model.py | 285 ++--------- .../covidcast_utils/test_smooth_diff.py | 2 +- 4 files changed, 339 insertions(+), 598 deletions(-) diff --git a/src/server/endpoints/covidcast.py b/src/server/endpoints/covidcast.py index d8863146e..e723f6ae8 100644 --- a/src/server/endpoints/covidcast.py +++ b/src/server/endpoints/covidcast.py @@ -1,14 +1,15 @@ from numbers import Number -from typing import List, Optional, Union, Tuple, Dict, Any +from typing import Callable, Iterable, List, Optional, Union, Tuple, Dict, Any from itertools import groupby from datetime import date, timedelta from bisect import bisect_right from epiweeks import Week -from flask import Blueprint, request +from flask import Blueprint, Response, request from flask.json import loads, jsonify from more_itertools import peekable from numpy import nan from sqlalchemy import text +from sqlalchemy.engine import Row from pandas import read_csv, to_datetime from .._common import is_compatibility_mode, app, db @@ -28,8 +29,9 @@ parse_single_geo_arg, ) from .._query import QueryBuilder, execute_query, run_query, parse_row, filter_fields -from .._printer import create_printer, CSVPrinter +from .._printer import create_printer, CSVPrinter, APrinter from .._validate import ( + DateRange, extract_bool, extract_date, extract_dates, @@ -41,7 +43,7 @@ from .._pandas import as_pandas, print_pandas from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry from ..utils import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, shift_week_value, week_value_to_week, guess_time_value_is_day, week_to_time_value -from .covidcast_utils.model import TimeType, TransformType, count_signal_time_types, data_sources, data_sources_by_id, create_source_signal_alias_mapper, get_pad_length, pad_time_pairs, pad_time_window, get_basename_signal_and_jit_generator +from .covidcast_utils.model import TimeType, TransformType, count_signal_time_types, 
data_sources, data_sources_by_id, create_source_signal_alias_mapper, get_pad_length, pad_time_pairs, is_derived, get_base_signal_transform, reindex_iterable from .covidcast_utils.smooth_diff import SmootherKernelValue # first argument is the endpoint name @@ -164,83 +166,210 @@ def parse_jit_bypass(): return jit_bypass -@bp.route("/", methods=("GET", "POST")) -def handle(): - source_signal_pairs = parse_source_signal_pairs() - source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) - time_pairs = parse_time_pairs() - geo_pairs = parse_geo_pairs() - jit_bypass = parse_jit_bypass() - - as_of = extract_date("as_of") - issues = extract_dates("issues") - lag = extract_integer("lag") +def get_response( + source_signal_pairs: List[SourceSignalPair], + geo_pairs: List[GeoPair], + time_pairs: List[TimePair], + issues: List[DateRange], + lag: int, + as_of: int, + jit_bypass: bool, + fields_string: List[str], + fields_int: List[str], + fields_float: List[str], + fields_order: List[str], + extra_transform: Optional[TransformType], + printer: Optional[APrinter] = None, +) -> Response: + """Iterate over signals, get their queries and transforms, and execute them one by one to get the response for the API.""" is_time_type_week = any(time_pair.time_type == "week" for time_pair in time_pairs) is_time_value_true = any(isinstance(time_pair.time_values, bool) for time_pair in time_pairs) + use_jit_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE_ON and not jit_bypass - is_compatibility = is_compatibility_mode() - def alias_row(row): - if is_compatibility: - # old api returned fewer fields - remove_fields = ["geo_type", "source", "time_type"] - for field in remove_fields: - if field in row: - del row[field] - if is_compatibility or not alias_mapper or "source" not in row: - return row - row["source"] = alias_mapper(row["source"], row["signal"]) - return row + query_transforms: List[Tuple[QueryBuilder, TransformType]] = [] + for source_signal_pair in source_signal_pairs: + if isinstance(source_signal_pair.signal, bool): + query_transforms.append(base_signal_handler(SourceSignalPair(source_signal_pair.source, True), geo_pairs, time_pairs, issues, lag, as_of, fields_string, fields_int, fields_float, fields_order)) + continue + + for signal in source_signal_pair.signal: + if use_jit_compute and is_derived(source_signal_pair.source, signal): + query_transforms.append(derived_signal_handler(SourceSignalPair(source_signal_pair.source, [signal]), geo_pairs, time_pairs, issues, lag, as_of, fields_string, fields_int, fields_float, fields_order)) + else: + query_transforms.append(base_signal_handler(SourceSignalPair(source_signal_pair.source, [signal]), geo_pairs, time_pairs, issues, lag, as_of, fields_string, fields_int, fields_float, fields_order)) + + p = create_printer() if not printer else printer + + def gen() -> Iterable[Dict]: + for query, transform in query_transforms: + try: + if p.remaining_rows <= 0: + break + r = run_query(p, (str(query), query.params)) + except Exception as e: + raise DatabaseErrorException(str(e) + "\n" + str(query)) + + parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in r) + + if transform: + parsed_rows = transform(parsed_rows) + + if extra_transform: + parsed_rows = extra_transform(parsed_rows) + + yield from parsed_rows + + return p(gen()) + + +def base_signal_handler( + source_signal_pair: SourceSignalPair, + geo_pairs: List[GeoPair], + time_pairs: List[TimePair], + issues: 
List[DateRange], + lag: int, + as_of: int, + fields_string: List[str], + fields_int: List[str], + fields_float: List[str], + fields_order: List[str], +) -> Tuple[QueryBuilder, Optional[TransformType]]: + """Handles a base signal.""" # build query q = QueryBuilder(latest_table, "t") - fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"] - fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"] - fields_float = ["value", "stderr", "sample_size"] + q.set_order(*fields_order) + q.set_fields(fields_string, fields_int, fields_float) + + # basic query info + # data type of each field + # build the source, signal, time, and location (type and id) filters + q.where_source_signal_pairs("source", "signal", [source_signal_pair]) + q.where_geo_pairs("geo_type", "geo_value", geo_pairs) + q.where_time_pairs("time_type", "time_value", time_pairs) + + _handle_lag_issues_as_of(q, issues, lag, as_of) + + return q, None + + +def derived_signal_handler( + source_signal_pair: SourceSignalPair, + geo_pairs: List[GeoPair], + time_pairs: List[TimePair], + issues: List[DateRange], + lag: int, + as_of: int, + fields_string: List[str], + fields_int: List[str], + fields_float: List[str], + fields_order: List[str], +) -> Tuple[QueryBuilder, Optional[TransformType]]: + """Handles a derived signal.""" + transform_args = parse_transform_args() # TODO: JIT computations don't support time_value = *; there may be a clever way to implement this. - use_jit_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE_ON and not jit_bypass - if use_jit_compute: - transform_args = parse_transform_args() - pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) - time_pairs = pad_time_pairs(time_pairs, pad_length) - app.logger.info(f"JIT compute enabled for route '/': {source_signal_pairs}") - source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, transform_args=transform_args) - app.logger.info(f"JIT base signals: {source_signal_pairs}") - - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=time_pairs, transform_args=transform_args) - for row in transformed_rows: - yield alias_row(row) - else: - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - for row in parsed_rows: - yield alias_row(row) + app.logger.info(f"JIT compute enabled for signal: {source_signal_pair}") + try: + base_source_signal_pair, base_signal_transform = get_base_signal_transform(source_signal_pair) + app.logger.info(f"JIT transformation enabled for derived signal: {source_signal_pair} and using base signal: {base_source_signal_pair}.") + except ValueError as e: + app.logger.error("JIT transformation could not be found.", exc_info=e) + # TODO: fail gracefully, exit out with an empty iterator. 
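+        # For now the fallback below returns a raw zero-row query string (in place of a
+        # QueryBuilder) together with an identity transform, so the unrecognized derived
+        # signal simply contributes no rows to the response.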
+ return "SELECT * FROM epimetric_latest_v LIMIT 0;", lambda x: x + app.logger.info(f"JIT base signals: {base_source_signal_pair}") + + pad_length = get_pad_length(base_signal_transform, transform_args.get("smoother_window_length")) + time_pairs = pad_time_pairs(time_pairs, pad_length) + + def gen_jit_transform(rows: Iterable[Dict]) -> Iterable[Dict]: + for _, grouped_rows in groupby(rows, lambda row: row["geo_value"]): + # Put the current time series on a contiguous time index. + grouped_rows = reindex_iterable(grouped_rows, time_pairs=time_pairs, fill_value=transform_args.get("pad_fill_value")) + for row in base_signal_transform(grouped_rows): + row["signal"] = source_signal_pair.signal[0] + yield row + + # build query + q = QueryBuilder(latest_table, "t") - q.set_order("source", "signal", "geo_type", "geo_value", "time_type", "time_value", "issue") + q.set_order(*fields_order) q.set_fields(fields_string, fields_int, fields_float) # basic query info # data type of each field # build the source, signal, time, and location (type and id) filters - q.where_source_signal_pairs("source", "signal", source_signal_pairs) + q.where_source_signal_pairs("source", "signal", [base_source_signal_pair]) q.where_geo_pairs("geo_type", "geo_value", geo_pairs) q.where_time_pairs("time_type", "time_value", time_pairs) _handle_lag_issues_as_of(q, issues, lag, as_of) - p = create_printer() + return q, gen_jit_transform - # execute first query - try: - r = run_query(p, (str(q), q.params)) - except Exception as e: - raise DatabaseErrorException(str(e)) - # now use a generator for sending the rows and execute all the other queries - return p(filter_fields(gen_transform(r))) +def compose_funcs(f: Callable, g: Callable) -> Callable: + """Compose two functions. + + Functions are applied from left to right. 
+ """ + return lambda x: g(f(x)) + + +@bp.route("/", methods=("GET", "POST")) +def handle(): + """Base route for the API.""" + source_signal_pairs = parse_source_signal_pairs() + source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) + time_pairs = parse_time_pairs() + geo_pairs = parse_geo_pairs() + jit_bypass = parse_jit_bypass() + as_of = extract_date("as_of") + issues = extract_dates("issues") + lag = extract_integer("lag") + is_compatibility = is_compatibility_mode() + + fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"] + fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"] + fields_float = ["value", "stderr", "sample_size"] + fields_order = ["source", "signal", "geo_type", "geo_value", "time_type", "time_value", "issue"] + + def alias_row(row: Dict) -> Dict: + if is_compatibility: + # old api returned fewer fields + remove_fields = ["geo_type", "source", "time_type"] + for field in remove_fields: + if field in row: + del row[field] + return row + + if not alias_mapper or "source" not in row: + return row + else: + row["source"] = alias_mapper(row["source"], row["signal"]) + return row + + def gen_transform(rows: Iterable[Dict]) -> Iterable[Dict]: + for row in rows: + yield alias_row(row) + + filter_gen_transform = compose_funcs(gen_transform, filter_fields) + + return get_response( + source_signal_pairs, + geo_pairs, + time_pairs, + issues, + lag, + as_of, + jit_bypass, + fields_string, + fields_int, + fields_float, + fields_order, + filter_gen_transform + ) def _verify_argument_time_type_matches(is_day_argument: bool, count_daily_signal: int, count_weekly_signal: int) -> None: @@ -257,7 +386,6 @@ def handle_trend(): daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) geo_pairs = parse_geo_pairs() - transform_args = parse_transform_args() jit_bypass = parse_jit_bypass() time_window, is_day = parse_day_or_week_range_arg("window") @@ -275,7 +403,7 @@ def handle_trend(): base_shift = 7 basis_time_value = shift_time_value(time_value, -1 * base_shift) if is_day else shift_week_value(time_value, -1 * base_shift) - def gen_trend(rows): + def gen_trend(rows: Iterable[Dict]) -> Iterable[Dict]: for key, group in groupby(rows, lambda row: (row["source"], row["signal"], row["geo_type"], row["geo_value"])): source, signal, geo_type, geo_value = key if alias_mapper: @@ -283,54 +411,28 @@ def gen_trend(rows): trend = compute_trend(geo_type, geo_value, source, signal, time_value, basis_time_value, ((row["time_value"], row["value"]) for row in group)) yield trend.asdict() - # build query - q = QueryBuilder(latest_table, "t") + filter_gen_trend = compose_funcs(gen_trend, filter_fields) fields_string = ["geo_type", "geo_value", "source", "signal"] fields_int = ["time_value"] fields_float = ["value"] + fields_order = ["source", "signal", "geo_type", "geo_value", "time_type", "time_value"] + time_pairs = [TimePair("day", [time_window])] if is_day else [TimePair("week", [time_window])] - use_jit_compute = all((is_day, is_also_day)) and JIT_COMPUTE_ON and not jit_bypass - if use_jit_compute: - pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) - app.logger.info(f"JIT compute enabled for route '/trend': {source_signal_pairs}") - source_signal_pairs, row_transform_generator = 
get_basename_signal_and_jit_generator(source_signal_pairs) - app.logger.info(f"JIT base signals: {source_signal_pairs}") - time_window = pad_time_window(time_window, pad_length) - - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args) - for row in transformed_rows: - yield row - else: - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - for row in parsed_rows: - yield row - - q.set_fields(fields_string, fields_int, fields_float) - q.set_order("source", "signal", "geo_type", "geo_value", "time_value") - - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_geo_pairs("geo_type", "geo_value", geo_pairs) - q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])]) - - # fetch most recent issue fast - _handle_lag_issues_as_of(q, None, None, None) - - p = create_printer() - - - # execute first query - try: - r = run_query(p, (str(q), q.params)) - except Exception as e: - raise DatabaseErrorException(str(e)) - - # now use a generator for sending the rows and execute all the other queries - return p(filter_fields(gen_trend(gen_transform(r)))) - + return get_response( + source_signal_pairs, + geo_pairs, + time_pairs, + None, + None, + None, + jit_bypass, + fields_string, + fields_int, + fields_float, + fields_order, + filter_gen_trend + ) @bp.route("/trendseries", methods=("GET", "POST")) def handle_trendseries(): @@ -339,7 +441,6 @@ def handle_trendseries(): daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) geo_pairs = parse_geo_pairs() - transform_args = parse_transform_args() jit_bypass = parse_jit_bypass() time_window, is_day = parse_day_or_week_range_arg("window") @@ -354,7 +455,7 @@ def handle_trendseries(): if not is_day: shifter = lambda x: shift_week_value(x, -basis_shift) - def gen_trend(rows): + def gen_trend(rows: Iterable[Dict]) -> Iterable[Dict]: for key, group in groupby(rows, lambda row: (row["source"], row["signal"], row["geo_type"], row["geo_value"])): source, signal, geo_type, geo_value = key if alias_mapper: @@ -363,54 +464,31 @@ def gen_trend(rows): for t in trends: yield t.asdict() - # build query - q = QueryBuilder(latest_table, "t") + filter_gen_trend = compose_funcs(gen_trend, filter_fields) fields_string = ["geo_type", "geo_value", "source", "signal"] fields_int = ["time_value"] fields_float = ["value"] + fields_order = ["source", "signal", "geo_type", "geo_value", "time_type", "time_value"] + time_pairs = [TimePair("day", [time_window])] if is_day else [TimePair("week", [time_window])] - use_jit_compute = is_day and JIT_COMPUTE_ON and not jit_bypass - if use_jit_compute: - pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) - app.logger.info(f"JIT compute enabled for route '/trendseries': {source_signal_pairs}") - source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) - app.logger.info(f"JIT base signals: {source_signal_pairs}") - time_window = pad_time_window(time_window, pad_length) - - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - transformed_rows = 
row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args) - for row in transformed_rows: - yield row - else: - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - for row in parsed_rows: - yield row - - q.set_fields(fields_string, fields_int, fields_float) - q.set_order("source", "signal", "geo_type", "geo_value", "time_value") - - q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_geo_pairs("geo_type", "geo_value", geo_pairs) - q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])]) - - # fetch most recent issue fast - _handle_lag_issues_as_of(q, None, None, None) - - p = create_printer() - - # execute first query - try: - r = run_query(p, (str(q), q.params)) - except Exception as e: - raise DatabaseErrorException(str(e)) - - # now use a generator for sending the rows and execute all the other queries - return p(filter_fields(gen_trend(gen_transform(r)))) + return get_response( + source_signal_pairs, + geo_pairs, + time_pairs, + None, + None, + None, + jit_bypass, + fields_string, + fields_int, + fields_float, + fields_order, + filter_gen_trend + ) +# TODO: Update this endpoint if we don't use streaming in JIT. @bp.route("/correlation", methods=("GET", "POST")) def handle_correlation(): require_all("reference", "window", "others", "geo") @@ -486,6 +564,8 @@ def gen(): @bp.route("/csv", methods=("GET", "POST")) def handle_export(): source, signal = request.args.get("signal", "jhu-csse:confirmed_incidence_num").split(":") + if "," in signal: + raise ValidationFailedException("Only one signal can be exported at a time") source_signal_pairs = [SourceSignalPair(source, [signal])] daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs) source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs) @@ -497,7 +577,6 @@ def handle_export(): _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals) - transform_args = parse_transform_args() jit_bypass = parse_jit_bypass() geo_type = request.args.get("geo_type", "county") @@ -510,46 +589,21 @@ def handle_export(): if is_day != is_as_of_day: raise ValidationFailedException("mixing weeks with day arguments") - # build query - q = QueryBuilder(latest_table, "t") - fields_string = ["geo_value", "signal", "geo_type", "source"] fields_int = ["time_value", "issue", "lag"] fields_float = ["value", "stderr", "sample_size"] - - use_jit_compute = all([is_day, is_end_day]) and JIT_COMPUTE_ON and not jit_bypass - if use_jit_compute: - pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length")) - app.logger.info(f"JIT compute enabled for route '/csv': {source_signal_pairs}") - source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) - app.logger.info(f"JIT base signals: {source_signal_pairs}") - time_window = pad_time_window(time_window, pad_length) - - def gen_transform(rows): - parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows) - transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args) - for row in transformed_rows: - yield row - else: - def gen_transform(rows): - for row in rows: - yield row - - q.set_fields(fields_string, fields_int, fields_float) - q.set_order("geo_value", "time_value") - 
q.where_source_signal_pairs("source", "signal", source_signal_pairs) - q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])]) - q.where_geo_pairs("geo_type", "geo_value", [GeoPair(geo_type, True if geo_values == "*" else geo_values)]) - - _handle_lag_issues_as_of(q, None, None, as_of) + fields_order = ["source", "signal", "geo_type", "geo_value", "time_type", "time_value"] + time_pairs = [TimePair("day", [time_window])] if is_day else [TimePair("week", [time_window])] + geo_pairs = [GeoPair(geo_type, True if geo_values == "*" else geo_values)] format_date = time_value_to_iso if is_day else lambda x: week_value_to_week(x).cdcformat() # tag as_of in filename, if it was specified as_of_str = "-asof-{as_of}".format(as_of=format_date(as_of)) if as_of is not None else "" filename = "covidcast-{source}-{signal}-{start_day}-to-{end_day}{as_of}".format(source=source, signal=signal, start_day=format_date(start_day), end_day=format_date(end_day), as_of=as_of_str) + # TODO: We might have to pass printer to get_response p = CSVPrinter(filename) - def parse_csv_row(i, row): + def parse_csv_row(i: int, row: Row) -> Dict[str, Any]: # '',geo_value,signal,{time_value,issue},lag,value,stderr,sample_size,geo_type,data_source return { "": i, @@ -565,26 +619,39 @@ def parse_csv_row(i, row): "data_source": alias_mapper(row["source"], row["signal"]) if alias_mapper else row["source"], } - def gen_parse(rows): + def gen_parse(rows: Iterable[Row]) -> Iterable[Dict[str, Any]]: for i, row in enumerate(rows): yield parse_csv_row(i, row) - # execute query - try: - r = run_query(p, (str(q), q.params)) - except Exception as e: - raise DatabaseErrorException(str(e)) + def gen_peek(rows: Iterable[Dict[str, Any]]) -> Iterable[Dict[str, Any]]: + # special case for no data to be compatible with the CSV server + rows = peekable(rows) + first_row = rows.peek(None) + if not first_row: + return "No matching data found for signal {source}:{signal} " "at {geo} level from {start_day} to {end_day}, as of {as_of}.".format( + source=source, signal=signal, geo=geo_type, start_day=format_date(start_day), end_day=format_date(end_day), as_of=(date.today().isoformat() if as_of is None else format_date(as_of)) + ) - # special case for no data to be compatible with the CSV server - transformed_query = peekable(gen_transform(r)) - first_row = transformed_query.peek(None) - if not first_row: - return "No matching data found for signal {source}:{signal} " "at {geo} level from {start_day} to {end_day}, as of {as_of}.".format( - source=source, signal=signal, geo=geo_type, start_day=format_date(start_day), end_day=format_date(end_day), as_of=(date.today().isoformat() if as_of is None else format_date(as_of)) - ) + for row in rows: + yield row - # now use a generator for sending the rows and execute all the other queries - return p(gen_parse(transformed_query)) + gen_parse_gen_peek = compose_funcs(gen_parse, gen_peek) + + return get_response( + source_signal_pairs, + geo_pairs, + time_pairs, + None, + None, + None, + jit_bypass, + fields_string, + fields_int, + fields_float, + fields_order, + gen_parse_gen_peek, + p + ) @bp.route("/backfill", methods=("GET", "POST")) diff --git a/src/server/endpoints/covidcast_utils/model.py b/src/server/endpoints/covidcast_utils/model.py index 7c3df337f..002461733 100644 --- a/src/server/endpoints/covidcast_utils/model.py +++ b/src/server/endpoints/covidcast_utils/model.py @@ -1,13 +1,11 @@ from dataclasses import asdict, dataclass, field from enum import Enum -from 
functools import partial -from itertools import groupby, repeat, tee from numbers import Number -from typing import Callable, Generator, Iterator, Optional, Dict, List, Set, Tuple, Union +from typing import Callable, Iterator, Optional, Dict, List, Set, Tuple, Union from pathlib import Path import re -from more_itertools import flatten, interleave_longest, peekable +from more_itertools import flatten, peekable import pandas as pd import numpy as np @@ -318,7 +316,7 @@ def map_row(source: str, signal: str) -> str: return transformed_pairs, map_row -def _reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePair]], fill_value: Optional[Number] = None) -> Iterator[Dict]: +def reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePair]], fill_value: Optional[Number] = None) -> Iterator[Dict]: """Produces an iterator that fills in gaps in the time window of another iterator. Used to produce an iterator with a contiguous time index for time series operations. @@ -365,64 +363,69 @@ def _reindex_iterable(iterator: Iterator[Dict], time_pairs: Optional[List[TimePa yield default_item -def _get_base_signal_transform(signal: Union[DataSignal, Tuple[str, str]]) -> Callable: - """Given a DataSignal, return the transformation that needs to be applied to its base signal to derive the signal.""" - if isinstance(signal, DataSignal): - base_signal = data_signals_by_key.get((signal.source, signal.signal_basename)) - if signal.format not in [SignalFormat.raw, SignalFormat.raw_count, SignalFormat.count] or not signal.compute_from_base or not base_signal: - return IDENTITY - if signal.is_cumulative and signal.is_smoothed: - return SMOOTH - if not signal.is_cumulative and not signal.is_smoothed: - return DIFF if base_signal.is_cumulative else IDENTITY - if not signal.is_cumulative and signal.is_smoothed: - return DIFF_SMOOTH if base_signal.is_cumulative else SMOOTH - return IDENTITY +def is_derived(source: str, signal: Union[bool, str]) -> bool: + """Returns a list of derived signal pairs.""" + if isinstance(signal, bool): + return False - if isinstance(signal, tuple): - if signal := data_signals_by_key.get(signal): - return _get_base_signal_transform(signal) - return IDENTITY + data_signal = data_signals_by_key.get((source, signal)) + if data_signal and data_signal.compute_from_base and data_signal.format in [SignalFormat.raw, SignalFormat.raw_count, SignalFormat.count]: + return True + else: + return False - raise TypeError("signal must be either Tuple[str, str] or DataSignal.") +def get_base_signal_transform(derived_source_signal_pair: SourceSignalPair) -> Tuple[SourceSignalPair, TransformType]: + """Returns the transformation needed to get the derived signal. -def get_transform_types(source_signal_pairs: List[SourceSignalPair]) -> Set[Callable]: - """Return a collection of the unique transforms required for transforming a given source-signal pair list. + Assumed to have derived_source_signal_pair.signal be a singleton list with an actual derived signal. + Use is_derived beforehand to determine. + """ + derived_data_signal = data_signals_by_key.get((derived_source_signal_pair.source, derived_source_signal_pair.signal[0])) - Example: - SourceSignalPair("src", ["sig", "sig_smoothed", "sig_diff"]) would return {IDENTITY, SMOOTH, DIFF}. + if not derived_data_signal: + raise ValueError(f"Unrecognized signal {derived_data_signal}.") + elif not derived_data_signal.compute_from_base: + raise ValueError(f"A non-derived signal {derived_data_signal}. 
Verify signal is derived first.") - Used to pad the user DB query with extra days. - """ - transform_types = set() - for source_signal_pair in source_signal_pairs: - source_name = source_signal_pair.source - signal_names = source_signal_pair.signal + base_data_signal = data_signals_by_key.get((derived_data_signal.source, derived_data_signal.signal_basename)) - if isinstance(signal_names, bool): - continue + if not base_data_signal: + raise ValueError(f"The base signal could not be found for derived signal {derived_data_signal}.") - transform_types |= {_get_base_signal_transform((source_name, signal_name)) for signal_name in signal_names} + base_source_signal_pair = SourceSignalPair(derived_data_signal.source, [derived_data_signal.signal_basename]) - return transform_types + # Pure incidence signal, e.g. confirmed_cumulative_num -> confirmed_incidence_num + if base_data_signal.is_cumulative and not derived_data_signal.is_cumulative and not derived_data_signal.is_smoothed: + transform = DIFF + # Diffed and then smoothed signal, e.g. confirmed_cumulative_num -> confirmed_7dav_incidence_num + elif base_data_signal.is_cumulative and not derived_data_signal.is_cumulative and derived_data_signal.is_smoothed: + transform = DIFF_SMOOTH + # Smoothed signal, e.g. ageusia_raw_search -> ageusia_smoothed_search + elif not base_data_signal.is_cumulative and not derived_data_signal.is_cumulative and derived_data_signal.is_smoothed: + transform = SMOOTH + # Currently no signals fit this, but here for completeness. + elif base_data_signal.is_cumulative and derived_data_signal.is_cumulative and derived_data_signal.is_smoothed: + transform = SMOOTH + # Something went wrong, fail gracefully. + else: + raise ValueError(f"Transformation couldn't be found for derived signal: {derived_source_signal_pair} with base signal: {base_source_signal_pair}") + return base_source_signal_pair, transform -def get_pad_length(source_signal_pairs: List[SourceSignalPair], smoother_window_length: int): - """Returns the size of the extra date padding needed, depending on the transformations the source-signal pair list requires. + +def get_pad_length(transform: TransformType, smoother_window_length: int): + """Returns the size of the extra date padding needed, depending on the transform. If smoothing is required, we fetch an extra smoother_window_length - 1 days (6 by default). If both diffing and smoothing is required on the same signal, then we fetch extra smoother_window_length days (7 by default). - - Used to pad the user DB query with extra days. """ - transform_types = get_transform_types(source_signal_pairs) pad_length = [0] - if DIFF_SMOOTH in transform_types: + if DIFF_SMOOTH == transform: pad_length.append(smoother_window_length) - if SMOOTH in transform_types: + if SMOOTH == transform: pad_length.append(smoother_window_length - 1) - if DIFF in transform_types: + if DIFF == transform: pad_length.append(1) return max(pad_length) @@ -448,22 +451,6 @@ def pad_time_pairs(time_pairs: List[TimePair], pad_length: int) -> List[TimePair return time_pairs + [padded_time] -def pad_time_window(time_window: Tuple[int, int], pad_length: int) -> Tuple[int, int]: - """Extend a time window on the left by pad_length. - - Example: - (20210407, 20210413) with pad_length 6 would return (20210401, 20210413). - - Used to pad the user DB query with extra days. 
- """ - if pad_length < 0: - raise ValueError("pad_length should non-negative.") - if pad_length == 0: - return time_window - min_time, max_time = time_window - return (shift_time_value(min_time, -1 * pad_length), max_time) - - def get_day_range(time_pairs: List[TimePair]) -> Iterator[int]: """Iterate over a list of TimePair time_values, including the values contained in the ranges. @@ -485,91 +472,3 @@ def get_day_range(time_pairs: List[TimePair]) -> Iterator[int]: raise ValueError("get_day_range only supports int or list time_values.") return iterate_over_ints_and_ranges(time_values_flattened) - - -def _generate_transformed_rows( - parsed_rows: Iterator[Dict], - time_pairs: Optional[List[TimePair]] = None, - transform_dict: Optional[SignalTransforms] = None, - transform_args: Optional[Dict] = None, - group_keyfunc: Optional[Callable] = None, -) -> Iterator[Dict]: - """Applies time-series transformations to streamed rows from a database. - - Parameters: - parsed_rows: Iterator[Dict] - An iterator streaming rows from a database query. Assumed to be sorted by source, signal, geo_type, geo_value, time_type, and time_value. - time_pairs: Optional[List[TimePair]], default None - A list of TimePairs, which can be used to create a continguous time index for time-series operations. - The min and max dates in the TimePairs list is used. - transform_dict: Optional[SignalTransforms], default None - A dictionary mapping base sources to a list of their derived signals that the user wishes to query. - For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}. - transform_args: Optional[Dict], default None - A dictionary of keyword arguments for the transformer functions. - group_keyfunc: Optional[Callable], default None - The groupby function to use to order the streamed rows. Note that Python groupby does not do any sorting, so - parsed_rows are assumed to be sorted in accord with this groupby. - - Yields: - transformed rows: Dict - The transformed rows returned in an interleaved fashion. Non-transformed rows have the IDENTITY operation applied. - """ - if not transform_args: - transform_args = dict() - if not transform_dict: - transform_dict = dict() - if not group_keyfunc: - group_keyfunc = lambda row: (row["source"], row["signal"], row["geo_type"], row["geo_value"]) - - for key, source_signal_geo_rows in groupby(parsed_rows, group_keyfunc): - base_source_name, base_signal_name, _, _ = key - # Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map. - derived_signal_transform_map: SourceSignalPair = transform_dict.get(SourceSignalPair(base_source_name, [base_signal_name]), SourceSignalPair(base_source_name, [base_signal_name])) - # Create a list of source-signal pairs along with the transformation required for the signal. - signal_names_and_transforms: List[Tuple[Tuple[str, str], Callable]] = [(derived_signal, _get_base_signal_transform((base_source_name, derived_signal))) for derived_signal in derived_signal_transform_map.signal] - # Put the current time series on a contiguous time index. - source_signal_geo_rows = _reindex_iterable(source_signal_geo_rows, time_pairs, fill_value=transform_args.get("pad_fill_value")) - # Create copies of the iterable, with smart memory usage. 
- source_signal_geo_rows_copies: Iterator[Iterator[Dict]] = tee(source_signal_geo_rows, len(signal_names_and_transforms)) - # Create a list of transformed group iterables, remembering their derived name as needed. - transformed_signals_iterator: Iterator[Tuple[str, Iterator[Dict]]] = (zip(repeat(derived_signal), transform(rows, **transform_args)) for (derived_signal, transform), rows in zip(signal_names_and_transforms, source_signal_geo_rows_copies)) - # Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window - # of the original iterable (group) is stored in memory. - for derived_signal_name, row in interleave_longest(*transformed_signals_iterator): - row["signal"] = derived_signal_name - yield row - - -def get_basename_signal_and_jit_generator(source_signal_pairs: List[SourceSignalPair], transform_args: Optional[Dict[str, Union[str, int]]] = None) -> Tuple[List[SourceSignalPair], Generator]: - """From a list of SourceSignalPairs, return the base signals required to derive them and a transformation function to take a stream - of the base signals and return the transformed signals. - - Example: - SourceSignalPair("src", signal=["sig_base", "sig_smoothed"]) would return SourceSignalPair("src", signal=["sig_base"]) and a transformation function - that will take the returned database query for "sig_base" and return both the base time series and the smoothed time series. transform_dict in this case - would be {("src", "sig_base"): [("src", "sig_base"), ("src", "sig_smooth")]}. - """ - base_signal_pairs: List[SourceSignalPair] = [] - transform_dict: SignalTransforms = dict() - - for pair in source_signal_pairs: - # Should only occur when the SourceSignalPair was unrecognized by _resolve_bool_source_signals. Useful for testing with fake signal names. 
- if isinstance(pair.signal, bool): - base_signal_pairs.append(pair) - continue - - signals = [] - for signal_name in pair.signal: - signal = data_signals_by_key.get((pair.source, signal_name)) - if not signal or not signal.compute_from_base: - transform_dict.setdefault(SourceSignalPair(source=pair.source, signal=[signal_name]), SourceSignalPair(source=pair.source, signal=[])).add_signal(signal_name) - signals.append(signal_name) - else: - transform_dict.setdefault(SourceSignalPair(source=pair.source, signal=[signal.signal_basename]), SourceSignalPair(source=pair.source, signal=[])).add_signal(signal_name) - signals.append(signal.signal_basename) - base_signal_pairs.append(SourceSignalPair(pair.source, signals)) - - row_transform_generator = partial(_generate_transformed_rows, transform_dict=transform_dict, transform_args=transform_args) - - return base_signal_pairs, row_transform_generator diff --git a/tests/server/endpoints/covidcast_utils/test_model.py b/tests/server/endpoints/covidcast_utils/test_model.py index ffe05405f..c4dd89aba 100644 --- a/tests/server/endpoints/covidcast_utils/test_model.py +++ b/tests/server/endpoints/covidcast_utils/test_model.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pandas as pd -from more_itertools import interleave_longest, windowed +import pytest from pandas.testing import assert_frame_equal from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows @@ -13,38 +13,35 @@ DIFF_SMOOTH, IDENTITY, SMOOTH, - _generate_transformed_rows, - _get_base_signal_transform, - _reindex_iterable, - get_basename_signal_and_jit_generator, + is_derived, + get_base_signal_transform, + reindex_iterable, get_day_range, get_pad_length, - get_transform_types, pad_time_pairs, ) -from delphi_utils.nancodes import Nans -from delphi.epidata.server.endpoints.covidcast_utils.test_utils import DATA_SOURCES_BY_ID, DATA_SIGNALS_BY_KEY, _diff_rows, _smooth_rows, _reindex_windowed +from delphi.epidata.server.endpoints.covidcast_utils.test_utils import DATA_SOURCES_BY_ID, DATA_SIGNALS_BY_KEY @patch("delphi.epidata.server.endpoints.covidcast_utils.model.data_sources_by_id", DATA_SOURCES_BY_ID) @patch("delphi.epidata.server.endpoints.covidcast_utils.model.data_signals_by_key", DATA_SIGNALS_BY_KEY) class TestModel(unittest.TestCase): - def test__reindex_iterable(self): + def test_reindex_iterable(self): # Trivial test. 
time_pairs = [(20210503, 20210508)] - assert list(_reindex_iterable([], time_pairs)) == [] + assert list(reindex_iterable([], time_pairs)) == [] data = CovidcastRows.from_args(time_value=pd.date_range("2021-05-03", "2021-05-08").to_list()).api_row_df for time_pairs in [[TimePair("day", [(20210503, 20210508)])], [], None]: with self.subTest(f"Identity operations: {time_pairs}"): - df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df + df = CovidcastRows.from_records(reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df assert_frame_equal(df, data) data = CovidcastRows.from_args(time_value=pd.date_range("2021-05-03", "2021-05-08").to_list() + pd.date_range("2021-05-11", "2021-05-14").to_list()).api_row_df with self.subTest("Non-trivial operations"): time_pairs = [TimePair("day", [(20210501, 20210513)])] - df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df + df = CovidcastRows.from_records(reindex_iterable(data.to_dict(orient="records"), time_pairs)).api_row_df expected_df = CovidcastRows.from_args( time_value=pd.date_range("2021-05-03", "2021-05-13"), issue=pd.date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + pd.date_range("2021-05-11", "2021-05-13").to_list(), @@ -55,7 +52,7 @@ def test__reindex_iterable(self): ).api_row_df assert_frame_equal(df, expected_df) - df = CovidcastRows.from_records(_reindex_iterable(data.to_dict(orient="records"), time_pairs, fill_value=2.0)).api_row_df + df = CovidcastRows.from_records(reindex_iterable(data.to_dict(orient="records"), time_pairs, fill_value=2.0)).api_row_df expected_df = CovidcastRows.from_args( time_value=pd.date_range("2021-05-03", "2021-05-13"), issue=pd.date_range("2021-05-03", "2021-05-08").to_list() + [None] * 2 + pd.date_range("2021-05-11", "2021-05-13").to_list(), @@ -66,42 +63,27 @@ def test__reindex_iterable(self): ).api_row_df assert_frame_equal(df, expected_df) - def test__get_base_signal_transform(self): - assert _get_base_signal_transform(("src", "sig_smooth")) == SMOOTH - assert _get_base_signal_transform(("src", "sig_diff_smooth")) == DIFF_SMOOTH - assert _get_base_signal_transform(("src", "sig_diff")) == DIFF - assert _get_base_signal_transform(("src", "sig_diff")) == DIFF - assert _get_base_signal_transform(("src", "sig_base")) == IDENTITY - assert _get_base_signal_transform(("src", "sig_unknown")) == IDENTITY + def test_is_derived(self): + assert is_derived("src", "sig_smooth") is True + assert is_derived("src", True) is False + assert is_derived("src", "sig_base") is False - def test_get_transform_types(self): - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] - transform_types = get_transform_types(source_signal_pairs) - expected_transform_types = {DIFF} - assert transform_types == expected_transform_types + def test_get_base_signal_transform(self): + assert get_base_signal_transform(SourceSignalPair("src", ["sig_smooth"])) == (SourceSignalPair("src", ["sig_base"]), SMOOTH) + assert get_base_signal_transform(SourceSignalPair("src", ["sig_diff_smooth"])) == (SourceSignalPair("src", ["sig_base"]), DIFF_SMOOTH) + assert get_base_signal_transform(SourceSignalPair("src", ["sig_diff"])) == (SourceSignalPair("src", ["sig_base"]), DIFF) - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] - transform_types = get_transform_types(source_signal_pairs) - expected_transform_types = {SMOOTH} - assert transform_types == expected_transform_types - - 
source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] - transform_types = get_transform_types(source_signal_pairs) - expected_transform_types = {DIFF_SMOOTH} - assert transform_types == expected_transform_types + with pytest.raises(ValueError, match=r"A non-derived signal*"): + get_base_signal_transform(SourceSignalPair("src", ["sig_base"])) + with pytest.raises(ValueError, match=r"Unrecognized signal*"): + get_base_signal_transform(SourceSignalPair("src", ["sig_unknown"])) def test_get_pad_length(self): - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff"])] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=7) - assert pad_length == 1 - - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_smooth"])] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=5) - assert pad_length == 4 - - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_diff_smooth"])] - pad_length = get_pad_length(source_signal_pairs, smoother_window_length=10) - assert pad_length == 10 + assert get_pad_length(IDENTITY, smoother_window_length=7) == 0 + assert get_pad_length(SMOOTH, smoother_window_length=7) == 6 + assert get_pad_length(DIFF, smoother_window_length=7) == 1 + assert get_pad_length(SMOOTH, smoother_window_length=5) == 4 + assert get_pad_length(DIFF_SMOOTH, smoother_window_length=10) == 10 def test_pad_time_pairs(self): # fmt: off @@ -111,9 +93,9 @@ def test_pad_time_pairs(self): TimePair("day", [20210816]) ] expected_padded_time_pairs = [ - TimePair("day", [20210817, (20210810, 20210815)]), - TimePair("day", True), - TimePair("day", [20210816]), + TimePair("day", [20210817, (20210810, 20210815)]), + TimePair("day", True), + TimePair("day", [20210816]), TimePair("day", [(20210803, 20210810)]) ] assert pad_time_pairs(time_pairs, pad_length=7) == expected_padded_time_pairs @@ -142,213 +124,6 @@ def test_pad_time_pairs(self): assert pad_time_pairs(time_pairs, pad_length=0) == expected_padded_time_pairs # fmt: on - def test__generate_transformed_rows(self): - # fmt: off - with self.subTest("diffed signal test"): - data = CovidcastRows.from_args( - signal=["sig_base"] * 5, - time_value=pd.date_range("2021-05-01", "2021-05-05"), - value=range(5) - ).api_row_df - transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff"])} - df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict)).api_row_df - - expected_df = CovidcastRows.from_args( - signal=["sig_diff"] * 4, - time_value=pd.date_range("2021-05-02", "2021-05-05"), - value=[1.0] * 4, - stderr=[None] * 4, - sample_size=[None] * 4, - missing_stderr=[Nans.NOT_APPLICABLE] * 4, - missing_sample_size=[Nans.NOT_APPLICABLE] * 4, - ).api_row_df - - assert_frame_equal(df, expected_df) - - with self.subTest("smoothed and diffed signals on one base test"): - data = CovidcastRows.from_args( - signal=["sig_base"] * 10, - time_value=pd.date_range("2021-05-01", "2021-05-10"), - value=range(10), - stderr=range(10), - sample_size=range(10) - ).api_row_df - transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} - df = CovidcastRows.from_records(_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict)).api_row_df - - expected_df = CovidcastRows.from_args( - signal=interleave_longest(["sig_diff"] * 9, ["sig_smooth"] * 4), - 
time_value=interleave_longest(pd.date_range("2021-05-02", "2021-05-10"), pd.date_range("2021-05-07", "2021-05-10")), - value=interleave_longest(_diff_rows(data.value.to_list()), _smooth_rows(data.value.to_list())), - stderr=[None] * 13, - sample_size=[None] * 13, - ).api_row_df - - # Test no order. - idx = ["source", "signal", "time_value"] - assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - # Test order. - assert_frame_equal(df, expected_df) - - with self.subTest("smoothed and diffed signal on two non-continguous regions"): - data = CovidcastRows.from_args( - signal=["sig_base"] * 15, - time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-16", "2021-05-20")), - value=range(15), - stderr=range(15), - sample_size=range(15), - ).api_row_df - transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])} - time_pairs = [TimePair("day", [(20210501, 20210520)])] - df = CovidcastRows.from_records( - _generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict) - ).api_row_df - - filled_values = data.value.to_list()[:10] + [None] * 5 + data.value.to_list()[10:] - filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 5, pd.date_range("2021-05-16", "2021-05-20"))) - - expected_df = CovidcastRows.from_args( - signal=interleave_longest(["sig_diff"] * 19, ["sig_smooth"] * 14), - time_value=interleave_longest(pd.date_range("2021-05-02", "2021-05-20"), pd.date_range("2021-05-07", "2021-05-20")), - value=interleave_longest(_diff_rows(filled_values), _smooth_rows(filled_values)), - stderr=[None] * 33, - sample_size=[None] * 33, - issue=interleave_longest(_reindex_windowed(filled_time_values, 2), _reindex_windowed(filled_time_values, 7)), - ).api_row_df - # Test no order. - idx = ["source", "signal", "time_value"] - assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - # Test order. 
- assert_frame_equal(df, expected_df) - # fmt: on - - def test_get_basename_signals(self): - with self.subTest("none to transform"): - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_base"])] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs) - expected_basename_pairs = [SourceSignalPair(source="src", signal=["sig_base"])] - assert basename_pairs == expected_basename_pairs - - with self.subTest("unrecognized signal"): - source_signal_pairs = [SourceSignalPair(source="src", signal=["sig_unknown"])] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs) - expected_basename_pairs = [SourceSignalPair(source="src", signal=["sig_unknown"])] - assert basename_pairs == expected_basename_pairs - - with self.subTest("plain"): - source_signal_pairs = [ - SourceSignalPair(source="src", signal=["sig_diff", "sig_smooth", "sig_diff_smooth", "sig_base"]), - SourceSignalPair(source="src2", signal=["sig"]), - ] - basename_pairs, _ = get_basename_signal_and_jit_generator(source_signal_pairs) - expected_basename_pairs = [ - SourceSignalPair(source="src", signal=["sig_base", "sig_base", "sig_base", "sig_base"]), - SourceSignalPair(source="src2", signal=["sig"]), - ] - assert basename_pairs == expected_basename_pairs - - with self.subTest("test base, diff, smooth"): - # fmt: off - data = CovidcastRows.from_args( - signal=["sig_base"] * 20 + ["sig_other"] * 5, - time_value=chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30"), pd.date_range("2021-05-01", "2021-05-05")), - value=chain(range(20), range(5)), - stderr=chain(range(20), range(5)), - sample_size=chain(range(20), range(5)), - ).api_row_df - source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) - time_pairs = [TimePair("day", [(20210501, 20210530)])] - df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df - - filled_values = list(chain(range(10), [None] * 10, range(10, 20))) - filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30"))) - - expected_df = CovidcastRows.from_args( - signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_other"] * 5 + ["sig_smooth"] * 24, - time_value=chain( - pd.date_range("2021-05-01", "2021-05-30"), - pd.date_range("2021-05-02", "2021-05-30"), - pd.date_range("2021-05-01", "2021-05-05"), - pd.date_range("2021-05-07", "2021-05-30") - ), - value=chain( - filled_values, - _diff_rows(filled_values), - range(5), - _smooth_rows(filled_values) - ), - stderr=chain( - chain(range(10), [None] * 10, range(10, 20)), - chain([None] * 29), - range(5), - chain([None] * 24), - ), - sample_size=chain( - chain(range(10), [None] * 10, range(10, 20)), - chain([None] * 29), - range(5), - chain([None] * 24), - ), - issue=chain(filled_time_values, _reindex_windowed(filled_time_values, 2), pd.date_range("2021-05-01", "2021-05-05"), _reindex_windowed(filled_time_values, 7)), - ).api_row_df - # fmt: on - # Test no order. 
- idx = ["source", "signal", "time_value"] - assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - - with self.subTest("test base, diff, smooth; multiple geos"): - # fmt: off - data = CovidcastRows.from_args( - signal=["sig_base"] * 40, - geo_value=["ak"] * 20 + ["ca"] * 20, - time_value=chain(pd.date_range("2021-05-01", "2021-05-20"), pd.date_range("2021-05-01", "2021-05-20")), - value=chain(range(20), range(0, 40, 2)), - stderr=chain(range(20), range(0, 40, 2)), - sample_size=chain(range(20), range(0, 40, 2)), - ).api_row_df - source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) - df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"))).api_row_df - - expected_df = CovidcastRows.from_args( - signal=["sig_base"] * 40 + ["sig_diff"] * 38 + ["sig_smooth"] * 28, - geo_value=["ak"] * 20 + ["ca"] * 20 + ["ak"] * 19 + ["ca"] * 19 + ["ak"] * 14 + ["ca"] * 14, - time_value=chain( - pd.date_range("2021-05-01", "2021-05-20"), - pd.date_range("2021-05-01", "2021-05-20"), - pd.date_range("2021-05-02", "2021-05-20"), - pd.date_range("2021-05-02", "2021-05-20"), - pd.date_range("2021-05-07", "2021-05-20"), - pd.date_range("2021-05-07", "2021-05-20"), - ), - value=chain( - chain(range(20), range(0, 40, 2)), - chain([1] * 19, [2] * 19), - chain([sum(x) / len(x) for x in windowed(range(20), 7)], - [sum(x) / len(x) for x in windowed(range(0, 40, 2), 7)]) - ), - stderr=chain( - chain(range(20), range(0, 40, 2)), - chain([None] * 38), - chain([None] * 28), - ), - sample_size=chain( - chain(range(20), range(0, 40, 2)), - chain([None] * 38), - chain([None] * 28), - ), - ).api_row_df - # fmt: on - # Test no order. - idx = ["source", "signal", "time_value"] - assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index()) - - with self.subTest("empty iterator"): - source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_smooth"])] - _, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs) - assert list(row_transform_generator({})) == [] - def test_get_day_range(self): assert list(get_day_range([TimePair("day", [20210817])])) == [20210817] assert list(get_day_range([TimePair("day", [20210817, (20210810, 20210815)])])) == [20210810, 20210811, 20210812, 20210813, 20210814, 20210815, 20210817] diff --git a/tests/server/endpoints/covidcast_utils/test_smooth_diff.py b/tests/server/endpoints/covidcast_utils/test_smooth_diff.py index ad72e2e2f..1a9fba052 100644 --- a/tests/server/endpoints/covidcast_utils/test_smooth_diff.py +++ b/tests/server/endpoints/covidcast_utils/test_smooth_diff.py @@ -58,7 +58,7 @@ def test_generate_smoothed_rows(self): with self.subTest("regular window, 0 fill"): smoothed_df = CovidcastRows.from_records(generate_smoothed_rows(data.to_dict(orient='records'), nan_fill_value=0.)).api_row_df - + smoothed_values = _smooth_rows([v if v is not None and not isnan(v) else 0. for v in data.value.to_list()]) reduced_time_values = data.time_value.to_list()[-len(smoothed_values):]
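
A note for reviewers: the control flow introduced by this refactor can be hard to follow from the diff alone, so here is a minimal, self-contained sketch of the pattern. Each source/signal pair is resolved to a (query, transform) pair by base_signal_handler or derived_signal_handler, and get_response streams parsed rows through the per-signal transform and then the route-specific transform built with compose_funcs. This is only an illustration under simplifying assumptions: fake_run_query, FakeQuery, diff_transform, round_values, and drop_time are hypothetical stand-ins, not code from this patch, and the real get_response additionally handles the printer, row limits, and DatabaseErrorException.

    from typing import Callable, Dict, Iterable, Iterator, Optional, Tuple

    def compose_funcs(f: Callable, g: Callable) -> Callable:
        # Compose two functions, applied left to right (same shape as the helper in the patch).
        return lambda x: g(f(x))

    # Toy stand-ins for QueryBuilder / run_query; the real code builds SQL and streams DB rows.
    FakeQuery = str
    TransformType = Optional[Callable[[Iterable[Dict]], Iterator[Dict]]]

    def fake_run_query(query: FakeQuery) -> Iterator[Dict]:
        # Pretend each query yields a small stream of rows sorted by time_value.
        for t, v in enumerate([10.0, 12.0, 15.0], start=1):
            yield {"signal": query, "time_value": t, "value": v}

    def diff_transform(rows: Iterable[Dict]) -> Iterator[Dict]:
        # Per-signal transform in the spirit of DIFF: first differences of the base signal.
        prev = None
        for row in rows:
            if prev is not None:
                yield {**row, "signal": row["signal"] + "_diff", "value": row["value"] - prev}
            prev = row["value"]

    def round_values(rows: Iterable[Dict]) -> Iterator[Dict]:
        # Example route-specific post-processing (stands in for alias_row).
        for row in rows:
            yield {**row, "value": round(row["value"], 1)}

    def drop_time(rows: Iterable[Dict]) -> Iterator[Dict]:
        # Example field filtering (stands in for filter_fields).
        for row in rows:
            yield {k: v for k, v in row.items() if k != "time_value"}

    def get_response(query_transforms: Iterable[Tuple[FakeQuery, TransformType]],
                     extra_transform: TransformType = None) -> Iterator[Dict]:
        # Mirrors the streaming loop: run each query, apply its per-signal transform,
        # then the route-level transform, and yield rows one at a time.
        for query, transform in query_transforms:
            rows = fake_run_query(query)
            if transform:
                rows = transform(rows)
            if extra_transform:
                rows = extra_transform(rows)
            yield from rows

    if __name__ == "__main__":
        plan = [("sig_base", None), ("sig_base", diff_transform)]
        for row in get_response(plan, extra_transform=compose_funcs(round_values, drop_time)):
            print(row)

The sketch also shows why the refactor keeps memory bounded: queries are executed one at a time and every stage is a generator, so the printer can stop early (remaining_rows) without materializing whole result sets.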