Skip to content

Commit d3cb7bc

Browse files
committed
fix: epiweek and date handling
1 parent 9c79289 commit d3cb7bc

File tree

3 files changed

+86
-44
lines changed

3 files changed

+86
-44
lines changed

epidatpy/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Fetch data from Delphi's API."""
22

33
# Make the linter happy about the unused variables
4-
__all__ = ["__version__", "Epidata", "CovidcastEpidata", "EpiRange"]
4+
__all__ = ["__version__", "EpiDataContext", "CovidcastEpidata", "EpiRange"]
55
__author__ = "Delphi Research Group"
66

77

88
from ._constants import __version__
9-
from .request import CovidcastEpidata, EpiDataContext, EpiRange
9+
from ._model import EpiRange
10+
from .request import CovidcastEpidata, EpiDataContext

epidatpy/request.py

+49-28
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from os import environ
12
from typing import (
23
Any,
34
Dict,
@@ -13,7 +14,6 @@
1314
from appdirs import user_cache_dir
1415
from diskcache import Cache
1516
from pandas import CategoricalDtype, DataFrame, Series, to_datetime
16-
from os import environ
1717
from requests import Response, Session
1818
from requests.auth import HTTPBasicAuth
1919
from tenacity import retry, stop_after_attempt
@@ -27,21 +27,22 @@
2727
EpidataFieldInfo,
2828
EpidataFieldType,
2929
EpiDataResponse,
30-
EpiRange,
3130
EpiRangeParam,
3231
OnlySupportsClassicFormatException,
3332
add_endpoint_to_url,
3433
)
3534
from ._parse import fields_to_predicate
3635

3736
# Make the linter happy about the unused variables
38-
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]
3937
CACHE_DIRECTORY = user_cache_dir(appname="epidatpy", appauthor="delphi")
4038

4139
if environ.get("USE_EPIDATPY_CACHE", None):
42-
print(f"diskcache is being used (unset USE_EPIDATPY_CACHE if not intended). "
43-
f"The cache directory is {CACHE_DIRECTORY}. "
44-
f"The TTL is set to {environ.get("EPIDATPY_CACHE_MAX_AGE_DAYS", "7")} days.")
40+
print(
41+
f"diskcache is being used (unset USE_EPIDATPY_CACHE if not intended). "
42+
f"The cache directory is {CACHE_DIRECTORY}. "
43+
f"The TTL is set to {environ.get('EPIDATPY_CACHE_MAX_AGE_DAYS', '7')} days."
44+
)
45+
4546

