Skip to content

Commit 136792d

Browse files
committed
feat: work on week support in covidcast
1 parent 4c6e72d commit 136792d

File tree

6 files changed

+97
-26
lines changed

6 files changed

+97
-26
lines changed

src/server/_params.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,15 @@ class TimePair:
109109
time_type: str
110110
time_values: Union[bool, Sequence[Union[int, Tuple[int, int]]]]
111111

112+
@property
def is_week(self) -> bool:
    """True when this pair's time_type is 'week'."""
    return self.time_type == 'week'
115+
116+
@property
def is_day(self) -> bool:
    # NOTE(review): defined as "not week" rather than == 'day', so any other
    # time_type is also treated as day — confirm that is intended.
    """True when this pair is not week-typed."""
    return self.time_type != 'week'
119+
120+
112121
def count(self) -> float:
113122
"""
114123
returns the count of items in this pair
@@ -225,3 +234,43 @@ def parse_day_arg(key: str) -> int:
225234
if not isinstance(r, int):
226235
raise ValidationFailedException(f"{key} must match YYYYMMDD or YYYY-MM-DD")
227236
return r
237+
238+
def parse_week_arg(key: str) -> int:
    """Read the given request parameter and parse it as a single epi week (YYYYWW)."""
    raw = request.values.get(key)
    if not raw:
        raise ValidationFailedException(f"{key} param is required")
    parsed = parse_week_value(raw)
    if not isinstance(parsed, int):
        # a range was given where only a single week is allowed
        raise ValidationFailedException(f"{key} must match YYYYWW")
    return parsed
246+
247+
248+
def parse_week_range_arg(key: str) -> Tuple[int, int]:
    """Read the given request parameter and parse it as an epi week range (YYYYWW-YYYYWW)."""
    raw = request.values.get(key)
    if not raw:
        raise ValidationFailedException(f"{key} param is required")
    parsed = parse_week_value(raw)
    if not isinstance(parsed, tuple):
        # a single week was given where a (start, end) range is required
        raise ValidationFailedException(f"{key} must match YYYYWW-YYYYWW")
    return parsed
256+
257+
def parse_day_or_week_arg(key: str) -> Tuple[int, bool]:
    """
    Parse the given request parameter as either a single day or a single epi week.

    Returns a (time_value, is_day) tuple: is_day is True for day values
    (YYYY-MM-DD or YYYYMMDD) and False for week values (YYYYWW).

    Raises ValidationFailedException when the parameter is missing or malformed.
    """
    v = request.values.get(key)
    if not v:
        raise ValidationFailedException(f"{key} param is required")
    # format is either YYYY-MM-DD (10 chars), YYYYMMDD (8 chars), or YYYYWW
    # (6 chars) — the original comment said "YYYYMM", but weeks are YYYYWW —
    # so the length alone distinguishes a week from a day
    if len(v) == 6:
        return parse_week_arg(key), False
    return parse_day_arg(key), True
266+
267+
def parse_day_or_week_range_arg(key: str) -> Tuple[Tuple[int, int], bool]:
    """
    Parse the given request parameter as either a day range or an epi week range.

    Returns a ((start, end), is_day) tuple: is_day is True for day ranges
    (YYYY-MM-DD--YYYY-MM-DD or YYYYMMDD-YYYYMMDD) and False for week ranges
    (YYYYWW-YYYYWW).

    Raises ValidationFailedException when the parameter is missing or malformed.
    """
    v = request.values.get(key)
    if not v:
        raise ValidationFailedException(f"{key} param is required")
    # format is either YYYY-MM-DD--YYYY-MM-DD, YYYYMMDD-YYYYMMDD, or
    # YYYYWW-YYYYWW (the original comment's "YYYYMM-YYYYMM" was a typo):
    # only for a week range does the segment before the first '-' have length 6
    is_week = len(v.split('-', 2)[0]) == 6
    if is_week:
        return parse_week_range_arg(key), False
    return parse_day_range_arg(key), True

src/server/endpoints/covidcast.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from flask.json import loads, jsonify
66
from bisect import bisect_right
77
from sqlalchemy import text
8-
from pandas import read_csv
8+
from pandas import read_csv, to_datetime
99

1010
from .._common import is_compatibility_mode, db
1111
from .._exceptions import ValidationFailedException, DatabaseErrorException
@@ -16,8 +16,9 @@
1616
parse_geo_arg,
1717
parse_source_signal_arg,
1818
parse_time_arg,
19-
parse_day_arg,
19+
parse_day_or_week_arg,
2020
parse_day_range_arg,
21+
parse_day_or_week_range_arg,
2122
parse_single_source_signal_arg,
2223
parse_single_time_arg,
2324
parse_single_geo_arg,
@@ -34,7 +35,7 @@
3435
)
3536
from .._pandas import as_pandas, print_pandas
3637
from .covidcast_utils import compute_trend, compute_trends, compute_correlations, compute_trend_value, CovidcastMetaEntry
37-
from ..utils import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date
38+
from ..utils import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, shift_week_value, week_value_to_week
3839
from .covidcast_utils.model import TimeType, data_sources, create_source_signal_alias_mapper
3940

4041
# first argument is the endpoint name
@@ -172,15 +173,18 @@ def transform_row(row, proxy):
172173

173174
@bp.route("/trend", methods=("GET", "POST"))
174175
def handle_trend():
175-
require_all("date", "window")
176+
require_all("window", "date")
176177
source_signal_pairs = parse_source_signal_pairs()
177178
source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
178-
# TODO alias
179179
geo_pairs = parse_geo_pairs()
180180

181-
time_value = parse_day_arg("date")
182-
time_window = parse_day_range_arg("window")
183-
basis_time_value = extract_date("basis") or shift_time_value(time_value, -7)
181+
time_window, is_day = parse_day_or_week_range_arg("window")
182+
time_value, is_also_day = parse_day_or_week_arg("date")
183+
if is_day != is_also_day:
184+
raise ValidationFailedException('mixing weeks with day arguments')
185+
basis_time_value = extract_date("basis")
186+
if basis_time_value is None:
187+
basis_time_value = (shift_time_value(time_value, -7) if is_day else shift_week_value(time_value, -7))
184188

185189
# build query
186190
q = QueryBuilder("covidcast", "t")
@@ -193,7 +197,7 @@ def handle_trend():
193197

194198
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
195199
q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
196-
q.where_time_pairs("time_type", "time_value", [TimePair("day", [time_window])])
200+
q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
197201

198202
# fetch most recent issue fast
199203
_handle_lag_issues_as_of(q, None, None, None)
@@ -225,7 +229,7 @@ def handle_trendseries():
225229
source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
226230
geo_pairs = parse_geo_pairs()
227231

228-
time_window = parse_day_range_arg("window")
232+
time_window, is_day = parse_day_or_week_range_arg("window")
229233
basis_shift = extract_integer("basis")
230234
if basis_shift is None:
231235
basis_shift = 7
@@ -241,14 +245,16 @@ def handle_trendseries():
241245

242246
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
243247
q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
244-
q.where_time_pairs("time_type", "time_value", [TimePair("day", [time_window])])
248+
q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else 'week', [time_window])])
245249

246250
# fetch most recent issue fast
247251
_handle_lag_issues_as_of(q, None, None, None)
248252

249253
p = create_printer()
250254

251255
shifter = lambda x: shift_time_value(x, -basis_shift)
256+
if not is_day:
257+
shifter = lambda x: shift_week_value(x, -basis_shift)
252258

253259
def gen(rows):
254260
for key, group in groupby((parse_row(row, fields_string, fields_int, fields_float) for row in rows), lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
@@ -276,7 +282,8 @@ def handle_correlation():
276282
other_pairs = parse_source_signal_arg("others")
277283
source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(other_pairs + [reference])
278284
geo_pairs = parse_geo_arg()
279-
time_window = parse_day_range_arg("window")
285+
time_window, is_day = parse_day_or_week_range_arg("window")
286+
280287
lag = extract_integer("lag")
281288
if lag is None:
282289
lag = 28
@@ -296,12 +303,17 @@ def handle_correlation():
296303
source_signal_pairs,
297304
)
298305
q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
299-
q.where_time_pairs("time_type", "time_value", [TimePair("day", [time_window])])
306+
q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
300307

301308
# fetch most recent issue fast
302309
q.conditions.append(f"({q.alias}.is_latest_issue IS TRUE)")
303310

304-
df = as_pandas(str(q), q.params, parse_dates={"time_value": "%Y%m%d"})
311+
df = as_pandas(str(q), q.params)
312+
if is_day:
313+
df['time_value'] = to_datetime(df['time_value'], format="%Y%m%d")
314+
else:
315+
# week but convert to date for simpler shifting
316+
df['time_value'] = to_datetime(df['time_value'].apply(lambda v: week_value_to_week(v).startdate()))
305317

306318
p = create_printer()
307319

@@ -329,7 +341,7 @@ def gen():
329341
for (source, signal), other_group in other_groups:
330342
if alias_mapper:
331343
source = alias_mapper(source, signal)
332-
for cor in compute_correlations(geo_type, geo_value, source, signal, lag, reference_group, other_group):
344+
for cor in compute_correlations(geo_type, geo_value, source, signal, lag, reference_group, other_group, is_day):
333345
yield cor.asdict()
334346

335347
# now use a generator for sending the rows and execute all the other queries
@@ -345,6 +357,8 @@ def handle_export():
345357
geo_type = request.args.get("geo_type", "county")
346358
geo_values = request.args.get("geo_values", "*")
347359

360+
# TODO weekly signals
361+
348362
if geo_values != "*":
349363
geo_values = geo_values.split(",")
350364

@@ -424,8 +438,9 @@ def handle_backfill():
424438
# don't need the alias mapper since we don't return the source
425439

426440
time_pair = parse_single_time_arg("time")
441+
is_day = time_pair.is_day
427442
geo_pair = parse_single_geo_arg("geo")
428-
reference_anchor_lag = extract_integer("anchor_lag") # in days
443+
reference_anchor_lag = extract_integer("anchor_lag") # in days or weeks
429444
if reference_anchor_lag is None:
430445
reference_anchor_lag = 60
431446

@@ -461,7 +476,8 @@ def gen(rows):
461476
for time_value, group in groupby((parse_row(row, fields_string, fields_int, fields_float) for row in rows), lambda row: row["time_value"]):
462477
# compute data per time value
463478
issues: List[Dict[str, Any]] = [r for r in group]
464-
anchor_row = find_anchor_row(issues, shift_time_value(time_value, reference_anchor_lag))
479+
shifted_time_value = shift_time_value(time_value, reference_anchor_lag) if is_day else shift_week_value(time_value, reference_anchor_lag)
480+
anchor_row = find_anchor_row(issues, shifted_time_value)
465481

466482
for i, row in enumerate(issues):
467483
if i > 0:
@@ -576,8 +592,9 @@ def handle_coverage():
576592
source_signal_pairs, alias_mapper = create_source_signal_alias_mapper(source_signal_pairs)
577593
geo_type = request.args.get("geo_type", "county")
578594
if "window" in request.values:
579-
time_window = parse_day_range_arg("window")
595+
time_window, is_day = parse_day_or_week_range_arg("window")
580596
else:
597+
is_day = False # TODO
581598
now_time = extract_date("latest")
582599
now = date.today() if now_time is None else time_value_to_date(now_time)
583600
last = extract_integer("days")
@@ -601,7 +618,7 @@ def handle_coverage():
601618
else:
602619
q.where(geo_type=geo_type)
603620
q.where_source_signal_pairs("source", "signal", source_signal_pairs)
604-
q.where_time_pairs("time_type", "time_value", [TimePair("day", [time_window])])
621+
q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else 'week', [time_window])])
605622
q.group_by = "c.source, c.signal, c.time_value"
606623
q.set_order("source", "signal", "time_value")
607624

src/server/endpoints/covidcast_utils/correlation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Correlation:
4949
"""
5050

