Skip to content

Commit bf31c04

Browse files
ahcubjreback
authored andcommitted
ENH: fill_value argument for shift #15486 (#24128)
1 parent 1905485 commit bf31c04

File tree

16 files changed

+183
-30
lines changed

16 files changed

+183
-30
lines changed

doc/source/whatsnew/v0.24.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ New features
3131
- :func:`read_feather` now accepts ``columns`` as an argument, allowing the user to specify which columns should be read. (:issue:`24025`)
3232
- :func:`DataFrame.to_html` now accepts ``render_links`` as an argument, allowing the user to generate HTML with links to any URLs that appear in the DataFrame.
3333
See the :ref:`section on writing HTML <io.html>` in the IO docs for example usage. (:issue:`2679`)
34+
- :meth:`DataFrame.shift` :meth:`Series.shift`, :meth:`ExtensionArray.shift`, :meth:`SparseArray.shift`, :meth:`Period.shift`, :meth:`GroupBy.shift`, :meth:`Categorical.shift`, :meth:`NDFrame.shift` and :meth:`Block.shift` now accept `fill_value` as an argument, allowing the user to specify a value which will be used instead of NA/NaT in the empty periods. (:issue:`15486`)
3435

3536
.. _whatsnew_0240.values_api:
3637

pandas/core/arrays/base.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from pandas.core.dtypes.common import is_list_like
1818
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
19+
from pandas.core.dtypes.missing import isna
1920

2021
from pandas.core import ops
2122

@@ -449,8 +450,8 @@ def dropna(self):
449450
"""
450451
return self[~self.isna()]
451452

452-
def shift(self, periods=1):
453-
# type: (int) -> ExtensionArray
453+
def shift(self, periods=1, fill_value=None):
454+
# type: (int, object) -> ExtensionArray
454455
"""
455456
Shift values by desired number.
456457
@@ -465,6 +466,12 @@ def shift(self, periods=1):
465466
The number of periods to shift. Negative values are allowed
466467
for shifting backwards.
467468
469+
fill_value : object, optional
470+
The scalar value to use for newly introduced missing values.
471+
The default is ``self.dtype.na_value``
472+
473+
.. versionadded:: 0.24.0
474+
468475
Returns
469476
-------
470477
shifted : ExtensionArray
@@ -483,8 +490,11 @@ def shift(self, periods=1):
483490
if not len(self) or periods == 0:
484491
return self.copy()
485492

493+
if isna(fill_value):
494+
fill_value = self.dtype.na_value
495+
486496
empty = self._from_sequence(
487-
[self.dtype.na_value] * min(abs(periods), len(self)),
497+
[fill_value] * min(abs(periods), len(self)),
488498
dtype=self.dtype
489499
)
490500
if periods > 0:

pandas/core/arrays/categorical.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1257,14 +1257,18 @@ def shape(self):
12571257

12581258
return tuple([len(self._codes)])
12591259

1260-
def shift(self, periods):
1260+
def shift(self, periods, fill_value=None):
12611261
"""
12621262
Shift Categorical by desired number of periods.
12631263
12641264
Parameters
12651265
----------
12661266
periods : int
12671267
Number of periods to move, can be positive or negative
1268+
fill_value : object, optional
1269+
The scalar value to use for newly introduced missing values.
1270+
1271+
.. versionadded:: 0.24.0
12681272
12691273
Returns
12701274
-------
@@ -1277,10 +1281,18 @@ def shift(self, periods):
12771281
raise NotImplementedError("Categorical with ndim > 1.")
12781282
if np.prod(codes.shape) and (periods != 0):
12791283
codes = np.roll(codes, ensure_platform_int(periods), axis=0)
1284+
if isna(fill_value):
1285+
fill_value = -1
1286+
elif fill_value in self.categories:
1287+
fill_value = self.categories.get_loc(fill_value)
1288+
else:
1289+
raise ValueError("'fill_value={}' is not present "
1290+
"in this Categorical's "
1291+
"categories".format(fill_value))
12801292
if periods > 0:
1281-
codes[:periods] = -1
1293+
codes[:periods] = fill_value
12821294
else:
1283-
codes[periods:] = -1
1295+
codes[periods:] = fill_value
12841296

12851297
return self.from_codes(codes, categories=self.categories,
12861298
ordered=self.ordered)

pandas/core/arrays/period.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def value_counts(self, dropna=False):
457457

458458
# --------------------------------------------------------------------
459459

460-
def shift(self, periods=1):
460+
def shift(self, periods=1, fill_value=None):
461461
"""
462462
Shift values by desired number.
463463
@@ -471,6 +471,9 @@ def shift(self, periods=1):
471471
periods : int, default 1
472472
The number of periods to shift. Negative values are allowed
473473
for shifting backwards.
474+
fill_value : optional, default NaT
475+
476+
.. versionadded:: 0.24.0
474477
475478
Returns
476479
-------
@@ -479,7 +482,7 @@ def shift(self, periods=1):
479482
# TODO(DatetimeArray): remove
480483
# The semantics for Index.shift differ from EA.shift
481484
# then just call super.
482-
return ExtensionArray.shift(self, periods)
485+
return ExtensionArray.shift(self, periods, fill_value=fill_value)
483486

