Skip to content

Commit f1abd83

Browse files
raise if dtype class obj passed
1 parent 2404f3d commit f1abd83

File tree

3 files changed

+26
-51
lines changed

3 files changed

+26
-51
lines changed

pandas/core/dtypes/common.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1760,13 +1760,15 @@ def pandas_dtype(dtype) -> DtypeObj:
17601760
------
17611761
TypeError if not a dtype
17621762
"""
1763+
if inspect.isclass(dtype) and issubclass(dtype, (np.dtype, ExtensionDtype)):
1764+
msg = "Must pass dtype instance, not dtype class"
1765+
raise TypeError(msg)
1766+
17631767
# short-circuit
17641768
if isinstance(dtype, np.ndarray):
17651769
return dtype.dtype
17661770
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
17671771
return dtype
1768-
elif inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
1769-
return dtype()
17701772

17711773
# registered extension types
17721774
result = registry.find(dtype)

pandas/core/nanops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def _mask_datetimelike_result(
648648
return result
649649

650650

651-
@disallow(PeriodDtype)
651+
@disallow(PeriodDtype())
652652
@bottleneck_switch()
653653
@_datetimelike_compat
654654
def nanmean(

pandas/tests/dtypes/test_common.py

+21-48
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
import pandas.util._test_decorators as td
99

1010
from pandas.core.dtypes.astype import astype_nansafe
11+
from pandas.core.dtypes.base import (
12+
ExtensionDtype,
13+
_registry,
14+
)
1115
import pandas.core.dtypes.common as com
1216
from pandas.core.dtypes.dtypes import (
1317
CategoricalDtype,
@@ -23,6 +27,8 @@
2327
from pandas.api.types import pandas_dtype
2428
from pandas.arrays import SparseArray
2529

30+
ALL_EA_DTYPES = _registry.dtypes
31+
2632

2733
# EA & Actual Dtypes
2834
def to_ea_dtypes(dtypes):
@@ -36,25 +42,36 @@ def to_numpy_dtypes(dtypes):
3642

3743

3844
class TestPandasDtype:
39-
4045
# Passing invalid dtype, both as a string or object, must raise TypeError
4146
# Per issue GH15520
4247
@pytest.mark.parametrize("box", [pd.Timestamp, "pd.Timestamp", list])
4348
def test_invalid_dtype_error(self, box):
44-
with pytest.raises(TypeError, match="not understood"):
49+
msg = "|".join(
50+
(
51+
"Must pass dtype instance, not dtype class",
52+
"not understood",
53+
)
54+
)
55+
with pytest.raises(TypeError, match=msg):
4556
com.pandas_dtype(box)
4657

58+
@pytest.mark.parametrize("cls", ALL_EA_DTYPES)
59+
def test_raises_for_dtype_class(self, cls: type[ExtensionDtype]):
60+
msg = "Must pass dtype instance, not dtype class"
61+
with pytest.raises(TypeError, match=msg):
62+
com.pandas_dtype(cls)
63+
4764
@pytest.mark.parametrize(
4865
"dtype",
4966
[
5067
object,
51-
"float64",
5268
np.object_,
5369
np.dtype("object"),
5470
"O",
55-
np.float64,
5671
float,
72+
np.float64,
5773
np.dtype("float64"),
74+
"float64",
5875
],
5976
)
6077
def test_pandas_dtype_valid(self, dtype):
@@ -106,48 +123,6 @@ def test_period_dtype(self, dtype):
106123
assert com.pandas_dtype(dtype) == PeriodDtype(dtype)
107124
assert com.pandas_dtype(dtype) == dtype
108125

109-
@pytest.mark.parametrize(
110-
"cls",
111-
(
112-
pd.BooleanDtype,
113-
pd.Int8Dtype,
114-
pd.Int16Dtype,
115-
pd.Int32Dtype,
116-
pd.Int64Dtype,
117-
pd.UInt8Dtype,
118-
pd.UInt16Dtype,
119-
pd.UInt32Dtype,
120-
pd.UInt64Dtype,
121-
pd.Float32Dtype,
122-
pd.Float64Dtype,
123-
pd.SparseDtype,
124-
pd.StringDtype,
125-
IntervalDtype,
126-
CategoricalDtype,
127-
pytest.param(
128-
DatetimeTZDtype,
129-
marks=pytest.mark.xfail(reason="must specify TZ", raises=TypeError),
130-
),
131-
pytest.param(
132-
PeriodDtype,
133-
marks=pytest.mark.xfail(
134-
reason="must specify frequency", raises=AttributeError
135-
),
136-
),
137-
),
138-
)
139-
def test_pd_extension_dtype(self, cls):
140-
"""
141-
TODO: desired behavior?
142-
143-
For extension dtypes that admit no options OR can be initialized with no args
144-
passed, convert the extension dtype class to an instance of that class.
145-
"""
146-
expected = cls()
147-
result = com.pandas_dtype(cls)
148-
149-
assert result == expected
150-
151126

152127
dtypes = {
153128
"datetime_tz": com.pandas_dtype("datetime64[ns, US/Eastern]"),
@@ -625,7 +600,6 @@ def test_is_bool_dtype_returns_false(value):
625600
bool,
626601
np.bool_,
627602
np.dtype(np.bool_),
628-
pd.BooleanDtype,
629603
pd.BooleanDtype(),
630604
"bool",
631605
"boolean",
@@ -731,7 +705,6 @@ def test_is_complex_dtype():
731705
(PeriodDtype(freq="D"), PeriodDtype(freq="D")),
732706
("period[D]", PeriodDtype(freq="D")),
733707
(IntervalDtype(), IntervalDtype()),
734-
(pd.BooleanDtype, pd.BooleanDtype()),
735708
(pd.BooleanDtype(), pd.BooleanDtype()),
736709
],
737710
)

0 commit comments

Comments
 (0)