Skip to content

Commit dfaf7e0

Browse files
committed
JIT: add Pandas JIT approach
1 parent 0586952 commit dfaf7e0

File tree

4 files changed

+80
-66
lines changed

4 files changed

+80
-66
lines changed

integrations/server/test_covidcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_csv_format(self):
138138
**{'format':'csv'}
139139
)
140140

141-
# TODO: This is a mess because of api.php.
141+
# TODO: This is a mess because of api.php. Or maybe it's just a mess.
142142
column_order = [
143143
"geo_value", "signal", "time_value", "direction", "issue", "lag", "missing_value",
144144
"missing_stderr", "missing_sample_size", "value", "stderr", "sample_size"

src/server/endpoints/covidcast.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from flask import Blueprint, request
88
from flask.json import loads, jsonify
99
from more_itertools import peekable
10-
from numpy import nan
10+
from numpy import int64
1111
from sqlalchemy import text
1212
from pandas import read_csv, to_datetime
1313

@@ -136,22 +136,9 @@ def parse_transform_args():
136136
if smoother_window_length is None:
137137
smoother_window_length = 7
138138

139-
# TODO: Add support for float inputs here.
140-
# The value to fill for missing date values.
141-
pad_fill_value = extract_integer("pad_fill_value")
142-
if pad_fill_value is None:
143-
pad_fill_value = nan
144-
145-
# The value to fill for None or nan values.
146-
nan_fill_value = extract_integer("nans_fill_value")
147-
if nan_fill_value is None:
148-
nan_fill_value = nan
149-
150139
smoother_args = {
151140
"smoother_kernel": SmootherKernelValue.average,
152141
"smoother_window_length": smoother_window_length if isinstance(smoother_window_length, Number) and smoother_window_length <= MAX_SMOOTHER_WINDOW else MAX_SMOOTHER_WINDOW,
153-
"pad_fill_value": pad_fill_value if isinstance(pad_fill_value, Number) else nan,
154-
"nans_fill_value": nan_fill_value if isinstance(nan_fill_value, Number) else nan
155142
}
156143
return smoother_args
157144

@@ -210,7 +197,7 @@ def alias_row(row):
210197

211198
def gen_transform(rows):
212199
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
213-
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=time_pairs, transform_args=transform_args)
200+
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, transform_args=transform_args)
214201
for row in transformed_rows:
215202
yield alias_row(row)
216203
else:
@@ -300,7 +287,7 @@ def gen_trend(rows):
300287

301288
def gen_transform(rows):
302289
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
303-
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
290+
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, transform_args=transform_args)
304291
for row in transformed_rows:
305292
yield row
306293
else:
@@ -380,7 +367,7 @@ def gen_trend(rows):
380367

381368
def gen_transform(rows):
382369
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
383-
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
370+
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, transform_args=transform_args)
384371
for row in transformed_rows:
385372
yield row
386373
else:
@@ -527,7 +514,7 @@ def handle_export():
527514

528515
def gen_transform(rows):
529516
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
530-
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
517+
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, transform_args=transform_args)
531518
for row in transformed_rows:
532519
yield row
533520
else:
@@ -556,7 +543,7 @@ def parse_csv_row(i, row):
556543
"geo_value": row["geo_value"],
557544
"signal": row["signal"],
558545
"time_value": time_value_to_iso(row["time_value"]) if is_day else row["time_value"],
559-
"issue": time_value_to_iso(row["issue"]) if is_day else row["issue"],
546+
"issue": time_value_to_iso(row["issue"]) if is_day and isinstance(row["issue"], (int, int64)) else row["issue"],
560547
"lag": row["lag"],
561548
"value": row["value"],
562549
"stderr": row["stderr"],

src/server/endpoints/covidcast_utils/model.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
from dataclasses import asdict, dataclass, field
22
from enum import Enum
33
from functools import partial
4-
from itertools import groupby, repeat, tee
4+
from itertools import groupby
55
from numbers import Number
66
from typing import Callable, Generator, Iterator, Optional, Dict, List, Set, Tuple, Union
77

88
from pathlib import Path
99
import re
10-
from more_itertools import flatten, interleave_longest, peekable
10+
from more_itertools import flatten, peekable
1111
import pandas as pd
1212
import numpy as np
1313