484487
def _time_shift(self, n, freq=None):
485488
"""

pandas/core/arrays/sparse.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -889,12 +889,15 @@ def fillna(self, value=None, method=None, limit=None):
889889

890890
return self._simple_new(new_values, self._sparse_index, new_dtype)
891891

892-
def shift(self, periods=1):
892+
def shift(self, periods=1, fill_value=None):
893893

894894
if not len(self) or periods == 0:
895895
return self.copy()
896896

897-
subtype = np.result_type(np.nan, self.dtype.subtype)
897+
if isna(fill_value):
898+
fill_value = self.dtype.na_value
899+
900+
subtype = np.result_type(fill_value, self.dtype.subtype)
898901

899902
if subtype != self.dtype.subtype:
900903
# just coerce up front
@@ -903,7 +906,7 @@ def shift(self, periods=1):
903906
arr = self
904907

905908
empty = self._from_sequence(
906-
[self.dtype.na_value] * min(abs(periods), len(self)),
909+
[fill_value] * min(abs(periods), len(self)),
907910
dtype=arr.dtype
908911
)
909912

pandas/core/frame.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3938,9 +3938,9 @@ def replace(self, to_replace=None, value=None, inplace=False, limit=None,
39383938
method=method)
39393939

39403940
@Appender(_shared_docs['shift'] % _shared_doc_kwargs)
3941-
def shift(self, periods=1, freq=None, axis=0):
3941+
def shift(self, periods=1, freq=None, axis=0, fill_value=None):
39423942
return super(DataFrame, self).shift(periods=periods, freq=freq,
3943-
axis=axis)
3943+
axis=axis, fill_value=fill_value)
39443944

39453945
def set_index(self, keys, drop=True, append=False, inplace=False,
39463946
verify_integrity=False):

pandas/core/generic.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -8849,6 +8849,14 @@ def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None,
88498849
extend the index when shifting and preserve the original data.
88508850
axis : {0 or 'index', 1 or 'columns', None}, default None
88518851
Shift direction.
8852+
fill_value : object, optional
8853+
The scalar value to use for newly introduced missing values.
8854+
the default depends on the dtype of `self`.
8855+
For numeric data, ``np.nan`` is used.
8856+
For datetime, timedelta, or period data, etc. :attr:`NaT` is used.
8857+
For extension dtypes, ``self.dtype.na_value`` is used.
8858+
8859+
.. versionchanged:: 0.24.0
88528860
88538861
Returns
88548862
-------
@@ -8884,16 +8892,25 @@ def mask(self, cond, other=np.nan, inplace=False, axis=None, level=None,
88848892
2 NaN 15.0 18.0
88858893
3 NaN 30.0 33.0
88868894
4 NaN 45.0 48.0
8895+
8896+
>>> df.shift(periods=3, fill_value=0)
8897+
Col1 Col2 Col3
8898+
0 0 0 0
8899+
1 0 0 0
8900+
2 0 0 0
8901+
3 10 13 17
8902+
4 20 23 27
88878903
""")
88888904

88898905
@Appender(_shared_docs['shift'] % _shared_doc_kwargs)
8890-
def shift(self, periods=1, freq=None, axis=0):
8906+
def shift(self, periods=1, freq=None, axis=0, fill_value=None):
88918907
if periods == 0:
88928908
return self.copy()
88938909

88948910
block_axis = self._get_block_manager_axis(axis)
88958911
if freq is None:
8896-
new_data = self._data.shift(periods=periods, axis=block_axis)
8912+
new_data = self._data.shift(periods=periods, axis=block_axis,
8913+
fill_value=fill_value)
88978914
else:
88988915
return self.tshift(periods, freq)
88998916