5151

52-
def lag_join(lag: int, x: pd.DataFrame, y: pd.DataFrame) -> pd.DataFrame:
52+
def lag_join(lag: int, x: pd.DataFrame, y: pd.DataFrame, is_day = True) -> pd.DataFrame:
5353
# x_t_i ~ y_t_(i-lag)
5454
# aka x_t_(i+lag) ~ y_t_i
5555

@@ -60,24 +60,24 @@ def lag_join(lag: int, x: pd.DataFrame, y: pd.DataFrame) -> pd.DataFrame:
6060
# x_t_i ~ y_shifted_t_i
6161
# shift y such that y_t(i - lag) -> y_shifted_t_i
6262
x_shifted = x
63-
y_shifted = y.shift(lag, freq="D")
63+
y_shifted = y.shift(lag, freq="D" if is_day else 'W')
6464
else: # lag < 0
6565
# x_shifted_t_i ~ y_t_i
6666
# shift x such that x_t(i+lag) -> x_shifted_t_i
6767
# lag < 0 -> - - lag = + lag
68-
x_shifted = x.shift(-lag, freq="D")
68+
x_shifted = x.shift(-lag, freq="D" if is_day else 'W')
6969
y_shifted = y
7070
# inner join to remove invalid pairs
7171
r = x_shifted.join(y_shifted, how="inner", lsuffix="_x", rsuffix="_y")
7272
return r.rename(columns=dict(value_x="x", value_y="y"))
7373

