Skip to content

Commit 9df8261

Browse files
committed
BUG: Allow a categorical series to accept dict or Series in fillna (pandas-dev#17033)
1 parent 488db6f commit 9df8261

File tree

3 files changed

+102
-11
lines changed

3 files changed

+102
-11
lines changed

pandas/core/categorical.py

+46-9
Original file line numberDiff line numberDiff line change
@@ -1689,16 +1689,53 @@ def fillna(self, value=None, method=None, limit=None):
16891689

16901690
else:
16911691

1692-
if not isna(value) and value not in self.categories:
1693-
raise ValueError("fill value must be in categories")
1694-
1695-
mask = values == -1
1696-
if mask.any():
1697-
values = values.copy()
1698-
if isna(value):
1699-
values[mask] = -1
1692+
if isinstance(value, ABCSeries):
1693+
if not value[~value.isin(self.categories)].isna().all():
1694+
raise ValueError("fill value must be in categories")
1695+
1696+
# Check if a single scalar in the value Series
1697+
# (e.g., s.fillna(pd.Series('a')))
1698+
if (len(value[value.notna()]) == 1 and
1699+
value[value.notna()].index == 0):
1700+
values_codes = _get_codes_for_values(value[value.notna()],
1701+
self.categories)
1702+
mask = values == -1
1703+
values[mask] = values_codes
17001704
else:
1701-
values[mask] = self.categories.get_loc(value)
1705+
values_codes = _get_codes_for_values(value,
1706+
self.categories)
1707+
index = np.where(values_codes != -1)
1708+
values[index] = values_codes[values_codes != -1]
1709+
# from pandas import Series
1710+
# values_codes = Series(_get_codes_for_values(value[value.notna()],
1711+
# self.categories), index=value[value.notna()].index)
1712+
# values[values_codes.index] = values_codes
1713+
1714+
elif isinstance(value, dict):
1715+
from pandas import Series
1716+
value = Series(value)
1717+
1718+
if not value[~value.isin(self.categories)].isna().all():
1719+
raise ValueError("fill value must be in categories")
1720+
1721+
# The size of value is different than in the Series case
1722+
# Convert to Series to allow use of index values
1723+
values_codes = Series(_get_codes_for_values(value,
1724+
self.categories), index=value.index)
1725+
values[values_codes.index] = values_codes
1726+
1727+
# Scalar value
1728+
else:
1729+
if not isna(value) and value not in self.categories:
1730+
raise ValueError("fill value must be in categories")
1731+
1732+
mask = values == -1
1733+
if mask.any():
1734+
values = values.copy()
1735+
if isna(value):
1736+
values[mask] = -1
1737+
else:
1738+
values[mask] = self.categories.get_loc(value)
17021739

17031740
return self._constructor(values, categories=self.categories,
17041741
ordered=self.ordered, fastpath=True)

pandas/core/generic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pandas.core.common import (_count_not_none,
3333
_maybe_box_datetimelike, _values_from_object,
3434
AbstractMethodError, SettingWithCopyError,
35-
SettingWithCopyWarning)
35+
SettingWithCopyWarning, is_categorical_dtype)
3636

3737
from pandas.core.base import PandasObject, SelectionMixin
3838
from pandas.core.index import (Index, MultiIndex, _ensure_index,
@@ -4298,7 +4298,9 @@ def fillna(self, value=None, method=None, axis=None, inplace=False,
42984298
return self
42994299

43004300
if self.ndim == 1:
4301-
if isinstance(value, (dict, ABCSeries)):
4301+
if isinstance(value, dict) and is_categorical_dtype(self):
4302+
pass
4303+
elif isinstance(value, (ABCSeries, dict)):
43024304
from pandas import Series
43034305
value = Series(value)
43044306
elif not is_list_like(value):

pandas/tests/test_categorical.py

+52
Original file line numberDiff line numberDiff line change
@@ -4603,6 +4603,58 @@ def f():
46034603
df = pd.DataFrame({'a': pd.Categorical(idx)})
46044604
tm.assert_frame_equal(df.fillna(value=pd.NaT), df)
46054605

4606+
def fillna_categorical(self):
4607+
# GH 17033
4608+
# Test fillna for a Categorical series
4609+
s1 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4610+
categories=['a', 'b']))
4611+
s1_exp = pd.Series(pd.Categorical(['a', 'a', 'b', 'a']))
4612+
tm.assert_series_equal(s1.fillna('a'), s1_exp)
4613+
4614+
s2 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4615+
categories=['a', 'b']))
4616+
s2_exp = pd.Series(pd.Categorical(['a', 'a', 'b', 'b']))
4617+
tm.assert_series_equal(s2.fillna({1: 'a', 3: 'b'}), s2_exp)
4618+
4619+
s3 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4620+
categories=['a', 'b']))
4621+
s3_exp = pd.Series(pd.Categorical(['a', 'a', 'b', np.nan]))
4622+
tm.assert_series_equal(s2.fillna({1: 'a'}), s3_exp)
4623+
4624+
s4 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan, np.nan],
4625+
categories=['a', 'b']))
4626+
s4_exp = pd.Series(pd.Categorical(['a', 'a', 'b', 'b', np.nan]))
4627+
tm.assert_series_equal(s4.fillna({1: 'a', 3: 'b'}), s4_exp)
4628+
4629+
s5 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4630+
categories=['a', 'b']))
4631+
s5_exp = pd.Series(pd.Categorical(['a', 'a', 'b', 'a']))
4632+
tm.assert_series_equal(s5.fillna(pd.Series('a')), s5_exp)
4633+
4634+
s6 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4635+
categories=['a', 'b']))
4636+
s6_exp = pd.Series(pd.Categorical(['a', 'a', 'b', 'a']))
4637+
tm.assert_series_equal(s6.fillna(pd.Series({1: 'a', 3: 'b'})), s6_exp)
4638+
4639+
s7 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4640+
categories=['a', 'b']))
4641+
with tm.assert_raises_regex(ValueError,
4642+
"fill value must be in categories"):
4643+
s7.fillna('cat')
4644+
4645+
s8 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4646+
categories=['a', 'b']))
4647+
with tm.assert_raises_regex(ValueError,
4648+
"fill value must be in categories"):
4649+
s8.fillna(pd.Series('cat'))
4650+
4651+
s9 = pd.Series(pd.Categorical(['a', np.nan, 'b', np.nan],
4652+
categories=['a', 'b']))
4653+
with tm.assert_raises_regex(TypeError,
4654+
'"value" parameter must be a scalar or '
4655+
'dict but you passed a "list"'):
4656+
s9.fillna(['a', 'b'])
4657+
46064658
def test_astype_to_other(self):
46074659

46084660
s = self.cat['value_group']

0 commit comments

Comments
 (0)