pandas/core/groupby/groupby.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1994,7 +1994,7 @@ def _get_cythonized_result(self, how, grouper, aggregate=False,
19941994

19951995
@Substitution(name='groupby')
19961996
@Appender(_common_see_also)
1997-
def shift(self, periods=1, freq=None, axis=0):
1997+
def shift(self, periods=1, freq=None, axis=0, fill_value=None):
19981998
"""
19991999
Shift each group by periods observations.
20002000
@@ -2004,10 +2004,14 @@ def shift(self, periods=1, freq=None, axis=0):
20042004
number of periods to shift
20052005
freq : frequency string
20062006
axis : axis to shift, default 0
2007+
fill_value : optional
2008+
2009+
.. versionadded:: 0.24.0
20072010
"""
20082011

2009-
if freq is not None or axis != 0:
2010-
return self.apply(lambda x: x.shift(periods, freq, axis))
2012+
if freq is not None or axis != 0 or not isna(fill_value):
2013+
return self.apply(lambda x: x.shift(periods, freq,
2014+
axis, fill_value))
20112015

20122016
return self._get_cythonized_result('group_shift_indexer',
20132017
self.grouper, cython_dtype=np.int64,

pandas/core/internals/blocks.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1261,12 +1261,12 @@ def diff(self, n, axis=1):
12611261
new_values = algos.diff(self.values, n, axis=axis)
12621262
return [self.make_block(values=new_values)]
12631263

1264-
def shift(self, periods, axis=0):
1264+
def shift(self, periods, axis=0, fill_value=None):
12651265
""" shift the block by periods, possibly upcast """
12661266

12671267
# convert integer to float if necessary. need to do a lot more than
12681268
# that, handle boolean etc also
1269-
new_values, fill_value = maybe_upcast(self.values)
1269+
new_values, fill_value = maybe_upcast(self.values, fill_value)
12701270

12711271
# make sure array sent to np.roll is c_contiguous
12721272
f_ordered = new_values.flags.f_contiguous
@@ -1955,17 +1955,19 @@ def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
19551955
limit=limit),
19561956
placement=self.mgr_locs)
19571957

1958-
def shift(self, periods, axis=0):
1958+
def shift(self, periods, axis=0, fill_value=None):
19591959
"""
19601960
Shift the block by `periods`.
19611961
19621962
Dispatches to underlying ExtensionArray and re-boxes in an
19631963
ExtensionBlock.
19641964
"""
19651965
# type: (int, Optional[BlockPlacement]) -> List[ExtensionBlock]
1966-
return [self.make_block_same_class(self.values.shift(periods=periods),
1967-
placement=self.mgr_locs,
1968-
ndim=self.ndim)]
1966+
return [
1967+
self.make_block_same_class(
1968+
self.values.shift(periods=periods, fill_value=fill_value),
1969+
placement=self.mgr_locs, ndim=self.ndim)
1970+
]
19691971

19701972
def where(self, other, cond, align=True, errors='raise',
19711973
try_cast=False, axis=0, transpose=False):
@@ -3023,7 +3025,7 @@ def _try_coerce_result(self, result):
30233025
def _box_func(self):
30243026
return lambda x: tslibs.Timestamp(x, tz=self.dtype.tz)
30253027

3026-
def shift(self, periods, axis=0):
3028+
def shift(self, periods, axis=0, fill_value=None):
30273029
""" shift the block by periods """
30283030

30293031
# think about moving this to the DatetimeIndex. This is a non-freq
@@ -3038,10 +3040,12 @@ def shift(self, periods, axis=0):
30383040

30393041
new_values = self.values.asi8.take(indexer)
30403042

3043+
if isna(fill_value):
3044+
fill_value = tslibs.iNaT
30413045
if periods > 0:
3042-
new_values[:periods] = tslibs.iNaT
3046+
new_values[:periods] = fill_value
30433047
else:
3044-
new_values[periods:] = tslibs.iNaT
3048+
new_values[periods:] = fill_value
30453049

