Skip to content

Commit bf89220

Browse files
joshuastorckjreback
authored andcommitted
Implementing rolling min/max functions that can retain the original type:
* Changed the rolling min/max functions in algos.pyx so that they use a cython fused type as input instead of a float64 so that the function can accept arrays of any numeric type
* Merged the functionality of rolling min/max into a common function with branches based on whether or not it's running min/max
* When running rolling min/max for integral types and there are not enough minimum periods, the output values returned are zero
* Added a unit test to test_moments to make sure that rolling min/max works for all integral types and float32/64
* Updated computations and whatsnew doc
closes pandas-dev#12595
1 parent 48e49ac commit bf89220

File tree

3 files changed

+97
-125
lines changed

3 files changed

+97
-125
lines changed

doc/source/whatsnew/v0.18.1.txt

+2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ Bug Fixes
105105
- Bug in ``value_counts`` when ``normalize=True`` and ``dropna=True`` where nulls still contributed to the normalized count (:issue:`12558`)
106106
- Bug in ``Panel.fillna()`` ignoring ``inplace=True`` (:issue:`12633`)
107107
- Bug in ``Series.rename``, ``DataFrame.rename`` and ``DataFrame.rename_axis`` not treating ``Series`` as mappings to relabel (:issue:`12623`).
108+
- Cleanup of ``.rolling.min`` and ``.rolling.max`` to improve dtype handling (:issue:`12373`)
109+
108110

109111

110112

pandas/algos.pyx

+79-125
Original file line numberDiff line numberDiff line change
@@ -1625,117 +1625,56 @@ def roll_median_c(ndarray[float64_t] arg, int win, int minp):
16251625
# of its Simplified BSD license
16261626
# https://github.com/kwgoodman/bottleneck
16271627

1628-
cdef struct pairs:
1629-
double value
1630-
int death
1631-
16321628
from libc cimport stdlib
16331629

16341630
@cython.boundscheck(False)
16351631
@cython.wraparound(False)
1636-
def roll_max(ndarray[float64_t] a, int window, int minp):
1637-
"Moving max of 1d array of dtype=float64 along axis=0 ignoring NaNs."
1638-
cdef np.float64_t ai, aold
1639-
cdef Py_ssize_t count
1640-
cdef pairs* ring
1641-
cdef pairs* minpair
1642-
cdef pairs* end
1643-
cdef pairs* last
1644-
cdef Py_ssize_t i0
1645-
cdef np.npy_intp *dim
1646-
dim = PyArray_DIMS(a)
1647-
cdef Py_ssize_t n0 = dim[0]
1648-
cdef np.npy_intp *dims = [n0]
1649-
cdef np.ndarray[np.float64_t, ndim=1] y = PyArray_EMPTY(1, dims,
1650-
NPY_float64, 0)
1651-
1652-
if window < 1:
1653-
raise ValueError('Invalid window size %d'
1654-
% (window))
1655-
1656-
if minp > window:
1657-
raise ValueError('Invalid min_periods size %d greater than window %d'
1658-
% (minp, window))
1659-
1660-
minp = _check_minp(window, minp, n0)
1661-
with nogil:
1662-
ring = <pairs*>stdlib.malloc(window * sizeof(pairs))
1663-
end = ring + window
1664-
last = ring
1665-
1666-
minpair = ring
1667-
ai = a[0]
1668-
if ai == ai:
1669-
minpair.value = ai
1670-
else:
1671-
minpair.value = MINfloat64
1672-
minpair.death = window
1673-
1674-
count = 0
1675-
for i0 in range(n0):
1676-
ai = a[i0]
1677-
if ai == ai:
1678-
count += 1
1679-
else:
1680-
ai = MINfloat64
1681-
if i0 >= window:
1682-
aold = a[i0 - window]
1683-
if aold == aold:
1684-
count -= 1
1685-
if minpair.death == i0:
1686-
minpair += 1
1687-
if minpair >= end:
1688-
minpair = ring
1689-
if ai >= minpair.value:
1690-
minpair.value = ai
1691-
minpair.death = i0 + window
1692-
last = minpair
1693-
else:
1694-
while last.value <= ai:
1695-
if last == ring:
1696-
last = end
1697-
last -= 1
1698-
last += 1
1699-
if last == end:
1700-
last = ring
1701-
last.value = ai
1702-
last.death = i0 + window
1703-
if count >= minp:
1704-
y[i0] = minpair.value
1705-
else:
1706-
y[i0] = NaN
1707-
1708-
for i0 in range(minp - 1):
1709-
y[i0] = NaN
1710-
1711-
stdlib.free(ring)
1712-
return y
1632+
def roll_max(ndarray[numeric] a, int window, int minp):
1633+
"""
1634+
Moving max of 1d array of any numeric type along axis=0 ignoring NaNs.
17131635
1636+
Parameters
1637+
----------
1638+
a: numpy array
1639+
window: int, size of rolling window
1640+
minp: if number of observations in window
1641+
is below this, output a NaN (0 for integer dtypes)
1642+
"""
1643+
return _roll_min_max(a, window, minp, 1)
17141644

