Skip to content

Commit 92cc1fe

Browse files
committed
BUG: Allow a categorical series to accept dict or Series in fillna (pandas-dev#17033)
cleaning up comments and tests
1 parent 63e8527 commit 92cc1fe

File tree

3 files changed

+81
-11
lines changed

3 files changed

+81
-11
lines changed

pandas/core/categorical.py

+45-9
Original file line numberDiff line numberDiff line change
@@ -1660,16 +1660,52 @@ def fillna(self, value=None, method=None, limit=None):
16601660

16611661
else:
16621662

1663-
if not isna(value) and value not in self.categories:
1664-
raise ValueError("fill value must be in categories")
1665-
1666-
mask = values == -1
1667-
if mask.any():
1668-
values = values.copy()
1669-
if isna(value):
1670-
values[mask] = -1
1663+
if isinstance(value, ABCSeries):
1664+
if not value[~value.isin(self.categories)].isna().all():
1665+
raise ValueError("fill value must be in categories")
1666+
1667+
# Check if single scalar in the value Series
1668+
# (e.g., s.fillna(pd.Series('a')))
1669+
if (len(value[value.notna()]) == 1 and
1670+
value[value.notna()].index == 0):
1671+
values_codes = _get_codes_for_values(value[value.notna()],
1672+
self.categories)
1673+
mask = values == -1
1674+
values[mask] = values_codes
16711675
else:
1672-
values[mask] = self.categories.get_loc(value)
1676+
values_codes = _get_codes_for_values(value,
1677+
self.categories)
1678+
index = np.where(values_codes != -1)
1679+
values[index] = values_codes[values_codes != -1]
1680+
# from pandas import Series
1681+
# values_codes = Series(_get_codes_for_values(value[value.notna()],
1682+
# self.categories), index=value[value.notna()].index)
1683+
# values[values_codes.index] = values_codes
1684+
1685+
elif isinstance(value, dict):
1686+
from pandas import Series
1687+
value = Series(value)
1688+
1689+
if not value[~value.isin(self.categories)].isna().all():
1690+
raise ValueError("fill value must be in categories")
1691+
1692+
# Convert to Series to allow use of index values
1693+
values_codes = Series(_get_codes_for_values(value,
1694+
self.categories), index=value.index)
1695+
values[values_codes.index] = values_codes
1696+
1697+
# Scalar value
1698+
else:
1699+
if not isna(value) and value not in self.categories:
1700+
raise ValueError("fill value must be in categories")
1701+
1702+
mask = values == -1
1703+
if mask.any():
1704+
values = values.copy()
1705+
if isna(value):
1706+
values[mask] = -1
1707+
else:
1708+
values[mask] = self.categories.get_loc(value)
16731709

16741710
return self._constructor(values, categories=self.categories,
16751711
ordered=self.ordered, fastpath=True)

pandas/core/generic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pandas.core.common import (_count_not_none,
3333
_maybe_box_datetimelike, _values_from_object,
3434
AbstractMethodError, SettingWithCopyError,
35-
SettingWithCopyWarning)
35+
SettingWithCopyWarning, is_categorical_dtype)
3636

3737
from pandas.core.base import PandasObject, SelectionMixin
3838
from pandas.core.index import (Index, MultiIndex, _ensure_index,
@@ -4298,7 +4298,9 @@ def fillna(self, value=None, method=None, axis=None, inplace=False,
42984298
return self
42994299

43004300
if self.ndim == 1:
4301-
if isinstance(value, (dict, ABCSeries)):
4301+
if isinstance(value, dict) and is_categorical_dtype(self):
4302+
pass
4303+
elif isinstance(value, (ABCSeries, dict)):
43024304
from pandas import Series
43034305
value = Series(value)
43044306
elif not is_list_like(value):

pandas/tests/test_categorical.py

+32
Original file line numberDiff line numberDiff line change
@@ -4601,6 +4601,38 @@ def f():
46014601
df = pd.DataFrame({'a': pd.Categorical(idx)})
46024602
tm.assert_frame_equal(df.fillna(value=pd.NaT), df)
46034603

4604+
@pytest.mark.parametrize('fill_value expected_output', [
4605+
('a', ['a', 'a', 'b', 'a', 'a']),
4606+
({1: 'a', 3: 'b', 4: 'b'}, ['a', 'a', 'b', 'b', 'b']),
4607+
({1: 'a'}, ['a', 'a', 'b', np.nan, np.nan]),
4608+
({1: 'a', 3: 'b'}, ['a', 'a', 'b', 'b', np.nan]),
4609+
(pd.Series('a'), ['a', 'a', 'b', 'a', 'a']),
4610+
(pd.Series({1: 'a', 3: 'b'}), ['a', 'a', 'b', 'b', np.nan])
4611+
])
4612+
def fillna_series_categorical(self, fill_value, expected_output):
4613+
# GH 17033
4614+
# Test fillna for a Categorical series
4615+
data = ['a', np.nan, 'b', np.nan, np.nan]
4616+
s = pd.Series(pd.Categorical(data, categories=['a', 'b']))
4617+
exp = pd.Series(pd.Categorical(expected_output, categories=['a', 'b']))
4618+
tm.assert_series_equal(s.fillna(fill_value), exp)
4619+
4620+
s1 = pd.Series(data, categories=['a', 'b'])
4621+
with tm.assert_raises_regex(ValueError,
4622+
"fill value must be in categories"):
4623+
s1.fillna('cat')
4624+
4625+
s2 = pd.Series(data, categories=['a', 'b'])
4626+
with tm.assert_raises_regex(ValueError,
4627+
"fill value must be in categories"):
4628+
s2.fillna(pd.Series('cat'))
4629+
4630+
s3 = pd.Series(data, categories=['a', 'b'])
4631+
with tm.assert_raises_regex(TypeError,
4632+
'"value" parameter must be a scalar or '
4633+
'dict but you passed a "list"'):
4634+
s3.fillna(['a', 'b'])
4635+
46044636
def test_astype_to_other(self):
46054637

46064638
s = self.cat['value_group']

0 commit comments

Comments
 (0)