[JIT] Implement non-streaming JIT with polars #1093


Status: Draft. Wants to merge 48 commits into base: dev.

Commits (48; the diff below shows changes from 2 commits)
6d82286
System: update delphi_web_python Docker image
dshemetov Dec 5, 2022
33037dc
devtools: remove unused Docker images from Makefile
dshemetov Dec 5, 2022
9a3c07d
CI: remove unused Docker images from build
dshemetov Dec 5, 2022
96153db
System: remove stale comment from Dockerfile
dshemetov Dec 5, 2022
8da9b0f
Docker: merge requirements.txt and pin versions #1043 (#1046)
melange396 Dec 5, 2022
869ff21
Server: add CovidcastRow helper class for testing
dshemetov Oct 7, 2022
db3405c
Server: update csv_to_database to use CovidcastRow
dshemetov Oct 7, 2022
f46b7a2
Server: update test_db to use CovidcastRow
dshemetov Oct 7, 2022
9a609e9
Server: update test_delete_batch to use CovidcastRow
dshemetov Oct 7, 2022
1032e5b
Server: update test_delphi_epidata to use CovidcastRow
dshemetov Oct 7, 2022
c106430
Server: update test_covidcast_endpoints to use CovidcastRow
dshemetov Oct 7, 2022
c2bcbb0
Server: update test_covidcast to use CovidcastRow
dshemetov Oct 7, 2022
357b3af
Server: update test_utils to use CovidcastRow
dshemetov Oct 7, 2022
7e5cc0f
Server: update TimePair to auto-sort tuples
dshemetov Oct 7, 2022
d23d599
Server: minor model.py data_source_by_id name update
dshemetov Oct 7, 2022
b707a66
Server: update csv issue none handling
dshemetov Dec 1, 2022
ca4e50d
Server: add type hints to _query
dshemetov Nov 4, 2022
98e45b7
Acquisition: update test_csv_uploading to remove Pandas warning
dshemetov Oct 11, 2022
bc40d33
Server: add PANDAS_DTYPES to model.py
dshemetov Dec 5, 2022
c391a28
Docker: add more_itertools==8.4.0 to Python and API images
dshemetov Dec 5, 2022
163185f
Acquisition: update database.py to use CovidcastRow
dshemetov Dec 5, 2022
aff6036
Docker: bump API and Python pandas to 1.5.1
dshemetov Dec 5, 2022
828836d
JIT: major feature commit
dshemetov Dec 5, 2022
8f2bdaf
CI: Build a container image from this branch
korlaxxalrok Oct 11, 2022
c1d7ca2
JIT: add Pandas JIT approach with optimizations
dshemetov Dec 2, 2022
87a3a60
CI: Update to build a JIT Pandas image
dshemetov Oct 31, 2022
dc171a9
JIT: Push new approach
dshemetov Dec 7, 2022
2d9f84f
JIT: update to not groupby geo again
dshemetov Dec 8, 2022
69ccf51
JIT: improve reindex_iterable
dshemetov Dec 9, 2022
e5d4bf2
Do concatenation all together
dshemetov Jan 19, 2023
23f608a
Add reindex_iterable2
dshemetov Jan 25, 2023
67a106b
Do the row assignment step all in one
dshemetov Jan 26, 2023
dd5ddc4
Remove old code
dshemetov Jan 26, 2023
59f7e49
JIT: try new Pandas approach
dshemetov Jan 28, 2023
11680ed
CI: update to build new Docker image
dshemetov Jan 28, 2023
eaabc19
CI: update to account for typo in branch name
dshemetov Jan 28, 2023
4d2967a
Partial code cleanup, derived_signals_map bugfix
dshemetov Feb 3, 2023
d286f3a
fix bug
dshemetov Feb 7, 2023
7f9c8e8
Better handle JIT routing
dshemetov Feb 9, 2023
2b5d546
New JIT optimization:
dshemetov Feb 9, 2023
205636f
More optimizations:
dshemetov Feb 10, 2023
7bc24d3
Merge pull request #1073 from cmu-delphi/ds/jit-pandas-mutli-sql
dshemetov Feb 22, 2023
3aa2252
JIT: clean up optimizations
dshemetov Feb 22, 2023
b431b9e
Update the image name in ci.yaml
dshemetov Feb 22, 2023
13bf7a1
Update requirements with polars pyarrow
dshemetov Feb 23, 2023
7298862
Update ci.yaml with new image tag
dshemetov Feb 23, 2023
283d2cd
JIT: Switch to polars
dshemetov Feb 23, 2023
e0e0d03
Fix test
dshemetov Feb 23, 2023
9 changes: 5 additions & 4 deletions integrations/server/test_covidcast_endpoints.py
@@ -205,21 +205,22 @@ def test_derived_signals(self):
out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo="county:01", time="day:20200407-20200420")
out_df = CovidcastRows.from_records(out["epidata"]).api_row_df.set_index(["signal", "geo_value", "time_value"])
merged_df = pd.merge(
expected_df.query("signal == 'confirmed_7dav_incidence_num' and geo_value == '01' and time_value >= 20200407 and time_value <= 20200420"),
expected_df.query("signal == 'confirmed_7dav_incidence_num' and geo_value == '01' and time_value >= 20200401 and time_value <= 20200420"),
out_df,
how="outer",
left_index=True,
right_index=True,
suffixes=["_out", "_expected"]
)[["value_out", "value_expected"]]
expected_df.query("signal == 'confirmed_7dav_incidence_num' and geo_value == '01' and time_value >= 20200407 and time_value <= 20200420").value
assert merged_df.empty is False
assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected, nan_ok=True)

with self.subTest("smoothing and diffing with a time gap and geo=* and time=*"):
out = self._fetch("/", signal="jhu-csse:confirmed_7dav_incidence_num", geo="county:*", time="day:*")
out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"])
merged_df = pd.merge(
expected_df.query("signal == 'confirmed_7dav_incidence_num' and time_value >= 20200407 and time_value <= 20200420"),
expected_df.query("signal == 'confirmed_7dav_incidence_num' and time_value >= 20200401 and time_value <= 20200420"),
out_df,
how="outer",
left_index=True,
@@ -234,8 +235,8 @@ def test_derived_signals(self):
out_df = pd.DataFrame.from_records(out["epidata"]).set_index(["signal", "time_value", "geo_value"])
query_lines = [
"(signal == 'confirmed_cumulative_num')",
"(signal == 'confirmed_incidence_num' and time_value >= 20200402 and time_value <= 20200420)",
"(signal == 'confirmed_7dav_incidence_num' and time_value >= 20200407 and time_value <= 20200420)",
"(signal == 'confirmed_incidence_num' and time_value >= 20200401 and time_value <= 20200420)",
"(signal == 'confirmed_7dav_incidence_num' and time_value >= 20200401 and time_value <= 20200420)",
]
merged_df = pd.merge(
expected_df.query(" or ".join(query_lines)),
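A minimal, self-contained sketch of the comparison pattern the hunks above rely on: expected and returned values are outer-merged on a shared index and checked with pytest.approx so that aligned NaNs pass. Toy frames stand in for this PR's fixtures.

# Sketch only, not part of this PR: the outer-merge comparison used in the test above.
import numpy as np
import pandas as pd
import pytest

expected_df = pd.DataFrame({"value": [1.0, 2.0, np.nan]}, index=[20200401, 20200402, 20200403])
out_df = pd.DataFrame({"value": [1.0, 2.0, np.nan]}, index=[20200401, 20200402, 20200403])

merged_df = pd.merge(
    expected_df,
    out_df,
    how="outer",
    left_index=True,
    right_index=True,
    suffixes=["_expected", "_out"],
)[["value_out", "value_expected"]]

assert merged_df.empty is False
# nan_ok=True lets aligned NaN values count as equal.
assert merged_df.value_out.to_numpy() == pytest.approx(merged_df.value_expected.to_numpy(), nan_ok=True)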
22 changes: 12 additions & 10 deletions src/server/endpoints/covidcast.py
Expand Up @@ -8,7 +8,8 @@
from flask.json import loads, jsonify
from more_itertools import peekable
from sqlalchemy import text
from pandas import read_csv, to_datetime, concat, DataFrame
from pandas import read_csv, to_datetime
import polars as pl

from .._common import is_compatibility_mode, app, db
from .._config import MAX_SMOOTHER_WINDOW, MAX_RESULTS
@@ -157,11 +158,11 @@ def parse_jit_bypass():
MIMETYPE_JSON = "application/json"

def df_to_response(
df: DataFrame,
df: pl.DataFrame,
filename: Optional[str] = None,
) -> Response:
is_compatibility = is_compatibility_mode()
if df.empty:
if df.is_empty():
if is_compatibility:
return Response(
"""{"result": -2, "message": "no results"}""",
@@ -174,30 +175,31 @@ def df_to_response(
)

if is_compatibility:
df.drop(columns=["source", "geo_type", "time_type"], inplace=True, errors="ignore")
columns_to_drop = [x for x in ["source", "geo_type", "time_type"] if x in df.columns]
df = df.drop(columns=columns_to_drop)

fields = request.values.get("fields")
if fields:
keep_fields = []
for field in fields.split(","):
if field.startswith("-") and field[1:] in df.columns:
df.drop(columns=[field[1:]], inplace=True)
df.drop_in_place(field[1:])
elif field in df.columns:
keep_fields.append(field)
if keep_fields:
df = df[keep_fields]
df = df.select([keep_fields])
else:
keep_fields = df.columns

return_format = request.values.get("format", "classic")
if return_format == "classic":
json_str = df.to_json(orient="records")
json_str = df.write_json(row_oriented=True)
return Response(
"""{"epidata":""" + json_str + """, "result": 1, "message": "success"}""",
mimetype=MIMETYPE_JSON
)
elif return_format == "json":
json_str = df.to_json(orient="records")
json_str = df.write_json(row_oriented=True)
return Response(json_str, mimetype=MIMETYPE_JSON)
elif return_format == "csv":
column_order = [
@@ -208,7 +210,7 @@ def df_to_response(
filename = "epidata" if not filename else filename
headers = {"Content-Disposition": f"attachment; filename={filename}.csv"}
return Response(
df[cols].to_csv(index=False),
df[cols].write_csv(),
mimetype="text/csv; charset=utf8",
headers=headers
)
@@ -224,7 +226,7 @@ def jit_request_to_df(
lag: Optional[int],
alias_mapper: Optional[Callable[[str, str], str]],
transform_args: Dict[str, Any],
) -> DataFrame:
) -> pl.DataFrame:
"""Fetches data from the database, performs JIT transformations, and returns a DataFrame.

Assumptions:
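For orientation on the df_to_response changes above, a short sketch of the polars calls that replace the former pandas ones. The frame and column values are illustrative, and it assumes the polars release pinned by this branch, where DataFrame.write_json still accepts row_oriented.

# Sketch only, not part of this PR: pandas-to-polars equivalents used in df_to_response.
# Assumes the polars version pinned by this branch (write_json(row_oriented=True) exists there).
import polars as pl

df = pl.DataFrame({
    "source": ["jhu-csse"],
    "signal": ["confirmed_7dav_incidence_num"],
    "geo_value": ["01001"],
    "time_value": [20200407],
    "value": [3.5],
})

assert not df.is_empty()                                      # replaces pandas df.empty
df = df.drop([c for c in ["geo_type"] if c in df.columns])    # replaces drop(..., errors="ignore")
json_str = df.write_json(row_oriented=True)                   # replaces df.to_json(orient="records")
csv_str = df.write_csv()                                      # replaces df.to_csv(index=False)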
166 changes: 85 additions & 81 deletions src/server/endpoints/covidcast_utils/model.py
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
import re
import pandas as pd
import polars as pl
import numpy as np

from delphi_utils.nancodes import Nans
@@ -609,38 +610,12 @@ def pad_time_window(time_window: TimePair, pad_length: int) -> TimePair:
return TimePair("day", [(shift_day_value(min_time, -1 * pad_length), max_time)])


def to_dict_custom(df: pd.DataFrame) -> Iterable[Dict[str, Any]]:
"""This is a workaround a performance bug in Pandas.

- See this issue: https://github.com/pandas-dev/pandas/issues/46470,
- The first if branch is to avoid using reset_index(), which I found to be a good deal slower than just reading the index,
- All the dtype conversions are to avoid JSON serialization errors (e.g. numpy.int64).
"""
df = df.reset_index()
col_arr_map = {col: df[col].to_numpy(dtype=object, na_value=None) for col in df.columns}

for i in range(len(df)):
yield {col: col_arr_map[col][i] for col in df.columns}


def _set_df_dtypes(df: pd.DataFrame, dtypes: Dict[str, Any]) -> pd.DataFrame:
"""Set the dataframe column datatypes."""
for d in dtypes.values():
try:
pd.api.types.pandas_dtype(d)
except TypeError:
raise ValueError(f"Invalid dtype {d}")

sub_dtypes = {k: v for k, v in dtypes.items() if k in df.columns}
return df.astype(sub_dtypes)


def generate_transformed_rows(
rows: Iterable[Dict],
transform_dict: Optional[Dict[SourceSignal, List[SourceSignal]]] = None,
transform_args: Optional[Dict] = None,
alias_mapper: Callable = None,
) -> pd.DataFrame:
) -> pl.DataFrame:
"""Applies time-series transformations to streamed rows from a database.

This function is written for performance, so many components are very fragile. Be careful.
@@ -679,74 +654,103 @@ def generate_transformed_rows(
transform_type = get_base_signal_transform((base_source_name, derived_signal_name))

if transform_type == SeriesTransform.identity:
identity_row_cols = ["time_value", "geo_value", "value", "sample_size", "stderr", "missing_value", "missing_sample_size", "missing_stderr", "issue", "lag"]
derived_df = pd.DataFrame(grouped_rows_copy, columns=identity_row_cols)
derived_df["time_value"] = pd.to_datetime(derived_df["time_value"], format="%Y%m%d")

# Set dtypes. Int8/Int64 are needed to allow null values.
# TODO: Try using StringDType instead of object. Or categorical. This is mostly for memory usage. No worries about to_dict.
derived_df = _set_df_dtypes(derived_df, PANDAS_DTYPES_TIME)

derived_df = derived_df.set_index("time_value")

derived_df = derived_df.assign(
source=alias_mapper(base_source_name, derived_signal_name),
signal=derived_signal_name,
geo_type=geo_type,
time_type="day",
direction=None
)
dfs.append(derived_df)
IDENTITY_POLARS_SCHEMA = {
"time_value": pl.Utf8,
"geo_value": pl.Utf8,
"value": pl.Float64,
"stderr": pl.Float64,
"sample_size": pl.Float64,
"missing_value": pl.Int8,
"missing_stderr": pl.Int8,
"missing_sample_size": pl.Int8,
"issue": pl.Int64,
"lag": pl.Int8,
}
derived_df = pl.DataFrame(grouped_rows_copy, IDENTITY_POLARS_SCHEMA)
derived_df = derived_df.with_columns(pl.col("time_value").str.strptime(pl.Date, fmt="%Y%m%d"))

derived_df = derived_df.with_columns([
pl.lit(alias_mapper(base_source_name, derived_signal_name)).alias("source"),
pl.lit(derived_signal_name).alias("signal"),
pl.lit(geo_type).alias("geo_type"),
pl.lit("day").alias("time_type"),
pl.lit(None).alias("direction")
])
# Reorder here to match with the order of the derived schema.
dfs.append(derived_df.select([
"time_value", "geo_value", "value", "issue", "lag",
"source", "signal", "stderr", "sample_size",
"missing_value", "missing_stderr", "missing_sample_size",
"geo_type", "time_type", "direction"
]))
continue

derived_df = pd.DataFrame(grouped_rows_copy, columns=["time_value", "geo_value", "value", "issue", "lag"])
derived_df["time_value"] = pd.to_datetime(derived_df["time_value"], format="%Y%m%d")

# Set dtypes. Int8/Int64 are needed to allow null values.
# TODO: Try using StringDType instead of object. Or categorical. This is mostly for memory usage. No worries about to_dict.
derived_df = _set_df_dtypes(derived_df, PANDAS_DTYPES_TIME)

derived_df = derived_df.set_index("time_value")
# breakpoint()
DERIVED_POLARS_SCHEMA = {
"time_value": pl.Utf8,
"geo_value": pl.Utf8,
"value": pl.Float64,
"issue": pl.Int64,
"lag": pl.Int8,
}
derived_df = pl.DataFrame(grouped_rows_copy, DERIVED_POLARS_SCHEMA)
derived_df = derived_df.with_columns(pl.col("time_value").str.strptime(pl.Date, fmt="%Y%m%d"))

if transform_type == SeriesTransform.diff:
derived_df["value"] = derived_df.groupby("geo_value", sort=False)["value"].diff()
window_length = 2
derived_df = derived_df.with_columns(pl.col("value").diff().over("geo_value").alias("value"))
derived_df = derived_df.with_columns(
derived_df.groupby_rolling("time_value", period="2d", by=["geo_value"]).agg([
pl.col("issue").max().alias("issue"),
pl.col("lag").max().alias("lag")
])
)
elif transform_type == SeriesTransform.smooth:
window_length = transform_args.get("smoother_window_length", 7)
derived_df["value"] = derived_df.groupby("geo_value", sort=False)["value"].rolling(f"{window_length}D", min_periods=window_length-1).mean().droplevel(level=0)
derived_df = derived_df.with_columns(
derived_df.groupby_rolling("time_value", period=f"{window_length}d", by=["geo_value"]).agg([
pl.col("value").mean().alias("value"),
pl.col("issue").max().alias("issue"),
pl.col("lag").max().alias("lag"),
])
)
elif transform_type == SeriesTransform.diff_smooth:
window_length = transform_args.get("smoother_window_length", 7)
derived_df["value"] = derived_df.groupby("geo_value", sort=False)["value"].diff()
derived_df["value"] = derived_df.groupby("geo_value", sort=False)["value"].rolling(f"{window_length}D", min_periods=window_length-1).mean().droplevel(level=0)
window_length += 1
derived_df = derived_df.with_columns(pl.col("value").diff().over("geo_value").alias("value"))
derived_df = derived_df.with_columns(
derived_df.groupby_rolling("time_value", period="2d", by=["geo_value"]).agg([
pl.col("issue").max().alias("issue"),
pl.col("lag").max().alias("lag")
])
)
derived_df = derived_df.with_columns(
derived_df.groupby_rolling("time_value", period=f"{window_length}d", by=["geo_value"]).agg([
pl.col("value").mean().alias("value"),
pl.col("issue").max().alias("issue"),
pl.col("lag").max().alias("lag"),
])
)
else:
raise ValueError(f"Unknown transform for {derived_signal}.")

derived_df = derived_df.assign(
source=alias_mapper(base_source_name, derived_signal_name),
signal=derived_signal_name,
issue=derived_df.groupby("geo_value", sort=False)["issue"].rolling(window_length).max().droplevel(level=0).astype("Int64") if "issue" in derived_df.columns else None,
lag=derived_df.groupby("geo_value", sort=False)["lag"].rolling(window_length).max().droplevel(level=0).astype("Int64") if "lag" in derived_df.columns else None,
stderr=np.nan,
sample_size=np.nan,
missing_value=np.where(derived_df["value"].isna(), Nans.NOT_APPLICABLE, Nans.NOT_MISSING),
missing_stderr=Nans.NOT_APPLICABLE,
missing_sample_size=Nans.NOT_APPLICABLE,
time_type="day",
geo_type=geo_type,
direction=None,
)
derived_df = derived_df.with_columns([
pl.lit(alias_mapper(base_source_name, derived_signal_name)).alias("source"),
pl.lit(derived_signal_name).alias("signal"),
pl.lit(np.nan).alias("stderr"),
pl.lit(np.nan).alias("sample_size"),
pl.when(pl.col("value").is_null()).then(Nans.NOT_APPLICABLE).otherwise(Nans.NOT_MISSING).alias("missing_value").cast(pl.Int8),
pl.lit(Nans.NOT_APPLICABLE).alias("missing_stderr").cast(pl.Int8),
pl.lit(Nans.NOT_APPLICABLE).alias("missing_sample_size").cast(pl.Int8),
pl.lit(geo_type).alias("geo_type"),
pl.lit("day").alias("time_type"),
pl.lit(None).alias("direction"),
])
dfs.append(derived_df)

if not dfs:
return pd.DataFrame()

derived_df_full = pd.concat(dfs)
# Ok to do in place because nothing else depends on this memory chunk.
derived_df_full.reset_index(inplace=True)
derived_df_full["time_value"] = derived_df_full["time_value"].dt.strftime("%Y%m%d").astype("Int64")
# TODO: Testing whether we really need this. It's an expensive operation.
# derived_df_full = _set_df_dtypes(derived_df_full, PANDAS_DTYPES)
return pl.DataFrame()

derived_df_full = pl.concat(dfs)
derived_df_full = derived_df_full.with_columns(pl.col("time_value").dt.strftime("%Y%m%d").cast(pl.Int64))
return derived_df_full


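The diff, smooth, and diff_smooth branches above all build their windowed value/issue/lag aggregations on polars' groupby_rolling. A small sketch of that pattern on toy data (3-day window, two geos), assuming the polars release pinned by this branch, where the method still carries the groupby_rolling name:

# Sketch only, not part of this PR: a per-geo rolling mean via groupby_rolling.
from datetime import date
import polars as pl

df = pl.DataFrame({
    "geo_value": ["01", "01", "01", "02", "02", "02"],
    "time_value": [date(2020, 4, 1), date(2020, 4, 2), date(2020, 4, 3)] * 2,
    "value": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
}).sort(["geo_value", "time_value"])

# 3-day windows keyed on time_value, computed separately within each geo_value.
smoothed = df.groupby_rolling("time_value", period="3d", by=["geo_value"]).agg([
    pl.col("value").mean().alias("value"),
])
print(smoothed)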
2 changes: 1 addition & 1 deletion src/server/endpoints/covidcast_utils/test_utils.py
@@ -158,7 +158,7 @@ def smooth_df(df: pd.DataFrame, signal_name: str, nan_fill_value: float = np.nan

for key, group_df in df.groupby(["source", "signal", "geo_value"]):
group_df = group_df.set_index("time_value").sort_index()
group_df["value"] = group_df["value"].fillna(nan_fill_value).rolling(f"{window_length}D", min_periods=window_length-1).mean()
group_df["value"] = group_df["value"].fillna(nan_fill_value).rolling(f"{window_length}D").mean()
group_df["stderr"] = np.nan
group_df["sample_size"] = np.nan
group_df["missing_value"] = np.where(group_df["value"].isna(), Nans.NOT_APPLICABLE, Nans.NOT_MISSING)
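The one-line change above drops min_periods from the test helper's rolling mean. A quick sketch of what that changes at the leading edge of a time-based window, on a made-up series rather than the PR's fixtures:

# Sketch only, not part of this PR: effect of dropping min_periods from a time-based rolling mean.
import pandas as pd

s = pd.Series(
    [1.0, 2.0, 3.0, 4.0],
    index=pd.date_range("2020-04-01", periods=4, freq="D"),
)

strict = s.rolling("3D", min_periods=2).mean()   # first window has only one observation, so NaN
lenient = s.rolling("3D").mean()                 # offset windows default to min_periods=1, so 1.0
print(pd.DataFrame({"strict": strict, "lenient": lenient}))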