1715-
cdef double_t _get_max(object skiplist, int nobs, int minp):
1716-
if nobs >= minp:
1717-
return <IndexableSkiplist> skiplist.get(nobs - 1)
1718-
else:
1719-
return NaN
1645+
@cython.boundscheck(False)
1646+
@cython.wraparound(False)
1647+
def roll_min(ndarray[numeric] a, int window, int minp):
1648+
"""
1649+
Moving min of 1d array of any numeric type along axis=0 ignoring NaNs.
17201650
1651+
Parameters
1652+
----------
1653+
a: numpy array
1654+
window: int, size of rolling window
1655+
minp: if number of observations in window
1656+
is below this, output a NaN (0 for integer dtypes)
1657+
"""
1658+
return _roll_min_max(a, window, minp, 0)
17211659

17221660
@cython.boundscheck(False)
17231661
@cython.wraparound(False)
1724-
def roll_min(np.ndarray[np.float64_t, ndim=1] a, int window, int minp):
1725-
"Moving min of 1d array of dtype=float64 along axis=0 ignoring NaNs."
1726-
cdef np.float64_t ai, aold
1662+
cdef _roll_min_max(ndarray[numeric] a, int window, int minp, bint is_max):
1663+
"Moving min/max of 1d array of any numeric type along axis=0 ignoring NaNs."
1664+
cdef numeric ai, aold
17271665
cdef Py_ssize_t count
1728-
cdef pairs* ring
1729-
cdef pairs* minpair
1730-
cdef pairs* end
1731-
cdef pairs* last
1666+
cdef Py_ssize_t* death
1667+
cdef numeric* ring
1668+
cdef numeric* minvalue
1669+
cdef numeric* end
1670+
cdef numeric* last
17321671
cdef Py_ssize_t i0
17331672
cdef np.npy_intp *dim
17341673
dim = PyArray_DIMS(a)
17351674
cdef Py_ssize_t n0 = dim[0]
17361675
cdef np.npy_intp *dims = [n0]
1737-
cdef np.ndarray[np.float64_t, ndim=1] y = PyArray_EMPTY(1, dims,
1738-
NPY_float64, 0)
1676+
cdef bint should_replace
1677+
cdef np.ndarray[numeric, ndim=1] y = PyArray_EMPTY(1, dims, PyArray_TYPE(a), 0)
17391678