7474

75-
def compute_correlations(geo_type: str, geo_value: str, signal_source: str, signal_signal: str, lag: int, x: pd.DataFrame, y: pd.DataFrame, is_day = True) -> Iterable[CorrelationResult]:
    """
    x,y ... DataFrame with "time_value" (Date) index and "value" (float) column

    Yields one CorrelationResult per shift in [-lag, lag], inclusive.
    is_day selects the alignment frequency used by lag_join when shifting
    (daily when True, weekly otherwise).
    """
    for current_lag in range(-lag, lag + 1):
        # align y against x shifted by current_lag steps, then correlate the overlap
        xy = lag_join(current_lag, x, y, is_day)
        c = compute_correlation(xy)

        yield CorrelationResult(geo_type, geo_value, signal_source, signal_signal, current_lag, r2=c.r2, intercept=c.intercept, slope=c.slope, samples=c.samples)

src/server/endpoints/covidcast_utils/trend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Optional, Iterable, Tuple, Dict, List, Callable
33
from enum import Enum
44
from collections import OrderedDict
5-
from ...utils import shift_time_value
65

76

87
class TrendEnum(str, Enum):

src/server/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .dates import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, days_in_range, weeks_in_range
1+
from .dates import shift_time_value, date_to_time_value, time_value_to_iso, time_value_to_date, days_in_range, weeks_in_range, shift_week_value, week_to_time_value, week_value_to_week

src/server/utils/dates.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ def shift_time_value(time_value: int, days: int) -> int:
3737
shifted = d + timedelta(days=days)
3838
return date_to_time_value(shifted)
3939

40+
def shift_week_value(week_value: int, weeks: int) -> int:
    """Shift an epi week value (YYYYWW) forward/backward by the given number of weeks."""
    if weeks == 0:
        # no-op shift: skip the Week round-trip entirely
        return week_value
    shifted = week_value_to_week(week_value) + weeks
    return week_to_time_value(shifted)
4046

4147
def days_in_range(range: Tuple[int, int]) -> int:
4248
"""

0 commit comments

Comments
 (0)