Skip to content

Commit 90548d5

Browse files
committed
Server-side compute: add /trend endpoint
1 parent 5f72ebe commit 90548d5

File tree

3 files changed

+92
-25
lines changed

3 files changed

+92
-25
lines changed

integrations/server/test_covidcast_endpoints.py

+67-22
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,14 @@ def test_compatibility(self):
296296
out = self._fetch(is_compatibility=True, source=first.source, signal=first.signal, geo=first.geo_pair, time="day:*")
297297
self.assertEqual(len(out["epidata"]), len(rows))
298298

299+
def _diff_covidcast_rows(self, rows: List[CovidcastRow]) -> List[CovidcastRow]:
300+
new_rows = list()
301+
for x, y in zip(rows[1:], rows[:-1]):
302+
new_row = copy(x)
303+
new_row.value = x.value - y.value
304+
new_rows.append(new_row)
305+
return new_rows
306+
299307
def test_trend(self):
300308
"""Request a signal the /trend endpoint."""
301309

@@ -306,29 +314,66 @@ def test_trend(self):
306314
ref = rows[num_rows // 2]
307315
self._insert_rows(rows)
308316

309-
out = self._fetch("/trend", signal=first.signal_pair, geo=first.geo_pair, date=last.time_value, window="20200401-20201212", basis=ref.time_value)
317+
with self.subTest("no server-side compute"):
318+
out = self._fetch("/trend", signal=first.signal_pair, geo=first.geo_pair, date=last.time_value, window="20200401-20201212", basis=ref.time_value)
319+
320+
self.assertEqual(out["result"], 1)
321+
self.assertEqual(len(out["epidata"]), 1)
322+
trend = out["epidata"][0]
323+
self.assertEqual(trend["geo_type"], last.geo_type)
324+
self.assertEqual(trend["geo_value"], last.geo_value)
325+
self.assertEqual(trend["signal_source"], last.source)
326+
self.assertEqual(trend["signal_signal"], last.signal)
327+
328+
self.assertEqual(trend["date"], last.time_value)
329+
self.assertEqual(trend["value"], last.value)
330+
331+
self.assertEqual(trend["basis_date"], ref.time_value)
332+
self.assertEqual(trend["basis_value"], ref.value)
333+
self.assertEqual(trend["basis_trend"], "increasing")
334+
335+
self.assertEqual(trend["min_date"], first.time_value)
336+
self.assertEqual(trend["min_value"], first.value)
337+
self.assertEqual(trend["min_trend"], "increasing")
338+
self.assertEqual(trend["max_date"], last.time_value)
339+
self.assertEqual(trend["max_value"], last.value)
340+
self.assertEqual(trend["max_trend"], "steady")
341+
342+
num_rows = 30
343+
time_value_pairs = [(20200331, 0)] + [(20200401 + i, v) for i, v in enumerate(accumulate(range(num_rows)))]
344+
rows = [CovidcastRow(source="jhu-csse", signal="confirmed_cumulative_num", time_value=t, value=v) for t, v in time_value_pairs]
345+
self._insert_rows(rows)
346+
diffed_rows = self._diff_covidcast_rows(rows)
347+
for row in diffed_rows:
348+
row.signal = "confirmed_incidence_num"
349+
first = diffed_rows[0]
350+
last = diffed_rows[-1]
351+
ref = diffed_rows[num_rows // 2]
352+
with self.subTest("use server-side compute"):
353+
out = self._fetch("/trend", signal="jhu-csse:confirmed_incidence_num", geo=first.geo_pair, date=last.time_value, window="20200401-20201212", basis=ref.time_value)
354+
355+
self.assertEqual(out["result"], 1)
356+
self.assertEqual(len(out["epidata"]), 1)
357+
trend = out["epidata"][0]
358+
self.assertEqual(trend["geo_type"], last.geo_type)
359+
self.assertEqual(trend["geo_value"], last.geo_value)
360+
self.assertEqual(trend["signal_source"], last.source)
361+
self.assertEqual(trend["signal_signal"], last.signal)
362+
363+
self.assertEqual(trend["date"], last.time_value)
364+
self.assertEqual(trend["value"], last.value)
365+
366+
self.assertEqual(trend["basis_date"], ref.time_value)
367+
self.assertEqual(trend["basis_value"], ref.value)
368+
self.assertEqual(trend["basis_trend"], "increasing")
369+
370+
self.assertEqual(trend["min_date"], first.time_value)
371+
self.assertEqual(trend["min_value"], first.value)
372+
self.assertEqual(trend["min_trend"], "increasing")
373+
self.assertEqual(trend["max_date"], last.time_value)
374+
self.assertEqual(trend["max_value"], last.value)
375+
self.assertEqual(trend["max_trend"], "steady")
310376

311-
self.assertEqual(out["result"], 1)
312-
self.assertEqual(len(out["epidata"]), 1)
313-
trend = out["epidata"][0]
314-
self.assertEqual(trend["geo_type"], last.geo_type)
315-
self.assertEqual(trend["geo_value"], last.geo_value)
316-
self.assertEqual(trend["signal_source"], last.source)
317-
self.assertEqual(trend["signal_signal"], last.signal)
318-
319-
self.assertEqual(trend["date"], last.time_value)
320-
self.assertEqual(trend["value"], last.value)
321-
322-
self.assertEqual(trend["basis_date"], ref.time_value)
323-
self.assertEqual(trend["basis_value"], ref.value)
324-
self.assertEqual(trend["basis_trend"], "increasing")
325-
326-
self.assertEqual(trend["min_date"], first.time_value)
327-
self.assertEqual(trend["min_value"], first.value)
328-
self.assertEqual(trend["min_trend"], "increasing")
329-
self.assertEqual(trend["max_date"], last.time_value)
330-
self.assertEqual(trend["max_value"], last.value)
331-
self.assertEqual(trend["max_trend"], "steady")
332377

333378
def test_trendseries(self):
334379
"""Request a signal the /trendseries endpoint."""

src/server/endpoints/covidcast.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ def handle_trend():
231231
daily_signals, weekly_signals = count_signal_time_types(source_signal_pairs)
232232
source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
233233
geo_pairs = parse_geo_pairs()
234+
transform_args = parse_transform_args()
235+
jit_bypass = parse_jit_bypass()
234236

235237
time_window, is_day = parse_day_or_week_range_arg("window")
236238
time_value, is_also_day = parse_day_or_week_arg("date")
@@ -244,6 +246,12 @@ def handle_trend():
244246
base_shift = 7
245247
basis_time_value = shift_time_value(time_value, -1 * base_shift) if is_day else shift_week_value(time_value, -1 * base_shift)
246248

249+
use_server_side_compute = not any((not is_day, not is_also_day)) and JIT_COMPUTE and not jit_bypass
250+
if use_server_side_compute:
251+
pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
252+
source_signal_pairs, row_transform_generator = get_basename_signals(source_signal_pairs)
253+
time_window = pad_time_window(time_window, pad_length)
254+
247255
# build query
248256
q = QueryBuilder("covidcast", "t")
249257

@@ -262,8 +270,20 @@ def handle_trend():
262270

263271
p = create_printer()
264272

265-
def gen(rows):
266-
for key, group in groupby((parse_row(row, fields_string, fields_int, fields_float) for row in rows), lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
273+
if use_server_side_compute:
274+
def gen_transform(rows):
275+
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
276+
transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
277+
for row in transformed_rows:
278+
yield row
279+
else:
280+
def gen_transform(rows):
281+
parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
282+
for row in parsed_rows:
283+
yield row
284+
285+
def gen_trend(rows):
286+
for key, group in groupby(rows, lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
267287
geo_type, geo_value, source, signal = key
268288
if alias_mapper:
269289
source = alias_mapper(source, signal)
@@ -277,7 +297,7 @@ def gen(rows):
277297
raise DatabaseErrorException(str(e))
278298

279299
# now use a generator for sending the rows and execute all the other queries
280-
return p(filter_fields(gen(r)))
300+
return p(filter_fields(gen_trend(gen_transform(r))))
281301

282302

283303
@bp.route("/trendseries", methods=("GET", "POST"))

src/server/endpoints/covidcast_utils/trend.py

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def compute_trend(geo_type: str, geo_value: str, signal_source: str, signal_sign
4242

4343
# find all needed rows
4444
for time, value in rows:
45+
if value is None:
46+
continue
4547
if time == current_time:
4648
t.value = value
4749
if time == basis_time:

0 commit comments

Comments
 (0)