
Commit e679e50

Server-side compute: add /csv endpoint

1 parent 965b4fc
2 files changed: +73 −18 lines

integrations/server/test_covidcast_endpoints.py (+42 −10)
```diff
@@ -530,19 +530,51 @@ def test_correlation(self):
     def test_csv(self):
         """Request a signal the /csv endpoint."""
 
+        expected_columns = ["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "data_source"]
         rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)]
-        first = rows[0]
         self._insert_rows(rows)
+        first = rows[0]
+        with self.subTest("no server-side compute"):
+            response = requests.get(
+                f"{BASE_URL}/csv",
+                params=dict(signal=first.signal_pair, start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type),
+            )
+            response.raise_for_status()
+            out = response.text
+            df = pd.read_csv(StringIO(out), index_col=0)
 
-        response = requests.get(
-            f"{BASE_URL}/csv",
-            params=dict(signal=first.signal_pair, start_day="2020-04-01", end_day="2020-12-12", geo_type=first.geo_type),
-        )
-        response.raise_for_status()
-        out = response.text
-        df = pd.read_csv(StringIO(out), index_col=0)
-        self.assertEqual(df.shape, (len(rows), 10))
-        self.assertEqual(list(df.columns), ["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "data_source"])
+            self.assertEqual(df.shape, (len(rows), 10))
+            self.assertEqual(list(df.columns), expected_columns)
+
+        num_rows = 10
+        time_value_pairs = [(20200331, 0)] + [(20200401 + i, v) for i, v in enumerate(accumulate(range(num_rows)))]
+        rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=t, value=v) for t, v in time_value_pairs]
+        self._insert_rows(rows)
+        first = rows[0]
+        with self.subTest("use server-side compute"):
+            response = requests.get(
+                f"{BASE_URL}/csv",
+                params=dict(signal="src:sig", start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type),
+            )
+            response.raise_for_status()
+            out = response.text
+            df = pd.read_csv(StringIO(out), index_col=0)
+            df.stderr = np.nan
+            df.sample_size = np.nan
+
+            response = requests.get(
+                f"{BASE_URL}/csv",
+                params=dict(signal="jhu-csse:confirmed_incidence_num", start_day="2020-04-01", end_day="2020-04-10", geo_type=first.geo_type),
+            )
+            response.raise_for_status()
+            out = response.text
+            df_diffed = pd.read_csv(StringIO(out), index_col=0)
+            df_diffed.signal = "sig"
+            df_diffed.data_source = "src"
+
+            self.assertEqual(df_diffed.shape, (num_rows, 10))
+            self.assertEqual(list(df_diffed.columns), expected_columns)
+            pd.testing.assert_frame_equal(df_diffed, df)
 
     def test_backfill(self):
         """Request a signal the /backfill endpoint."""
```

src/server/endpoints/covidcast.py (+31 −8)
```diff
@@ -458,9 +458,12 @@ def handle_export():
     source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
     start_day, is_day = parse_day_or_week_arg("start_day", 202001 if weekly_signals > 0 else 20200401)
     end_day, is_end_day = parse_day_or_week_arg("end_day", 202020 if weekly_signals > 0 else 20200901)
+    time_window = (start_day, end_day)
     if is_day != is_end_day:
         raise ValidationFailedException("mixing weeks with day arguments")
     _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals)
+    transform_args = parse_transform_args()
+    jit_bypass = parse_jit_bypass()
 
     geo_type = request.args.get("geo_type", "county")
     geo_values = request.args.get("geo_values", "*")
```

```diff
@@ -472,13 +475,22 @@
     if is_day != is_as_of_day:
         raise ValidationFailedException("mixing weeks with day arguments")
 
+    use_server_side_compute = all([is_day, is_end_day]) and JIT_COMPUTE and not jit_bypass
+    if use_server_side_compute:
+        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
+        source_signal_pairs, row_transform_generator = get_basename_signals(source_signal_pairs)
+        time_window = pad_time_window(time_window, pad_length)
+
     # build query
     q = QueryBuilder("covidcast", "t")
 
-    q.set_fields(["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "source"], [], [])
+    fields_string = ["geo_value", "signal", "geo_type", "source"]
+    fields_int = ["time_value", "issue", "lag"]
+    fields_float = ["value", "stderr", "sample_size"]
+    q.set_fields(fields_string + fields_int + fields_float, [], [])
     q.set_order("time_value", "geo_value")
     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
-    q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [(start_day, end_day)])])
+    q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
     q.where_geo_pairs("geo_type", "geo_value", [GeoPair(geo_type, True if geo_values == "*" else geo_values)])
 
     _handle_lag_issues_as_of(q, None, None, as_of)
```
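
The padding step exists because a derived signal needs history from before the requested window: one prior day for a difference, a full window for a smoother. `get_pad_length` and `pad_time_window` are delphi-epidata helpers whose internals are not shown in this commit; the sketch below is only a hypothetical illustration of the idea on integer YYYYMMDD dates.

```python
# Hypothetical illustration only; the real pad_time_window lives elsewhere
# in the delphi-epidata codebase. Padding widens the queried window
# backwards so transforms have the history they need.
from datetime import datetime, timedelta

def pad_time_window_sketch(time_window, pad_length):
    start, end = time_window
    padded = datetime.strptime(str(start), "%Y%m%d") - timedelta(days=pad_length)
    return (int(padded.strftime("%Y%m%d")), end)

print(pad_time_window_sketch((20200401, 20200410), 1))  # (20200331, 20200410)
```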
```diff
@@ -489,7 +501,7 @@ def handle_export():
     filename = "covidcast-{source}-{signal}-{start_day}-to-{end_day}{as_of}".format(source=source, signal=signal, start_day=format_date(start_day), end_day=format_date(end_day), as_of=as_of_str)
     p = CSVPrinter(filename)
 
-    def parse_row(i, row):
+    def parse_csv_row(i, row):
         # '',geo_value,signal,{time_value,issue},lag,value,stderr,sample_size,geo_type,data_source
         return {
             "": i,
```

```diff
@@ -505,10 +517,20 @@ def parse_row(i, row):
         "data_source": alias_mapper(row["source"], row["signal"]) if alias_mapper else row["source"],
     }
 
-    def gen(first_row, rows):
-        yield parse_row(0, first_row)
+    if use_server_side_compute:
+        def gen_transform(rows):
+            parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
+            transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
+            for row in transformed_rows:
+                yield row
+    else:
+        def gen_transform(rows):
+            for row in rows:
+                yield row
+
+    def gen_parse(rows):
         for i, row in enumerate(rows):
-            yield parse_row(i + 1, row)
+            yield parse_csv_row(i, row)
 
     # execute query
     try:
```
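
The export path is now a two-stage lazy pipeline: `gen_transform` optionally rewrites database rows (the server-side compute path), then `gen_parse` numbers and formats whatever comes out. A toy sketch of that composition follows; the names mirror the diff, but the doubling transform is a stand-in, not the real `row_transform_generator`.

```python
# Toy sketch of the generator pipeline; the *2 transform is a stand-in
# for the commit's row_transform_generator.
def gen_transform(rows):
    for row in rows:
        yield {**row, "value": row["value"] * 2}  # lazily rewrite each row

def gen_parse(rows):
    for i, row in enumerate(rows):
        yield {"": i, **row}  # prepend the CSV index column

db_rows = iter([{"value": 1.0}, {"value": 2.0}])
print(list(gen_parse(gen_transform(db_rows))))
# [{'': 0, 'value': 2.0}, {'': 1, 'value': 4.0}]
```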

```diff
@@ -517,14 +539,15 @@ def gen(first_row, rows):
         raise DatabaseErrorException(str(e))
 
     # special case for no data to be compatible with the CSV server
-    first_row = next(r, None)
+    transformed_query = peekable(gen_transform(r))
+    first_row = transformed_query.peek(None)
     if not first_row:
         return "No matching data found for signal {source}:{signal} " "at {geo} level from {start_day} to {end_day}, as of {as_of}.".format(
             source=source, signal=signal, geo=geo_type, start_day=format_date(start_day), end_day=format_date(end_day), as_of=(date.today().isoformat() if as_of is None else format_date(as_of))
         )
 
     # now use a generator for sending the rows and execute all the other queries
-    return p(gen(first_row, r))
+    return p(gen_parse(transformed_query))
 
 
 @bp.route("/backfill", methods=("GET", "POST"))
```
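
The old code consumed the first row with `next(r, None)` and then had to stitch it back in through `gen(first_row, rows)`. Wrapping the stream in `peekable` (presumably `more_itertools.peekable`, whose `peek(default)` matches the call here) keeps the empty-result check without consuming anything:

```python
# Sketch of the peek-without-consuming pattern, assuming peekable is
# more_itertools.peekable as the peek(None) call suggests.
from more_itertools import peekable

stream = peekable(iter([{"value": 1}, {"value": 2}]))
first = stream.peek(None)   # look at the first item; nothing is consumed
if first is None:
    print("No matching data found")
else:
    print(list(stream))     # the full stream is still available
# [{'value': 1}, {'value': 2}]
```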
