diff --git a/pandas/core/reshape.py b/pandas/core/reshape.py index 87cb088c2e91e..b76e52acb8be4 100644 --- a/pandas/core/reshape.py +++ b/pandas/core/reshape.py @@ -25,6 +25,7 @@ import pandas.core.algorithms as algos import pandas.algos as _algos +from pandas.core.missing import validate_fill_value from pandas.core.index import MultiIndex, _get_na_value @@ -405,6 +406,9 @@ def _slow_pivot(index, columns, values): def unstack(obj, level, fill_value=None): + if fill_value: + validate_fill_value(fill_value, obj.values.dtype) + if isinstance(level, (tuple, list)): return _unstack_multiple(obj, level) diff --git a/pandas/tests/types/test_missing.py b/pandas/tests/types/test_missing.py index cab44f1122ae1..bdbebfe01985b 100644 --- a/pandas/tests/types/test_missing.py +++ b/pandas/tests/types/test_missing.py @@ -12,7 +12,8 @@ DatetimeIndex, TimedeltaIndex, date_range) from pandas.types.dtypes import DatetimeTZDtype from pandas.types.missing import (array_equivalent, isnull, notnull, - na_value_for_dtype) + na_value_for_dtype, + validate_fill_value) def test_notnull(): @@ -301,3 +302,11 @@ def test_na_value_for_dtype(): for dtype in ['O']: assert np.isnan(na_value_for_dtype(np.dtype(dtype))) + + +class TestValidateFillValue(tm.TestCase): + # TODO: Fill out the test cases. + def test_validate_fill_value(self): + # validate_fill_value() + # import pdb; pdb.set_trace() + pass diff --git a/pandas/types/missing.py b/pandas/types/missing.py index e6791b79bf3bd..9d36a38b59c38 100644 --- a/pandas/types/missing.py +++ b/pandas/types/missing.py @@ -19,7 +19,10 @@ is_object_dtype, is_integer, _TD_DTYPE, - _NS_DTYPE) + _NS_DTYPE, + is_datetime64_any_dtype, is_float, + is_numeric_dtype, is_complex) +from datetime import datetime, timedelta from .inference import is_list_like @@ -391,3 +394,30 @@ def na_value_for_dtype(dtype): elif is_bool_dtype(dtype): return False return np.nan + + +def validate_fill_value(value, dtype): + """ + Make sure the fill value is appropriate for the given dtype. + """ + if not is_scalar(value): + raise TypeError('"fill_value" parameter must be ' + 'a scalar, but you passed a ' + '"{0}"'.format(type(value).__name__)) + elif not isnull(value): + if is_numeric_dtype(dtype): + if not (is_float(value) or is_integer(value) or is_complex(value)): + raise TypeError('"fill_value" parameter must be ' + 'numeric, but you passed a ' + '"{0}"'.format(type(value).__name__)) + elif is_datetime64_any_dtype(dtype): + if not isinstance(value, (np.datetime64, datetime)): + raise TypeError('"fill_value" parameter must be a ' + 'datetime, but you passed a ' + '"{0}"'.format(type(value).__name__)) + elif is_timedelta64_dtype(dtype): + if not isinstance(value, (np.timedelta64, timedelta)): + raise TypeError('"value" parameter must be ' + 'a timedelta, but you passed a ' + '"{0}"'.format(type(value).__name__)) + # if object dtype, do nothing.