diff --git a/integrations/server/test_covidcast.py b/integrations/server/test_covidcast.py index 01d81bf29..e96488fb6 100644 --- a/integrations/server/test_covidcast.py +++ b/integrations/server/test_covidcast.py @@ -81,6 +81,22 @@ def _insert_placeholder_set_five(self): ] self._insert_rows(rows) return rows + + def _insert_placeholder_set_six(self): + rows = [ + CovidcastTestRow.make_default_row(time_value=2000_01_01+i, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_03) + for i in [1, 2, 3, 4, 5, 6] + ] + self._insert_rows(rows) + return rows + + def _insert_placeholder_set_seven(self): + rows = [ + CovidcastTestRow.make_default_row(time_value=2000_01_01, value=i*1., stderr=i*10., sample_size=i*100., issue=2000_01_01+i) + for i in [1, 2, 3, 4, 5, 6] + ] + self._insert_rows(rows) + return rows def test_round_trip(self): """Make a simple round-trip with some sample data.""" @@ -167,7 +183,7 @@ def test_csv_format(self): .assign(direction = None) .to_csv(columns=column_order, index=False) ) - + # assert that the right data came back self.assertEqual(response, expected) @@ -231,7 +247,6 @@ def test_location_wildcard(self): # make the request response = self.request_based_on_row(rows[0], geo_value="*") - self.maxDiff = None # assert that the right data came back self.assertEqual(response, { 'result': 1, @@ -249,7 +264,6 @@ def test_time_values_wildcard(self): # make the request response = self.request_based_on_row(rows[0], time_values="*") - self.maxDiff = None # assert that the right data came back self.assertEqual(response, { 'result': 1, @@ -257,6 +271,67 @@ def test_time_values_wildcard(self): 'message': 'success', }) + def helper_test_inequality(self, rows, field): + expected = [row.as_api_compatibility_row_dict() for row in rows] + + def fetch(datecode): + # make the request + response = self.request_based_on_row(rows[0], **{field: datecode}) + return response + + # test fetch time_value with < + r = fetch('<20000104') + self.assertEqual(r['message'], 'success') + self.assertEqual(r['epidata'], expected[:2]) + # test fetch time_value with <= + r = fetch('<=20000104') + self.assertEqual(r['message'], 'success') + self.assertEqual(r['epidata'], expected[:3]) + # test fetch time_value with > + r = fetch('>20000104') + self.assertEqual(r['message'], 'success') + self.assertEqual(r['epidata'], expected[3:]) + # test fetch time_value with >= + r = fetch('>=20000104') + self.assertEqual(r['message'], 'success') + self.assertEqual(r['epidata'], expected[2:]) + # test fetch multiple inequalities + r = fetch('<20000104,>20000104') + self.assertEqual(r['message'], 'success') + self.assertEqual(r['epidata'], expected[:2] + expected[3:]) + # test overlapped inequalities, pick the more extreme one + r = fetch('<20000104,<20000105') + self.assertEqual(r['message'], 'success') + self.assertEqual(r['epidata'], expected[:3]) + # test fetch inequalities that has no results + r = fetch('>20000107') + self.assertEqual(r['message'], 'no results') + # test fetch empty value + r = fetch('') + if field == 'time_values': + self.assertEqual(r['message'], 'missing parameter: need [time_type, time_values]') + else: + self.assertEqual(r['message'], 'not a valid date: (empty)') + # test fetch invalid time_value + r = fetch('>') + self.assertEqual(r['message'], 'missing parameter: date after the inequality operator') + # test if extra operators provided + r = fetch('>>') + self.assertEqual(r['message'], 'not a valid date: >') + r = fetch('>>20000103') + self.assertEqual(r['message'], 'not a valid date: >20000103') + # test invalid operator + r = fetch('#') + self.assertEqual(r['message'], 'not a valid date: #') + + def test_time_values_inequality(self): + rows = self._insert_placeholder_set_six() + self.helper_test_inequality(rows, "time_values") + + def test_issues_inequality(self): + rows = self._insert_placeholder_set_seven() + self.helper_test_inequality(rows, "issues") + def test_issues_wildcard(self): """Select all issues with a wildcard query.""" @@ -267,7 +342,6 @@ def test_issues_wildcard(self): # make the request response = self.request_based_on_row(rows[0], issues="*") - self.maxDiff = None # assert that the right data came back self.assertEqual(response, { 'result': 1, @@ -285,7 +359,6 @@ def test_signal_wildcard(self): # make the request response = self.request_based_on_row(rows[0], signals="*") - self.maxDiff = None # assert that the right data came back self.assertEqual(response, { 'result': 1, diff --git a/src/server/_params.py b/src/server/_params.py index 39c25ce1e..02defab70 100644 --- a/src/server/_params.py +++ b/src/server/_params.py @@ -403,11 +403,12 @@ def _parse_range(part: str): def parse_date(s: str) -> int: # parses a given string in format YYYYMMDD or YYYY-MM-DD to a number in the form YYYYMMDD + if s == "*": + return s + if not s: + raise ValidationFailedException("not a valid date: (empty)") try: - if s == "*": - return s - else: - return int(s.replace("-", "")) + return int(s.replace("-", "")) except ValueError: raise ValidationFailedException(f"not a valid date: {s}") @@ -425,6 +426,17 @@ def extract_dates(key: Union[str, Sequence[str]]) -> Optional[TimeValues]: return None values: TimeValues = [] + def handle_inequality(part: str): + inequality_operator = None + for operator in ['<=', '>=', '<', '>']: + if part.startswith(operator): + inequality_operator = operator + part = part[len(operator):] + if not part: + raise ValidationFailedException("missing parameter: date after the inequality operator") + return (inequality_operator,), part + return None, part + def push_range(first: str, last: str): first_d = parse_date(first) last_d = parse_date(last) @@ -438,6 +450,12 @@ def push_range(first: str, last: str): raise ValidationFailedException(f"{key}: the given range is inverted") for part in parts: + # Handle inequality operators + inequality_operator, part = handle_inequality(part) + # >/=/<= YYYYMMDD/YYYY-MM-DD + if inequality_operator is not None: + values.append((inequality_operator, parse_date(part))) + continue if "-" not in part and ":" not in part: # YYYYMMDD values.append(parse_date(part)) diff --git a/src/server/_query.py b/src/server/_query.py index 267a78eb1..0fa1f6ff3 100644 --- a/src/server/_query.py +++ b/src/server/_query.py @@ -41,6 +41,11 @@ def to_condition( formatter=lambda x: x, ) -> str: if isinstance(value, (list, tuple)): + # Check if the first element is a tuple with an inequality operator + if isinstance(value[0], tuple): + inequality_operator, date_value = value[0][0], value[1] + params[param_key] = formatter(date_value) + return f"{field} {inequality_operator} :{param_key}" params[param_key] = formatter(value[0]) params[f"{param_key}_2"] = formatter(value[1]) return f"{field} BETWEEN :{param_key} AND :{param_key}_2" @@ -476,10 +481,8 @@ def apply_lag_filter(self, history_table: str, lag: Optional[int]) -> "QueryBuil def apply_issues_filter(self, history_table: str, issues: Optional[TimeValues]) -> "QueryBuilder": if issues: - if issues == ["*"]: - self.retable(history_table) - else: - self.retable(history_table) + self.retable(history_table) + if issues != ["*"]: self.where_integers("issue", issues) return self diff --git a/src/server/utils/dates.py b/src/server/utils/dates.py index e810e2146..4f1195473 100644 --- a/src/server/utils/dates.py +++ b/src/server/utils/dates.py @@ -116,16 +116,18 @@ def weeks_to_ranges(values: TimeValues) -> TimeValues: def _to_ranges(values: TimeValues, value_to_date: Callable, date_to_value: Callable, time_unit: Union[int, timedelta]) -> TimeValues: try: intervals = [] - # populate list of intervals based on original date/week values for v in values: if isinstance(v, int): # 20200101 -> [20200101, 20200101] intervals.append([value_to_date(v), value_to_date(v)]) else: # tuple + if isinstance(v[0], tuple): # inequality operator + intervals.append([value_to_date(v[1])]) + else: # (20200101, 20200102) -> [20200101, 20200102] - intervals.append([value_to_date(v[0]), value_to_date(v[1])]) - + intervals.append([value_to_date(v[0]), value_to_date(v[1])]) + intervals.sort() # merge overlapping intervals https://leetcode.com/problems/merge-intervals/ @@ -142,10 +144,14 @@ def _to_ranges(values: TimeValues, value_to_date: Callable, date_to_value: Calla # convert intervals from dates/weeks back to integers ranges = [] for m in merged: - if m[0] == m[1]: - ranges.append(date_to_value(m[0])) - else: - ranges.append((date_to_value(m[0]), date_to_value(m[1]))) + # if inequality operator, length of m is only 1 + if len(m) == 1: + ranges.append((v[0], date_to_value(m[0]))) + else: + if m[0] == m[1]: + ranges.append(date_to_value(m[0])) + else: + ranges.append((date_to_value(m[0]), date_to_value(m[1]))) get_structured_logger('server_utils').info("Optimized list of date values", original=values, optimized=ranges, original_length=len(values), optimized_length=len(ranges))