17401679
if window < 1:
17411680
raise ValueError('Invalid window size %d'
@@ -1747,64 +1686,79 @@ def roll_min(np.ndarray[np.float64_t, ndim=1] a, int window, int minp):
17471686

17481687
minp = _check_minp(window, minp, n0)
17491688
with nogil:
1750-
ring = <pairs*>stdlib.malloc(window * sizeof(pairs))
1689+
ring = <numeric*>stdlib.malloc(window * sizeof(numeric))
1690+
death = <Py_ssize_t*>stdlib.malloc(window * sizeof(Py_ssize_t))
17511691
end = ring + window
17521692
last = ring
17531693

1754-
minpair = ring
1694+
minvalue = ring
17551695
ai = a[0]
1756-
if ai == ai:
1757-
minpair.value = ai
1696+
if numeric in cython.floating:
1697+
if ai == ai:
1698+
minvalue[0] = ai
1699+
elif is_max:
1700+
minvalue[0] = MINfloat64
1701+
else:
1702+
minvalue[0] = MAXfloat64
17581703
else:
1759-
minpair.value = MAXfloat64
1760-
minpair.death = window
1704+
minvalue[0] = ai
1705+
death[0] = window
17611706

17621707
count = 0
17631708
for i0 in range(n0):
17641709
ai = a[i0]
1765-
if ai == ai:
1766-
count += 1
1710+
if numeric in cython.floating:
1711+
if ai == ai:
1712+
count += 1
1713+
elif is_max:
1714+
ai = MINfloat64
1715+
else:
1716+
ai = MAXfloat64
17671717
else:
1768-
ai = MAXfloat64
1718+
count += 1
17691719
if i0 >= window:
17701720
aold = a[i0 - window]
17711721
if aold == aold:
17721722
count -= 1
1773-
if minpair.death == i0:
1774-
minpair += 1
1775-
if minpair >= end:
1776-
minpair = ring
1777-
if ai <= minpair.value:
1778-
minpair.value = ai
1779-
minpair.death = i0 + window
1780-
last = minpair
1723+
if death[minvalue-ring] == i0:
1724+
minvalue += 1
1725+
if minvalue >= end:
1726+
minvalue = ring
1727+
should_replace = ai >= minvalue[0] if is_max else ai <= minvalue[0]
1728+
if should_replace:
1729+
minvalue[0] = ai
1730+
death[minvalue-ring] = i0 + window
1731+
last = minvalue
17811732
else:
1782-
while last.value >= ai:
1733+
should_replace = last[0] <= ai if is_max else last[0] >= ai
1734+
while should_replace:
17831735
if last == ring:
17841736
last = end
17851737
last -= 1
1738+
should_replace = last[0] <= ai if is_max else last[0] >= ai
17861739
last += 1
17871740
if last == end:
17881741
last = ring
1789-
last.value = ai
1790-
last.death = i0 + window
1791-
if count >= minp:
1792-
y[i0] = minpair.value
1742+
last[0] = ai
1743+
death[last - ring] = i0 + window
1744+
if numeric in cython.floating:
1745+
if count >= minp:
1746+
y[i0] = minvalue[0]
1747+
else:
1748+
y[i0] = NaN
17931749
else:
1794-
y[i0] = NaN
1750+
y[i0] = minvalue[0]
17951751

17961752
for i0 in range(minp - 1):
1797-
y[i0] = NaN
1753+
if numeric in cython.floating:
1754+
y[i0] = NaN
1755+
else:
1756+
y[i0] = 0
17981757

17991758
stdlib.free(ring)
1759+
stdlib.free(death)
18001760
return y
18011761

1802-
cdef double_t _get_min(object skiplist, int nobs, int minp):
1803-
if nobs >= minp:
1804-
return <IndexableSkiplist> skiplist.get(0)
1805-
else:
1806-
return NaN
1807-
18081762
def roll_quantile(ndarray[float64_t, cast=True] input, int win,
18091763
int minp, double quantile):
18101764
'''

pandas/tests/test_window.py

+16
Original file line numberDiff line numberDiff line change
@@ -2697,3 +2697,19 @@ def test_rolling_median_memory_error(self):
26972697
n = 20000
26982698
Series(np.random.randn(n)).rolling(window=2, center=False).median()
26992699
Series(np.random.randn(n)).rolling(window=2, center=False).median()
2700+
2701+
def test_rolling_min_max_numeric_types(self):
2702+
# GH12373
2703+
types_test = [np.dtype("f{}".format(width)) for width in [4, 8]]
2704+
types_test.extend([np.dtype("{}{}".format(sign, width))
2705+
for width in [1, 2, 4, 8] for sign in "ui"])
2706+
for data_type in types_test:
2707+
# Just testing that these don't throw exceptions and that
2708+
# the return type is float64. Other tests will cover quantitative
2709+
# correctness
2710+
result = (DataFrame(np.arange(20, dtype=data_type))
2711+
.rolling(window=5).max())
2712+
self.assertEqual(result.dtypes[0], np.dtype("f8"))
2713+
result = (DataFrame(np.arange(20, dtype=data_type))
2714+
.rolling(window=5).min())
2715+
self.assertEqual(result.dtypes[0], np.dtype("f8"))

0 commit comments

Comments
 (0)