13
13
from appdirs import user_cache_dir
14
14
from diskcache import Cache
15
15
from pandas import CategoricalDtype , DataFrame , Series , to_datetime
16
+ from os import environ
16
17
from requests import Response , Session
17
18
from requests .auth import HTTPBasicAuth
18
19
from tenacity import retry , stop_after_attempt
37
38
# Names exported by a star-import of this module.
# NOTE(review): "Epidata" is still listed here, but the visible change set
# removes the module-level ``Epidata = EpiDataContext()`` assignment. Confirm
# the name is still defined somewhere in this module; otherwise a star-import
# will raise AttributeError.
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]

# Per-user cache directory handed to diskcache's ``Cache`` by the request code.
CACHE_DIRECTORY = user_cache_dir(appname="epidatpy", appauthor="delphi")

# Tell the user at import time that USE_EPIDATPY_CACHE is set, where the cache
# lives, and which TTL applies (EPIDATPY_CACHE_MAX_AGE_DAYS, shown as "7" when
# unset).
if environ.get("USE_EPIDATPY_CACHE", None):
    print(
        "diskcache is being used (unset USE_EPIDATPY_CACHE if not intended). "
        f"The cache directory is {CACHE_DIRECTORY}. "
        # Single quotes inside the replacement field: reusing the enclosing
        # double quote is a SyntaxError on Python versions before 3.12.
        f"The TTL is set to {environ.get('EPIDATPY_CACHE_MAX_AGE_DAYS', '7')} days."
    )
