Skip to content

Commit 74614cb

Browse files
committed
ENH: ExtensionArray.fillna
1 parent 1e4c50a commit 74614cb

File tree

6 files changed

+225
-65
lines changed

6 files changed

+225
-65
lines changed

pandas/core/arrays/base.py

+84
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
"""An interface for extending pandas with custom arrays."""
2+
import itertools
3+
24
import numpy as np
35

46
from pandas.errors import AbstractMethodError
@@ -216,6 +218,88 @@ def isna(self):
216218
"""
217219
raise AbstractMethodError(self)
218220

221+
def tolist(self):
222+
# type: () -> list
223+
"""Convert the array to a list of scalars."""
224+
return list(self)
225+
226+
def fillna(self, value=None, method=None, limit=None):
    """Fill NA/NaN values using the specified method.

    Parameters
    ----------
    value : scalar, array-like
        If a scalar value is passed it is used to fill all missing values.
        Alternatively, an array-like 'value' can be given. It's expected
        that the array-like have the same length as 'self'.
    method : {'backfill', 'bfill', 'pad', 'ffill', None}, default None
        Method to use for filling holes in reindexed Series
        pad / ffill: propagate last valid observation forward to next valid
        backfill / bfill: use NEXT valid observation to fill gap
    limit : int, default None
        (Not implemented yet for ExtensionArray!)
        If method is specified, this is the maximum number of consecutive
        NaN values to forward/backward fill. In other words, if there is
        a gap with more than this number of consecutive NaNs, it will only
        be partially filled. If method is not specified, this is the
        maximum number of entries along the entire axis where NaNs will be
        filled.

    Returns
    -------
    filled : ExtensionArray with NA/NaN filled

    Raises
    ------
    ValueError
        If a non-scalar 'value' does not have the same length as 'self'.
    NotImplementedError
        If 'limit' is not None.
    """
    from pandas.api.types import is_scalar
    from pandas.util._validators import validate_fillna_kwargs

    value, method = validate_fillna_kwargs(value, method)

    if is_scalar(value):
        # An endless stream of the scalar, so it can be zipped against self.
        value = itertools.repeat(value)
    elif len(value) != len(self):
        raise ValueError("Length of 'value' does not match. Got ({}) "
                         "expected {}".format(len(value), len(self)))

    if limit is not None:
        msg = ("Specifying 'limit' for 'fillna' has not been implemented "
               "yet for {} typed data".format(self.dtype))
        raise NotImplementedError(msg)

    mask = self.isna()

    if not mask.any():
        # Nothing is missing; still hand back a newly constructed array.
        return type(self)(self)

    if method is None:
        # fill with value
        new_values = [
            fill if is_na else original
            for is_na, original, fill in zip(mask, self, value)
        ]
        return type(self)(new_values)

    # ffill / bfill
    # Decide the direction once, accepting both spellings, so that the walk
    # and the final un-reversal can never disagree.
    backward = method in {'bfill', 'backfill'}

    if backward:
        data = reversed(self)
        flags = reversed(mask)
        last_valid = self[len(self) - 1]
    else:
        data = self
        flags = mask
        last_valid = self[0]

    # `last_valid` is seeded with the first element in walk order, so a
    # leading run of missing values is filled with that (missing) element.
    new_values = []
    for is_na, val in zip(flags, data):
        if not is_na:
            last_valid = val
        new_values.append(last_valid)

    if backward:
        new_values.reverse()

    return type(self)(new_values)
302+
219303
# ------------------------------------------------------------------------
220304
# Indexing methods
221305
# ------------------------------------------------------------------------

pandas/core/internals.py

+17-21
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,23 @@ def concat_same_type(self, to_concat, placement=None):
19631963
return self.make_block_same_class(values, ndim=self.ndim,
19641964
placement=placement)
19651965

1966+
def fillna(self, value, limit=None, inplace=False, downcast=None,
1967+
mgr=None):
1968+
values = self.values if inplace else self.values.copy()
1969+
values = values.fillna(value=value, limit=limit)
1970+
return [self.make_block_same_class(values=values,
1971+
placement=self.mgr_locs,
1972+
ndim=self.ndim)]
1973+
1974+
def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
1975+
fill_value=None, **kwargs):
1976+
1977+
values = self.values if inplace else self.values.copy()
1978+
return self.make_block_same_class(
1979+
values=values.fillna(value=fill_value, method=method,
1980+
limit=limit),
1981+
placement=self.mgr_locs)
1982+
19661983

19671984
class NumericBlock(Block):
19681985
__slots__ = ()
@@ -2522,27 +2539,6 @@ def _try_coerce_result(self, result):
25222539

25232540
return result
25242541

2525-
def fillna(self, value, limit=None, inplace=False, downcast=None,
2526-
mgr=None):
2527-
# we may need to upcast our fill to match our dtype
2528-
if limit is not None:
2529-
raise NotImplementedError("specifying a limit for 'fillna' has "
2530-
"not been implemented yet")
2531-
2532-
values = self.values if inplace else self.values.copy()
2533-
values = self._try_coerce_result(values.fillna(value=value,
2534-
limit=limit))
2535-
return [self.make_block(values=values)]
2536-
2537-
def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
2538-
fill_value=None, **kwargs):
2539-
2540-
values = self.values if inplace else self.values.copy()
2541-
return self.make_block_same_class(
2542-
values=values.fillna(fill_value=fill_value, method=method,
2543-
limit=limit),
2544-
placement=self.mgr_locs)
2545-
25462542
def shift(self, periods, axis=0, mgr=None):
25472543
return self.make_block_same_class(values=self.values.shift(periods),
25482544
placement=self.mgr_locs)

pandas/tests/extension/base/missing.py

+80
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import pandas as pd
45
import pandas.util.testing as tm
@@ -45,3 +46,82 @@ def test_dropna_frame(self, data_missing):
4546
result = df.dropna()
4647
expected = df.iloc[:0]
4748
self.assert_frame_equal(result, expected)
49+
50+
def test_fillna_limit_raises(self, data_missing):
51+
ser = pd.Series(data_missing)
52+
fill_value = data_missing[1]
53+
xpr = "Specifying 'limit' for 'fillna'.*{}".format(data_missing.dtype)
54+
55+
with tm.assert_raises_regex(NotImplementedError, xpr):
56+
ser.fillna(fill_value, limit=2)
57+
58+
def test_fillna_series(self, data_missing):
59+
fill_value = data_missing[1]
60+
ser = pd.Series(data_missing)
61+
62+
result = ser.fillna(fill_value)
63+
expected = pd.Series(type(data_missing)([fill_value, fill_value]))
64+
self.assert_series_equal(result, expected)
65+
66+
# Fill with a series
67+
result = ser.fillna(expected)
68+
self.assert_series_equal(result, expected)
69+
70+
# Fill with a series not affecting the missing values
71+
result = ser.fillna(ser)
72+
self.assert_series_equal(result, ser)
73+
74+
@pytest.mark.xfail(reason="Too magical?")
75+
def test_fillna_series_with_dict(self, data_missing):
76+
fill_value = data_missing[1]
77+
ser = pd.Series(data_missing)
78+
expected = pd.Series(type(data_missing)([fill_value, fill_value]))
79+
80+
# Fill with a dict
81+
result = ser.fillna({0: fill_value})
82+
self.assert_series_equal(result, expected)
83+
84+
# Fill with a dict not affecting the missing values
85+
result = ser.fillna({1: fill_value})
86+
ser = pd.Series(data_missing)
87+
self.assert_series_equal(result, ser)
88+
89+
@pytest.mark.parametrize('method', ['ffill', 'bfill'])
90+
def test_fillna_series_method(self, data_missing, method):
91+
fill_value = data_missing[1]
92+
93+
if method == 'ffill':
94+
data_missing = type(data_missing)(data_missing[::-1])
95+
96+
result = pd.Series(data_missing).fillna(method=method)
97+
expected = pd.Series(type(data_missing)([fill_value, fill_value]))
98+
99+
self.assert_series_equal(result, expected)
100+
101+
def test_fillna_frame(self, data_missing):
102+
fill_value = data_missing[1]
103+
104+
result = pd.DataFrame({
105+
"A": data_missing,
106+
"B": [1, 2]
107+
}).fillna(fill_value)
108+
109+
expected = pd.DataFrame({
110+
"A": type(data_missing)([fill_value, fill_value]),
111+
"B": [1, 2],
112+
})
113+
114+
self.assert_frame_equal(result, expected)
115+
116+
def test_fillna_fill_other(self, data):
117+
result = pd.DataFrame({
118+
"A": data,
119+
"B": [np.nan] * len(data)
120+
}).fillna({"B": 0.0})
121+
122+
expected = pd.DataFrame({
123+
"A": data,
124+
"B": [0.0] * len(result),
125+
})
126+
127+
self.assert_frame_equal(result, expected)

pandas/tests/extension/category/test_categorical.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def test_getitem_scalar(self):
6969

7070

7171
class TestMissing(base.BaseMissingTests):
72-
pass
72+
73+
@pytest.mark.skip(reason="Backwards compatability")
74+
def test_fillna_limit_raises(self):
75+
"""Has a different error message."""
7376

7477

7578
class TestMethods(base.BaseMethodsTests):

pandas/tests/extension/decimal/test_decimal.py

+33-42
Original file line numberDiff line numberDiff line change
@@ -35,68 +35,59 @@ def na_value():
3535
return decimal.Decimal("NaN")
3636

3737

38-
class TestDtype(base.BaseDtypeTests):
39-
pass
38+
class BaseDecimal(object):
    """Mixin supplying NA-aware equality assertions for the decimal tests."""

    @staticmethod
    def assert_series_equal(left, right, *args, **kwargs):
        """Assert two Series are equal, comparing NA positions separately.

        Extra positional and keyword arguments are forwarded to
        ``tm.assert_series_equal`` for the comparison of non-NA values.
        """
        # tm.assert_series_equal doesn't handle Decimal('NaN').
        # We will ensure that the NA values match, and then
        # drop those values before moving on.
        left_na = left.isna()
        right_na = right.isna()

        tm.assert_series_equal(left_na, right_na)
        tm.assert_series_equal(left[~left_na], right[~right_na],
                               *args, **kwargs)

    @staticmethod
    def assert_frame_equal(left, right, *args, **kwargs):
        """Assert two DataFrames are equal, treating decimal columns specially.

        Decimal columns are compared one by one with
        ``BaseDecimal.assert_series_equal``; the remaining columns are
        compared together with ``tm.assert_frame_equal``.
        """
        # TODO(EA): select_dtypes
        # Keep only the labels whose mask entry is True. Taking `.index` of
        # the mask itself would select every column, leaving an empty frame
        # for the final comparison below.
        is_decimal = left.dtypes == 'decimal'
        decimals = is_decimal[is_decimal].index

        for col in decimals:
            BaseDecimal.assert_series_equal(left[col], right[col],
                                            *args, **kwargs)

        left = left.drop(columns=decimals)
        right = right.drop(columns=decimals)
        tm.assert_frame_equal(left, right, *args, **kwargs)
4964

50-
class TestReshaping(base.BaseReshapingTests):
5165

52-
def test_align(self, data, na_value):
53-
# Have to override since assert_series_equal doesn't
54-
# compare Decimal(NaN) properly.
55-
a = data[:3]
56-
b = data[2:5]
57-
r1, r2 = pd.Series(a).align(pd.Series(b, index=[1, 2, 3]))
66+
class TestDtype(BaseDecimal, base.BaseDtypeTests):
67+
pass
5868

59-
# NaN handling
60-
e1 = pd.Series(type(data)(list(a) + [na_value]))
61-
e2 = pd.Series(type(data)([na_value] + list(b)))
62-
tm.assert_series_equal(r1.iloc[:3], e1.iloc[:3])
63-
assert r1[3].is_nan()
64-
assert e1[3].is_nan()
6569

66-
tm.assert_series_equal(r2.iloc[1:], e2.iloc[1:])
67-
assert r2[0].is_nan()
68-
assert e2[0].is_nan()
70+
class TestInterface(BaseDecimal, base.BaseInterfaceTests):
71+
pass
6972

70-
def test_align_frame(self, data, na_value):
71-
# Override for Decimal(NaN) comparison
72-
a = data[:3]
73-
b = data[2:5]
74-
r1, r2 = pd.DataFrame({'A': a}).align(
75-
pd.DataFrame({'A': b}, index=[1, 2, 3])
76-
)
7773

78-
# Assumes that the ctor can take a list of scalars of the type
79-
e1 = pd.DataFrame({'A': type(data)(list(a) + [na_value])})
80-
e2 = pd.DataFrame({'A': type(data)([na_value] + list(b))})
74+
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
75+
pass
8176

82-
tm.assert_frame_equal(r1.iloc[:3], e1.iloc[:3])
83-
assert r1.loc[3, 'A'].is_nan()
84-
assert e1.loc[3, 'A'].is_nan()
8577

86-
tm.assert_frame_equal(r2.iloc[1:], e2.iloc[1:])
87-
assert r2.loc[0, 'A'].is_nan()
88-
assert e2.loc[0, 'A'].is_nan()
78+
class TestReshaping(BaseDecimal, base.BaseReshapingTests):
79+
pass
8980

9081

91-
class TestGetitem(base.BaseGetitemTests):
82+
class TestGetitem(BaseDecimal, base.BaseGetitemTests):
9283
pass
9384

9485

95-
class TestMissing(base.BaseMissingTests):
86+
class TestMissing(BaseDecimal, base.BaseMissingTests):
9687
pass
9788

9889

99-
class TestMethods(base.BaseMethodsTests):
90+
class TestMethods(BaseDecimal, base.BaseMethodsTests):
10091
@pytest.mark.parametrize('dropna', [True, False])
10192
@pytest.mark.xfail(reason="value_counts not implemented yet.")
10293
def test_value_counts(self, all_data, dropna):
@@ -112,7 +103,7 @@ def test_value_counts(self, all_data, dropna):
112103
tm.assert_series_equal(result, expected)
113104

114105

115-
class TestCasting(base.BaseCastingTests):
106+
class TestCasting(BaseDecimal, base.BaseCastingTests):
116107
pass
117108

118109

pandas/tests/extension/json/test_json.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@ class TestGetitem(base.BaseGetitemTests):
6060

6161

6262
class TestMissing(base.BaseMissingTests):
63-
pass
63+
@pytest.mark.xfail(reason="Setting a dict as a scalar")
64+
def test_fillna_series(self):
65+
"""We treat dictionaries as a mapping in fillna, not a scalar."""
66+
67+
@pytest.mark.xfail(reason="Setting a dict as a scalar")
68+
def test_fillna_frame(self):
69+
"""We treat dictionaries as a mapping in fillna, not a scalar."""
6470

6571

6672
class TestMethods(base.BaseMethodsTests):

0 commit comments

Comments
 (0)