Skip to content

Add caching #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions epidatpy/_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from datetime import date
from enum import Enum
from os import environ
from typing import (
Final,
List,
Expand Down Expand Up @@ -146,6 +147,7 @@ class AEpiDataCall:
meta: Final[Sequence[EpidataFieldInfo]]
meta_by_name: Final[Mapping[str, EpidataFieldInfo]]
only_supports_classic: Final[bool]
use_cache: Final[bool]

def __init__(
self,
Expand All @@ -154,13 +156,18 @@ def __init__(
params: Mapping[str, Optional[EpiRangeParam]],
meta: Optional[Sequence[EpidataFieldInfo]] = None,
only_supports_classic: bool = False,
use_cache: Optional[bool] = None,
) -> None:
self._base_url = base_url
self._endpoint = endpoint
self._params = params
self.only_supports_classic = only_supports_classic
self.meta = meta or []
self.meta_by_name = {k.name: k for k in self.meta}
# Enable caching when the use_cache constructor argument is truthy.
# Otherwise (use_cache is None or False) fall back to the USE_EPIDATPY_CACHE environment variable, which enables caching when set to "true", "t", or "1" (case-insensitive).
self.use_cache = use_cache \
or (environ.get("USE_EPIDATPY_CACHE", "").lower() in ['true', 't', '1'])

def _verify_parameters(self) -> None:
# hook for verifying parameters before sending
Expand Down
34 changes: 31 additions & 3 deletions epidatpy/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
cast,
)

from appdirs import user_cache_dir
from diskcache import Cache
from pandas import CategoricalDtype, DataFrame, Series, to_datetime
from requests import Response, Session
from requests.auth import HTTPBasicAuth
Expand All @@ -33,7 +35,7 @@

# Make the linter happy about the unused variables
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]

CACHE_DIRECTORY = user_cache_dir(appname="epidatpy", appauthor="delphi")

@retry(reraise=True, stop=stop_after_attempt(2))
def _request_with_retry(
Expand Down Expand Up @@ -73,8 +75,9 @@ def __init__(
params: Mapping[str, Optional[EpiRangeParam]],
meta: Optional[Sequence[EpidataFieldInfo]] = None,
only_supports_classic: bool = False,
use_cache = None,
) -> None:
super().__init__(base_url, endpoint, params, meta, only_supports_classic)
super().__init__(base_url, endpoint, params, meta, only_supports_classic, use_cache)
self._session = session

def with_base_url(self, base_url: str) -> "EpiDataCall":
Expand All @@ -100,13 +103,23 @@ def classic(
"""Request and parse epidata in CLASSIC message format."""
self._verify_parameters()
try:
if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = str(self._endpoint) + str(self._params)
if cache_key in cache:
return cache[cache_key]
response = self._call(fields)
r = cast(EpiDataResponse, response.json())
if disable_type_parsing:
return r
epidata = r.get("epidata")
if epidata and isinstance(epidata, list) and len(epidata) > 0 and isinstance(epidata[0], dict):
r["epidata"] = [self._parse_row(row, disable_date_parsing=disable_date_parsing) for row in epidata]
if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = str(self._endpoint) + str(self._params)
# Set TTL to 7 days (TODO: configurable?)
cache.set(cache_key, r, expire=7*24*60*60)
return r
except Exception as e: # pylint: disable=broad-except
return {"result": 0, "message": f"error: {e}", "epidata": []}
Expand All @@ -130,6 +143,13 @@ def df(
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()

if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = str(self._endpoint) + str(self._params)
if cache_key in cache:
return cache[cache_key]

json = self.classic(fields, disable_type_parsing=True)
rows = json.get("epidata", [])
pred = fields_to_predicate(fields)
Expand Down Expand Up @@ -175,6 +195,13 @@ def df(
df[info.name] = to_datetime(df[info.name], format="%Y%m%d")
except ValueError:
pass

if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = str(self._endpoint) + str(self._params)
# Set TTL to 7 days (TODO: configurable?)
cache.set(cache_key, df, expire=7*24*60*60)

return df


Expand Down Expand Up @@ -203,8 +230,9 @@ def _create_call(
params: Mapping[str, Optional[EpiRangeParam]],
meta: Optional[Sequence[EpidataFieldInfo]] = None,
only_supports_classic: bool = False,
use_cache: bool = False,
) -> EpiDataCall:
return EpiDataCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic)
return EpiDataCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic, use_cache)


Epidata = EpiDataContext()
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ classifiers = [
requires-python = ">=3.8"
dependencies = [
"aiohttp",
"appdirs",
"diskcache",
"epiweeks>=2.1",
"pandas>=1",
"requests>=2.25",
Expand Down
Loading