45
+
40
46
@retry (reraise = True , stop = stop_after_attempt (2 ))
41
47
def _request_with_retry (
42
48
url : str ,
@@ -75,9 +81,10 @@ def __init__(
75
81
params : Mapping [str , Optional [EpiRangeParam ]],
76
82
meta : Optional [Sequence [EpidataFieldInfo ]] = None ,
77
83
only_supports_classic : bool = False ,
78
- use_cache = None ,
84
+ use_cache : Optional [bool ] = None ,
85
+ cache_max_age_days : Optional [int ] = None ,
79
86
) -> None :
80
- super ().__init__ (base_url , endpoint , params , meta , only_supports_classic , use_cache )
87
+ super ().__init__ (base_url , endpoint , params , meta , only_supports_classic , use_cache , cache_max_age_days )
81
88
self ._session = session
82
89
83
90
def with_base_url (self , base_url : str ) -> "EpiDataCall" :
@@ -94,6 +101,12 @@ def _call(
94
101
url , params = self .request_arguments (fields )
95
102
return _request_with_retry (url , params , self ._session , stream )
96
103
104
def _get_cache_key(self, method: str) -> str:
    """Build the cache key for this call.

    The key is ``"<endpoint> | <method>"``. When the call has parameters,
    their dict representation, ordered by parameter name, is appended after
    another ``" | "`` so that the same parameters always yield the same key
    regardless of insertion order.

    Args:
        method: Label of the calling method (callers pass ``"classic"`` or
            ``"df"``), keeping results of different shapes apart.

    Returns:
        The cache key string.
    """
    prefix = f"{self._endpoint} | {method}"
    if not self._params:
        # No parameters: endpoint and method alone identify the request.
        return prefix
    ordered_params = dict(sorted(self._params.items()))
    return f"{prefix} | {ordered_params}"
109
+
97
110
def classic (
98
111
self ,
99
112
fields : Optional [Sequence [str ]] = None ,
@@ -105,7 +118,7 @@ def classic(
105
118
try :
106
119
if self .use_cache :
107
120
with Cache (CACHE_DIRECTORY ) as cache :
108
- cache_key = str ( self ._endpoint ) + str ( self . _params )
121
+ cache_key = self ._get_cache_key ( "classic" )
109
122
if cache_key in cache :
110
123
return cache [cache_key ]
111
124
response = self ._call (fields )
@@ -117,9 +130,8 @@ def classic(
117
130
r ["epidata" ] = [self ._parse_row (row , disable_date_parsing = disable_date_parsing ) for row in epidata ]
118
131
if self .use_cache :
119
132
with Cache (CACHE_DIRECTORY ) as cache :
120
- cache_key = str (self ._endpoint ) + str (self ._params )
121
- # Set TTL to 7 days (TODO: configurable?)
122
- cache .set (cache_key , r , expire = 7 * 24 * 60 * 60 )
133
+ cache_key = self ._get_cache_key ("classic" )
134
+ cache .set (cache_key , r , expire = self .cache_max_age_days * 24 * 60 * 60 )
123
135
return r
124
136
except Exception as e : # pylint: disable=broad-except
125
137
return {"result" : 0 , "message" : f"error: { e } " , "epidata" : []}
@@ -146,7 +158,7 @@ def df(
146
158
147
159
if self .use_cache :
148
160
with Cache (CACHE_DIRECTORY ) as cache :
149
- cache_key = str ( self ._endpoint ) + str ( self . _params )
161
+ cache_key = self ._get_cache_key ( "df" )
150
162
if cache_key in cache :
151
163
return cache [cache_key ]
152
164
@@ -184,7 +196,7 @@ def df(
184
196
df = df .astype (data_types )
185
197
if not disable_date_parsing :
186
198
for info in time_fields :
187
- if info .type == EpidataFieldType .epiweek :
199
+ if info .type == EpidataFieldType .epiweek or info . type == EpidataFieldType . date_or_epiweek :
188
200
continue
189
201
try :
190
202
df [info .name ] = to_datetime (df [info .name ], format = "%Y-%m-%d" )
@@ -198,9 +210,8 @@ def df(
198
210
199
211
if self .use_cache :
200
212
with Cache (CACHE_DIRECTORY ) as cache :
201
- cache_key = str (self ._endpoint ) + str (self ._params )
202
- # Set TTL to 7 days (TODO: configurable?)
203
- cache .set (cache_key , df , expire = 7 * 24 * 60 * 60 )
213
+ cache_key = self ._get_cache_key ("df" )
214
+ cache .set (cache_key , df , expire = self .cache_max_age_days * 24 * 60 * 60 )
204
215
205
216
return df
206
217
@@ -213,10 +224,18 @@ class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):
213
224
_base_url : Final [str ]
214
225
_session : Final [Optional [Session ]]
215
226
216
- def __init__ (self , base_url : str = BASE_URL , session : Optional [Session ] = None ) -> None :
227
def __init__(
    self,
    base_url: str = BASE_URL,
    session: Optional[Session] = None,
    use_cache: Optional[bool] = None,
    cache_max_age_days: Optional[int] = None,
) -> None:
    """Create a context bound to a base URL, a session and cache settings.

    Args:
        base_url: Base URL stored on the context and passed to the calls
            it creates.
        session: Optional ``requests`` session stored on the context and
            passed to the calls it creates.
        use_cache: Cache switch forwarded unchanged to every
            ``EpiDataCall`` built by this context. How ``None`` is
            interpreted is decided downstream, not here.
        cache_max_age_days: Cache lifetime in days, forwarded unchanged
            to every ``EpiDataCall`` built by this context.
    """
    super().__init__()
    # Cache settings are only stored here; ``_create_call`` hands them on.
    self.cache_max_age_days = cache_max_age_days
    self.use_cache = use_cache
    self._session = session
    self._base_url = base_url
220
239
221
240
def with_base_url(self, base_url: str) -> "EpiDataContext":
    """Return a new context that targets ``base_url``.

    The session and the cache settings (``use_cache`` and
    ``cache_max_age_days``) of this context are carried over, so switching
    the base URL does not silently reset an explicit cache configuration.

    Args:
        base_url: Base URL for the new context.

    Returns:
        A new ``EpiDataContext``; this instance is left unchanged.
    """
    return EpiDataContext(
        base_url,
        self._session,
        use_cache=self.use_cache,
        cache_max_age_days=self.cache_max_age_days,
    )
@@ -230,15 +249,16 @@ def _create_call(
230
249
params : Mapping [str , Optional [EpiRangeParam ]],
231
250
meta : Optional [Sequence [EpidataFieldInfo ]] = None ,
232
251
only_supports_classic : bool = False ,
233
- use_cache : bool = False ,
234
- ) -> EpiDataCall :
235
- return EpiDataCall (self ._base_url , self ._session , endpoint , params , meta , only_supports_classic , use_cache )
236
-
237
-
238
- Epidata = EpiDataContext ()
239
252
240
-
241
- def CovidcastEpidata (base_url : str = BASE_URL , session : Optional [Session ] = None ) -> CovidcastDataSources [EpiDataCall ]:
253
+ ) -> EpiDataCall :
254
+ return EpiDataCall (self ._base_url , self ._session , endpoint , params , meta , only_supports_classic , self .use_cache , self .cache_max_age_days )
255
+
256
+ def CovidcastEpidata (
257
+ base_url : str = BASE_URL ,
258
+ session : Optional [Session ] = None ,
259
+ use_cache : Optional [bool ] = None ,
260
+ cache_max_age_days : Optional [int ] = None ,
261
+ ) -> CovidcastDataSources [EpiDataCall ]:
242
262
url = add_endpoint_to_url (base_url , "covidcast/meta" )
243
263
meta_data_res = _request_with_retry (url , {}, session , False )
244
264
meta_data_res .raise_for_status ()
@@ -247,6 +267,6 @@ def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None
247
267
def create_call (
248
268
params : Mapping [str , Optional [EpiRangeParam ]],
249
269
) -> EpiDataCall :
250
- return EpiDataCall (base_url , session , "covidcast" , params , define_covidcast_fields ())
270
+ return EpiDataCall (base_url , session , "covidcast" , params , define_covidcast_fields (), use_cache = use_cache , cache_max_age_days = cache_max_age_days )
251
271
252
272
return CovidcastDataSources .create (meta_data , create_call )
0 commit comments