Skip to content

Add caching #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions epidatpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Fetch data from Delphi's API."""

# Make the linter happy about the unused variables
__all__ = ["__version__", "Epidata", "CovidcastEpidata", "EpiRange"]
__all__ = ["__version__", "EpiDataContext", "CovidcastEpidata", "EpiRange"]
__author__ = "Delphi Research Group"


from ._constants import __version__
from .request import CovidcastEpidata, Epidata, EpiRange
from ._model import EpiRange
from .request import CovidcastEpidata, EpiDataContext
17 changes: 17 additions & 0 deletions epidatpy/_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from datetime import date
from enum import Enum
from os import environ
from typing import (
Final,
List,
Expand Down Expand Up @@ -146,6 +147,7 @@ class AEpiDataCall:
meta: Final[Sequence[EpidataFieldInfo]]
meta_by_name: Final[Mapping[str, EpidataFieldInfo]]
only_supports_classic: Final[bool]
use_cache: Final[bool]

def __init__(
self,
Expand All @@ -154,13 +156,28 @@ def __init__(
params: Mapping[str, Optional[EpiRangeParam]],
meta: Optional[Sequence[EpidataFieldInfo]] = None,
only_supports_classic: bool = False,
use_cache: Optional[bool] = None,
cache_max_age_days: Optional[int] = None,
) -> None:
self._base_url = base_url
self._endpoint = endpoint
self._params = params
self.only_supports_classic = only_supports_classic
self.meta = meta or []
self.meta_by_name = {k.name: k for k in self.meta}
# Set the use_cache value from the constructor if present.
# Otherwise check the USE_EPIDATPY_CACHE variable, accepting various "truthy" values.
self.use_cache = use_cache if use_cache is not None \
else (environ.get("USE_EPIDATPY_CACHE", "").lower() in ['true', 't', '1'])
# Set cache_max_age_days from the constructor, fall back to environment variable.
if cache_max_age_days:
self.cache_max_age_days = cache_max_age_days
else:
env_days = environ.get("EPIDATPY_CACHE_MAX_AGE_DAYS", "7")
if env_days.isdigit():
self.cache_max_age_days = int(env_days)
else: # handle string / negative / invalid environment variable
self.cache_max_age_days = 7

def _verify_parameters(self) -> None:
# hook for verifying parameters before sending
Expand Down
109 changes: 89 additions & 20 deletions epidatpy/request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from os import environ
from typing import (
Any,
Dict,
Expand All @@ -10,6 +11,8 @@
cast,
)

from appdirs import user_cache_dir
from diskcache import Cache
from pandas import CategoricalDtype, DataFrame, Series, to_datetime
from requests import Response, Session
from requests.auth import HTTPBasicAuth
Expand All @@ -24,15 +27,21 @@
EpidataFieldInfo,
EpidataFieldType,
EpiDataResponse,
EpiRange,
EpiRangeParam,
OnlySupportsClassicFormatException,
add_endpoint_to_url,
)
from ._parse import fields_to_predicate

# Make the linter happy about the unused variables
__all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"]
CACHE_DIRECTORY = user_cache_dir(appname="epidatpy", appauthor="delphi")

# Announce at import time that the on-disk cache is enabled through the
# environment. Only the spellings that actually enable caching when a call is
# constructed ("true", "t", "1", case-insensitive) trigger the notice, so
# values such as "false" or "0" no longer print a misleading message.
if environ.get("USE_EPIDATPY_CACHE", "").lower() in ("true", "t", "1"):
    print(
        "diskcache is being used (unset USE_EPIDATPY_CACHE if not intended). "
        f"The cache directory is {CACHE_DIRECTORY}. "
        f"The TTL is set to {environ.get('EPIDATPY_CACHE_MAX_AGE_DAYS', '7')} days."
    )


@retry(reraise=True, stop=stop_after_attempt(2))
Expand All @@ -59,9 +68,7 @@ def call_impl(s: Session) -> Response:


class EpiDataCall(AEpiDataCall):
"""
epidata call representation
"""
"""epidata call representation"""

_session: Final[Optional[Session]]

Expand All @@ -73,8 +80,10 @@ def __init__(
params: Mapping[str, Optional[EpiRangeParam]],
meta: Optional[Sequence[EpidataFieldInfo]] = None,
only_supports_classic: bool = False,
use_cache: Optional[bool] = None,
cache_max_age_days: Optional[int] = None,
) -> None:
super().__init__(base_url, endpoint, params, meta, only_supports_classic)
super().__init__(base_url, endpoint, params, meta, only_supports_classic, use_cache, cache_max_age_days)
self._session = session

def with_base_url(self, base_url: str) -> "EpiDataCall":
Expand All @@ -91,6 +100,12 @@ def _call(
url, params = self.request_arguments(fields)
return _request_with_retry(url, params, self._session, stream)

def _get_cache_key(self, method: str) -> str:
    """Build the cache key identifying this call for one output method.

    The key is made of the base URL, the endpoint, the output method
    (e.g. ``"classic"`` or ``"df"``) and, when any parameters are set, the
    parameters sorted by name. Including the base URL keeps responses from
    different servers (for example a staging host) from sharing entries.

    Args:
        method: Name of the output format the cached value belongs to.

    Returns:
        The parts joined with ``" | "``.
    """
    # NOTE(review): the key does not cover the `fields` filter or the parsing
    # flags passed to classic()/df(); calls differing only in those share an
    # entry. Confirm whether that is intended.
    key_parts = [self._base_url, self._endpoint, method]
    if self._params:
        # Sort by parameter name so the key is independent of insertion order.
        key_parts.append(str(dict(sorted(self._params.items()))))
    return " | ".join(key_parts)

def classic(
self,
fields: Optional[Sequence[str]] = None,
Expand All @@ -100,13 +115,22 @@ def classic(
"""Request and parse epidata in CLASSIC message format."""
self._verify_parameters()
try:
if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = self._get_cache_key("classic")
if cache_key in cache:
return cast(EpiDataResponse, cache[cache_key])
response = self._call(fields)
r = cast(EpiDataResponse, response.json())
if disable_type_parsing:
return r
epidata = r.get("epidata")
if epidata and isinstance(epidata, list) and len(epidata) > 0 and isinstance(epidata[0], dict):
r["epidata"] = [self._parse_row(row, disable_date_parsing=disable_date_parsing) for row in epidata]
if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = self._get_cache_key("classic")
cache.set(cache_key, r, expire=self.cache_max_age_days * 24 * 60 * 60)
return r
except Exception as e: # pylint: disable=broad-except
return {"result": 0, "message": f"error: {e}", "epidata": []}
Expand All @@ -118,7 +142,11 @@ def __call__(
) -> Union[EpiDataResponse, DataFrame]:
"""Request and parse epidata in df message format."""
if self.only_supports_classic:
return self.classic(fields, disable_date_parsing=disable_date_parsing, disable_type_parsing=False)
return self.classic(
fields,
disable_date_parsing=disable_date_parsing,
disable_type_parsing=False,
)
return self.df(fields, disable_date_parsing=disable_date_parsing)

def df(
Expand All @@ -130,6 +158,13 @@ def df(
if self.only_supports_classic:
raise OnlySupportsClassicFormatException()
self._verify_parameters()

if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = self._get_cache_key("df")
if cache_key in cache:
return cast(DataFrame, cache[cache_key])

json = self.classic(fields, disable_type_parsing=True)
rows = json.get("epidata", [])
pred = fields_to_predicate(fields)
Expand All @@ -145,7 +180,8 @@ def df(
data_types[info.name] = bool
elif info.type == EpidataFieldType.categorical:
data_types[info.name] = CategoricalDtype(
categories=Series(info.categories) if info.categories else None, ordered=True
categories=Series(info.categories) if info.categories else None,
ordered=True,
)
elif info.type == EpidataFieldType.int:
data_types[info.name] = "Int64"
Expand All @@ -166,6 +202,8 @@ def df(
for info in time_fields:
if info.type == EpidataFieldType.epiweek:
continue
# Try two date formats, otherwise keep as string. The try except
# is needed because the time field might be date_or_epiweek.
try:
df[info.name] = to_datetime(df[info.name], format="%Y-%m-%d")
continue
Expand All @@ -175,21 +213,33 @@ def df(
df[info.name] = to_datetime(df[info.name], format="%Y%m%d")
except ValueError:
pass

if self.use_cache:
with Cache(CACHE_DIRECTORY) as cache:
cache_key = self._get_cache_key("df")
cache.set(cache_key, df, expire=self.cache_max_age_days * 24 * 60 * 60)

return df


class EpiDataContext(AEpiDataEndpoints[EpiDataCall]):
"""
sync epidata call class
"""
"""sync epidata call class"""

_base_url: Final[str]
_session: Final[Optional[Session]]

def __init__(self, base_url: str = BASE_URL, session: Optional[Session] = None) -> None:
def __init__(
    self,
    base_url: str = BASE_URL,
    session: Optional[Session] = None,
    use_cache: Optional[bool] = None,
    cache_max_age_days: Optional[int] = None,
) -> None:
    """Set up a synchronous epidata context.

    Args:
        base_url: Root URL of the API that calls from this context target.
        session: Optional ``requests`` session to reuse for the calls.
        use_cache: Cache switch stored on the context; ``None`` leaves it unset.
        cache_max_age_days: Cache entry lifetime in days stored on the
            context; ``None`` leaves it unset.
    """
    super().__init__()
    # Cache preferences are stored as given (possibly None).
    self.cache_max_age_days = cache_max_age_days
    self.use_cache = use_cache
    # Connection settings.
    self._session = session
    self._base_url = base_url

def with_base_url(self, base_url: str) -> "EpiDataContext":
    """Return a new context pointing at ``base_url``.

    The session and the cache settings (``use_cache`` and
    ``cache_max_age_days``) of this context are carried over, so switching
    servers does not silently drop the caching configuration.

    Args:
        base_url: Root URL the new context should target.

    Returns:
        A fresh ``EpiDataContext``; this instance is left unchanged.
    """
    return EpiDataContext(
        base_url,
        self._session,
        self.use_cache,
        self.cache_max_age_days,
    )
Expand All @@ -204,13 +254,24 @@ def _create_call(
meta: Optional[Sequence[EpidataFieldInfo]] = None,
only_supports_classic: bool = False,
) -> EpiDataCall:
return EpiDataCall(self._base_url, self._session, endpoint, params, meta, only_supports_classic)


Epidata = EpiDataContext()


def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None) -> CovidcastDataSources[EpiDataCall]:
return EpiDataCall(
self._base_url,
self._session,
endpoint,
params,
meta,
only_supports_classic,
self.use_cache,
self.cache_max_age_days,
)


def CovidcastEpidata(
base_url: str = BASE_URL,
session: Optional[Session] = None,
use_cache: Optional[bool] = None,
cache_max_age_days: Optional[int] = None,
) -> CovidcastDataSources[EpiDataCall]:
url = add_endpoint_to_url(base_url, "covidcast/meta")
meta_data_res = _request_with_retry(url, {}, session, False)
meta_data_res.raise_for_status()
Expand All @@ -219,6 +280,14 @@ def CovidcastEpidata(base_url: str = BASE_URL, session: Optional[Session] = None
def create_call(
params: Mapping[str, Optional[EpiRangeParam]],
) -> EpiDataCall:
return EpiDataCall(base_url, session, "covidcast", params, define_covidcast_fields())
return EpiDataCall(
base_url,
session,
"covidcast",
params,
define_covidcast_fields(),
use_cache=use_cache,
cache_max_age_days=cache_max_age_days,
)

return CovidcastDataSources.create(meta_data, create_call)
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ classifiers = [
requires-python = ">=3.8"
dependencies = [
"aiohttp",
"appdirs",
"diskcache",
"epiweeks>=2.1",
"pandas>=1",
"requests>=2.25",
Expand Down
11 changes: 6 additions & 5 deletions smoke_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import date

from epidatpy import CovidcastEpidata, Epidata, EpiRange
from epidatpy import CovidcastEpidata, EpiDataContext, EpiRange

print("Epidata Test")
apicall = Epidata.pub_covidcast("fb-survey", "smoothed_cli", "nation", "day", "us", EpiRange(20210405, 20210410))
epidata = EpiDataContext(use_cache=True, cache_max_age_days=1)
apicall = epidata.pub_covidcast("fb-survey", "smoothed_cli", "nation", "day", "us", EpiRange(20210405, 20210410))

# Call info
print(apicall)
Expand All @@ -27,17 +28,17 @@
print(df.iloc[0])


StagingEpidata = Epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/")
staging_epidata = epidata.with_base_url("https://staging.delphi.cmu.edu/epidata/")

epicall = StagingEpidata.pub_covidcast(
epicall = staging_epidata.pub_covidcast(
"fb-survey", "smoothed_cli", "nation", "day", "*", EpiRange(date(2021, 4, 5), date(2021, 4, 10))
)
print(epicall._base_url)


# Covidcast test
print("Covidcast Test")
epidata = CovidcastEpidata()
epidata = CovidcastEpidata(use_cache=True, cache_max_age_days=1)
print(epidata.source_names())
print(epidata.signal_names("fb-survey"))
epidata["fb-survey"].signal_df
Expand Down
Loading
Loading