Skip to content

API server code health pass - rework parse* & extract* methods #1062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions src/server/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ._exceptions import ValidationFailedException
from .utils import days_in_range, weeks_in_range, guess_time_value_is_day, guess_time_value_is_week, IntRange, TimeValues, days_to_ranges, weeks_to_ranges
from ._validate import require_any, require_all


def _parse_common_multi_arg(key: str) -> List[Tuple[str, Union[bool, Sequence[str]]]]:
Expand Down Expand Up @@ -308,3 +309,177 @@ def parse_day_or_week_range_arg(key: str) -> TimeSet:
if is_week:
return TimeSet("week", [parse_week_range_arg(key)])
return TimeSet("day", [parse_day_range_arg(key)])


def _extract_value(key: Union[str, Sequence[str]]) -> Optional[str]:
    """Look up a single request parameter.

    ``key`` is either one parameter name or a sequence of alternative names.
    For a sequence, the value of the first name present in ``request.values``
    is returned. ``None`` is returned when nothing matches.
    """
    if not isinstance(key, str):
        # several aliases: the first one present in the request wins
        return next((request.values[name] for name in key if name in request.values), None)
    return request.values.get(key)


def _extract_list_value(key: Union[str, Sequence[str]]) -> List[str]:
    """Look up every value supplied for a request parameter.

    ``key`` is either one parameter name or a sequence of alternative names.
    For a sequence, the values of the first name present in
    ``request.values`` are returned. An empty list is returned when no alias
    is present.
    """
    if not isinstance(key, str):
        # several aliases: the first one present in the request wins
        for name in key:
            if name in request.values:
                return request.values.getlist(name)
        return []
    return request.values.getlist(key)


def extract_strings(key: Union[str, Sequence[str]]) -> Optional[List[str]]:
    """Return the comma-separated tokens of a request parameter.

    The parameter may be given several times; every occurrence is split on
    ``","`` and the pieces are concatenated in order. Returns ``None`` when
    the parameter has no values.
    """
    raw_values = _extract_list_value(key)
    if not raw_values:
        return None
    tokens: List[str] = []
    for raw in raw_values:
        # each occurrence may itself carry a comma-separated list
        tokens.extend(raw.split(","))
    return tokens


def extract_integer(key: Union[str, Sequence[str]]) -> Optional[int]:
    """Return a request parameter converted to ``int``.

    Returns ``None`` when the parameter is missing or empty.

    Raises:
        ValidationFailedException: if the value cannot be parsed by ``int``.
    """
    raw = _extract_value(key)
    if not raw:
        return None
    try:
        number = int(raw)
    except ValueError:
        raise ValidationFailedException(f"{key}: not a number: {raw}")
    return number


def extract_integers(key: Union[str, Sequence[str]]) -> Optional[List[IntRange]]:
    """Return a request parameter parsed as a list of integers and ranges.

    Each comma-separated token is either a single integer (``"5"``) or an
    inclusive range written ``"first-last"``. A range whose ends are equal is
    collapsed to a single integer; otherwise it is returned as a
    ``(first, last)`` tuple. Returns ``None`` when the parameter has no
    values.

    Because a dash is always read as the range separator, a token with a
    leading dash (e.g. ``"-5"``) is rejected as "not a number".

    Raises:
        ValidationFailedException: if a token is not numeric, has more than
            one range separator, or describes an inverted range.
    """
    parts = extract_strings(key)
    if not parts:
        # nothing to do
        return None

    def _parse_range(part: str):
        if "-" not in part:
            return int(part)
        # Split on the first dash only: anything after a second dash stays in
        # ``last_s`` and makes int() fail, instead of being silently dropped
        # (previously "1-2-3" was accepted as the range 1-2).
        first_s, last_s = part.split("-", 1)
        first = int(first_s)
        last = int(last_s)
        if first == last:
            # the first and last numbers are the same, just treat it as a single value
            return first
        if last > first:
            # a proper range is returned as a (first, last) tuple
            return (first, last)
        # the range is inverted, this is an error
        raise ValidationFailedException(f"{key}: the given range is inverted")

    try:
        return [_parse_range(part) for part in parts]
    except ValueError as e:
        raise ValidationFailedException(f"{key}: not a number: {str(e)}") from e


def parse_date(s: str) -> int:
    """Convert ``YYYYMMDD`` or ``YYYY-MM-DD`` text to the integer ``YYYYMMDD``.

    All dashes are removed and the remainder is parsed with ``int``; no
    further check of the date's shape is performed.

    Raises:
        ValidationFailedException: if the dash-free text is not an integer.
    """
    digits = s.replace("-", "")
    try:
        return int(digits)
    except ValueError:
        raise ValidationFailedException(f"not a valid date: {s}")


