Skip to content

Commit 647e05d

Browse files
committed
Fixes the bug where the smoother drops the pandas series index:
* restore the index after smoothing * test to match
1 parent 4421f28 commit 647e05d

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

_delphi_utils_python/delphi_utils/smooth.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def smooth(self, signal: Union[np.ndarray, pd.Series]) -> Union[np.ndarray, pd.S
172172
return signal
173173

174174
is_pandas_series = isinstance(signal, pd.Series)
175+
pandas_index = signal.index if is_pandas_series else None
175176
signal = signal.to_numpy() if is_pandas_series else signal
176177

177178
# Find where the first non-nan value is located and truncate the initial nans
@@ -197,7 +198,10 @@ def smooth(self, signal: Union[np.ndarray, pd.Series]) -> Union[np.ndarray, pd.S
197198

198199
# Append the nans back, since we want to preserve length
199200
signal_smoothed = np.hstack([np.nan*np.ones(ix), signal_smoothed])
200-
signal_smoothed = signal_smoothed if not is_pandas_series else pd.Series(signal_smoothed)
201+
# Convert back to pandas if necessary
202+
if is_pandas_series:
203+
signal_smoothed = pd.Series(signal_smoothed)
204+
signal_smoothed.index = pandas_index
201205
return signal_smoothed
202206

203207
def impute(self, signal):

_delphi_utils_python/tests/test_smooth.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,11 @@ def test_pandas_series_input(self):
254254
assert np.allclose(
255255
signal[window_length - 1 :], smoothed_signal[window_length - 1 :]
256256
)
257+
258+
# Test that the index of the series gets preserved
259+
signal = pd.Series(np.ones(30), index=np.arange(50, 80))
260+
smoother = Smoother(smoother_name="moving_average", window_length=10)
261+
smoothed_signal = signal.transform(smoother.smooth)
262+
ix1 = signal.index
263+
ix2 = smoothed_signal.index
264+
assert ix1.equals(ix2)

0 commit comments

Comments
 (0)