Skip to content

Commit 545adcd

Browse files
committed
Server-side compute: add /trendseries endpoint
1 parent 90548d5 commit 545adcd

File tree

3 files changed

+85
-3
lines changed

3 files changed

+85
-3
lines changed

integrations/server/test_covidcast_endpoints.py

+60
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def match_row(trend, row):
412412
self.assertEqual(trend["max_date"], first.time_value)
413413
self.assertEqual(trend["max_value"], first.value)
414414
self.assertEqual(trend["max_trend"], "steady")
415+
415416
with self.subTest("trend1"):
416417
trend = trends[1]
417418
match_row(trend, rows[1])
@@ -440,6 +441,65 @@ def match_row(trend, row):
440441
self.assertEqual(trend["max_value"], first.value)
441442
self.assertEqual(trend["max_trend"], "decreasing")
442443

444+
num_rows = 3
445+
time_value_pairs = [(20200331, 0)] + [(20200401 + i, v) for i, v in enumerate(accumulate([num_rows - i for i in range(num_rows)]))]
446+
rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=t, value=v) for t, v in time_value_pairs]
447+
self._insert_rows(rows)
448+
diffed_rows = self._diff_covidcast_rows(rows)
449+
for row in diffed_rows:
450+
row.signal = "confirmed_incidence_num"
451+
first = diffed_rows[0]
452+
last = diffed_rows[-1]
453+
454+
out = self._fetch("/trendseries", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, date=last.time_value, window="20200401-20200410", basis=1)
455+
456+
self.assertEqual(out["result"], 1)
457+
self.assertEqual(len(out["epidata"]), 3)
458+
trends = out["epidata"]
459+
460+
with self.subTest("trend0, server-side compute"):
461+
trend = trends[0]
462+
match_row(trend, first)
463+
self.assertEqual(trend["basis_date"], None)
464+
self.assertEqual(trend["basis_value"], None)
465+
self.assertEqual(trend["basis_trend"], "unknown")
466+
467+
self.assertEqual(trend["min_date"], last.time_value)
468+
self.assertEqual(trend["min_value"], last.value)
469+
self.assertEqual(trend["min_trend"], "increasing")
470+
self.assertEqual(trend["max_date"], first.time_value)
471+
self.assertEqual(trend["max_value"], first.value)
472+
self.assertEqual(trend["max_trend"], "steady")
473+
474+
with self.subTest("trend1"):
475+
trend = trends[1]
476+
match_row(trend, diffed_rows[1])
477+
self.assertEqual(trend["basis_date"], first.time_value)
478+
self.assertEqual(trend["basis_value"], first.value)
479+
self.assertEqual(trend["basis_trend"], "decreasing")
480+
481+
self.assertEqual(trend["min_date"], last.time_value)
482+
self.assertEqual(trend["min_value"], last.value)
483+
self.assertEqual(trend["min_trend"], "increasing")
484+
self.assertEqual(trend["max_date"], first.time_value)
485+
self.assertEqual(trend["max_value"], first.value)
486+
self.assertEqual(trend["max_trend"], "decreasing")
487+
488+
with self.subTest("trend2"):
489+
trend = trends[2]
490+
match_row(trend, last)
491+
self.assertEqual(trend["basis_date"], diffed_rows[1].time_value)
492+
self.assertEqual(trend["basis_value"], diffed_rows[1].value)
493+
self.assertEqual(trend["basis_trend"], "decreasing")
494+
495+
self.assertEqual(trend["min_date"], last.time_value)
496+
self.assertEqual(trend["min_value"], last.value)
497+
self.assertEqual(trend["min_trend"], "steady")
498+
self.assertEqual(trend["max_date"], first.time_value)
499+
self.assertEqual(trend["max_value"], first.value)
500+
self.assertEqual(trend["max_trend"], "decreasing")
501+
502+
443503
def test_correlation(self):
444504
"""Request a signal from the /correlation endpoint."""
445505

src/server/endpoints/covidcast.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -307,13 +307,21 @@ def handle_trendseries():
307307
daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs)
308308
source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
309309
geo_pairs = parse_geo_pairs()
310+
transform_args = parse_transform_args()
311+
jit_bypass = parse_jit_bypass()
310312

311313
time_window, is_day = parse_day_or_week_range_arg("window")
312314
_verify_argument_time_type_matches(is_day, daily_signals, weekly_signals)
313315
basis_shift = extract_integer(("basis", "basis_shift"))
314316
if basis_shift is None:
315317
basis_shift = 7
316318

319+
use_server_side_compute = is_day and JIT_COMPUTE and not jit_bypass
320+
if use_server_side_compute:
321+
pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
322+
source_signal_pairs, row_transform_generator = get_basename_signals(source_signal_pairs)
323+
time_window = pad_time_window(time_window, pad_length)
324+
317325
# build query
318326
q = QueryBuilder("covidcast", "t")
319327

@@ -336,8 +344,20 @@ def handle_trendseries():
336344
if not is_day:
337345
shifter = lambda x: shift_week_value(x, -basis_shift)
338346

339-
def gen(rows):
340-
for key, group in groupby((parse_row(row, fields_string, fields_int, fields_float) for row in rows), lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
347+
if use_server_side_compute:
348+
def gen_transform(rows):
349+
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
350+
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
351+
for row in transformed_rows:
352+
yield row
353+
else:
354+
def gen_transform(rows):
355+
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
356+
for row in parsed_rows:
357+
yield row
358+
359+
def gen_trend(rows):
360+
for key, group in groupby(rows, lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
341361
geo_type, geo_value, source, signal = key
342362
if alias_mapper:
343363
source = alias_mapper(source, signal)
@@ -352,7 +372,7 @@ def gen(rows):
352372
raise DatabaseErrorException(str(e))
353373

354374
# now use a generator for sending the rows and execute all the other queries
355-
return p(filter_fields(gen(r)))
375+
return p(filter_fields(gen_trend(gen_transform(r))))
356376

357377

358378
@bp.route("/correlation", methods=("GET", "POST"))

src/server/endpoints/covidcast_utils/trend.py

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def compute_trends(geo_type: str, geo_value: str, signal_source: str, signal_sig
7575
lookup: Dict[int, float] = OrderedDict()
7676
# find all needed rows
7777
for time, value in rows:
78+
if value is None:
79+
continue
7880
lookup[time] = value
7981
if min_value is None or min_value > value:
8082
min_date = time

0 commit comments

Comments
 (0)