Skip to content

Commit 4421f28

Browse files
committed
Update smoother to gracefully handle nans:
* entire array of nans is handled
* left-padded nans are now ignored
* a few other edge cases
* add tests to match
1 parent ad77995 commit 4421f28

File tree

2 files changed

+63
-13
lines changed

2 files changed

+63
-13
lines changed

_delphi_utils_python/delphi_utils/smooth.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ def __init__(
138138
raise ValueError("Invalid impute_method given.")
139139
if self.boundary_method not in valid_boundary_methods:
140140
raise ValueError("Invalid boundary_method given.")
141+
if self.window_length <= 1:
142+
raise ValueError("Window length is too short.")
141143

142144
if smoother_name == "savgol":
143145
# The polynomial fitting is done on a past window of size window_length
@@ -165,20 +167,36 @@ def smooth(self, signal: Union[np.ndarray, pd.Series]) -> Union[np.ndarray, pd.S
165167
A smoothed 1D signal. Returns an array of the same type and length as
166168
the input.
167169
"""
170+
# If all nans, pass through
171+
if np.all(np.isnan(signal)):
172+
return signal
173+
168174
is_pandas_series = isinstance(signal, pd.Series)
169175
signal = signal.to_numpy() if is_pandas_series else signal
170176

171-
signal = self.impute(signal)
177+
# Find where the first non-nan value is located and truncate the initial nans
178+
ix = np.where(~np.isnan(signal))[0][0]
179+
signal = signal[ix:]
172180

173-
if self.smoother_name == "savgol":
174-
signal_smoothed = self.savgol_smoother(signal)
175-
elif self.smoother_name == "left_gauss_linear":
176-
signal_smoothed = self.left_gauss_linear_smoother(signal)
177-
elif self.smoother_name == "moving_average":
178-
signal_smoothed = self.moving_average_smoother(signal)
179-
else:
181+
# Don't smooth in certain edge cases
182+
if len(signal) < self.poly_fit_degree or len(signal) == 1:
180183
signal_smoothed = signal.copy()
181-
184+
else:
185+
# Impute
186+
signal = self.impute(signal)
187+
188+
# Smooth
189+
if self.smoother_name == "savgol":
190+
signal_smoothed = self.savgol_smoother(signal)
191+
elif self.smoother_name == "left_gauss_linear":
192+
signal_smoothed = self.left_gauss_linear_smoother(signal)
193+
elif self.smoother_name == "moving_average":
194+
signal_smoothed = self.moving_average_smoother(signal)
195+
elif self.smoother_name == "identity":
196+
signal_smoothed = signal
197+
198+
# Append the nans back, since we want to preserve length
199+
signal_smoothed = np.hstack([np.nan*np.ones(ix), signal_smoothed])
182200
signal_smoothed = signal_smoothed if not is_pandas_series else pd.Series(signal_smoothed)
183201
return signal_smoothed
184202

@@ -282,7 +300,7 @@ def left_gauss_linear_smoother(self, signal):
282300

283301
def savgol_predict(self, signal, poly_fit_degree, nr):
284302
"""Predict a single value using the savgol method.
285-
303+
286304
Fits a polynomial through the values given by the signal and returns the value
287305
of the polynomial at the right-most signal-value. More precisely, for a signal of length
288306
n, fits a poly_fit_degree polynomial through the points signal[-n+1+nr], signal[-n+2+nr],
@@ -311,7 +329,8 @@ def savgol_predict(self, signal, poly_fit_degree, nr):
311329
def savgol_coeffs(self, nl, nr, poly_fit_degree):
312330
"""Solve for the Savitzky-Golay coefficients.
313331
314-
The coefficients c_i give a filter so that
332+
Solves for the Savitzky-Golay coefficients. The coefficients c_i
333+
give a filter so that
315334
y = sum_{i=-{n_l}}^{n_r} c_i x_i
316335
is the value at 0 (thus the constant term) of the polynomial fit
317336
through the points {x_i}. The coefficients c_i are calculated as
@@ -385,7 +404,7 @@ def savgol_smoother(self, signal):
385404
# - identity keeps the original signal (doesn't smooth)
386405
# - nan writes nans
387406
if self.boundary_method == "shortened_window":
388-
for ix in range(len(self.coeffs)):
407+
for ix in range(min(len(self.coeffs), len(signal))):
389408
if ix == 0:
390409
signal_smoothed[ix] = signal[ix]
391410
else:

_delphi_utils_python/tests/test_smooth.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Tests for the smoothing utility.
33
Authors: Dmitry Shemetov, Addison Hu, Maria Jahja
44
"""
5+
from numpy.lib.polynomial import poly
56
import pytest
67

78
import numpy as np
@@ -109,9 +110,39 @@ def test_causal_savgol_smoother(self):
109110
smoother_name="savgol", window_length=window_length, poly_fit_degree=1,
110111
)
111112
smoothed_signal2 = smoother.smooth(signal)
112-
113113
assert np.allclose(smoothed_signal1, smoothed_signal2)
114114

115+
# Test the all nans case
116+
signal = np.nan * np.ones(10)
117+
smoother = Smoother(window_length=9)
118+
smoothed_signal = smoother.smooth(signal)
119+
assert np.all(np.isnan(smoothed_signal))
120+
121+
# Test the case where the signal is length 1
122+
signal = np.ones(1)
123+
smoother = Smoother()
124+
smoothed_signal = smoother.smooth(signal)
125+
assert np.allclose(smoothed_signal, signal)
126+
127+
# Test the case where the signal length is less than poly_fit_degree
128+
signal = np.ones(2)
129+
smoother = Smoother(poly_fit_degree=3)
130+
smoothed_signal = smoother.smooth(signal)
131+
assert np.allclose(smoothed_signal, signal)
132+
133+
# Test an edge fitting case
134+
signal = np.array([np.nan, 1, np.nan])
135+
smoother = Smoother(poly_fit_degree=1, window_length=2)
136+
smoothed_signal = smoother.smooth(signal)
137+
assert np.allclose(smoothed_signal, np.array([np.nan, 1, 1]), equal_nan=True)
138+
139+
# Test a range of signal lengths following a run of leading nans; the signal should be returned unchanged
140+
for i in range(10):
141+
signal = np.hstack([[np.nan, np.nan, np.nan], np.ones(i)])
142+
smoother = Smoother(poly_fit_degree=0, window_length=5)
143+
smoothed_signal = smoother.smooth(signal)
144+
assert np.allclose(smoothed_signal, signal, equal_nan=True)
145+
115146
def test_impute(self):
116147
# test the nan imputer
117148
signal = np.array([i if i % 3 else np.nan for i in range(1, 40)])

0 commit comments

Comments
 (0)