30463050
new_values = self.values._shallow_copy(new_values)
30473051
return [self.make_block_same_class(new_values,

pandas/core/series.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3765,8 +3765,9 @@ def replace(self, to_replace=None, value=None, inplace=False, limit=None,
37653765
regex=regex, method=method)
37663766

37673767
@Appender(generic._shared_docs['shift'] % _shared_doc_kwargs)
3768-
def shift(self, periods=1, freq=None, axis=0):
3769-
return super(Series, self).shift(periods=periods, freq=freq, axis=axis)
3768+
def shift(self, periods=1, freq=None, axis=0, fill_value=None):
3769+
return super(Series, self).shift(periods=periods, freq=freq, axis=axis,
3770+
fill_value=fill_value)
37703771

37713772
def reindex_axis(self, labels, axis=0, **kwargs):
37723773
"""

pandas/tests/arrays/sparse/test_array.py

+13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pandas.util._test_decorators as td
1111

1212
import pandas as pd
13+
from pandas import isna
1314
from pandas.core.sparse.api import SparseArray, SparseDtype, SparseSeries
1415
import pandas.util.testing as tm
1516
from pandas.util.testing import assert_almost_equal
@@ -262,6 +263,18 @@ def test_take_negative(self):
262263
exp = SparseArray(np.take(self.arr_data, [-4, -3, -2]))
263264
tm.assert_sp_array_equal(self.arr.take([-4, -3, -2]), exp)
264265

266+
@pytest.mark.parametrize('fill_value', [0, None, np.nan])
267+
def test_shift_fill_value(self, fill_value):
268+
# GH #24128
269+
sparse = SparseArray(np.array([1, 0, 0, 3, 0]),
270+
fill_value=8.0)
271+
res = sparse.shift(1, fill_value=fill_value)
272+
if isna(fill_value):
273+
fill_value = res.dtype.na_value
274+
exp = SparseArray(np.array([fill_value, 1, 0, 0, 3]),
275+
fill_value=8.0)
276+
tm.assert_sp_array_equal(res, exp)
277+
265278
def test_bad_take(self):
266279
with pytest.raises(IndexError, match="bounds"):
267280
self.arr.take([11])

pandas/tests/extension/base/methods.py

+11
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,17 @@ def test_shift_empty_array(self, data, periods):
221221
expected = empty
222222
self.assert_extension_array_equal(result, expected)
223223

224+
def test_shift_fill_value(self, data):
225+
arr = data[:4]
226+
fill_value = data[0]
227+
result = arr.shift(1, fill_value=fill_value)
228+
expected = data.take([0, 0, 1, 2])
229+
self.assert_extension_array_equal(result, expected)
230+
231+
result = arr.shift(-2, fill_value=fill_value)
232+
expected = data.take([2, 3, 0, 0])
233+
self.assert_extension_array_equal(result, expected)
234+
224235
@pytest.mark.parametrize("as_frame", [True, False])
225236
def test_hash_pandas_object_works(self, data, as_frame):
226237
# https://github.com/pandas-dev/pandas/issues/23066

pandas/tests/frame/test_timeseries.py

+14
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,20 @@ def test_shift_categorical(self):
320320
xp = DataFrame({'one': s1.shift(1), 'two': s2.shift(1)})
321321
assert_frame_equal(rs, xp)
322322

323+
def test_shift_fill_value(self):
324+
# GH #24128
325+
df = DataFrame([1, 2, 3, 4, 5],
326+
index=date_range('1/1/2000', periods=5, freq='H'))
327+
exp = DataFrame([0, 1, 2, 3, 4],
328+
index=date_range('1/1/2000', periods=5, freq='H'))
329+
result = df.shift(1, fill_value=0)
330+
assert_frame_equal(result, exp)
331+
332+
exp = DataFrame([0, 0, 1, 2, 3],
333+
index=date_range('1/1/2000', periods=5, freq='H'))
334+
result = df.shift(2, fill_value=0)
335+
assert_frame_equal(result, exp)
336+
323337
def test_shift_empty(self):
324338
# Regression test for #8019
325339
df = DataFrame({'foo': []})

pandas/tests/groupby/test_categorical.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from pandas.compat import PY37
1010
from pandas import (Index, MultiIndex, CategoricalIndex,
1111
DataFrame, Categorical, Series, qcut)
12-
from pandas.util.testing import assert_frame_equal, assert_series_equal
12+
from pandas.util.testing import (assert_equal,
13+
assert_frame_equal, assert_series_equal)
1314
import pandas.util.testing as tm
1415

1516

@@ -860,3 +861,13 @@ def test_groupby_multiindex_categorical_datetime():
860861
expected = pd.DataFrame(
861862
{'values': [0, 4, 8, 3, 4, 5, 6, np.nan, 2]}, index=idx)
862863
assert_frame_equal(result, expected)
864+
865+
866+
@pytest.mark.parametrize('fill_value', [None, np.nan, pd.NaT])
867+
def test_shift(fill_value):
868+
ct = pd.Categorical(['a', 'b', 'c', 'd'],
869+
categories=['a', 'b', 'c', 'd'], ordered=False)
870+
expected = pd.Categorical([None, 'a', 'b', 'c'],
871+
categories=['a', 'b', 'c', 'd'], ordered=False)
872+
res = ct.shift(1, fill_value=fill_value)
873+
assert_equal(res, expected)

0 commit comments

Comments
 (0)