def extract_date(key: Union[str, Sequence[str]]) -> Optional[int]:
    """Return a request parameter parsed by ``parse_date``.

    Returns ``None`` when the parameter is missing or empty.
    """
    raw = _extract_value(key)
    return parse_date(raw) if raw else None


def extract_dates(key: Union[str, Sequence[str]]) -> Optional[TimeValues]:
    """Return a request parameter parsed as a list of dates and date ranges.

    Each comma-separated token takes one of these forms:

    * ``YYYYMMDD`` or ``YYYY-MM-DD``: a single date
    * ``YYYYMMDD-YYYYMMDD``: a range (exactly one dash)
    * ``<date>:<date>``: a range whose ends may themselves contain dashes

    Single dates are returned as integers via ``parse_date``. A range whose
    ends are equal is collapsed to a single integer; otherwise it is returned
    as a ``(first, last)`` tuple. Returns ``None`` when the parameter has no
    values.

    Raises:
        ValidationFailedException: if a date cannot be parsed, a token has
            more than one ``":"`` separator, or a range is inverted.
    """
    parts = extract_strings(key)
    if not parts:
        return None

    def _to_range(first: str, last: str):
        first_d = parse_date(first)
        last_d = parse_date(last)
        if first_d == last_d:
            # the first and last numbers are the same, just treat it as a single value
            return first_d
        if last_d > first_d:
            # a proper range is returned as a (first, last) tuple
            return (first_d, last_d)
        # the range is inverted, this is an error
        raise ValidationFailedException(f"{key}: the given range is inverted")

    values: TimeValues = []
    for part in parts:
        if ":" in part:
            # YYYY-MM-DD:YYYY-MM-DD
            # Split on the first colon only: anything after a second colon
            # stays in ``last`` and is rejected by parse_date, instead of
            # being silently dropped.
            first, last = part.split(":", 1)
            values.append(_to_range(first, last))
            continue
        dash_fields = part.split("-")
        if len(dash_fields) == 2:
            # YYYYMMDD-YYYYMMDD
            values.append(_to_range(dash_fields[0], dash_fields[1]))
            continue
        # YYYYMMDD (no dash) or YYYY-MM-DD (two dashes)
        values.append(parse_date(part))
    # success, return the list
    return values

def parse_source_signal_sets() -> List[SourceSignalSet]:
    """Build the source/signal sets for the current request.

    Two parameter styles are accepted:

    * old version: ``data_source`` plus ``signal`` or ``signals``; a lone
      ``"*"`` selects every signal of that source
    * otherwise ``signal`` must contain a ``":"`` and is handed to
      ``parse_source_signal_arg``

    Raises:
        ValidationFailedException: if neither style is present.
    """
    data_source = request.values.get("data_source")
    if not data_source:
        if ":" in request.values.get("signal", ""):
            return parse_source_signal_arg()
        raise ValidationFailedException("missing parameter: signal or (data_source and signal[s])")

    # old version
    require_any("signal", "signals", empty=True)
    signals = extract_strings(("signals", "signal"))
    is_wildcard = len(signals) == 1 and signals[0] == "*"
    return [SourceSignalSet(data_source, True if is_wildcard else signals)]


def parse_geo_sets() -> List[GeoSet]:
    """Build the geo sets for the current request.

    Two parameter styles are accepted:

    * old version: ``geo_type`` plus ``geo_value`` or ``geo_values``; a lone
      ``"*"`` selects every value of that geo type
    * otherwise ``geo`` must contain a ``":"`` and is handed to
      ``parse_geo_arg``

    Raises:
        ValidationFailedException: if neither style is present.
    """
    geo_type = request.values.get("geo_type")
    if not geo_type:
        if ":" in request.values.get("geo", ""):
            return parse_geo_arg()
        raise ValidationFailedException("missing parameter: geo or (geo_type and geo_value[s])")

    # old version
    require_any("geo_value", "geo_values", empty=True)
    geo_values = extract_strings(("geo_values", "geo_value"))
    is_wildcard = len(geo_values) == 1 and geo_values[0] == "*"
    return [GeoSet(geo_type, True if is_wildcard else geo_values)]


def parse_time_set() -> TimeSet:
    """Build the time set for the current request.

    Two parameter styles are accepted:

    * old version: ``time_type`` plus ``time_values``, the latter parsed by
      ``extract_dates``
    * otherwise ``time`` must contain a ``":"`` and is handed to
      ``parse_time_arg``

    Raises:
        ValidationFailedException: if neither style is present.
    """
    time_type = request.values.get("time_type")
    if not time_type:
        if ":" in request.values.get("time", ""):
            return parse_time_arg()
        raise ValidationFailedException("missing parameter: time or (time_type and time_values)")

    # old version
    require_all("time_type", "time_values")
    return TimeSet(time_type, extract_dates("time_values"))
