Skip to content

Commit 33a79bd

Browse files
committed
ensure smoothed num <= smoothed den (for synthetic data), deescalate assert to warning, remove extra print
1 parent ceae6fa commit 33a79bd

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

emr_hosp/delphi_emr_hosp/sensor.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,12 @@ def fit(y_data, sensor_dates, geo_id):
105105
# the left_gauss_linear smoother is not guaranteed to return values greater than 0
106106
smoothed_total_counts = np.clip(left_gauss_linear(total_counts.values), 0, None)
107107
smoothed_total_visits = np.clip(left_gauss_linear(total_visits.values), 0, None)
108+
109+
# in smoothing, the numerator may have become more than the denominator
110+
# simple fix is to clip the max values elementwise to the denominator (note that
111+
# this has only been observed in synthetic data)
112+
smoothed_total_counts = np.clip(smoothed_total_counts, 0, smoothed_total_visits)
113+
108114
smoothed_total_rates = (
109115
(smoothed_total_counts + 0.5) / (smoothed_total_visits + 1)
110116
)

emr_hosp/delphi_emr_hosp/update_sensor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,8 @@ def write_to_csv(output_dict, out_name, output_path="."):
5656
if all_include[geo_id][i]:
5757
assert not np.isnan(sensor), "value for included sensor is nan"
5858
assert not np.isnan(se), "se for included sensor is nan"
59-
assert sensor < 90, f"value suspiciously high, {geo_id}: {sensor}"
59+
if sensor > 90:
60+
logging.warning(f"value suspiciously high, {geo_id}: {sensor}")
6061
assert se < 5, f"se suspiciously high, {geo_id}: {se}"
6162

6263
# for privacy reasons we will not report the standard error
@@ -136,7 +137,6 @@ def update_sensor(
136137
unique_geo_ids = list(sorted(np.unique(data_frame.index.get_level_values(0))))
137138

138139
# for each location, fill in all missing dates with 0 values
139-
print(len(unique_geo_ids), len(fit_dates))
140140
multiindex = pd.MultiIndex.from_product((unique_geo_ids, fit_dates),
141141
names=[geo, "date"])
142142
assert (len(multiindex) <= (Constants.MAX_GEO[geo] * len(fit_dates))

emr_hosp/tests/test_update_sensor.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -118,16 +118,10 @@ def test_write_to_csv_wrong_results(self):
118118
with pytest.raises(AssertionError):
119119
write_to_csv(res2, "name_of_signal", td.name)
120120

121-
# large sensor value
121+
# large se value
122122
res3 = deepcopy(res0)
123-
res3["rates"]["a"][0] = 95
123+
res3["se"]["a"][0] = 10
124124
with pytest.raises(AssertionError):
125125
write_to_csv(res3, "name_of_signal", td.name)
126126

127-
# large se value
128-
res4 = deepcopy(res0)
129-
res4["se"]["a"][0] = 10
130-
with pytest.raises(AssertionError):
131-
write_to_csv(res4, "name_of_signal", td.name)
132-
133127
td.cleanup()

0 commit comments

Comments
 (0)