@@ -1634,10 +1634,10 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
1634
1634
1635
1635
Parameters
1636
1636
----------
1637
- arr : ndarray
1637
+ arr : ndarray or ExtensionArray
1638
1638
n : int
1639
1639
number of periods
1640
- axis : int
1640
+ axis : {0, 1}
1641
1641
axis to shift on
1642
1642
stacklevel : int
1643
1643
The stacklevel for the lost dtype warning.
@@ -1651,7 +1651,8 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
1651
1651
na = np .nan
1652
1652
dtype = arr .dtype
1653
1653
1654
- if dtype .kind == "b" :
1654
+ is_bool = is_bool_dtype (dtype )
1655
+ if is_bool :
1655
1656
op = operator .xor
1656
1657
else :
1657
1658
op = operator .sub
@@ -1677,17 +1678,15 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
1677
1678
dtype = arr .dtype
1678
1679
1679
1680
is_timedelta = False
1680
- is_bool = False
1681
1681
if needs_i8_conversion (arr .dtype ):
1682
1682
dtype = np .int64
1683
1683
arr = arr .view ("i8" )
1684
1684
na = iNaT
1685
1685
is_timedelta = True
1686
1686
1687
- elif is_bool_dtype ( dtype ) :
1687
+ elif is_bool :
1688
1688
# We have to cast in order to be able to hold np.nan
1689
1689
dtype = np .object_
1690
- is_bool = True
1691
1690
1692
1691
elif is_integer_dtype (dtype ):
1693
1692
# We have to cast in order to be able to hold np.nan
@@ -1708,45 +1707,26 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):
1708
1707
dtype = np .dtype (dtype )
1709
1708
out_arr = np .empty (arr .shape , dtype = dtype )
1710
1709
1711
- na_indexer = [slice (None )] * arr . ndim
1710
+ na_indexer = [slice (None )] * 2
1712
1711
na_indexer [axis ] = slice (None , n ) if n >= 0 else slice (n , None )
1713
1712
out_arr [tuple (na_indexer )] = na
1714
1713
1715
- if arr .ndim == 2 and arr . dtype .name in _diff_special :
1714
+ if arr .dtype .name in _diff_special :
1716
1715
# TODO: can diff_2d dtype specialization troubles be fixed by defining
1717
1716
# out_arr inside diff_2d?
1718
1717
algos .diff_2d (arr , out_arr , n , axis , datetimelike = is_timedelta )
1719
1718
else :
1720
1719
# To keep mypy happy, _res_indexer is a list while res_indexer is
1721
1720
# a tuple, ditto for lag_indexer.
1722
- _res_indexer = [slice (None )] * arr . ndim
1721
+ _res_indexer = [slice (None )] * 2
1723
1722
_res_indexer [axis ] = slice (n , None ) if n >= 0 else slice (None , n )
1724
1723
res_indexer = tuple (_res_indexer )
1725
1724
1726
- _lag_indexer = [slice (None )] * arr . ndim
1725
+ _lag_indexer = [slice (None )] * 2
1727
1726
_lag_indexer [axis ] = slice (None , - n ) if n > 0 else slice (- n , None )
1728
1727
lag_indexer = tuple (_lag_indexer )
1729
1728
1730
- # need to make sure that we account for na for datelike/timedelta
1731
- # we don't actually want to subtract these i8 numbers
1732
- if is_timedelta :
1733
- res = arr [res_indexer ]
1734
- lag = arr [lag_indexer ]
1735
-
1736
- mask = (arr [res_indexer ] == na ) | (arr [lag_indexer ] == na )
1737
- if mask .any ():
1738
- res = res .copy ()
1739
- res [mask ] = 0
1740
- lag = lag .copy ()
1741
- lag [mask ] = 0
1742
-
1743
- result = res - lag
1744
- result [mask ] = na
1745
- out_arr [res_indexer ] = result
1746
- elif is_bool :
1747
- out_arr [res_indexer ] = arr [res_indexer ] ^ arr [lag_indexer ]
1748
- else :
1749
- out_arr [res_indexer ] = arr [res_indexer ] - arr [lag_indexer ]
1729
+ out_arr [res_indexer ] = op (arr [res_indexer ], arr [lag_indexer ])
1750
1730
1751
1731
if is_timedelta :
1752
1732
out_arr = out_arr .view ("timedelta64[ns]" )
0 commit comments