Skip to content

Commit 6062b25

Browse files
committed
feat: enforce data frame dtypes
closes #5
1 parent 37a7cf9 commit 6062b25

File tree

5 files changed

+63
-19
lines changed

5 files changed

+63
-19
lines changed

delphi_epidata/_model.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,25 @@
22
from enum import Enum
33
from datetime import date
44
from urllib.parse import urlencode
5-
from typing import Final, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, TypeVar, TypedDict, Union
6-
7-
from pandas import DataFrame, to_datetime
8-
9-
from ._parse import parse_api_date, parse_api_week
5+
from typing import (
6+
Any,
7+
Dict,
8+
Final,
9+
Generic,
10+
Iterable,
11+
List,
12+
Mapping,
13+
Optional,
14+
Sequence,
15+
Tuple,
16+
TypeVar,
17+
TypedDict,
18+
Union,
19+
)
20+
21+
from pandas import DataFrame, CategoricalDtype
22+
23+
from ._parse import parse_api_date, parse_api_week, fields_to_predicate
1024

1125
EpiRangeDict = TypedDict("EpiRangeDict", {"from": int, "to": int})
1226
EpiRangeLike = Union[int, str, "EpiRange", EpiRangeDict, date]
@@ -224,17 +238,29 @@ def _parse_row(
224238
def _as_df(
    self,
    rows: Sequence[Mapping[str, Union[str, float, int, date, None]]],
    fields: Optional[Iterable[str]] = None,
    disable_date_parsing: Optional[bool] = False,
) -> DataFrame:
    """Build a data frame from ``rows`` and enforce a dtype for every selected column.

    Args:
        rows: the records to load, one mapping of field name to value per row.
        fields: optional field selection, interpreted by ``fields_to_predicate``;
            when falsy, every field listed in ``self.meta`` is selected.
        disable_date_parsing: when truthy, date and epiweek columns are cast to
            ``int`` instead of ``datetime64[ns]``.

    Returns:
        A data frame whose columns are the selected ``self.meta`` fields (in
        meta order) cast to the dtype derived from each field's declared type.
        If no meta field is selected, the columns are inferred from ``rows``
        and no cast is applied.
    """
    # NOTE(review): `fields` is iterated here; if a caller has already consumed
    # a one-shot iterator before passing it in, the selection will be empty.
    pred = fields_to_predicate(fields)
    selected = [info for info in self.meta if pred(info.name)]
    columns: List[str] = [info.name for info in selected]
    df = DataFrame(rows, columns=columns or None)

    # pandas >= 2.0 rejects the unit-less "datetime64" string in astype(),
    # so the nanosecond unit is spelled out explicitly.
    date_dtype: Any = int if disable_date_parsing else "datetime64[ns]"

    data_types: Dict[str, Any] = {}
    for info in selected:
        if info.type == EpidataFieldType.bool:
            data_types[info.name] = bool
        elif info.type == EpidataFieldType.categorical:
            data_types[info.name] = CategoricalDtype(categories=info.categories or None, ordered=True)
        elif info.type == EpidataFieldType.int:
            data_types[info.name] = int
        elif info.type in (EpidataFieldType.date, EpidataFieldType.epiweek):
            data_types[info.name] = date_dtype
        elif info.type == EpidataFieldType.float:
            data_types[info.name] = float
        else:
            # Any other field type is exposed as text.
            data_types[info.name] = str
    if data_types:
        df = df.astype(data_types)
    return df

delphi_epidata/_parse.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, cast
1+
from typing import Callable, Iterable, Optional, Set, cast
22

33
from typing import Union
44
from datetime import date, datetime
@@ -16,3 +16,16 @@ def parse_api_week(value: Union[str, int, float, None]) -> Optional[date]:
1616
if value is None:
1717
return None
1818
return cast(date, Week.fromstring(str(value)).startdate())
19+
20+
21+
def fields_to_predicate(fields: Optional[Iterable[str]] = None) -> Callable[[str], bool]:
    """Build a predicate that tells whether a field name is selected by ``fields``.

    Entries starting with ``-`` name fields to exclude (the prefix is dropped);
    every other entry names a field to include. A name is selected when it is
    not excluded and either no include entries were given or it is one of them.

    Args:
        fields: the field selection. ``None`` or an empty collection selects
            every field. A bare string is treated as a single entry.

    Returns:
        A function mapping a field name to ``True`` when it is selected.
    """
    if not fields:
        return lambda _: True
    if isinstance(fields, str):
        # A str is itself an iterable of characters; without this guard the
        # predicate would silently be built from single letters.
        fields = [fields]

    to_include: Set[str] = set()
    to_exclude: Set[str] = set()
    for entry in fields:
        if entry.startswith("-"):
            to_exclude.add(entry[1:])
        else:
            to_include.add(entry)

    def is_selected(name: str) -> bool:
        # Exclusion always wins over inclusion.
        if name in to_exclude:
            return False
        return not to_include or name in to_include

    return is_selected

delphi_epidata/async_requests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async def df(
121121
if self.only_supports_classic:
122122
raise OnlySupportsClassicFormatException()
123123
r = await self.json(fields, disable_date_parsing=disable_date_parsing)
124-
return self._as_df(r, disable_date_parsing)
124+
return self._as_df(r, fields, disable_date_parsing)
125125

126126
async def csv(self, fields: Optional[Iterable[str]] = None) -> str:
127127
"""Request and parse epidata in CSV format"""

delphi_epidata/requests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def df(self, fields: Optional[Iterable[str]] = None, disable_date_parsing: Optio
109109
"""Request and parse epidata as a pandas data frame"""
110110
if self.only_supports_classic:
111111
raise OnlySupportsClassicFormatException()
112-
r = self.json(fields)
113-
return self._as_df(r, disable_date_parsing=disable_date_parsing)
112+
r = self.json(fields, disable_date_parsing=disable_date_parsing)
113+
return self._as_df(r, fields, disable_date_parsing=disable_date_parsing)
114114

115115
def csv(self, fields: Optional[Iterable[str]] = None) -> str:
116116
"""Request and parse epidata in CSV format"""

smoke_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
df = apicall.df()
1818
print(df.columns)
1919
print(df.dtypes)
20+
print(df.iloc[0])
21+
df = apicall.df(disable_date_parsing=True)
22+
print(df.columns)
23+
print(df.dtypes)
24+
print(df.iloc[0])
2025

2126
for row in apicall.iter():
2227
print(row)

0 commit comments

Comments
 (0)