diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index c3705fada724a..0fa02d54b5b78 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1634,10 +1634,10 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3): Parameters ---------- - arr : ndarray + arr : ndarray or ExtensionArray n : int number of periods - axis : int + axis : {0, 1} axis to shift on stacklevel : int The stacklevel for the lost dtype warning. @@ -1651,7 +1651,8 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3): na = np.nan dtype = arr.dtype - if dtype.kind == "b": + is_bool = is_bool_dtype(dtype) + if is_bool: op = operator.xor else: op = operator.sub @@ -1677,17 +1678,15 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3): dtype = arr.dtype is_timedelta = False - is_bool = False if needs_i8_conversion(arr.dtype): dtype = np.int64 arr = arr.view("i8") na = iNaT is_timedelta = True - elif is_bool_dtype(dtype): + elif is_bool: # We have to cast in order to be able to hold np.nan dtype = np.object_ - is_bool = True elif is_integer_dtype(dtype): # We have to cast in order to be able to hold np.nan @@ -1708,45 +1707,26 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3): dtype = np.dtype(dtype) out_arr = np.empty(arr.shape, dtype=dtype) - na_indexer = [slice(None)] * arr.ndim + na_indexer = [slice(None)] * 2 na_indexer[axis] = slice(None, n) if n >= 0 else slice(n, None) out_arr[tuple(na_indexer)] = na - if arr.ndim == 2 and arr.dtype.name in _diff_special: + if arr.dtype.name in _diff_special: # TODO: can diff_2d dtype specialization troubles be fixed by defining # out_arr inside diff_2d? algos.diff_2d(arr, out_arr, n, axis, datetimelike=is_timedelta) else: # To keep mypy happy, _res_indexer is a list while res_indexer is # a tuple, ditto for lag_indexer. - _res_indexer = [slice(None)] * arr.ndim + _res_indexer = [slice(None)] * 2 _res_indexer[axis] = slice(n, None) if n >= 0 else slice(None, n) res_indexer = tuple(_res_indexer) - _lag_indexer = [slice(None)] * arr.ndim + _lag_indexer = [slice(None)] * 2 _lag_indexer[axis] = slice(None, -n) if n > 0 else slice(-n, None) lag_indexer = tuple(_lag_indexer) - # need to make sure that we account for na for datelike/timedelta - # we don't actually want to subtract these i8 numbers - if is_timedelta: - res = arr[res_indexer] - lag = arr[lag_indexer] - - mask = (arr[res_indexer] == na) | (arr[lag_indexer] == na) - if mask.any(): - res = res.copy() - res[mask] = 0 - lag = lag.copy() - lag[mask] = 0 - - result = res - lag - result[mask] = na - out_arr[res_indexer] = result - elif is_bool: - out_arr[res_indexer] = arr[res_indexer] ^ arr[lag_indexer] - else: - out_arr[res_indexer] = arr[res_indexer] - arr[lag_indexer] + out_arr[res_indexer] = op(arr[res_indexer], arr[lag_indexer]) if is_timedelta: out_arr = out_arr.view("timedelta64[ns]")