Skip to content

Commit 727c025

Browse files
committed
Address review comments
1 parent de30579 commit 727c025

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

src/server/_params.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
from ._exceptions import ValidationFailedException
10-
from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, guess_time_value_is_week, TimeValues, days_to_ranges, weeks_to_ranges
10+
from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, guess_time_value_is_week, IntRange, TimeValues, days_to_ranges, weeks_to_ranges
1111

1212

1313
def _parse_common_multi_arg(key: str) -> List[Tuple[str, Union[bool, Sequence[str]]]]:
@@ -140,7 +140,7 @@ def to_ranges(self):
140140
return TimePair(self.time_type, days_to_ranges(self.time_values))
141141

142142

143-
def _verify_range(start: int, end: int) -> Union[int, Tuple[int, int]]:
143+
def _verify_range(start: int, end: int) -> IntRange:
144144
if start == end:
145145
# the first and last numbers are the same, just treat it as a singe value
146146
return start
@@ -151,7 +151,7 @@ def _verify_range(start: int, end: int) -> Union[int, Tuple[int, int]]:
151151
raise ValidationFailedException(f"the given range {start}-{end} is inverted")
152152

153153

154-
def parse_week_value(time_value: str) -> Union[int, Tuple[int, int]]:
154+
def parse_week_value(time_value: str) -> IntRange:
155155
count_dashes = time_value.count("-")
156156
msg = f"{time_value} does not match a known format YYYYWW or YYYYWW-YYYYWW"
157157

@@ -171,7 +171,7 @@ def parse_week_value(time_value: str) -> Union[int, Tuple[int, int]]:
171171
raise ValidationFailedException(msg)
172172

173173

174-
def parse_day_value(time_value: str) -> Union[int, Tuple[int, int]]:
174+
def parse_day_value(time_value: str) -> IntRange:
175175
count_dashes = time_value.count("-")
176176
msg = f"{time_value} does not match a known format YYYYMMDD, YYYY-MM-DD, YYYYMMDD-YYYYMMDD, or YYYY-MM-DD--YYYY-MM-DD"
177177

src/server/_query.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def date_string(value: int) -> str:
3434

3535
def to_condition(
3636
field: str,
37-
value: Union[str, Tuple[int, int], int],
37+
value: Union[str, IntRange],
3838
param_key: str,
3939
params: Dict[str, Any],
4040
formatter=lambda x: x,
@@ -50,7 +50,7 @@ def to_condition(
5050

5151
def filter_values(
5252
field: str,
53-
values: Optional[Sequence[Union[str, Tuple[int, int], int]]],
53+
values: Optional[Sequence[Union[str, IntRange]]],
5454
param_key: str,
5555
params: Dict[str, Any],
5656
formatter=lambda x: x,
@@ -471,11 +471,13 @@ def apply_lag_filter(self, history_table: str, lag: Optional[int]):
471471
self.retable(history_table)
472472
# history_table has full spectrum of lag values to search from whereas the latest_table does not
473473
self.where(lag=lag)
474+
return self
474475

475476
def apply_issues_filter(self, history_table: str, issues: Optional[TimeValues]):
476477
if issues:
477478
self.retable(history_table)
478479
self.where_integers("issue", issues)
480+
return self
479481

480482
def apply_as_of_filter(self, history_table: str, as_of: Optional[int]):
481483
if as_of is not None:
@@ -484,8 +486,10 @@ def apply_as_of_filter(self, history_table: str, as_of: Optional[int]):
484486
self.params["as_of"] = as_of
485487
sub_fields = "max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value"
486488
sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value"
487-
sub_condition = f"x.max_issue = {self.alias}.issue AND x.time_type = {self.alias}.time_type AND x.time_value = {self.alias}.time_value AND x.source = {self.alias}.source AND x.signal = {self.alias}.signal AND x.geo_type = {self.alias}.geo_type AND x.geo_value = {self.alias}.geo_value"
489+
alias = self.alias
490+
sub_condition = f"x.max_issue = {alias}.issue AND x.time_type = {alias}.time_type AND x.time_value = {alias}.time_value AND x.source = {alias}.source AND x.signal = {alias}.signal AND x.geo_type = {alias}.geo_type AND x.geo_value = {alias}.geo_value"
488491
self.subquery = f"JOIN (SELECT {sub_fields} FROM {self.table} WHERE {self.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}"
492+
return self
489493

490494
def set_fields(self, *fields: Iterable[str]) -> "QueryBuilder":
491495
self.fields = [f"{self.alias}.{field}" for field_list in fields for field in field_list]
@@ -497,6 +501,7 @@ def set_sort_order(self, *args: str):
497501
"""
498502

499503
self.order = [f"{self.alias}.{k} ASC" for k in args]
504+
return self
500505

501506
def with_max_issue(self, *args: str) -> "QueryBuilder":
502507
fields: List[str] = [f for f in args]

0 commit comments

Comments
 (0)