4647
@retry(reraise=True, stop=stop_after_attempt(2))
4748
def _request_with_retry(
@@ -67,9 +68,7 @@ def call_impl(s: Session) -> Response:
6768

6869

6970
class EpiDataCall(AEpiDataCall):
70-
"""
71-
epidata call representation
72-
"""
71+
"""epidata call representation"""
7372

7473
_session: Final[Optional[Session]]
7574

@@ -101,7 +100,7 @@ def _call(
101100
url, params = self.request_arguments(fields)
102101
return _request_with_retry(url, params, self._session, stream)
103102

104-
def _get_cache_key(self, method) -> str:
103+
def _get_cache_key(self, method: str) -> str:
105104
cache_key = f"{self._endpoint} | {method}"
106105
if self._params:
107106
cache_key += f" | {str(dict(sorted(self._params.items())))}"
@@ -120,7 +119,7 @@ def classic(
120119
with Cache(CACHE_DIRECTORY) as cache:
121120
cache_key = self._get_cache_key("classic")
122121
if cache_key in cache:
123-
return cache[cache_key]
122+
return cast(EpiDataResponse, cache[cache_key])
124123
response = self._call(fields)
125124
r = cast(EpiDataResponse, response.json())
126125
if disable_type_parsing:
@@ -131,7 +130,7 @@ def classic(
131130
if self.use_cache:
132131
with Cache(CACHE_DIRECTORY) as cache:
133132
cache_key = self._get_cache_key("classic")
134-
cache.set(cache_key, r, expire=self.cache_max_age_days*24*60*60)
133+
cache.set(cache_key, r, expire=self.cache_max_age_days * 24 * 60 * 60)
135134
return r
136135
except Exception as e: # pylint: disable=broad-except
137136
return {"result": 0, "message": f"error: {e}", "epidata": []}
@@ -143,7 +142,11 @@ def __call__(
143142
) -> Union[EpiDataResponse, DataFrame]:
144143
"""Request and parse epidata in df message format."""
145144
if self.only_supports_classic:
146-
return self.classic(fields, disable_date_parsing=disable_date_parsing, disable_type_parsing=False)
145+
return self.classic(
146+
fields,
147+
disable_date_parsing=disable_date_parsing,
148+
disable_type_parsing=False,
149+
)
147150
return self.df(fields, disable_date_parsing=disable_date_parsing)
148151

149152
def df(
@@ -160,7 +163,7 @@ def df(
160163
with Cache(CACHE_DIRECTORY) as cache:
161164
cache_key = self._get_cache_key("df")
162165
if cache_key in cache:
163-
return cache[cache_key]
166+
return cast(DataFrame, cache[cache_key])
164167

165168
json = self.classic(fields, disable_type_parsing=True)
166169
rows = json.get("epidata", [])
@@ -177,7 +180,8 @@ def df(
177180
data_types[info.name] = bool
178181
elif info.type == EpidataFieldType.categorical:
179182
data_types[info.name] = CategoricalDtype(
180-
categories=Series(info.categories) if info.categories else None, ordered=True
183+
categories=Series(info.categories) if info.categories else None,
184+
ordered=True,
181185
)
182186
elif info.type == EpidataFieldType.int:
183187
data_types[info.name] = "Int64"
@@ -196,8 +200,10 @@ def df(
196200
df = df.astype(data_types)
197201
if not disable_date_parsing:
198202
for info in time_fields:
199-
if info.type == EpidataFieldType.epiweek or info.type == EpidataFieldType.date_or_epiweek:
203+
if info.type == EpidataFieldType.epiweek:
200204
continue
205+
# Try two date foramts, otherwise keep as string. The try except
206+
# is needed because the time field might be date_or_epiweek.
201207
try:
202208
df[info.name] = to_datetime(df[info.name], format="%Y-%m-%d")
203209
continue
@@ -211,15 +217,13 @@ def df(
211217
if self.use_cache:
212218
with Cache(CACHE_DIRECTORY) as cache:
213219
cache_key = self._get_cache_key("df")
214-
cache.set(cache_key, df, expire=self.cache_max_age_days*24*60*60)
220+
cache.set(cache_key, df, expire=self.cache_max_age_days * 24 * 60 * 60)
215221

216222
return df
217223

218224

219225
class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):
220-
"""
221-
sync epidata call class
222-
"""
226+
"""sync epidata call class"""
223227

224228
_base_url: Final[str]
225229
_session: Final[Optional[Session]]
@@ -249,16 +253,25 @@ def _create_call(
249253
params: Mapping[str, Optional[EpiRangeParam]],
250254
meta: Optional[Sequence[EpidataFieldInfo]] = None,
251255
only_supports_classic: bool = False,
252-
253256
) -> EpiDataCall:
254-
return EpiDataCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic, self.use_cache, self.cache_max_age_days)
257+
return EpiDataCall(
258+
self._base_url,
259+
self._session,
260+
endpoint,
261+
params,
262+
meta,
263+
only_supports_classic,
264+
self.use_cache,
265+
self.cache_max_age_days,
266+
)
267+
255268

256269
def CovidcastEpidata(
257-
base_url: str = BASE_URL,
258-
session: Optional[Session] = None,
259-
use_cache: Optional[bool] = None,
260-
cache_max_age_days: Optional[int] = None,
261-
) -> CovidcastDataSources[EpiDataCall]:
270+
base_url: str = BASE_URL,
271+
session: Optional[Session] = None,
272+
use_cache: Optional[bool] = None,
273+
cache_max_age_days: Optional[int] = None,
274+
) -> CovidcastDataSources[EpiDataCall]:
262275
url = add_endpoint_to_url(base_url, "covidcast/meta")
263276
meta_data_res = _request_with_retry(url, {}, session, False)
264277
meta_data_res.raise_for_status()
@@ -267,6 +280,14 @@ def CovidcastEpidata(
267280
def create_call(
268281
params: Mapping[str, Optional[EpiRangeParam]],
269282
) -> EpiDataCall:
270-
return EpiDataCall(base_url, session, "covidcast", params, define_covidcast_fields(), use_cache=use_cache, cache_max_age_days=cache_max_age_days)
283+
return EpiDataCall(
284+
base_url,
285+
session,
286+
"covidcast",
287+
params,
288+
define_covidcast_fields(),
289+
use_cache=use_cache,
290+
cache_max_age_days=cache_max_age_days,
291+
)
271292

272293
return CovidcastDataSources.create(meta_data, create_call)

tests/test_epidata_calls.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010

11-
from epidatpy.request import EpiDataContext, EpiRange
11+
from epidatpy import EpiDataContext, EpiRange
1212

1313
auth = os.environ.get("DELPHI_EPIDATA_KEY", "")
1414
secret_cdc = os.environ.get("SECRET_API_AUTH_CDC", "")
@@ -63,7 +63,9 @@ def test_pub_covid_hosp_facility_lookup(self) -> None:
6363

6464
@pytest.mark.filterwarnings("ignore:`collection_weeks` is in week format")
6565
def test_pub_covid_hosp_facility(self) -> None:
66-
apicall = EpiDataContext().pub_covid_hosp_facility(hospital_pks="100075", collection_weeks=EpiRange(20200101, 20200501))
66+
apicall = EpiDataContext().pub_covid_hosp_facility(
67+
hospital_pks="100075", collection_weeks=EpiRange(20200101, 20200501)
68+
)
6769
data = apicall.df()
6870
assert len(data) > 0
6971
assert str(data["hospital_pk"].dtype) == "string"
@@ -79,7 +81,9 @@ def test_pub_covid_hosp_facility(self) -> None:
7981
assert str(data["collection_week"].dtype) == "datetime64[ns]"
8082
assert str(data["is_metro_micro"].dtype) == "bool"
8183

82-
apicall2 = EpiDataContext().pub_covid_hosp_facility(hospital_pks="100075", collection_weeks=EpiRange(202001, 202030))
84+
apicall2 = EpiDataContext().pub_covid_hosp_facility(
85+
hospital_pks="100075", collection_weeks=EpiRange(202001, 202030)
86+
)
8387
data2 = apicall2.df()
8488
assert len(data2) > 0
8589

@@ -107,7 +111,7 @@ def test_pub_covidcast_meta(self) -> None:
107111
assert str(data["mean_value"].dtype) == "Float64"
108112
assert str(data["stdev_value"].dtype) == "Float64"
109113
assert str(data["last_update"].dtype) == "Int64"
110-
assert str(data["max_issue"].dtype) == "string"
114+
assert str(data["max_issue"].dtype) == "datetime64[ns]"
111115
assert str(data["min_lag"].dtype) == "Int64"
112116
assert str(data["max_lag"].dtype) == "Int64"
113117

@@ -167,7 +171,10 @@ def test_pub_dengue_nowcast(self) -> None:
167171
@pytest.mark.skipif(not secret_sensors, reason="Dengue sensors key not available.")
168172
def test_pvt_dengue_sensors(self) -> None:
169173
apicall = EpiDataContext().pvt_dengue_sensors(
170-
auth=secret_sensors, names="ght", locations="ag", epiweeks=EpiRange(201501, 202001)
174+
auth=secret_sensors,
175+
names="ght",
176+
locations="ag",
177+
epiweeks=EpiRange(201501, 202001),
171178
)
172179
data = apicall.df()
173180

@@ -225,7 +232,7 @@ def test_pub_fluview_meta(self) -> None:
225232

226233
assert len(data) > 0
227234
assert str(data["latest_update"].dtype) == "datetime64[ns]"
228-
assert str(data["latest_issue"].dtype) == "datetime64[ns]"
235+
assert str(data["latest_issue"].dtype) == "string"
229236
assert str(data["table_rows"].dtype) == "Int64"
230237

231238
def test_pub_fluview(self) -> None:
@@ -255,7 +262,10 @@ def test_pub_gft(self) -> None:
255262
@pytest.mark.skipif(not secret_ght, reason="GHT key not available.")
256263
def test_pvt_ght(self) -> None:
257264
apicall = EpiDataContext().pvt_ght(
258-
auth=secret_ght, locations="ma", epiweeks=EpiRange(199301, 202304), query="how to get over the flu"
265+
auth=secret_ght,
266+
locations="ma",
267+
epiweeks=EpiRange(199301, 202304),
268+
query="how to get over the flu",
259269
)
260270
data = apicall.df()
261271

@@ -315,10 +325,10 @@ def test_pvt_norostat(self) -> None:
315325
data = apicall.df()
316326

317327
# TODO: Need a non-trivial query for Norostat
318-
assert len(data) > 0
319-
assert str(data["release_date"].dtype) == "datetime64[ns]"
320-
assert str(data["epiweek"].dtype) == "string"
321-
assert str(data["value"].dtype) == "Int64"
328+
# assert len(data) > 0
329+
# assert str(data["release_date"].dtype) == "datetime64[ns]"
330+
# assert str(data["epiweek"].dtype) == "string"
331+
# assert str(data["value"].dtype) == "Int64"
322332

323333
def test_pub_nowcast(self) -> None:
324334
apicall = EpiDataContext().pub_nowcast(locations="ca", epiweeks=EpiRange(201201, 201301))
@@ -360,7 +370,10 @@ def test_pvt_quidel(self) -> None:
360370
@pytest.mark.skipif(not secret_sensors, reason="Sensors key not available.")
361371
def test_pvt_sensors(self) -> None:
362372
apicall = EpiDataContext().pvt_sensors(
363-
auth=secret_sensors, names="sar3", locations="nat", epiweeks=EpiRange(201501, 202001)
373+
auth=secret_sensors,
374+
names="sar3",
375+
locations="nat",
376+
epiweeks=EpiRange(201501, 202001),
364377
)
365378
data = apicall.df()
366379

@@ -373,7 +386,10 @@ def test_pvt_sensors(self) -> None:
373386
@pytest.mark.skipif(not secret_twitter, reason="Twitter key not available.")
374387
def test_pvt_twitter(self) -> None:
375388
apicall = EpiDataContext().pvt_twitter(
376-
auth=secret_twitter, locations="CA", time_type="week", time_values=EpiRange(201501, 202001)
389+
auth=secret_twitter,
390+
locations="CA",
391+
time_type="week",
392+
time_values=EpiRange(201501, 202001),
377393
)
378394
data = apicall.df()
379395

@@ -385,7 +401,11 @@ def test_pvt_twitter(self) -> None:
385401
assert str(data["percent"].dtype) == "Float64"
386402

387403
def test_pub_wiki(self) -> None:
388-
apicall = EpiDataContext().pub_wiki(articles="avian_influenza", time_type="week", time_values=EpiRange(201501, 201601))
404+
apicall = EpiDataContext().pub_wiki(
405+
articles="avian_influenza",
406+
time_type="week",
407+
time_values=EpiRange(201501, 201601),
408+
)
389409
data = apicall.df()
390410

391411
assert len(data) > 0

0 commit comments

Comments
 (0)