3 changes: 1 addition & 2 deletions src/server/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from ._common import db
from ._printer import create_printer, APrinter
from ._exceptions import DatabaseErrorException
from ._validate import extract_strings
from ._params import GeoSet, SourceSignalSet, TimeSet
from ._params import extract_strings, GeoSet, SourceSignalSet, TimeSet
from .utils import time_values_to_ranges, IntRange, TimeValues


Expand Down
129 changes: 0 additions & 129 deletions src/server/_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,132 +55,3 @@ def require_any(*values: str, empty=False) -> bool:
if request.values.get(value) or (empty and value in request.values):
return True
raise ValidationFailedException(f"missing parameter: need one of [{', '.join(values)}]")


def _extract_value(key: Union[str, Sequence[str]]) -> Optional[str]:
    """Look up a single request parameter.

    ``key`` is either one parameter name or a sequence of alternative names.
    For a sequence, the value of the first name present in ``request.values``
    is returned. ``None`` is returned when nothing matches.
    """
    if not isinstance(key, str):
        # several aliases: the first one present in the request wins
        return next((request.values[name] for name in key if name in request.values), None)
    return request.values.get(key)


def _extract_list_value(key: Union[str, Sequence[str]]) -> List[str]:
    """Look up every value supplied for a request parameter.

    ``key`` is either one parameter name or a sequence of alternative names.
    For a sequence, the values of the first name present in
    ``request.values`` are returned. An empty list is returned when no alias
    is present.
    """
    if not isinstance(key, str):
        # several aliases: the first one present in the request wins
        for name in key:
            if name in request.values:
                return request.values.getlist(name)
        return []
    return request.values.getlist(key)


def extract_strings(key: Union[str, Sequence[str]]) -> Optional[List[str]]:
    """Return the comma-separated tokens of a request parameter.

    The parameter may be given several times; every occurrence is split on
    ``","`` and the pieces are concatenated in order. Returns ``None`` when
    the parameter has no values.
    """
    raw_values = _extract_list_value(key)
    if not raw_values:
        return None
    tokens: List[str] = []
    for raw in raw_values:
        # each occurrence may itself carry a comma-separated list
        tokens.extend(raw.split(","))
    return tokens


def extract_integer(key: Union[str, Sequence[str]]) -> Optional[int]:
    """Return a request parameter converted to ``int``.

    Returns ``None`` when the parameter is missing or empty.

    Raises:
        ValidationFailedException: if the value cannot be parsed by ``int``.
    """
    raw = _extract_value(key)
    if not raw:
        return None
    try:
        number = int(raw)
    except ValueError:
        raise ValidationFailedException(f"{key}: not a number: {raw}")
    return number


def extract_integers(key: Union[str, Sequence[str]]) -> Optional[List[IntRange]]:
    """Return a request parameter parsed as a list of integers and ranges.

    Each comma-separated token is either a single integer (``"5"``) or an
    inclusive range written ``"first-last"``. A range whose ends are equal is
    collapsed to a single integer; otherwise it is returned as a
    ``(first, last)`` tuple. Returns ``None`` when the parameter has no
    values.

    Because a dash is always read as the range separator, a token with a
    leading dash (e.g. ``"-5"``) is rejected as "not a number".

    Raises:
        ValidationFailedException: if a token is not numeric, has more than
            one range separator, or describes an inverted range.
    """
    parts = extract_strings(key)
    if not parts:
        # nothing to do
        return None

    def _parse_range(part: str):
        if "-" not in part:
            return int(part)
        # Split on the first dash only: anything after a second dash stays in
        # ``last_s`` and makes int() fail, instead of being silently dropped
        # (previously "1-2-3" was accepted as the range 1-2).
        first_s, last_s = part.split("-", 1)
        first = int(first_s)
        last = int(last_s)
        if first == last:
            # the first and last numbers are the same, just treat it as a single value
            return first
        if last > first:
            # a proper range is returned as a (first, last) tuple
            return (first, last)
        # the range is inverted, this is an error
        raise ValidationFailedException(f"{key}: the given range is inverted")

    try:
        return [_parse_range(part) for part in parts]
    except ValueError as e:
        raise ValidationFailedException(f"{key}: not a number: {str(e)}") from e


