Skip to content

Commit 6f40d3d

Browse files
make pandas_dtype more consistent
1 parent c835dd8 commit 6f40d3d

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

pandas/core/dtypes/common.py

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Common type operations.
33
"""
44
from __future__ import annotations
5+
import inspect
56

67
from typing import (
78
Any,
@@ -30,6 +31,7 @@
3031
DatetimeTZDtype,
3132
ExtensionDtype,
3233
IntervalDtype,
34+
PandasExtensionDtype,
3335
PeriodDtype,
3436
)
3537
from pandas.core.dtypes.generic import (
@@ -1765,6 +1767,8 @@ def pandas_dtype(dtype) -> DtypeObj:
17651767
return dtype.dtype
17661768
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
17671769
return dtype
1770+
elif inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
1771+
return dtype()
17681772

17691773
# registered extension types
17701774
result = registry.find(dtype)

pandas/tests/dtypes/test_common.py

+44
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,48 @@ def test_period_dtype(self, dtype):
106106
assert com.pandas_dtype(dtype) == PeriodDtype(dtype)
107107
assert com.pandas_dtype(dtype) == dtype
108108

109+
@pytest.mark.parametrize(
110+
"cls",
111+
(
112+
pd.BooleanDtype,
113+
pd.Int8Dtype,
114+
pd.Int16Dtype,
115+
pd.Int32Dtype,
116+
pd.Int64Dtype,
117+
pd.UInt8Dtype,
118+
pd.UInt16Dtype,
119+
pd.UInt32Dtype,
120+
pd.UInt64Dtype,
121+
pd.Float32Dtype,
122+
pd.Float64Dtype,
123+
pd.SparseDtype,
124+
pd.StringDtype,
125+
IntervalDtype,
126+
CategoricalDtype,
127+
pytest.param(
128+
DatetimeTZDtype,
129+
marks=pytest.mark.xfail(reason="must specify TZ", raises=TypeError),
130+
),
131+
pytest.param(
132+
PeriodDtype,
133+
marks=pytest.mark.xfail(
134+
reason="must specify frequency", raises=AttributeError
135+
),
136+
),
137+
),
138+
)
139+
def test_pd_extension_dtype(self, cls):
140+
"""
141+
TODO: desired behavior?
142+
143+
For extension dtypes that admit no options OR can be initialized with no args
144+
passed, convert the extension dtype class to an instance of that class.
145+
"""
146+
expected = cls()
147+
result = com.pandas_dtype(cls)
148+
149+
assert result == expected
150+
109151

110152
dtypes = {
111153
"datetime_tz": com.pandas_dtype("datetime64[ns, US/Eastern]"),
@@ -689,6 +731,8 @@ def test_is_complex_dtype():
689731
(PeriodDtype(freq="D"), PeriodDtype(freq="D")),
690732
("period[D]", PeriodDtype(freq="D")),
691733
(IntervalDtype(), IntervalDtype()),
734+
(pd.BooleanDtype, pd.BooleanDtype()),
735+
(pd.BooleanDtype(), pd.BooleanDtype()),
692736
],
693737
)
694738
def test_get_dtype(input_param, result):

0 commit comments

Comments
 (0)