
API server code health pass - misc. refactors #1059

Merged: 6 commits, Dec 22, 2022
Changes from 5 commits
9 changes: 5 additions & 4 deletions src/server/_common.py
@@ -2,15 +2,16 @@
 import time

 from flask import Flask, g, request
-from sqlalchemy import event
-from sqlalchemy.engine import Connection
+from sqlalchemy import create_engine, event
+from sqlalchemy.engine import Connection, Engine
 from werkzeug.local import LocalProxy

 from .utils.logger import get_structured_logger
-from ._config import SECRET
-from ._db import engine
+from ._config import SECRET, SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS
 from ._exceptions import DatabaseErrorException, EpiDataException

+engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS)
+
 app = Flask("EpiData", static_url_path="")
 app.config["SECRET"] = SECRET
26 changes: 0 additions & 26 deletions src/server/_db.py

This file was deleted.

41 changes: 26 additions & 15 deletions src/server/_query.py
@@ -20,7 +20,7 @@
 from ._exceptions import DatabaseErrorException
 from ._validate import extract_strings
 from ._params import GeoPair, SourceSignalPair, TimePair
-from .utils import time_values_to_ranges, TimeValues
+from .utils import time_values_to_ranges, IntRange, TimeValues


 def date_string(value: int) -> str:
@@ -75,7 +75,7 @@ def filter_strings(

 def filter_integers(
     field: str,
-    values: Optional[Sequence[Union[Tuple[int, int], int]]],
+    values: Optional[Sequence[IntRange]],
     param_key: str,
     params: Dict[str, Any],
 ):
@@ -399,7 +399,7 @@ def _fq_field(self, field: str) -> str:
     def where_integers(
         self,
         field: str,
-        values: Optional[Sequence[Union[Tuple[int, int], int]]],
+        values: Optional[Sequence[IntRange]],
         param_key: Optional[str] = None,
     ) -> "QueryBuilder":
         fq_field = self._fq_field(field)
@@ -466,26 +466,37 @@ def where_time_pair(
         )
         return self

+    def apply_lag_filter(self, history_table: str, lag: Optional[int]):
+        if lag is not None:
+            self.retable(history_table)
+            # history_table has full spectrum of lag values to search from whereas the latest_table does not
+            self.where(lag=lag)
+
+    def apply_issues_filter(self, history_table: str, issues: Optional[TimeValues]):
+        if issues:
+            self.retable(history_table)
+            self.where_integers("issue", issues)
+
+    def apply_as_of_filter(self, history_table: str, as_of: Optional[int]):
+        if as_of is not None:
+            self.retable(history_table)
+            sub_condition_asof = "(issue <= :as_of)"
+            self.params["as_of"] = as_of
+            sub_fields = "max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value"
+            sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value"
+            sub_condition = f"x.max_issue = {self.alias}.issue AND x.time_type = {self.alias}.time_type AND x.time_value = {self.alias}.time_value AND x.source = {self.alias}.source AND x.signal = {self.alias}.signal AND x.geo_type = {self.alias}.geo_type AND x.geo_value = {self.alias}.geo_value"
+            self.subquery = f"JOIN (SELECT {sub_fields} FROM {self.table} WHERE {self.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}"
+
     def set_fields(self, *fields: Iterable[str]) -> "QueryBuilder":
         self.fields = [f"{self.alias}.{field}" for field_list in fields for field in field_list]
         return self

-    def set_order(self, *args: str, **kwargs: Union[str, bool]) -> "QueryBuilder":
+    def set_sort_order(self, *args: str):
         """
         sets the order for the given fields (as key word arguments), True = ASC, False = DESC
         """
-
-        def to_asc(v: Union[str, bool]) -> str:
-            if v is True:
-                return "ASC"
-            elif v is False:
-                return "DESC"
-            return cast(str, v)
-
-        args_order = [f"{self.alias}.{k} ASC" for k in args]
-        kw_order = [f"{self.alias}.{k} {to_asc(v)}" for k, v in kwargs.items()]
-        self.order = args_order + kw_order
-        return self
+        self.order = [f"{self.alias}.{k} ASC" for k in args]
Collaborator commented:
Suggested change:
     self.order = [f"{self.alias}.{k} ASC" for k in args]
+    return self

@rzats (Contributor, Author) commented on Dec 22, 2022:
This is actually one of the refactors I wanted to apply to query.py - the return self is intended to be used for method cascading:

q.apply_issues_filter().apply_lag_filter().apply_as_of_filter()

but it is not used that way anywhere in the codebase, so the statements serve no purpose. Do we still need to keep the return self for some other reason? (Or perhaps start using that style of cascading after all?)

Collaborator replied:

We can allow method cascading without requiring it. The rest of the methods of QueryBuilder do return self, and it's best to keep that consistent, especially in case someone does eventually choose to write something with cascading. It's also not bad form to return self on mutator methods from a functional programming perspective.

@rzats (Contributor, Author) replied:

Got it! I actually wanted to remove return self from those other methods too, but you make a great point about the functional programming angle.
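
For readers outside the thread, here is a minimal runnable sketch of the convention under discussion, using a made-up BuilderSketch class (this is not code from the PR): each mutator mutates the builder and returns self, which allows cascading without requiring it.

class BuilderSketch:
    """Hypothetical stand-in for QueryBuilder (illustration only)."""

    def __init__(self, table: str, alias: str):
        self.table = table
        self.alias = alias
        self.order = []

    def set_sort_order(self, *args: str) -> "BuilderSketch":
        # mutate, then hand the builder back so calls can chain
        self.order = [f"{self.alias}.{k} ASC" for k in args]
        return self

# cascading style (allowed, not required):
q = BuilderSketch("signal_latest", "t").set_sort_order("source", "signal")

# plain statement style works the same, since the return value can be ignored:
q2 = BuilderSketch("signal_latest", "t")
q2.set_sort_order("source", "signal")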


     def with_max_issue(self, *args: str) -> "QueryBuilder":
         fields: List[str] = [f for f in args]
5 changes: 1 addition & 4 deletions src/server/_validate.py
@@ -3,7 +3,7 @@
 from flask import request

 from ._exceptions import UnAuthenticatedException, ValidationFailedException
-from .utils import TimeValues
+from .utils import IntRange, TimeValues


 def resolve_auth_token() -> Optional[str]:
@@ -84,9 +84,6 @@ def extract_strings(key: Union[str, Sequence[str]]) -> Optional[List[str]]:
     return [v for vs in s for v in vs.split(",")]


-IntRange = Union[Tuple[int, int], int]
-
-
 def extract_integer(key: Union[str, Sequence[str]]) -> Optional[int]:
     s = _extract_value(key)
     if not s:
2 changes: 1 addition & 1 deletion src/server/endpoints/covid_hosp_facility.py
@@ -139,7 +139,7 @@ def handle():
     q.set_fields(fields_string, fields_int, fields_float)

     # basic query info
-    q.set_order("collection_week", "hospital_pk", "publication_date")
+    q.set_sort_order("collection_week", "hospital_pk", "publication_date")

     # build the filter
     q.where_integers("collection_week", collection_weeks)
2 changes: 1 addition & 1 deletion src/server/endpoints/covid_hosp_facility_lookup.py
@@ -33,7 +33,7 @@ def handle():
         ]
     )
     # basic query info
-    q.set_order("hospital_pk")
+    q.set_sort_order("hospital_pk")
     # build the filter
     # these are all fast because the table has indexes on each of these fields
     if state:
2 changes: 1 addition & 1 deletion src/server/endpoints/covid_hosp_state_timeseries.py
@@ -145,7 +145,7 @@ def handle():
     ]

     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("date", "state", "issue")
+    q.set_sort_order("date", "state", "issue")

     # build the filter
     q.where_integers("date", dates)
55 changes: 12 additions & 43 deletions src/server/endpoints/covidcast.py
@@ -91,28 +91,6 @@ def parse_time_pairs() -> TimePair:
     return parse_time_arg()


-def _handle_lag_issues_as_of(q: QueryBuilder, issues: Optional[TimeValues] = None, lag: Optional[int] = None, as_of: Optional[int] = None):
-    if issues:
-        q.retable(history_table)
-        q.where_integers("issue", issues)
-    elif lag is not None:
-        q.retable(history_table)
-        # history_table has full spectrum of lag values to search from whereas the latest_table does not
-        q.where(lag=lag)
-    elif as_of is not None:
-        # fetch the most recent issue as of a certain date (not to be confused w/ plain-old "most recent issue")
-        q.retable(history_table)
-        sub_condition_asof = "(issue <= :as_of)"
-        q.params["as_of"] = as_of
-        sub_fields = "max(issue) max_issue, time_type, time_value, `source`, `signal`, geo_type, geo_value"
-        sub_group = "time_type, time_value, `source`, `signal`, geo_type, geo_value"
-        sub_condition = f"x.max_issue = {q.alias}.issue AND x.time_type = {q.alias}.time_type AND x.time_value = {q.alias}.time_value AND x.source = {q.alias}.source AND x.signal = {q.alias}.signal AND x.geo_type = {q.alias}.geo_type AND x.geo_value = {q.alias}.geo_value"
-        q.subquery = f"JOIN (SELECT {sub_fields} FROM {q.table} WHERE {q.conditions_clause} AND {sub_condition_asof} GROUP BY {sub_group}) x ON {sub_condition}"
-    else:
-        # else we are using the (standard/default) `latest_table`, to fetch the most recent issue quickly
-        pass
-
-
 @bp.route("/", methods=("GET", "POST"))
 def handle():
     source_signal_pairs = parse_source_signal_pairs()
@@ -132,11 +110,11 @@ def handle():
     fields_float = ["value", "stderr", "sample_size"]
     is_compatibility = is_compatibility_mode()
     if is_compatibility:
-        q.set_order("signal", "time_value", "geo_value", "issue")
+        q.set_sort_order("signal", "time_value", "geo_value", "issue")
     else:
         # transfer also the new detail columns
         fields_string.extend(["source", "geo_type", "time_type"])
-        q.set_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
+        q.set_sort_order("source", "signal", "time_type", "time_value", "geo_type", "geo_value", "issue")
     q.set_fields(fields_string, fields_int, fields_float)

     # basic query info
@@ -147,7 +125,9 @@ def handle():
     q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
     q.where_time_pair("time_type", "time_value", time_pair)

-    _handle_lag_issues_as_of(q, issues, lag, as_of)
+    q.apply_issues_filter(history_table, issues)
+    q.apply_lag_filter(history_table, lag)
+    q.apply_as_of_filter(history_table, as_of)

     def transform_row(row, proxy):
         if is_compatibility or not alias_mapper or "source" not in row:
@@ -195,15 +175,12 @@ def handle_trend():
     fields_int = ["time_value"]
     fields_float = ["value"]
     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
+    q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value")

     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
     q.where_time_pair("time_type", "time_value", time_window)

-    # fetch most recent issue fast
-    _handle_lag_issues_as_of(q, None, None, None)
-
     p = create_printer()

     def gen(rows):
@@ -246,15 +223,12 @@ def handle_trendseries():
     fields_int = ["time_value"]
     fields_float = ["value"]
     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
+    q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value")

     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
     q.where_time_pair("time_type", "time_value", time_window)

-    # fetch most recent issue fast
-    _handle_lag_issues_as_of(q, None, None, None)
-
     p = create_printer()

     shifter = lambda x: shift_day_value(x, -basis_shift)
@@ -303,7 +277,7 @@ def handle_correlation():
     fields_int = ["time_value"]
     fields_float = ["value"]
     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
+    q.set_sort_order("geo_type", "geo_value", "source", "signal", "time_value")

     q.where_source_signal_pairs(
         "source",
@@ -381,12 +355,12 @@ def handle_export():
     q = QueryBuilder(latest_table, "t")

     q.set_fields(["geo_value", "signal", "time_value", "issue", "lag", "value", "stderr", "sample_size", "geo_type", "source"], [], [])
-    q.set_order("time_value", "geo_value")
+    q.set_sort_order("time_value", "geo_value")
     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_time_pair("time_type", "time_value", TimePair("day" if is_day else "week", [(start_day, end_day)]))
     q.where_geo_pairs("geo_type", "geo_value", [GeoPair(geo_type, True if geo_values == "*" else geo_values)])

-    _handle_lag_issues_as_of(q, None, None, as_of)
+    q.apply_as_of_filter(history_table, as_of)

     format_date = time_value_to_iso if is_day else lambda x: time_value_to_week(x).cdcformat()
     # tag as_of in filename, if it was specified
@@ -459,16 +433,13 @@ def handle_backfill():
     fields_int = ["time_value", "issue"]
     fields_float = ["value", "sample_size"]
     # sort by time value and issue asc
-    q.set_order(time_value=True, issue=True)
+    q.set_sort_order("time_value", "issue")
     q.set_fields(fields_string, fields_int, fields_float, ["is_latest_issue"])

     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_geo_pairs("geo_type", "geo_value", [geo_pair])
     q.where_time_pair("time_type", "time_value", time_pair)

-    # no restriction of issues or dates since we want all issues
-    # _handle_lag_issues_as_of(q, issues, lag, as_of)
-
     p = create_printer()

     def find_anchor_row(rows: List[Dict[str, Any]], issue: int) -> Optional[Dict[str, Any]]:
@@ -642,9 +613,7 @@ def handle_coverage():
     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_time_pair("time_type", "time_value", time_window)
     q.group_by = "c.source, c.signal, c.time_value"
-    q.set_order("source", "signal", "time_value")
-
-    _handle_lag_issues_as_of(q, None, None, None)
+    q.set_sort_order("source", "signal", "time_value")

     def transform_row(row, proxy):
         if not alias_mapper or "source" not in row:
5 changes: 0 additions & 5 deletions src/server/endpoints/covidcast_utils/model.py
@@ -236,11 +236,6 @@ def _load_data_signals(sources: List[DataSource]):
             data_signals_by_key[(source.db_source, d.signal)] = d


-
-def get_related_signals(signal: DataSignal) -> List[DataSignal]:
-    return [s for s in data_signals if s != signal and s.signal_basename == signal.signal_basename]
-
-
 def count_signal_time_types(source_signals: List[SourceSignalPair]) -> Tuple[int, int]:
     """
     count the number of signals in this query for each time type
2 changes: 1 addition & 1 deletion src/server/endpoints/dengue_nowcast.py
@@ -22,7 +22,7 @@ def handle():
     fields_float = ["value", "std"]
     q.set_fields(fields_string, fields_int, fields_float)

-    q.set_order("epiweek", "location")
+    q.set_sort_order("epiweek", "location")

     # build the filter
     q.where_strings("location", locations)
2 changes: 1 addition & 1 deletion src/server/endpoints/dengue_sensors.py
@@ -26,7 +26,7 @@ def handle():
     fields_float = ["value"]
     q.set_fields(fields_string, fields_int, fields_float)

-    q.set_order('epiweek', 'name', 'location')
+    q.set_sort_order('epiweek', 'name', 'location')

     q.where_strings('name', names)
     q.where_strings('location', locations)
2 changes: 1 addition & 1 deletion src/server/endpoints/ecdc_ili.py
@@ -24,7 +24,7 @@ def handle():
     fields_float = ["incidence_rate"]
     q.set_fields(fields_string, fields_int, fields_float)

-    q.set_order("epiweek", "region", "issue")
+    q.set_sort_order("epiweek", "region", "issue")

     q.where_integers("epiweek", epiweeks)
     q.where_strings("region", regions)
2 changes: 1 addition & 1 deletion src/server/endpoints/flusurv.py
@@ -29,7 +29,7 @@ def handle():
         "rate_overall",
     ]
     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("epiweek", "location", "issue")
+    q.set_sort_order("epiweek", "location", "issue")

     q.where_integers("epiweek", epiweeks)
     q.where_strings("location", locations)
2 changes: 1 addition & 1 deletion src/server/endpoints/fluview_clinicial.py
@@ -22,7 +22,7 @@ def handle():
     fields_int = ["issue", "epiweek", "lag", "total_specimens", "total_a", "total_b"]
     fields_float = ["percent_positive", "percent_a", "percent_b"]
     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("epiweek", "region", "issue")
+    q.set_sort_order("epiweek", "region", "issue")

     q.where_integers("epiweek", epiweeks)
     q.where_strings("region", regions)
2 changes: 1 addition & 1 deletion src/server/endpoints/gft.py
@@ -22,7 +22,7 @@ def handle():
     fields_int = ["epiweek", "num"]
     fields_float = []
     q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("epiweek", "location")
+    q.set_sort_order("epiweek", "location")

     # build the filter
     q.where_integers("epiweek", epiweeks)
2 changes: 1 addition & 1 deletion src/server/endpoints/ght.py
@@ -26,7 +26,7 @@ def handle():
     fields_float = ["value"]
     q.set_fields(fields_string, fields_int, fields_float)

-    q.set_order("epiweek", "location")
+    q.set_sort_order("epiweek", "location")

     # build the filter
     q.where_strings("location", locations)
2 changes: 1 addition & 1 deletion src/server/endpoints/kcdc_ili.py
@@ -24,7 +24,7 @@ def handle():
     fields_float = ["ili"]
     q.set_fields(fields_string, fields_int, fields_float)

-    q.set_order("epiweek", "region", "issue")
+    q.set_sort_order("epiweek", "region", "issue")
     # build the filter
     q.where_integers("epiweek", epiweeks)
     q.where_strings("region", regions)