def parse_date(s: str) -> int:
    """Convert ``YYYYMMDD`` or ``YYYY-MM-DD`` text to the integer ``YYYYMMDD``.

    All dashes are removed and the remainder is parsed with ``int``; no
    further check of the date's shape is performed.

    Raises:
        ValidationFailedException: if the dash-free text is not an integer.
    """
    digits = s.replace("-", "")
    try:
        return int(digits)
    except ValueError:
        raise ValidationFailedException(f"not a valid date: {s}")


def extract_date(key: Union[str, Sequence[str]]) -> Optional[int]:
    """Return a request parameter parsed by ``parse_date``.

    Returns ``None`` when the parameter is missing or empty.
    """
    raw = _extract_value(key)
    return parse_date(raw) if raw else None


def extract_dates(key: Union[str, Sequence[str]]) -> Optional[TimeValues]:
    """Return a request parameter parsed as a list of dates and date ranges.

    Each comma-separated token takes one of these forms:

    * ``YYYYMMDD`` or ``YYYY-MM-DD``: a single date
    * ``YYYYMMDD-YYYYMMDD``: a range (exactly one dash)
    * ``<date>:<date>``: a range whose ends may themselves contain dashes

    Single dates are returned as integers via ``parse_date``. A range whose
    ends are equal is collapsed to a single integer; otherwise it is returned
    as a ``(first, last)`` tuple. Returns ``None`` when the parameter has no
    values.

    Raises:
        ValidationFailedException: if a date cannot be parsed, a token has
            more than one ``":"`` separator, or a range is inverted.
    """
    parts = extract_strings(key)
    if not parts:
        return None

    def _to_range(first: str, last: str):
        first_d = parse_date(first)
        last_d = parse_date(last)
        if first_d == last_d:
            # the first and last numbers are the same, just treat it as a single value
            return first_d
        if last_d > first_d:
            # a proper range is returned as a (first, last) tuple
            return (first_d, last_d)
        # the range is inverted, this is an error
        raise ValidationFailedException(f"{key}: the given range is inverted")

    values: TimeValues = []
    for part in parts:
        if ":" in part:
            # YYYY-MM-DD:YYYY-MM-DD
            # Split on the first colon only: anything after a second colon
            # stays in ``last`` and is rejected by parse_date, instead of
            # being silently dropped.
            first, last = part.split(":", 1)
            values.append(_to_range(first, last))
            continue
        dash_fields = part.split("-")
        if len(dash_fields) == 2:
            # YYYYMMDD-YYYYMMDD
            values.append(_to_range(dash_fields[0], dash_fields[1]))
            continue
        # YYYYMMDD (no dash) or YYYY-MM-DD (two dashes)
        values.append(parse_date(part))
    # success, return the list
    return values
3 changes: 2 additions & 1 deletion src/server/endpoints/afhsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from flask import Blueprint

from .._config import AUTH
from .._params import extract_integers, extract_strings
from .._query import execute_queries, filter_integers, filter_strings
from .._validate import check_auth_token, extract_integers, extract_strings, require_all
from .._validate import check_auth_token, require_all

# first argument is the endpoint name
bp = Blueprint("afhsb", __name__)
Expand Down
3 changes: 2 additions & 1 deletion src/server/endpoints/cdc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from flask import Blueprint

from .._config import AUTH, NATION_REGION, REGION_TO_STATE
from .._validate import require_all, extract_strings, extract_integers, check_auth_token
from .._params import extract_strings, extract_integers
from .._query import filter_strings, execute_queries, filter_integers
from .._validate import require_all, check_auth_token

# first argument is the endpoint name
bp = Blueprint("cdc", __name__)
Expand Down
3 changes: 2 additions & 1 deletion src/server/endpoints/covid_hosp_facility.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from flask import Blueprint

from .._params import extract_integers, extract_strings
from .._query import execute_query, QueryBuilder
from .._validate import extract_integers, extract_strings, require_all
from .._validate import require_all

# first argument is the endpoint name
bp = Blueprint("covid_hosp_facility", __name__)
Expand Down
3 changes: 2 additions & 1 deletion src/server/endpoints/covid_hosp_facility_lookup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from flask import Blueprint

from .._params import extract_strings
from .._query import execute_query, QueryBuilder
from .._validate import extract_strings, require_any
from .._validate import require_any

# first argument is the endpoint name
bp = Blueprint("covid_hosp_facility_lookup", __name__)
Expand Down
3 changes: 2 additions & 1 deletion src/server/endpoints/covid_hosp_state_timeseries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from flask import Blueprint

from .._params import extract_integers, extract_strings, extract_date
from .._query import execute_query, QueryBuilder
from .._validate import extract_integers, extract_strings, extract_date, require_all
from .._validate import require_all

# first argument is the endpoint name
bp = Blueprint("covid_hosp_state_timeseries", __name__)
Expand Down
Loading