1
+ from os import environ
1
2
from typing import (
2
3
Any ,
3
4
Dict ,
10
11
cast ,
11
12
)
12
13
14
+ from appdirs import user_cache_dir
15
+ from diskcache import Cache
13
16
from pandas import CategoricalDtype , DataFrame , Series , to_datetime
14
17
from requests import Response , Session
15
18
from requests .auth import HTTPBasicAuth
24
27
EpidataFieldInfo ,
25
28
EpidataFieldType ,
26
29
EpiDataResponse ,
27
- EpiRange ,
28
30
EpiRangeParam ,
29
31
OnlySupportsClassicFormatException ,
30
32
add_endpoint_to_url ,
31
33
)
32
34
from ._parse import fields_to_predicate
33
35
34
36
# Make the linter happy about the unused variables
35
- __all__ = ["Epidata" , "EpiDataCall" , "EpiDataContext" , "EpiRange" , "CovidcastEpidata" ]
37
+ CACHE_DIRECTORY = user_cache_dir (appname = "epidatpy" , appauthor = "delphi" )
38
+
39
+ if environ .get ("USE_EPIDATPY_CACHE" , None ):
40
+ print (
41
+ f"diskcache is being used (unset USE_EPIDATPY_CACHE if not intended). "
42
+ f"The cache directory is { CACHE_DIRECTORY } . "
43
+ f"The TTL is set to { environ .get ('EPIDATPY_CACHE_MAX_AGE_DAYS' , '7' )} days."
44
+ )
36
45
37
46
38
47
@retry (reraise = True , stop = stop_after_attempt (2 ))
@@ -59,9 +68,7 @@ def call_impl(s: Session) -> Response:
59
68
60
69
61
70
class EpiDataCall (AEpiDataCall ):
62
- """
63
- epidata call representation
64
- """
71
+ """epidata call representation"""
65
72
66
73
_session : Final [Optional [Session ]]
67
74
@@ -73,8 +80,10 @@ def __init__(
73
80
params : Mapping [str , Optional [EpiRangeParam ]],
74
81
meta : Optional [Sequence [EpidataFieldInfo ]] = None ,
75
82
only_supports_classic : bool = False ,
83
+ use_cache : Optional [bool ] = None ,
84
+ cache_max_age_days : Optional [int ] = None ,
76
85
) -> None :
77
- super ().__init__ (base_url , endpoint , params , meta , only_supports_classic )
86
+ super ().__init__ (base_url , endpoint , params , meta , only_supports_classic , use_cache , cache_max_age_days )
78
87
self ._session = session
79
88
80
89
def with_base_url (self , base_url : str ) -> "EpiDataCall" :
@@ -91,6 +100,12 @@ def _call(
91
100
url , params = self .request_arguments (fields )
92
101
return _request_with_retry (url , params , self ._session , stream )
93
102
103
+ def _get_cache_key (self , method : str ) -> str :
104
+ cache_key = f"{ self ._endpoint } | { method } "
105
+ if self ._params :
106
+ cache_key += f" | { str (dict (sorted (self ._params .items ())))} "
107
+ return cache_key
108
+
94
109
def classic (
95
110
self ,
96
111
fields : Optional [Sequence [str ]] = None ,
@@ -100,13 +115,22 @@ def classic(
100
115
"""Request and parse epidata in CLASSIC message format."""
101
116
self ._verify_parameters ()
102
117
try :
118
+ if self .use_cache :
119
+ with Cache (CACHE_DIRECTORY ) as cache :
120
+ cache_key = self ._get_cache_key ("classic" )
121
+ if cache_key in cache :
122
+ return cast (EpiDataResponse , cache [cache_key ])
103
123
response = self ._call (fields )
104
124
r = cast (EpiDataResponse , response .json ())
105
125
if disable_type_parsing :
106
126
return r
107
127
epidata = r .get ("epidata" )
108
128
if epidata and isinstance (epidata , list ) and len (epidata ) > 0 and isinstance (epidata [0 ], dict ):
109
129
r ["epidata" ] = [self ._parse_row (row , disable_date_parsing = disable_date_parsing ) for row in epidata ]
130
+ if self .use_cache :
131
+ with Cache (CACHE_DIRECTORY ) as cache :
132
+ cache_key = self ._get_cache_key ("classic" )
133
+ cache .set (cache_key , r , expire = self .cache_max_age_days * 24 * 60 * 60 )
110
134
return r
111
135
except Exception as e : # pylint: disable=broad-except
112
136
return {"result" : 0 , "message" : f"error: { e } " , "epidata" : []}
@@ -118,7 +142,11 @@ def __call__(
118
142
) -> Union [EpiDataResponse , DataFrame ]:
119
143
"""Request and parse epidata in df message format."""
120
144
if self .only_supports_classic :
121
- return self .classic (fields , disable_date_parsing = disable_date_parsing , disable_type_parsing = False )
145
+ return self .classic (
146
+ fields ,
147
+ disable_date_parsing = disable_date_parsing ,
148
+ disable_type_parsing = False ,
149
+ )
122
150
return self .df (fields , disable_date_parsing = disable_date_parsing )
123
151
124
152
def df (
@@ -130,6 +158,13 @@ def df(
130
158
if self .only_supports_classic :
131
159
raise OnlySupportsClassicFormatException ()
132
160
self ._verify_parameters ()
161
+
162
+ if self .use_cache :
163
+ with Cache (CACHE_DIRECTORY ) as cache :
164
+ cache_key = self ._get_cache_key ("df" )
165
+ if cache_key in cache :
166
+ return cast (DataFrame , cache [cache_key ])
167
+
133
168
json = self .classic (fields , disable_type_parsing = True )
134
169
rows = json .get ("epidata" , [])
135
170
pred = fields_to_predicate (fields )
@@ -145,7 +180,8 @@ def df(
145
180
data_types [info .name ] = bool
146
181
elif info .type == EpidataFieldType .categorical :
147
182
data_types [info .name ] = CategoricalDtype (
148
- categories = Series (info .categories ) if info .categories else None , ordered = True
183
+ categories = Series (info .categories ) if info .categories else None ,
184
+ ordered = True ,
149
185
)
150
186
elif info .type == EpidataFieldType .int :
151
187
data_types [info .name ] = "Int64"
@@ -166,6 +202,8 @@ def df(
166
202
for info in time_fields :
167
203
if info .type == EpidataFieldType .epiweek :
168
204
continue
205
+ # Try two date foramts, otherwise keep as string. The try except
206
+ # is needed because the time field might be date_or_epiweek.
169
207
try :
170
208
df [info .name ] = to_datetime (df [info .name ], format = "%Y-%m-%d" )
171
209
continue
@@ -175,21 +213,33 @@ def df(
175
213
df [info .name ] = to_datetime (df [info .name ], format = "%Y%m%d" )
176
214
except ValueError :
177
215
pass
216
+
217
+ if self .use_cache :
218
+ with Cache (CACHE_DIRECTORY ) as cache :
219
+ cache_key = self ._get_cache_key ("df" )
220
+ cache .set (cache_key , df , expire = self .cache_max_age_days * 24 * 60 * 60 )
221
+
178
222
return df
179
223
180
224
181
225
class EpiDataContext (AEpiDataEndpoints [EpiDataCall ]):
182
- """
183
- sync epidata call class
184
- """
226
+ """sync epidata call class"""
185
227
186
228
_base_url : Final [str ]
187
229
_session : Final [Optional [Session ]]
188
230
189
- def __init__ (self , base_url : str = BASE_URL , session : Optional [Session ] = None ) -> None :
231
+ def __init__ (
232
+ self ,
233
+ base_url : str = BASE_URL ,
234
+ session : Optional [Session ] = None ,
235
+ use_cache : Optional [bool ] = None ,
236
+ cache_max_age_days : Optional [int ] = None ,
237
+ ) -> None :
190
238
super ().__init__ ()
191
239
self ._base_url = base_url
192
240
self ._session = session
241
+ self .use_cache = use_cache
242
+ self .cache_max_age_days = cache_max_age_days
193
243
194
244
def with_base_url (self , base_url : str ) -> "EpiDataContext" :
195
245
return EpiDataContext (base_url , self ._session )
@@ -204,13 +254,24 @@ def _create_call(
204
254
meta : Optional [Sequence [EpidataFieldInfo ]] = None ,
205
255
only_supports_classic : bool = False ,
206
256
) -> EpiDataCall :
207
- return EpiDataCall (self ._base_url , self ._session , endpoint , params , meta , only_supports_classic )
208
-
209
-
210
- Epidata = EpiDataContext ()
211
-
212
-
213
- def CovidcastEpidata (base_url : str = BASE_URL , session : Optional [Session ] = None ) -> CovidcastDataSources [EpiDataCall ]:
257
+ return EpiDataCall (
258
+ self ._base_url ,
259
+ self ._session ,
260
+ endpoint ,
261
+ params ,
262
+ meta ,
263
+ only_supports_classic ,
264
+ self .use_cache ,
265
+ self .cache_max_age_days ,
266
+ )
267
+
268
+
269
+ def CovidcastEpidata (
270
+ base_url : str = BASE_URL ,
271
+ session : Optional [Session ] = None ,
272
+ use_cache : Optional [bool ] = None ,
273
+ cache_max_age_days : Optional [int ] = None ,
274
+ ) -> CovidcastDataSources [EpiDataCall ]:
214
275
url = add_endpoint_to_url (base_url , "covidcast/meta" )
215
276
meta_data_res = _request_with_retry (url , {}, session , False )
216
277
meta_data_res .raise_for_status ()
@@ -219,6 +280,14 @@ def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None
219
280
def create_call (
220
281
params : Mapping [str , Optional [EpiRangeParam ]],
221
282
) -> EpiDataCall :
222
- return EpiDataCall (base_url , session , "covidcast" , params , define_covidcast_fields ())
283
+ return EpiDataCall (
284
+ base_url ,
285
+ session ,
286
+ "covidcast" ,
287
+ params ,
288
+ define_covidcast_fields (),
289
+ use_cache = use_cache ,
290
+ cache_max_age_days = cache_max_age_days ,
291
+ )
223
292
224
293
return CovidcastDataSources .create (meta_data , create_call )
0 commit comments