Skip to content

Commit 26ceaf8

Browse files
committed
BUG: Allow a categorical series to accept dict or Series in fillna (pandas-dev#17033)
cleaning up comments and tests
1 parent 488db6f commit 26ceaf8

File tree

3 files changed

+81
-11
lines changed

3 files changed

+81
-11
lines changed

pandas/core/categorical.py

+45-9
Original file line numberDiff line numberDiff line change
@@ -1689,16 +1689,52 @@ def fillna(self, value=None, method=None, limit=None):
16891689

16901690
else:
16911691

1692-
if not isna(value) and value not in self.categories:
1693-
raise ValueError("fill value must be in categories")
1694-
1695-
mask = values == -1
1696-
if mask.any():
1697-
values = values.copy()
1698-
if isna(value):
1699-
values[mask] = -1
1692+
if isinstance(value, ABCSeries):
1693+
if not value[~value.isin(self.categories)].isna().all():
1694+
raise ValueError("fill value must be in categories")
1695+
1696+
# Check if single scalar in the value Series
1697+
# (e.g., s.fillna(pd.Series('a')))
1698+
if (len(value[value.notna()]) == 1 and
1699+
value[value.notna()].index == 0):
1700+
values_codes = _get_codes_for_values(value[value.notna()],
1701+
self.categories)
1702+
mask = values == -1
1703+
values[mask] = values_codes
17001704
else:
1701-
values[mask] = self.categories.get_loc(value)
1705+
values_codes = _get_codes_for_values(value,
1706+
self.categories)
1707+
index = np.where(values_codes != -1)
1708+
values[index] = values_codes[values_codes != -1]
1709+
# from pandas import Series
1710+
# values_codes = Series(_get_codes_for_values(value[value.notna()],
1711+
# self.categories), index=value[value.notna()].index)
1712+
# values[values_codes.index] = values_codes
1713+
1714+
elif isinstance(value, dict):
1715+
from pandas import Series
1716+
value = Series(value)
1717+
1718+
if not value[~value.isin(self.categories)].isna().all():
1719+
raise ValueError("fill value must be in categories")
1720+
1721+
# Convert to Series to allow use of index values
1722+
values_codes = Series(_get_codes_for_values(value,
1723+
self.categories), index=value.index)
1724+
values[values_codes.index] = values_codes
1725+
1726+
# Scalar value
1727+
else:
1728+
if not isna(value) and value not in self.categories:
1729+
raise ValueError("fill value must be in categories")
1730+
1731+
mask = values == -1
1732+
if mask.any():
1733+
values = values.copy()
1734+
if isna(value):
1735+
values[mask] = -1
1736+
else:
1737+
values[mask] = self.categories.get_loc(value)
17021738

17031739
return self._constructor(values, categories=self.categories,
17041740
ordered=self.ordered, fastpath=True)

pandas/core/generic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from pandas.core.common import (_count_not_none,
3333
_maybe_box_datetimelike, _values_from_object,
3434
AbstractMethodError, SettingWithCopyError,
35-
SettingWithCopyWarning)
35+
SettingWithCopyWarning, is_categorical_dtype)
3636

3737
from pandas.core.base import PandasObject, SelectionMixin
3838
from pandas.core.index import (Index, MultiIndex, _ensure_index,
@@ -4298,7 +4298,9 @@ def fillna(self, value=None, method=None, axis=None, inplace=False,
42984298
return self
42994299

43004300
if self.ndim == 1:
4301-
if isinstance(value, (dict, ABCSeries)):
4301+
if isinstance(value, dict) and is_categorical_dtype(self):
4302+
pass
4303+
elif isinstance(value, (ABCSeries, dict)):
43024304
from pandas import Series
43034305
value = Series(value)
43044306
elif not is_list_like(value):

pandas/tests/test_categorical.py

+32
Original file line numberDiff line numberDiff line change
@@ -4603,6 +4603,38 @@ def f():
46034603
df = pd.DataFrame({'a': pd.Categorical(idx)})
46044604
tm.assert_frame_equal(df.fillna(value=pd.NaT), df)
46054605

4606+
@pytest.mark.parametrize('fill_value expected_output', [
4607+
('a', ['a', 'a', 'b', 'a', 'a']),
4608+
({1: 'a', 3: 'b', 4: 'b'}, ['a', 'a', 'b', 'b', 'b']),
4609+
({1: 'a'}, ['a', 'a', 'b', np.nan, np.nan]),
4610+
({1: 'a', 3: 'b'}, ['a', 'a', 'b', 'b', np.nan]),
4611+
(pd.Series('a'), ['a', 'a', 'b', 'a', 'a']),
4612+
(pd.Series({1: 'a', 3: 'b'}), ['a', 'a', 'b', 'b', np.nan])
4613+
])
4614+
def fillna_series_categorical(self, fill_value, expected_output):
4615+
# GH 17033
4616+
# Test fillna for a Categorical series
4617+
data = ['a', np.nan, 'b', np.nan, np.nan]
4618+
s = pd.Series(pd.Categorical(data, categories=['a', 'b']))
4619+
exp = pd.Series(pd.Categorical(expected_output, categories=['a', 'b']))
4620+
tm.assert_series_equal(s.fillna(fill_value), exp)
4621+
4622+
s1 = pd.Series(data, categories=['a', 'b'])
4623+
with tm.assert_raises_regex(ValueError,
4624+
"fill value must be in categories"):
4625+
s1.fillna('cat')
4626+
4627+
s2 = pd.Series(data, categories=['a', 'b'])
4628+
with tm.assert_raises_regex(ValueError,
4629+
"fill value must be in categories"):
4630+
s2.fillna(pd.Series('cat'))
4631+
4632+
s3 = pd.Series(data, categories=['a', 'b'])
4633+
with tm.assert_raises_regex(TypeError,
4634+
'"value" parameter must be a scalar or '
4635+
'dict but you passed a "list"'):
4636+
s3.fillna(['a', 'b'])
4637+
46064638
def test_astype_to_other(self):
46074639

46084640
s = self.cat['value_group']

0 commit comments

Comments
 (0)