Skip to content

Commit 112e178

Browse files
authored
Merge pull request #37 from cmu-delphi/rzatserkovnyi/caching
Add caching
2 parents 0d519b7 + d3cb7bc commit 112e178

File tree

6 files changed

+179
-69
lines changed

6 files changed

+179
-69
lines changed

epidatpy/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Fetch data from Delphi's API."""
22

33
# Make the linter happy about the unused variables
4-
__all__ = ["__version__", "Epidata", "CovidcastEpidata", "EpiRange"]
4+
__all__ = ["__version__", "EpiDataContext", "CovidcastEpidata", "EpiRange"]
55
__author__ = "Delphi Research Group"
66

77

88
from ._constants import __version__
9-
from .request import CovidcastEpidata, Epidata, EpiRange
9+
from ._model import EpiRange
10+
from .request import CovidcastEpidata, EpiDataContext

epidatpy/_model.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass, field
22
from datetime import date
33
from enum import Enum
4+
from os import environ
45
from typing import (
56
Final,
67
List,
@@ -146,6 +147,7 @@ class AEpiDataCall:
146147
meta: Final[Sequence[EpidataFieldInfo]]
147148
meta_by_name: Final[Mapping[str, EpidataFieldInfo]]
148149
only_supports_classic: Final[bool]
150+
use_cache: Final[bool]
149151

150152
def __init__(
151153
self,
@@ -154,13 +156,28 @@ def __init__(
154156
params: Mapping[str, Optional[EpiRangeParam]],
155157
meta: Optional[Sequence[EpidataFieldInfo]] = None,
156158
only_supports_classic: bool = False,
159+
use_cache: Optional[bool] = None,
160+
cache_max_age_days: Optional[int] = None,
157161
) -> None:
158162
self._base_url = base_url
159163
self._endpoint = endpoint
160164
self._params = params
161165
self.only_supports_classic = only_supports_classic
162166
self.meta = meta or []
163167
self.meta_by_name = {k.name: k for k in self.meta}
168+
# Set the use_cache value from the constructor if present.
169+
# Otherwise check the USE_EPIDATPY_CACHE variable, accepting various "truthy" values.
170+
self.use_cache = use_cache if use_cache is not None \
171+
else (environ.get("USE_EPIDATPY_CACHE", "").lower() in ['true', 't', '1'])
172+
# Set cache_max_age_days from the constructor, fall back to environment variable.
173+
if cache_max_age_days:
174+
self.cache_max_age_days = cache_max_age_days
175+
else:
176+
env_days = environ.get("EPIDATPY_CACHE_MAX_AGE_DAYS", "7")
177+
if env_days.isdigit():
178+
self.cache_max_age_days = int(env_days)
179+
else: # handle string / negative / invalid enviromment variable
180+
self.cache_max_age_days = 7
164181

165182
def _verify_parameters(self) -> None:
166183
# hook for verifying parameters before sending

epidatpy/request.py

+89-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from os import environ
12
from typing import (
23
Any,
34
Dict,
@@ -10,6 +11,8 @@
1011
cast,
1112
)
1213

14+
from appdirs import user_cache_dir
15+
from diskcache import Cache
1316
from pandas import CategoricalDtype, DataFrame, Series, to_datetime
1417
from requests import Response, Session
1518
from requests.auth import HTTPBasicAuth
@@ -24,15 +27,21 @@
2427
EpidataFieldInfo,
2528
EpidataFieldType,
2629
EpiDataResponse,
27-
EpiRange,
2830
EpiRangeParam,
2931
OnlySupportsClassicFormatException,
3032
add_endpoint_to_url,
3133
)
3234
from ._parse import fields_to_predicate
3335

3436
# Make the linter happy about the unused variables
35-
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]
37+
CACHE_DIRECTORY = user_cache_dir(appname="epidatpy", appauthor="delphi")
38+
39+
if environ.get("USE_EPIDATPY_CACHE", None):
40+
print(
41+
f"diskcache is being used (unset USE_EPIDATPY_CACHE if not intended). "
42+
f"The cache directory is {CACHE_DIRECTORY}. "
43+
f"The TTL is set to {environ.get('EPIDATPY_CACHE_MAX_AGE_DAYS', '7')} days."
44+
)
3645

3746

3847
@retry(reraise=True, stop=stop_after_attempt(2))
@@ -59,9 +68,7 @@ def call_impl(s: Session) -> Response:
5968

6069

6170
class EpiDataCall(AEpiDataCall):
62-
"""
63-
epidata call representation
64-
"""
71+
"""epidata call representation"""
6572

6673
_session: Final[Optional[Session]]
6774

