
Commit 8188bf7

Merge pull request #1954 from cmu-delphi/ds/nwss

lint: minor code clean to nwss

2 parents: e62f81f + f159a37

File tree

7 files changed: +69 -152 lines

nchs_mortality/delphi_nchs_mortality/constants.py

Lines changed: 0 additions & 5 deletions
@@ -25,8 +25,3 @@
     "prop"
 ]
 INCIDENCE_BASE = 100000
-
-# this is necessary as a delimiter in the f-string expressions we use to
-# construct detailed error reports
-# (https://www.python.org/dev/peps/pep-0498/#escape-sequences)
-NEWLINE = "\n"

nchs_mortality/delphi_nchs_mortality/pull.py

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 
 from delphi_utils.geomap import GeoMapper
 
-from .constants import METRICS, RENAME, NEWLINE
+from .constants import METRICS, RENAME
 
 def standardize_columns(df):
     """Rename columns to comply with a standard set.
@@ -90,10 +90,10 @@ def pull_nchs_mortality_data(socrata_token: str, test_file: Optional[str] = None
 have changed. Please investigate and amend the code.
 
 Columns needed:
-{NEWLINE.join(type_dict.keys())}
+{'\n'.join(type_dict.keys())}
 
 Columns available:
-{NEWLINE.join(df.columns)}
+{'\n'.join(df.columns)}
 """) from exc
 
     df = df[keep_columns + ["timestamp", "state"]].set_index("timestamp")
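
Worth noting: the `{'\n'.join(...)}` replacement only parses on Python 3.12+, where PEP 701 lifted the ban on backslashes inside f-string expressions; that restriction is exactly why the `NEWLINE` constant existed. A minimal sketch of the new form (the `columns` values are hypothetical):

```python
# Requires Python 3.12+ (PEP 701): backslashes are now legal inside
# f-string expression parts, so the NEWLINE constant workaround can go.
columns = ["timestamp", "state", "covid_19_deaths"]  # hypothetical example values
print(f"Columns needed:\n{'\n'.join(columns)}")
# On Python <= 3.11 the expression part would raise:
# SyntaxError: f-string expression part cannot include a backslash
```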

nwss_wastewater/delphi_nwss/constants.py

Lines changed: 18 additions & 13 deletions
@@ -22,18 +22,23 @@
         "microbial",
     ],
 }
-METRIC_DATES = ["date_start", "date_end"]
-SAMPLE_SITE_NAMES = {
-    "wwtp_jurisdiction": "category",
-    "wwtp_id": int,
-    "reporting_jurisdiction": "category",
-    "sample_location": "category",
-    "county_names": "category",
-    "county_fips": "category",
-    "population_served": float,
-    "sampling_prior": bool,
-    "sample_location_specify": float,
-}
 SIG_DIGITS = 4
 
-NEWLINE = "\n"
+TYPE_DICT = {key: float for key in SIGNALS}
+TYPE_DICT.update({"timestamp": "datetime64[ns]"})
+TYPE_DICT_METRIC = {key: float for key in METRIC_SIGNALS}
+TYPE_DICT_METRIC.update({key: "datetime64[ns]" for key in ["date_start", "date_end"]})
+# Sample site names
+TYPE_DICT_METRIC.update(
+    {
+        "wwtp_jurisdiction": "category",
+        "wwtp_id": int,
+        "reporting_jurisdiction": "category",
+        "sample_location": "category",
+        "county_names": "category",
+        "county_fips": "category",
+        "population_served": float,
+        "sampling_prior": bool,
+        "sample_location_specify": float,
+    }
+)
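
Hoisting the type dictionaries to module-level constants builds them once at import time and lets the indicator and its tests share them. A minimal sketch of how such a dict drives `pandas.DataFrame.astype`, using a toy frame rather than the real NWSS schema:

```python
import pandas as pd

# Hypothetical miniature version of the constants above.
SIGNALS = ["pcr_conc_smoothed"]
TYPE_DICT = {key: float for key in SIGNALS}
TYPE_DICT.update({"timestamp": "datetime64[ns]"})

df = pd.DataFrame({"timestamp": ["2023-01-01"], "pcr_conc_smoothed": ["1.5"]})
df = df.astype(TYPE_DICT)  # one call converts every listed column
print(df.dtypes)  # timestamp -> datetime64[ns], pcr_conc_smoothed -> float64
```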

nwss_wastewater/delphi_nwss/pull.py

Lines changed: 38 additions & 64 deletions
@@ -9,10 +9,9 @@
     SIGNALS,
     PROVIDER_NORMS,
     METRIC_SIGNALS,
-    METRIC_DATES,
-    SAMPLE_SITE_NAMES,
     SIG_DIGITS,
-    NEWLINE,
+    TYPE_DICT,
+    TYPE_DICT_METRIC,
 )
 
 
@@ -35,34 +34,29 @@ def sig_digit_round(value, n_digits):
     return result
 
 
-def construct_typedicts():
-    """Create the type conversion dictionary for both dataframes."""
-    # basic type conversion
-    type_dict = {key: float for key in SIGNALS}
-    type_dict["timestamp"] = "datetime64[ns]"
-    # metric type conversion
-    signals_dict_metric = {key: float for key in METRIC_SIGNALS}
-    metric_dates_dict = {key: "datetime64[ns]" for key in METRIC_DATES}
-    type_dict_metric = {**metric_dates_dict, **signals_dict_metric, **SAMPLE_SITE_NAMES}
-    return type_dict, type_dict_metric
-
-
-def warn_string(df, type_dict):
-    """Format the warning string."""
-    return f"""
+def convert_df_type(df, type_dict, logger):
+    """Convert types and warn if there are unexpected columns."""
+    try:
+        df = df.astype(type_dict)
+    except KeyError as exc:
+        newline = "\n"
+        raise KeyError(
+            f"""
 Expected column(s) missed, The dataset schema may
 have changed. Please investigate and amend the code.
 
-Columns needed:
-{NEWLINE.join(sorted(type_dict.keys()))}
+expected={newline.join(sorted(type_dict.keys()))}
 
-Columns available:
-{NEWLINE.join(sorted(df.columns))}
+received={newline.join(sorted(df.columns))}
 """
+        ) from exc
+    if new_columns := set(df.columns) - set(type_dict.keys()):
+        logger.info("New columns found in NWSS dataset.", new_columns=new_columns)
+    return df
 
 
 def reformat(df, df_metric):
-    """Add columns from df_metric to df, and rename some columns. 
+    """Add columns from df_metric to df, and rename some columns.
 
     Specifically the population and METRIC_SIGNAL columns, and renames date_start to timestamp.
     """
@@ -80,27 +74,16 @@ def reformat(df, df_metric):
     return df
 
 
-def drop_unnormalized(df):
-    """Drop unnormalized.
-
-    mutate `df` to no longer have rows where the normalization scheme isn't actually identified,
-    as we can't classify the kind of signal
-    """
-    return df[~df["normalization"].isna()]
-
-
 def add_identifier_columns(df):
     """Add identifier columns.
 
     Add columns to get more detail than key_plot_id gives;
     specifically, state, and `provider_normalization`, which gives the signal identifier
     """
-    df["state"] = df.key_plot_id.str.extract(
-        r"_(\w\w)_"
-    )  # a pair of alphanumerics surrounded by _
-    df["provider"] = df.key_plot_id.str.extract(
-        r"(.*)_[a-z]{2}_"
-    )  # anything followed by state ^
+    # a pair of alphanumerics surrounded by _
+    df["state"] = df.key_plot_id.str.extract(r"_(\w\w)_")
+    # anything followed by state ^
+    df["provider"] = df.key_plot_id.str.extract(r"(.*)_[a-z]{2}_")
     df["signal_name"] = df.provider + "_" + df.normalization
 
 
@@ -120,7 +103,7 @@ def check_endpoints(df):
     )
 
 
-def pull_nwss_data(token: str):
+def pull_nwss_data(token: str, logger):
     """Pull the latest NWSS Wastewater data, and conforms it into a dataset.
 
     The output dataset has:
@@ -141,11 +124,6 @@ def pull_nwss_data(token: str):
     pd.DataFrame
         Dataframe as described above.
     """
-    # Constants
-    keep_columns = [*SIGNALS, *METRIC_SIGNALS]
-    # concentration key types
-    type_dict, type_dict_metric = construct_typedicts()
-
     # Pull data from Socrata API
     client = Socrata("data.cdc.gov", token)
     results_concentration = client.get("g653-rqe2", limit=10 ** 10)
@@ -154,19 +132,14 @@ def pull_nwss_data(token: str):
     df_concentration = pd.DataFrame.from_records(results_concentration)
     df_concentration = df_concentration.rename(columns={"date": "timestamp"})
 
-    try:
-        df_concentration = df_concentration.astype(type_dict)
-    except KeyError as exc:
-        raise ValueError(warn_string(df_concentration, type_dict)) from exc
+    # Schema checks.
+    df_concentration = convert_df_type(df_concentration, TYPE_DICT, logger)
+    df_metric = convert_df_type(df_metric, TYPE_DICT_METRIC, logger)
 
-    try:
-        df_metric = df_metric.astype(type_dict_metric)
-    except KeyError as exc:
-        raise ValueError(warn_string(df_metric, type_dict_metric)) from exc
+    # Drop sites without a normalization scheme.
+    df = df_concentration[~df_concentration["normalization"].isna()]
 
-    # if the normalization scheme isn't recorded, why is it even included as a sample site?
-    df = drop_unnormalized(df_concentration)
-    # pull 2 letter state labels out of the key_plot_id labels
+    # Pull 2 letter state labels out of the key_plot_id labels.
    add_identifier_columns(df)
 
     # move population and metric signals over to df
@@ -180,13 +153,14 @@ def pull_nwss_data(token: str):
     # otherwise, best to assume some value rather than break the data)
     df.population_served = df.population_served.ffill()
     check_endpoints(df)
-    keep_columns.extend(
-        [
-            "timestamp",
-            "state",
-            "population_served",
-            "normalization",
-            "provider",
-        ]
-    )
+
+    keep_columns = [
+        *SIGNALS,
+        *METRIC_SIGNALS,
+        "timestamp",
+        "state",
+        "population_served",
+        "normalization",
+        "provider",
+    ]
     return df[keep_columns]
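
Beyond merging the two duplicated try/except blocks, `convert_df_type` adds behavior: columns absent from the type dict survive `astype` untouched, and the walrus-operator check logs them instead of letting schema additions pass silently. A minimal sketch of that check, with a stand-in logger (the real one is a structured logger passed down from `run_module`):

```python
import logging
import pandas as pd

logging.basicConfig(level=logging.INFO)

# Stand-in for illustration; the pipeline passes a structlog-style logger
# that accepts keyword arguments.
class FakeLogger:
    def info(self, msg, **kwargs):
        logging.getLogger("nwss").info("%s %s", msg, kwargs)

TYPE_DICT = {"timestamp": "datetime64[ns]", "pcr_conc_smoothed": float}
df = pd.DataFrame(
    {"timestamp": ["2023-01-01"], "pcr_conc_smoothed": [1.0], "surprise_column": [7]}
)

df = df.astype(TYPE_DICT)
# `surprise_column` survives astype untouched; the walrus check flags it.
if new_columns := set(df.columns) - set(TYPE_DICT.keys()):
    FakeLogger().info("New columns found in NWSS dataset.", new_columns=new_columns)
```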

nwss_wastewater/delphi_nwss/run.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@
     - "bucket_name: str, name of S3 bucket to read/write
     - "cache_dir": str, directory of locally cached data
 """
+
 import time
 from datetime import datetime
 
@@ -138,10 +139,10 @@ def run_module(params):
     run_stats = []
     ## build the base version of the signal at the most detailed geo level you can get.
     ## compute stuff here or farm out to another function or file
-    df_pull = pull_nwss_data(socrata_token)
+    df_pull = pull_nwss_data(socrata_token, logger)
     ## aggregate
     # iterate over the providers and the normalizations that they specifically provide
-    for (provider, normalization) in zip(
+    for provider, normalization in zip(
         PROVIDER_NORMS["provider"], PROVIDER_NORMS["normalization"]
     ):
         # copy by only taking the relevant subsection
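
The loop change is purely stylistic (the tuple parentheses are redundant), but it shows the shape `PROVIDER_NORMS` must have: two parallel lists zipped pairwise. A small sketch with hypothetical provider/normalization values:

```python
# Hypothetical parallel lists in the shape the loop expects;
# not the real provider list from constants.py.
PROVIDER_NORMS = {
    "provider": ["cdc_verily", "nwss"],
    "normalization": ["flow_population", "microbial"],
}

for provider, normalization in zip(
    PROVIDER_NORMS["provider"], PROVIDER_NORMS["normalization"]
):
    # e.g. the signal_name built in pull.py: provider + "_" + normalization
    print(f"{provider}_{normalization}")
```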

nwss_wastewater/tests/test_pull.py

Lines changed: 5 additions & 49 deletions
@@ -1,22 +1,12 @@
-from datetime import datetime, date
-import json
-from unittest.mock import patch
-import tempfile
-import os
-import time
-from datetime import datetime
-
 import pandas as pd
 import pandas.api.types as ptypes
 
 from delphi_nwss.pull import (
     add_identifier_columns,
-    check_endpoints,
-    construct_typedicts,
     sig_digit_round,
     reformat,
-    warn_string,
 )
+from delphi_nwss.constants import TYPE_DICT, TYPE_DICT_METRIC
 import numpy as np
 
 
@@ -31,32 +21,10 @@ def test_sig_digit():
     ).all()
 
 
-def test_column_type_dicts():
-    type_dict, type_dict_metric = construct_typedicts()
-    assert type_dict == {"pcr_conc_smoothed": float, "timestamp": "datetime64[ns]"}
-    assert type_dict_metric == {
-        "date_start": "datetime64[ns]",
-        "date_end": "datetime64[ns]",
-        "detect_prop_15d": float,
-        "percentile": float,
-        "ptc_15d": float,
-        "wwtp_jurisdiction": "category",
-        "wwtp_id": int,
-        "reporting_jurisdiction": "category",
-        "sample_location": "category",
-        "county_names": "category",
-        "county_fips": "category",
-        "population_served": float,
-        "sampling_prior": bool,
-        "sample_location_specify": float,
-    }
-
-
 def test_column_conversions_concentration():
-    type_dict, type_dict_metric = construct_typedicts()
     df = pd.read_csv("test_data/conc_data.csv", index_col=0)
     df = df.rename(columns={"date": "timestamp"})
-    converted = df.astype(type_dict)
+    converted = df.astype(TYPE_DICT)
     assert all(
         converted.columns
         == pd.Index(["key_plot_id", "timestamp", "pcr_conc_smoothed", "normalization"])
@@ -66,9 +34,8 @@ def test_column_conversions_concentration():
 
 
 def test_column_conversions_metric():
-    type_dict, type_dict_metric = construct_typedicts()
     df = pd.read_csv("test_data/metric_data.csv", index_col=0)
-    converted = df.astype(type_dict_metric)
+    converted = df.astype(TYPE_DICT_METRIC)
     assert all(
         converted.columns
         == pd.Index(
@@ -113,24 +80,13 @@ def test_column_conversions_metric():
     assert all(ptypes.is_numeric_dtype(converted[flo].dtype) for flo in float_typed)
 
 
-def test_warn_string():
-    type_dict, type_dict_metric = construct_typedicts()
-    df_conc = pd.read_csv("test_data/conc_data.csv")
-    assert (
-        warn_string(df_conc, type_dict)
-        == "\nExpected column(s) missed, The dataset schema may\nhave changed. Please investigate and amend the code.\n\nColumns needed:\npcr_conc_smoothed\ntimestamp\n\nColumns available:\nUnnamed: 0\ndate\nkey_plot_id\nnormalization\npcr_conc_smoothed\n"
-    )
-
-
 def test_formatting():
-    type_dict, type_dict_metric = construct_typedicts()
     df_metric = pd.read_csv("test_data/metric_data.csv", index_col=0)
-    df_metric = df_metric.astype(type_dict_metric)
+    df_metric = df_metric.astype(TYPE_DICT_METRIC)
 
-    type_dict, type_dict_metric = construct_typedicts()
     df = pd.read_csv("test_data/conc_data.csv", index_col=0)
     df = df.rename(columns={"date": "timestamp"})
-    df = df.astype(type_dict)
+    df = df.astype(TYPE_DICT)
 
     df_formatted = reformat(df, df_metric)
 
nwss_wastewater/tests/test_run.py

Lines changed: 2 additions & 16 deletions
@@ -1,17 +1,7 @@
-from datetime import datetime, date
-import json
-from unittest.mock import patch
-import tempfile
-import os
-import time
-from datetime import datetime
-
 import numpy as np
 import pandas as pd
 from pandas.testing import assert_frame_equal
-from delphi_utils import S3ArchiveDiffer, get_structured_logger, create_export_csv, Nans
 
-from delphi_nwss.constants import GEOS, SIGNALS
 from delphi_nwss.run import (
     add_needed_columns,
     generate_weights,
@@ -23,13 +13,9 @@
 
 def test_sum_all_nan():
     """Check that sum_all_nan returns NaN iff everything is a NaN"""
-    no_nans = np.array([3, 5])
-    assert sum_all_nan(no_nans) == 8
-    partial_nan = np.array([np.nan, 3, 5])
+    assert sum_all_nan(np.array([3, 5])) == 8
     assert np.isclose(sum_all_nan([np.nan, 3, 5]), 8)
-
-
-    oops_all_nans = np.array([np.nan, np.nan])
-    assert np.isnan(oops_all_nans).all()
+    assert np.isnan(np.array([np.nan, np.nan])).all()
 
 
 def test_weight_generation():
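
The consolidated assertions still pin down `sum_all_nan`'s contract: NaNs are ignored unless every element is NaN. One plausible implementation consistent with these tests; the committed version lives in `delphi_nwss.run` and may differ:

```python
import numpy as np

def sum_all_nan(values):
    """Return NaN if every element is NaN, otherwise the NaN-ignoring sum."""
    values = np.asarray(values, dtype=float)
    if np.isnan(values).all():
        return np.nan
    return np.nansum(values)

# Mirrors the assertions in test_sum_all_nan above.
assert sum_all_nan(np.array([3, 5])) == 8
assert np.isclose(sum_all_nan([np.nan, 3, 5]), 8)
assert np.isnan(sum_all_nan(np.array([np.nan, np.nan])))
```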
