Skip to content

Commit 5196993

Browse files
committed
BUG: Allow a categorical series to accept dict or Series in fillna (pandas-dev#17033)
1 parent 488db6f commit 5196993

File tree

3 files changed

+85
-11
lines changed

3 files changed

+85
-11
lines changed

pandas/core/categorical.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,16 +1689,53 @@ def fillna(self, value=None, method=None, limit=None):
16891689

16901690
else:
16911691

1692-
if not isna(value) and value not in self.categories:
1693-
raise ValueError("fill value must be in categories")
1694-
1695-
mask = values == -1
1696-
if mask.any():
1697-
values = values.copy()
1698-
if isna(value):
1699-
values[mask] = -1
1692+
if isinstance(value, ABCSeries):
1693+
if not value[~value.isin(self.categories)].isna().all():
1694+
raise ValueError("fill value must be in categories")
1695+
1696+
# Check if a single scalar in the value Series
1697+
# (e.g., s.fillna(pd.Series('a')))
1698+
if (len(value[value.notna()]) == 1 and
1699+
value[value.notna()].index == 0):
1700+
values_codes = _get_codes_for_values(value[value.notna()],
1701+
self.categories)
1702+
mask = values == -1
1703+
values[mask] = values_codes
17001704
else:
1701-
values[mask] = self.categories.get_loc(value)
1705+
values_codes = _get_codes_for_values(value,
1706+
self.categories)
1707+
index = np.where(values_codes != -1)
1708+
values[index] = values_codes[values_codes != -1]
1709+
# from pandas import Series
1710+
# values_codes = Series(_get_codes_for_values(value[value.notna()],
1711+
# self.categories), index=value[value.notna()].index)
1712+
# values[values_codes.index] = values_codes
1713+
1714+
elif isinstance(value, dict):
1715+
from pandas import Series
1716+
value = Series(value)
1717+
1718+
if not value[~value.isin(self.categories)].isna().all():
1719+
raise ValueError("fill value must be in categories")
1720+
1721+
# The size of value is different than in the Series case
1722+
# Convert to Series to allow use of index values
1723+
values_codes = Series(_get_codes_for_values(value,
1724+
self.categories), index=value.index)
1725+
values[values_codes.index] = values_codes
1726+
1727+
# Scalar value
1728+
else:
1729+
if not isna(value) and value not in self.categories:
1730+
raise ValueError("fill value must be in categories")
1731+
1732+
mask = values == -1
1733+
if mask.any():
1734+
values = values.copy()
1735+
if isna(value):
1736+
values[mask] = -1
1737+
else:
1738+
values[mask] = self.categories.get_loc(value)
17021739

17031740
return self._constructor(values, categories=self.categories,
17041741
ordered=self.ordered, fastpath=True)

pandas/core/generic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pandas.core.common import (_count_not_none,
3333
_maybe_box_datetimelike, _values_from_object,
3434
AbstractMethodError, SettingWithCopyError,
35-
SettingWithCopyWarning)
35+
SettingWithCopyWarning, is_categorical_dtype)
3636

3737
from pandas.core.base import PandasObject, SelectionMixin
3838
from pandas.core.index import (Index, MultiIndex, _ensure_index,
@@ -4298,7 +4298,9 @@ def fillna(self, value=None, method=None, axis=None, inplace=False,
42984298
return self
42994299

43004300
if self.ndim == 1:
4301-
if isinstance(value, (dict, ABCSeries)):
4301+
if isinstance(value, dict) and is_categorical_dtype(self):
4302+
pass
4303+
elif isinstance(value, (ABCSeries, dict)):
43024304
from pandas import Series
43034305
value = Series(value)
43044306
elif not is_list_like(value):

pandas/tests/test_categorical.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4603,6 +4603,41 @@ def f():
46034603
df = pd.DataFrame({'a': pd.Categorical(idx)})
46044604
tm.assert_frame_equal(df.fillna(value=pd.NaT), df)
46054605

4606+
@pytest.mark.parametrize('fill_value expected_output', [
4607+
('a', ['a', 'a', 'b', 'a', 'a']),
4608+
({1: 'a', 3: 'b', 4: 'b'}, ['a', 'a', 'b', 'b', 'b']),
4609+
({1: 'a'}, ['a', 'a', 'b', np.nan, np.nan]),
4610+
({1: 'a', 3: 'b'}, ['a', 'a', 'b', 'b', np.nan]),
4611+
(pd.Series('a'), ['a', 'a', 'b', 'a', 'a']),
4612+
(pd.Series({1: 'a', 3: 'b'}), ['a', 'a', 'b', 'b', np.nan])
4613+
])
4614+
def fillna_categorical(self, fill_value, expected_output):
4615+
# GH 17033
4616+
# Test fillna for a Categorical series
4617+
data = ['a', np.nan, 'b', np.nan, np.nan]
4618+
s = pd.Series(pd.Categorical(data, categories=['a', 'b']))
4619+
exp = pd.Series(pd.Categorical(expected_output, categories=['a', 'b']))
4620+
tm.assert_series_equal(s.fillna())
4621+
4622+
def fillna_cat_wronginput(self):
4623+
data = pd.Categorical(['a', np.nan, 'b', np.nan])
4624+
4625+
s1 = pd.Series(data, categories=['a', 'b'])
4626+
with tm.assert_raises_regex(ValueError,
4627+
"fill value must be in categories"):
4628+
s1.fillna('cat')
4629+
4630+
s2 = pd.Series(data, categories=['a', 'b'])
4631+
with tm.assert_raises_regex(ValueError,
4632+
"fill value must be in categories"):
4633+
s2.fillna(pd.Series('cat'))
4634+
4635+
s3 = pd.Series(data, categories=['a', 'b'])
4636+
with tm.assert_raises_regex(TypeError,
4637+
'"value" parameter must be a scalar or '
4638+
'dict but you passed a "list"'):
4639+
s3.fillna(['a', 'b'])
4640+
46064641
def test_astype_to_other(self):
46074642

46084643
s = self.cat['value_group']

0 commit comments

Comments
 (0)