Skip to content

Commit fc5c8dd

Browse files
committed
Standardize weekday calc_adjustment() function
1 parent 04d6185 commit fc5c8dd

File tree

6 files changed

+51
-44
lines changed

6 files changed

+51
-44
lines changed

changehc/delphi_changehc/update_sensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,10 @@ def update_sensor(self,
201201
if not self.parallel:
202202
dfs = []
203203
for geo_id, sub_data in data_frame.groupby(level=0):
204-
sub_data.reset_index(level=0,inplace=True)
204+
sub_data.reset_index(inplace=True)
205205
if self.weekday:
206-
sub_data = Weekday.calc_adjustment(wd_params, sub_data)
206+
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"])
207+
sub_data.set_index(Config.DATE_COL, inplace=True)
207208
res = CHCSensor.fit(sub_data, self.burnindate, geo_id, self.logger)
208209
res = pd.DataFrame(res).loc[final_sensor_idxs]
209210
dfs.append(res)
@@ -213,9 +214,10 @@ def update_sensor(self,
213214
with Pool(n_cpu) as pool:
214215
pool_results = []
215216
for geo_id, sub_data in data_frame.groupby(level=0,as_index=False):
216-
sub_data.reset_index(level=0, inplace=True)
217+
sub_data.reset_index(inplace=True)
217218
if self.weekday:
218-
sub_data = Weekday.calc_adjustment(wd_params, sub_data)
219+
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"])
220+
sub_data.set_index(Config.DATE_COL, inplace=True)
219221
pool_results.append(
220222
pool.apply_async(
221223
CHCSensor.fit, args=(sub_data, self.burnindate, geo_id, self.logger),

changehc/delphi_changehc/weekday.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ def get_params(data):
9090
_ = prob.solve()
9191
params = b.value
9292

93-
return params
93+
return params.reshape(1, -1)
9494

9595
@staticmethod
96-
def calc_adjustment(params, sub_data):
96+
def calc_adjustment(params, sub_data, cols):
9797
"""Apply the weekday adjustment to a specific time series.
9898
9999
Extracts the weekday fixed effects from the parameters and uses these to
@@ -112,14 +112,14 @@ def calc_adjustment(params, sub_data):
112112
-- this has the same effect.
113113
114114
"""
115-
tmp = sub_data.reset_index()
116-
117-
wd_correction = np.zeros((len(tmp["num"])))
118-
for wd in range(7):
119-
mask = tmp[Config.DATE_COL].dt.dayofweek == wd
120-
wd_correction[mask] = tmp.loc[mask, "num"] / (
121-
np.exp(params[wd]) if wd < 6 else np.exp(-np.sum(params[:6]))
122-
)
123-
tmp.loc[:, "num"] = wd_correction
124-
125-
return tmp.set_index(Config.DATE_COL)
115+
tmp = sub_data.copy()
116+
for i, c in enumerate(cols):
117+
wd_correction = np.zeros((len(tmp[c])))
118+
119+
for wd in range(7):
120+
mask = tmp[Config.DATE_COL].dt.dayofweek == wd
121+
wd_correction[mask] = tmp.loc[mask, c] / (
122+
np.exp(params[i, wd]) if wd < 6 else np.exp(-np.sum(params[i, :6]))
123+
)
124+
tmp.loc[:, c] = wd_correction
125+
return tmp

claims_hosp/delphi_claims_hosp/update_indicator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,10 @@ def update_indicator(self, input_filepath, outpath):
160160
valid_inds = {}
161161
if not self.parallel:
162162
for geo_id, sub_data in data_frame.groupby(level=0):
163-
sub_data.reset_index(level=0, inplace=True)
163+
sub_data.reset_index(inplace=True)
164164
if self.weekday:
165-
sub_data = Weekday.calc_adjustment(wd_params, sub_data)
165+
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"])
166+
sub_data.set_index(Config.DATE_COL, inplace=True)
166167
res = ClaimsHospIndicator.fit(sub_data, self.burnindate, geo_id)
167168
res = pd.DataFrame(res)
168169
rates[geo_id] = np.array(res.loc[final_output_inds, "rate"])
@@ -174,9 +175,10 @@ def update_indicator(self, input_filepath, outpath):
174175
with Pool(n_cpu) as pool:
175176
pool_results = []
176177
for geo_id, sub_data in data_frame.groupby(level=0, as_index=False):
177-
sub_data.reset_index(level=0, inplace=True)
178+
sub_data.reset_index(inplace=True)
178179
if self.weekday:
179-
sub_data = Weekday.calc_adjustment(wd_params, sub_data)
180+
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"])
181+
sub_data.set_index(Config.DATE_COL, inplace=True)
180182
pool_results.append(
181183
pool.apply_async(
182184
ClaimsHospIndicator.fit,

claims_hosp/delphi_claims_hosp/weekday.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ def get_params(data):
9090
_ = prob.solve()
9191
params = b.value
9292

93-
return params
93+
return params.reshape(1, -1)
9494

9595
@staticmethod
96-
def calc_adjustment(params, sub_data):
96+
def calc_adjustment(params, sub_data, cols):
9797
"""Apply the weekday adjustment to a specific time series.
9898
9999
Extracts the weekday fixed effects from the parameters and uses these to
@@ -112,14 +112,14 @@ def calc_adjustment(params, sub_data):
112112
-- this has the same effect.
113113
114114
"""
115-
tmp = sub_data.reset_index()
116-
117-
wd_correction = np.zeros((len(tmp["num"])))
118-
for wd in range(7):
119-
mask = tmp[Config.DATE_COL].dt.dayofweek == wd
120-
wd_correction[mask] = tmp.loc[mask, "num"] / (
121-
np.exp(params[wd]) if wd < 6 else np.exp(-np.sum(params[:6]))
122-
)
123-
tmp.loc[:, "num"] = wd_correction
124-
125-
return tmp.set_index(Config.DATE_COL)
115+
tmp = sub_data.copy()
116+
for i, c in enumerate(cols):
117+
wd_correction = np.zeros((len(tmp[c])))
118+
119+
for wd in range(7):
120+
mask = tmp[Config.DATE_COL].dt.dayofweek == wd
121+
wd_correction[mask] = tmp.loc[mask, c] / (
122+
np.exp(params[i, wd]) if wd < 6 else np.exp(-np.sum(params[i, :6]))
123+
)
124+
tmp.loc[:, c] = wd_correction
125+
return tmp

doctor_visits/delphi_doctor_visits/update_sensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def update_sensor(
145145
for geo_id in unique_geo_ids:
146146
sub_data = data_groups.get_group(geo_id).copy()
147147
if weekday:
148-
sub_data = Weekday.calc_adjustment(params, sub_data)
148+
sub_data = Weekday.calc_adjustment(params, sub_data, Config.CLI_COLS + Config.FLU1_COL)
149149

150150
res = DoctorVisitsSensor.fit(
151151
sub_data,
@@ -169,7 +169,9 @@ def update_sensor(
169169
for geo_id in unique_geo_ids:
170170
sub_data = data_groups.get_group(geo_id).copy()
171171
if weekday:
172-
sub_data = Weekday.calc_adjustment(params, sub_data)
172+
sub_data = Weekday.calc_adjustment(params,
173+
sub_data,
174+
Config.CLI_COLS + Config.FLU1_COL)
173175

174176
pool_results.append(
175177
pool.apply_async(

doctor_visits/delphi_doctor_visits/weekday.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def get_params(data, logger):
103103
return params
104104

105105
@staticmethod
106-
def calc_adjustment(params, sub_data):
106+
def calc_adjustment(params, sub_data, cols):
107107
"""Apply the weekday adjustment to a specific time series.
108108
109109
Extracts the weekday fixed effects from the parameters and uses these to
@@ -122,14 +122,15 @@ def calc_adjustment(params, sub_data):
122122
-- this has the same effect.
123123
124124
"""
125-
for i, c in enumerate(Config.CLI_COLS + Config.FLU1_COL):
126-
wd_correction = np.zeros((len(sub_data[c])))
125+
tmp = sub_data.copy()
126+
127+
for i, c in enumerate(cols):
128+
wd_correction = np.zeros((len(tmp[c])))
127129

128130
for wd in range(7):
129-
mask = sub_data[Config.DATE_COL].dt.dayofweek == wd
130-
wd_correction[mask] = sub_data.loc[mask, c] / (
131+
mask = tmp[Config.DATE_COL].dt.dayofweek == wd
132+
wd_correction[mask] = tmp.loc[mask, c] / (
131133
np.exp(params[i, wd]) if wd < 6 else np.exp(-np.sum(params[i, :6]))
132134
)
133-
sub_data.loc[:, c] = wd_correction
134-
135-
return sub_data
135+
tmp.loc[:, c] = wd_correction
136+
return tmp

0 commit comments

Comments
 (0)