1414
from delphi_utils.nancodes import Nans
1515
from ..._params import SourceSignalPair, TimePair
1616
from .smooth_diff import generate_smoothed_rows, generate_diffed_rows
17-
from ...utils import shift_time_value, iterate_over_ints_and_ranges
17+
from ...utils import shift_time_value, iterate_over_ints_and_ranges, iterate_over_range
1818

1919

2020
IDENTITY: Callable = lambda rows, **kwargs: rows
@@ -489,7 +489,6 @@ def get_day_range(time_pairs: List[TimePair]) -> Iterator[int]:
489489

490490
def _generate_transformed_rows(
491491
parsed_rows: Iterator[Dict],
492-
time_pairs: Optional[List[TimePair]] = None,
493492
transform_dict: Optional[SignalTransforms] = None,
494493
transform_args: Optional[Dict] = None,
495494
group_keyfunc: Optional[Callable] = None,
@@ -499,9 +498,6 @@ def _generate_transformed_rows(
499498
Parameters:
500499
parsed_rows: Iterator[Dict]
501500
An iterator streaming rows from a database query. Assumed to be sorted by source, signal, geo_type, geo_value, time_type, and time_value.
502-
time_pairs: Optional[List[TimePair]], default None
503-
A list of TimePairs, which can be used to create a contiguous time index for time-series operations.
504-
The min and max dates in the TimePairs list is used.
505501
transform_dict: Optional[SignalTransforms], default None
506502
A dictionary mapping base sources to a list of their derived signals that the user wishes to query.
507503
For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}.
@@ -527,18 +523,56 @@ def _generate_transformed_rows(
527523
# Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map.
528524
derived_signal_transform_map: SourceSignalPair = transform_dict.get(SourceSignalPair(base_source_name, [base_signal_name]), SourceSignalPair(base_source_name, [base_signal_name]))
529525
# Create a list of source-signal pairs along with the transformation required for the signal.
530-
signal_names_and_transforms: List[Tuple[Tuple[str, str], Callable]] = [(derived_signal, _get_base_signal_transform((base_source_name, derived_signal))) for derived_signal in derived_signal_transform_map.signal]
531-
# Put the current time series on a contiguous time index.
532-
source_signal_geo_rows = _reindex_iterable(source_signal_geo_rows, time_pairs, fill_value=transform_args.get("pad_fill_value"))
533-
# Create copies of the iterable, with smart memory usage.
534-
source_signal_geo_rows_copies: Iterator[Iterator[Dict]] = tee(source_signal_geo_rows, len(signal_names_and_transforms))
535-
# Create a list of transformed group iterables, remembering their derived name as needed.
536-
transformed_signals_iterator: Iterator[Tuple[str, Iterator[Dict]]] = (zip(repeat(derived_signal), transform(rows, **transform_args)) for (derived_signal, transform), rows in zip(signal_names_and_transforms, source_signal_geo_rows_copies))
537-
# Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window
538-
# of the original iterable (group) is stored in memory.
539-
for derived_signal_name, row in interleave_longest(*transformed_signals_iterator):
540-
row["signal"] = derived_signal_name
541-
yield row
526+
signal_names_and_transforms: List[Tuple[str, Callable]] = [(derived_signal, _get_base_signal_transform((base_source_name, derived_signal))) for derived_signal in derived_signal_transform_map.signal]
527+
528+
# TODO: Fix these to come as an argument.
529+
fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"]
530+
fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"]
531+
fields_float = ["value", "stderr", "sample_size"]
532+
columns = fields_string + fields_int + fields_float
533+
df = pd.DataFrame.from_records(source_signal_geo_rows, columns=columns)
534+
for derived_signal, transform in signal_names_and_transforms:
535+
if transform == IDENTITY:
536+
yield from df.to_dict(orient="records")
537+
continue
538+
539+
df2 = df.set_index(["time_value"])
540+
df2 = df2.reindex(iterate_over_range(df2.index.min(), df2.index.max(), inclusive=True))
541+
542+
if transform == DIFF:
543+
df2["value"] = df2["value"].diff()
544+
window_length = 2
545+
elif transform == SMOOTH:
546+
df2["value"] = df2["value"].rolling(7).mean()
547+
window_length = 7
548+
elif transform == DIFF_SMOOTH:
549+
df2["value"] = df2["value"].diff().rolling(7).mean()
550+
window_length = 8
551+
else:
552+
raise ValueError(f"Unknown transform for {derived_signal}.")
553+
554+
df2 = df2.assign(
555+
geo_type = df2["geo_type"].fillna(method="ffill"),
556+
geo_value = df2["geo_value"].fillna(method="ffill"),
557+
source = df2["source"].fillna(method="ffill"),
558+
signal = derived_signal,
559+
time_type = df2["time_type"].fillna(method="ffill"),
560+
direction = df2["direction"].fillna(method="ffill"),
561+
issue = df2["issue"].rolling(window_length).max(),
562+
lag = df2["lag"].rolling(window_length).max(),
563+
missing_value=np.where(df2["value"].isna(), Nans.NOT_APPLICABLE, Nans.NOT_MISSING),
564+
missing_stderr=Nans.NOT_APPLICABLE,
565+
missing_sample_size=Nans.NOT_APPLICABLE,
566+
stderr=np.nan,
567+
sample_size=np.nan,
568+
)
569+
df2 = df2.iloc[window_length - 1:]
570+
for row in df2.reset_index().to_dict(orient="records"):
571+
row.update({
572+
"issue": int(row["issue"]) if not np.isnan(row["issue"]) else row["issue"],
573+
"lag": int(row["lag"]) if not np.isnan(row["lag"]) else row["lag"]
574+
})
575+
yield row
542576

543577

544578
def get_basename_signal_and_jit_generator(source_signal_pairs: List[SourceSignalPair], transform_args: Optional[Dict[str, Union[str, int]]] = None) -> Tuple[List[SourceSignalPair], Generator]:

tests/server/endpoints/covidcast_utils/test_model.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from more_itertools import interleave_longest, windowed
77
from pandas.testing import assert_frame_equal
88

9-
from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows
9+
from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRows, assert_frame_equal_no_order
1010
from delphi.epidata.server._params import SourceSignalPair, TimePair
1111
from delphi.epidata.server.endpoints.covidcast_utils.model import (
1212
DIFF,
@@ -163,7 +163,7 @@ def test__generate_transformed_rows(self):
163163
missing_sample_size=[Nans.NOT_APPLICABLE] * 4,
164164
).api_row_df
165165

166-
assert_frame_equal(df, expected_df)
166+
assert_frame_equal_no_order(df, expected_df, index=["signal", "geo_value", "time_value"])
167167

168168
with self.subTest("smoothed and diffed signals on one base test"):
169169
data = CovidcastRows.from_args(
@@ -184,11 +184,7 @@ def test__generate_transformed_rows(self):
184184
sample_size=[None] * 13,
185185
).api_row_df
186186

187-
# Test no order.
188-
idx = ["source", "signal", "time_value"]
189-
assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index())
190-
# Test order.
191-
assert_frame_equal(df, expected_df)
187+
assert_frame_equal_no_order(df, expected_df, index=["signal", "geo_value", "time_value"])
192188

193189
with self.subTest("smoothed and diffed signal on two non-continguous regions"):
194190
data = CovidcastRows.from_args(
@@ -199,9 +195,8 @@ def test__generate_transformed_rows(self):
199195
sample_size=range(15),
200196
).api_row_df
201197
transform_dict = {SourceSignalPair("src", ["sig_base"]): SourceSignalPair("src", ["sig_diff", "sig_smooth"])}
202-
time_pairs = [TimePair("day", [(20210501, 20210520)])]
203198
df = CovidcastRows.from_records(
204-
_generate_transformed_rows(data.to_dict(orient="records"), time_pairs=time_pairs, transform_dict=transform_dict)
199+
_generate_transformed_rows(data.to_dict(orient="records"), transform_dict=transform_dict)
205200
).api_row_df
206201

207202
filled_values = data.value.to_list()[:10] + [None] * 5 + data.value.to_list()[10:]
@@ -215,11 +210,8 @@ def test__generate_transformed_rows(self):
215210
sample_size=[None] * 33,
216211
issue=interleave_longest(_reindex_windowed(filled_time_values, 2), _reindex_windowed(filled_time_values, 7)),
217212
).api_row_df
218-
# Test no order.
219-
idx = ["source", "signal", "time_value"]
220-
assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index())
221-
# Test order.
222-
assert_frame_equal(df, expected_df)
213+
214+
assert_frame_equal_no_order(df, expected_df, index=["signal", "geo_value", "time_value"])
223215
# fmt: on
224216