@@ -73,8 +80,10 @@ def __init__(
7380
params: Mapping[str, Optional[EpiRangeParam]],
7481
meta: Optional[Sequence[EpidataFieldInfo]] = None,
7582
only_supports_classic: bool = False,
83+
use_cache: Optional[bool] = None,
84+
cache_max_age_days: Optional[int] = None,
7685
) -> None:
77-
super().__init__(base_url, endpoint, params, meta, only_supports_classic)
86+
super().__init__(base_url, endpoint, params, meta, only_supports_classic, use_cache, cache_max_age_days)
7887
self._session = session
7988

8089
def with_base_url(self, base_url: str) -> "EpiDataCall":
@@ -91,6 +100,12 @@ def _call(
91100
url, params = self.request_arguments(fields)
92101
return _request_with_retry(url, params, self._session, stream)
93102

103+
def _get_cache_key(self, method: str) -> str:
104+
cache_key = f"{self._endpoint} | {method}"
105+
if self._params:
106+
cache_key += f" | {str(dict(sorted(self._params.items())))}"
107+
return cache_key
108+
94109
def classic(
95110
self,
96111
fields: Optional[Sequence[str]] = None,
@@ -100,13 +115,22 @@ def classic(
100115
"""Request and parse epidata in CLASSIC message format."""
101116
self._verify_parameters()
102117
try:
118+
if self.use_cache:
119+
with Cache(CACHE_DIRECTORY) as cache:
120+
cache_key = self._get_cache_key("classic")
121+
if cache_key in cache:
122+
return cast(EpiDataResponse, cache[cache_key])
103123
response = self._call(fields)
104124
r = cast(EpiDataResponse, response.json())
105125
if disable_type_parsing:
106126
return r
107127
epidata = r.get("epidata")
108128
if epidata and isinstance(epidata, list) and len(epidata) > 0 and isinstance(epidata[0], dict):
109129
r["epidata"] = [self._parse_row(row, disable_date_parsing=disable_date_parsing) for row in epidata]
130+
if self.use_cache:
131+
with Cache(CACHE_DIRECTORY) as cache:
132+
cache_key = self._get_cache_key("classic")
133+
cache.set(cache_key, r, expire=self.cache_max_age_days * 24 * 60 * 60)
110134
return r
111135
except Exception as e: # pylint: disable=broad-except
112136
return {"result": 0, "message": f"error: {e}", "epidata": []}
@@ -118,7 +142,11 @@ def __call__(
118142
) -> Union[EpiDataResponse, DataFrame]:
119143
"""Request and parse epidata in df message format."""
120144
if self.only_supports_classic:
121-
return self.classic(fields, disable_date_parsing=disable_date_parsing, disable_type_parsing=False)
145+
return self.classic(
146+
fields,
147+
disable_date_parsing=disable_date_parsing,
148+
disable_type_parsing=False,
149+
)
122150
return self.df(fields, disable_date_parsing=disable_date_parsing)
123151

124152
def df(
@@ -130,6 +158,13 @@ def df(
130158
if self.only_supports_classic:
131159
raise OnlySupportsClassicFormatException()
132160
self._verify_parameters()
161+
162+
if self.use_cache:
163+
with Cache(CACHE_DIRECTORY) as cache:
164+
cache_key = self._get_cache_key("df")
165+
if cache_key in cache:
166+
return cast(DataFrame, cache[cache_key])
167+
133168
json = self.classic(fields, disable_type_parsing=True)
134169
rows = json.get("epidata", [])
135170
pred = fields_to_predicate(fields)
@@ -145,7 +180,8 @@ def df(
145180
data_types[info.name] = bool
146181
elif info.type == EpidataFieldType.categorical:
147182
data_types[info.name] = CategoricalDtype(
148-
categories=Series(info.categories) if info.categories else None, ordered=True
183+
categories=Series(info.categories) if info.categories else None,
184+
ordered=True,
149185
)
150186
elif info.type == EpidataFieldType.int:
151187
data_types[info.name] = "Int64"
@@ -166,6 +202,8 @@ def df(
166202
for info in time_fields:
167203
if info.type == EpidataFieldType.epiweek:
168204
continue
205+
# Try two date foramts, otherwise keep as string. The try except
206+
# is needed because the time field might be date_or_epiweek.
169207
try:
170208
df[info.name] = to_datetime(df[info.name], format="%Y-%m-%d")
171209
continue
@@ -175,21 +213,33 @@ def df(
175213
df[info.name] = to_datetime(df[info.name], format="%Y%m%d")
176214
except ValueError:
177215
pass
216+
217+
if self.use_cache:
218+
with Cache(CACHE_DIRECTORY) as cache:
219+
cache_key = self._get_cache_key("df")
220+
cache.set(cache_key, df, expire=self.cache_max_age_days * 24 * 60 * 60)
221+
178222
return df
179223

180224

181225
class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):
182-
"""
183-
sync epidata call class
184-
"""
226+
"""sync epidata call class"""
185227

186228
_base_url: Final[str]
187229
_session: Final[Optional[Session]]
188230

189-
def __init__(self, base_url: str = BASE_URL, session: Optional[Session] = None) -> None:
231+
def __init__(
232+
self,
233+
base_url: str = BASE_URL,
234+
session: Optional[Session] = None,
235+
use_cache: Optional[bool] = None,
236+
cache_max_age_days: Optional[int] = None,
237+
) -> None:
190238
super().__init__()
191239
self._base_url = base_url
192240
self._session = session
241+
self.use_cache = use_cache
242+
self.cache_max_age_days = cache_max_age_days
193243

194244
def with_base_url(self, base_url: str) -> "EpiDataContext":
195245
return EpiDataContext(base_url, self._session)
@@ -204,13 +254,24 @@ def _create_call(
204254
meta: Optional[Sequence[EpidataFieldInfo]] = None,
205255
only_supports_classic: bool = False,
206256
) -> EpiDataCall:
207-
return EpiDataCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic)
208-
209-
210-
Epidata = EpiDataContext()
211-
212-
213-
def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None) -> CovidcastDataSources[EpiDataCall]:
257+
return EpiDataCall(
258+
self._base_url,
259+
self._session,
260+
endpoint,
261+
params,
262+
meta,
263+
only_supports_classic,
264+
self.use_cache,
265+
self.cache_max_age_days,
266+
)
267+
268+
269+
def CovidcastEpidata(
270+
base_url: str = BASE_URL,
271+
session: Optional[Session] = None,
272+
use_cache: Optional[bool] = None,
273+
cache_max_age_days: Optional[int] = None,
274+
) -> CovidcastDataSources[EpiDataCall]:
214275
url = add_endpoint_to_url(base_url, "covidcast/meta")
215276
meta_data_res = _request_with_retry(url, {}, session, False)
216277
meta_data_res.raise_for_status()
@@ -219,6 +280,14 @@ def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None
219280
def create_call(
220281
params: Mapping[str, Optional[EpiRangeParam]],
221282
) -> EpiDataCall:
222-
return EpiDataCall(base_url, session, "covidcast", params, define_covidcast_fields())
283+
return EpiDataCall(
284+
base_url,
285+
session,
286+
"covidcast",
287+
params,
288+
define_covidcast_fields(),
289+
use_cache=use_cache,
290+
cache_max_age_days=cache_max_age_days,
291+
)
223292

224293
return CovidcastDataSources.create(meta_data, create_call)

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ classifiers = [
3131
requires-python = ">=3.8"
3232
dependencies = [
3333
"aiohttp",
34+
"appdirs",
35+
"diskcache",
3436
"epiweeks>=2.1",
3537
"pandas>=1",
3638
"requests>=2.25",

smoke_test.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datetime import date
22

3-
from epidatpy import CovidcastEpidata, Epidata, EpiRange
3+
from epidatpy import CovidcastEpidata, EpiDataContext, EpiRange
44

55
print("Epidata Test")
6-
apicall = Epidata.pub_covidcast("fb-survey", "smoothed_cli", "nation", "day", "us", EpiRange(20210405, 20210410))
6+
epidata = EpiDataContext(use_cache=True, cache_max_age_days=1)
7+
apicall = epidata.pub_covidcast("fb-survey", "smoothed_cli", "nation", "day", "us", EpiRange(20210405, 20210410))
78

89
# Call info
910
print(apicall)
@@ -27,17 +28,17 @@
2728
print(df.iloc[0])
2829

2930

30-
StagingEpidata = Epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/")
31+
staging_epidata = epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/")
3132

32-
epicall = StagingEpidata.pub_covidcast(
33+
epicall = staging_epidata.pub_covidcast(
3334
"fb-survey", "smoothed_cli", "nation", "day", "*", EpiRange(date(2021, 4, 5), date(2021, 4, 10))
3435
)
3536
print(epicall._base_url)
3637

3738

3839
# Covidcast test
3940
print("Covidcast Test")
40-
epidata = CovidcastEpidata()
41+
epidata = CovidcastEpidata(use_cache=True, cache_max_age_days=1)
4142
print(epidata.source_names())
4243
print(epidata.signal_names("fb-survey"))
4344
epidata["fb-survey"].signal_df

0 commit comments

Comments
 (0)