diff --git a/doc/source/whatsnew/v0.20.0.txt b/doc/source/whatsnew/v0.20.0.txt index 945922b5f9ba8..7b9ac2ecf0724 100644 --- a/doc/source/whatsnew/v0.20.0.txt +++ b/doc/source/whatsnew/v0.20.0.txt @@ -1604,6 +1604,7 @@ Conversion - Bug in the return type of ``pd.unique`` on a ``Categorical``, which was returning an ndarray and not a ``Categorical`` (:issue:`15903`) - Bug in ``Index.to_series()`` where the index was not copied (and so mutating later would change the original), (:issue:`15949`) - Bug in indexing with partial string indexing with a len-1 DataFrame (:issue:`16071`) +- Bug in ``pandas_dtype`` where passing invalid dtype didn't raise an error. (:issue:`15520`) Indexing ^^^^^^^^ diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index a5e12e8262579..19d3792f73de7 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -19,7 +19,7 @@ is_datetime_or_timedelta_dtype, is_bool_dtype, is_scalar, _string_dtypes, - _coerce_to_dtype, + pandas_dtype, _ensure_int8, _ensure_int16, _ensure_int32, _ensure_int64, _NS_DTYPE, _TD_DTYPE, _INT64_DTYPE, @@ -576,7 +576,7 @@ def astype_nansafe(arr, dtype, copy=True): """ return a view if copy is False, but need to be very careful as the result shape could change! """ if not isinstance(dtype, np.dtype): - dtype = _coerce_to_dtype(dtype) + dtype = pandas_dtype(dtype) if issubclass(dtype.type, text_type): # in Py3 that's str, in Py2 that's unicode diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 156e43fc4e5fb..ba822071a3b72 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -788,4 +788,19 @@ def pandas_dtype(dtype): elif isinstance(dtype, ExtensionDtype): return dtype - return np.dtype(dtype) + try: + npdtype = np.dtype(dtype) + except (TypeError, ValueError): + raise + + # Any invalid dtype (such as pd.Timestamp) should raise an error. + # np.dtype(invalid_type).kind = 0 for such objects. However, this will + # also catch some valid dtypes such as object, np.object_ and 'object' + # which we safeguard against by catching them earlier and returning + # np.dtype(valid_dtype) before this condition is evaluated. + if dtype in [object, np.object_, 'object', 'O']: + return npdtype + elif npdtype.kind == 'O': + raise TypeError('dtype {0} not understood'.format(dtype)) + + return npdtype diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 74d3053821e39..39ffb8aae0e92 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -11,7 +11,6 @@ from pandas._libs import tslib, lib from pandas.core.dtypes.common import ( - _coerce_to_dtype, _ensure_int64, needs_i8_conversion, is_scalar, @@ -23,7 +22,8 @@ is_datetime64tz_dtype, is_list_like, is_dict_like, - is_re_compilable) + is_re_compilable, + pandas_dtype) from pandas.core.dtypes.cast import maybe_promote, maybe_upcast_putmask from pandas.core.dtypes.missing import isnull, notnull from pandas.core.dtypes.generic import ABCSeries, ABCPanel @@ -164,13 +164,14 @@ def _validate_dtype(self, dtype): """ validate the passed dtype """ if dtype is not None: - dtype = _coerce_to_dtype(dtype) + dtype = pandas_dtype(dtype) # a compound dtype if dtype.kind == 'V': raise NotImplementedError("compound dtypes are not implemented" "in the {0} constructor" .format(self.__class__.__name__)) + return dtype def _init_mgr(self, mgr, axes=None, dtype=None, copy=False): diff --git a/pandas/core/series.py b/pandas/core/series.py index d4511fb58b2f3..f03091d7e6a66 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -14,7 +14,7 @@ import numpy.ma as ma from pandas.core.dtypes.common import ( - _coerce_to_dtype, is_categorical_dtype, + is_categorical_dtype, is_bool, is_integer, is_integer_dtype, is_float_dtype, @@ -28,7 +28,8 @@ is_dict_like, is_scalar, _is_unorderable_exception, - _ensure_platform_int) + _ensure_platform_int, + pandas_dtype) from pandas.core.dtypes.generic import ABCSparseArray, ABCDataFrame from pandas.core.dtypes.cast import ( maybe_upcast, infer_dtype_from_scalar, @@ -2872,7 +2873,7 @@ def _sanitize_array(data, index, dtype=None, copy=False, """ if dtype is not None: - dtype = _coerce_to_dtype(dtype) + dtype = pandas_dtype(dtype) if isinstance(data, ma.MaskedArray): mask = ma.getmaskarray(data) diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py index 86233c5d2b192..c4ef5e48b4db9 100644 --- a/pandas/tests/dtypes/test_common.py +++ b/pandas/tests/dtypes/test_common.py @@ -2,6 +2,7 @@ import pytest import numpy as np +import pandas as pd from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, PeriodDtype, CategoricalDtype) @@ -13,6 +14,20 @@ class TestPandasDtype(tm.TestCase): + # Passing invalid dtype, both as a string or object, must raise TypeError + # Per issue GH15520 + def test_invalid_dtype_error(self): + msg = 'not understood' + invalid_list = [pd.Timestamp, 'pd.Timestamp', list] + for dtype in invalid_list: + with tm.assertRaisesRegexp(TypeError, msg): + pandas_dtype(dtype) + + valid_list = [object, 'float64', np.object_, np.dtype('object'), 'O', + np.float64, float, np.dtype('float64')] + for dtype in valid_list: + pandas_dtype(dtype) + def test_numpy_dtype(self): for dtype in ['M8[ns]', 'm8[ns]', 'object', 'float64', 'int64']: self.assertEqual(pandas_dtype(dtype), np.dtype(dtype)) diff --git a/pandas/tests/series/test_constructors.py b/pandas/tests/series/test_constructors.py index 57cce1d1cf199..74c2544d900ea 100644 --- a/pandas/tests/series/test_constructors.py +++ b/pandas/tests/series/test_constructors.py @@ -30,6 +30,14 @@ class TestSeriesConstructors(TestData, tm.TestCase): + def test_invalid_dtype(self): + # GH15520 + msg = 'not understood' + invalid_list = [pd.Timestamp, 'pd.Timestamp', list] + for dtype in invalid_list: + with tm.assertRaisesRegexp(TypeError, msg): + Series([], name='time', dtype=dtype) + def test_scalar_conversion(self): # Pass in scalar is disabled diff --git a/pandas/tests/test_strings.py b/pandas/tests/test_strings.py index a818bf84b8e9b..6733fbdc3b9c6 100644 --- a/pandas/tests/test_strings.py +++ b/pandas/tests/test_strings.py @@ -1208,10 +1208,9 @@ def test_extractall_same_as_extract_subject_index(self): tm.assert_frame_equal(extract_one_noname, no_match_index) def test_empty_str_methods(self): - empty_str = empty = Series(dtype=str) + empty_str = empty = Series(dtype=object) empty_int = Series(dtype=int) empty_bool = Series(dtype=bool) - empty_list = Series(dtype=list) empty_bytes = Series(dtype=object) # GH7241 @@ -1242,25 +1241,24 @@ def test_empty_str_methods(self): DataFrame(columns=[0, 1], dtype=str), empty.str.extract('()()', expand=False)) tm.assert_frame_equal(DataFrame(dtype=str), empty.str.get_dummies()) - tm.assert_series_equal(empty_str, empty_list.str.join('')) + tm.assert_series_equal(empty_str, empty_str.str.join('')) tm.assert_series_equal(empty_int, empty.str.len()) - tm.assert_series_equal(empty_list, empty_list.str.findall('a')) + tm.assert_series_equal(empty_str, empty_str.str.findall('a')) tm.assert_series_equal(empty_int, empty.str.find('a')) tm.assert_series_equal(empty_int, empty.str.rfind('a')) tm.assert_series_equal(empty_str, empty.str.pad(42)) tm.assert_series_equal(empty_str, empty.str.center(42)) - tm.assert_series_equal(empty_list, empty.str.split('a')) - tm.assert_series_equal(empty_list, empty.str.rsplit('a')) - tm.assert_series_equal(empty_list, + tm.assert_series_equal(empty_str, empty.str.split('a')) + tm.assert_series_equal(empty_str, empty.str.rsplit('a')) + tm.assert_series_equal(empty_str, empty.str.partition('a', expand=False)) - tm.assert_series_equal(empty_list, + tm.assert_series_equal(empty_str, empty.str.rpartition('a', expand=False)) tm.assert_series_equal(empty_str, empty.str.slice(stop=1)) tm.assert_series_equal(empty_str, empty.str.slice(step=1)) tm.assert_series_equal(empty_str, empty.str.strip()) tm.assert_series_equal(empty_str, empty.str.lstrip()) tm.assert_series_equal(empty_str, empty.str.rstrip()) - tm.assert_series_equal(empty_str, empty.str.rstrip()) tm.assert_series_equal(empty_str, empty.str.wrap(42)) tm.assert_series_equal(empty_str, empty.str.get(0)) tm.assert_series_equal(empty_str, empty_bytes.str.decode('ascii'))