225217
def test_get_basename_signals(self):
@@ -258,44 +250,46 @@ def test_get_basename_signals(self):
258250
).api_row_df
259251
source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_other", "sig_smooth"])]
260252
_, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs)
261-
time_pairs = [TimePair("day", [(20210501, 20210530)])]
262-
df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"), time_pairs=time_pairs)).api_row_df
253+
df = CovidcastRows.from_records(row_transform_generator(data.to_dict(orient="records"))).api_row_df
263254

264255
filled_values = list(chain(range(10), [None] * 10, range(10, 20)))
265256
filled_time_values = list(chain(pd.date_range("2021-05-01", "2021-05-10"), [None] * 10, pd.date_range("2021-05-21", "2021-05-30")))
266257

267258
expected_df = CovidcastRows.from_args(
268-
signal=["sig_base"] * 30 + ["sig_diff"] * 29 + ["sig_other"] * 5 + ["sig_smooth"] * 24,
259+
signal=["sig_base"] * 20 + ["sig_diff"] * 29 + ["sig_other"] * 5 + ["sig_smooth"] * 24,
269260
time_value=chain(
270-
pd.date_range("2021-05-01", "2021-05-30"),
261+
chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30")),
271262
pd.date_range("2021-05-02", "2021-05-30"),
272263
pd.date_range("2021-05-01", "2021-05-05"),
273264
pd.date_range("2021-05-07", "2021-05-30")
274265
),
275266
value=chain(
276-
filled_values,
267+
range(20),
277268
_diff_rows(filled_values),
278269
range(5),
279270
_smooth_rows(filled_values)
280271
),
281272
stderr=chain(
282-
chain(range(10), [None] * 10, range(10, 20)),
273+
range(20),
283274
chain([None] * 29),
284275
range(5),
285276
chain([None] * 24),
286277
),
287278
sample_size=chain(
288-
chain(range(10), [None] * 10, range(10, 20)),
279+
range(20),
289280
chain([None] * 29),
290281
range(5),
291282
chain([None] * 24),
292283
),
293-
issue=chain(filled_time_values, _reindex_windowed(filled_time_values, 2), pd.date_range("2021-05-01", "2021-05-05"), _reindex_windowed(filled_time_values, 7)),
284+
issue=chain(
285+
chain(pd.date_range("2021-05-01", "2021-05-10"), pd.date_range("2021-05-21", "2021-05-30")),
286+
_reindex_windowed(filled_time_values, 2),
287+
pd.date_range("2021-05-01", "2021-05-05"),
288+
_reindex_windowed(filled_time_values, 7)
289+
),
294290
).api_row_df
295291
# fmt: on
296-
# Test no order.
297-
idx = ["source", "signal", "time_value"]
298-
assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index())
292+
assert_frame_equal_no_order(df, expected_df, index=["signal", "geo_value", "time_value"])
299293

300294
with self.subTest("test base, diff, smooth; multiple geos"):
301295
# fmt: off
@@ -341,8 +335,7 @@ def test_get_basename_signals(self):
341335
).api_row_df
342336
# fmt: on
343337
# Test no order.
344-
idx = ["source", "signal", "time_value"]
345-
assert_frame_equal(df.set_index(idx).sort_index(), expected_df.set_index(idx).sort_index())
338+
assert_frame_equal_no_order(df, expected_df, index=["signal", "geo_value", "time_value"])
346339

347340
with self.subTest("empty iterator"):
348341
source_signal_pairs = [SourceSignalPair("src", ["sig_base", "sig_diff", "sig_smooth"])]

0 commit comments

Comments (0)