Skip to content

Convert ExtensionDtype class objects to instances of that class #47108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
6 changes: 4 additions & 2 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
from __future__ import annotations

import inspect
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -1574,9 +1575,8 @@ def get_dtype(arr_or_dtype) -> DtypeObj:
# fastpath
elif isinstance(arr_or_dtype, np.dtype):
return arr_or_dtype
elif isinstance(arr_or_dtype, type):
elif inspect.isclass(arr_or_dtype) and issubclass(arr_or_dtype, np.generic):
return np.dtype(arr_or_dtype)

# if we have an array-like
elif hasattr(arr_or_dtype, "dtype"):
arr_or_dtype = arr_or_dtype.dtype
Expand Down Expand Up @@ -1765,6 +1765,8 @@ def pandas_dtype(dtype) -> DtypeObj:
return dtype.dtype
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
return dtype
elif inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
return dtype()

# registered extension types
result = registry.find(dtype)
Expand Down
101 changes: 80 additions & 21 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,48 @@ def test_period_dtype(self, dtype):
assert com.pandas_dtype(dtype) == PeriodDtype(dtype)
assert com.pandas_dtype(dtype) == dtype

@pytest.mark.parametrize(
"cls",
(
pd.BooleanDtype,
pd.Int8Dtype,
pd.Int16Dtype,
pd.Int32Dtype,
pd.Int64Dtype,
pd.UInt8Dtype,
pd.UInt16Dtype,
pd.UInt32Dtype,
pd.UInt64Dtype,
pd.Float32Dtype,
pd.Float64Dtype,
pd.SparseDtype,
pd.StringDtype,
IntervalDtype,
CategoricalDtype,
pytest.param(
DatetimeTZDtype,
marks=pytest.mark.xfail(reason="must specify TZ", raises=TypeError),
),
pytest.param(
PeriodDtype,
marks=pytest.mark.xfail(
reason="must specify frequency", raises=AttributeError
),
),
),
)
def test_pd_extension_dtype(self, cls):
"""
TODO: desired behavior?

For extension dtypes that admit no options OR can be initialized with no args
passed, convert the extension dtype class to an instance of that class.
"""
expected = cls()
result = com.pandas_dtype(cls)

assert result == expected


dtypes = {
"datetime_tz": com.pandas_dtype("datetime64[ns, US/Eastern]"),
Expand Down Expand Up @@ -558,29 +600,44 @@ def test_is_float_dtype():
assert com.is_float_dtype(pd.Index([1, 2.0]))


def test_is_bool_dtype():
assert not com.is_bool_dtype(int)
assert not com.is_bool_dtype(str)
assert not com.is_bool_dtype(pd.Series([1, 2]))
assert not com.is_bool_dtype(pd.Series(["a", "b"], dtype="category"))
assert not com.is_bool_dtype(np.array(["a", "b"]))
assert not com.is_bool_dtype(pd.Index(["a", "b"]))
assert not com.is_bool_dtype("Int64")

assert com.is_bool_dtype(bool)
assert com.is_bool_dtype(np.bool_)
assert com.is_bool_dtype(pd.Series([True, False], dtype="category"))
assert com.is_bool_dtype(np.array([True, False]))
assert com.is_bool_dtype(pd.Index([True, False]))

assert com.is_bool_dtype(pd.BooleanDtype())
assert com.is_bool_dtype(pd.array([True, False, None], dtype="boolean"))
assert com.is_bool_dtype("boolean")
@pytest.mark.parametrize(
"value",
(
True,
False,
int,
str,
"Int64",
"0 - Name", # GH39010
pd.array(("a", "b")),
pd.Index(("a", "b")),
pd.Series(("a", "b"), dtype="category"),
pd.Series((1, 2)),
),
)
def test_is_bool_dtype_returns_false(value):
assert com.is_bool_dtype(value) is False


def test_is_bool_dtype_numpy_error():
# GH39010
assert not com.is_bool_dtype("0 - Name")
@pytest.mark.parametrize(
"value",
(
bool,
np.bool_,
np.dtype(np.bool_),
pd.BooleanDtype,
pd.BooleanDtype(),
"bool",
"boolean",
pd.array((True, False)),
pd.Index((True, False)),
pd.Series((True, False)),
pd.Series((True, False), dtype="category"),
pd.Series((True, False, None), dtype="boolean"),
),
)
def test_is_bool_dtype_returns_true(value):
assert com.is_bool_dtype(value) is True


@pytest.mark.filterwarnings("ignore:'is_extension_type' is deprecated:FutureWarning")
Expand Down Expand Up @@ -674,6 +731,8 @@ def test_is_complex_dtype():
(PeriodDtype(freq="D"), PeriodDtype(freq="D")),
("period[D]", PeriodDtype(freq="D")),
(IntervalDtype(), IntervalDtype()),
(pd.BooleanDtype, pd.BooleanDtype()),
(pd.BooleanDtype(), pd.BooleanDtype()),
],
)
def test_get_dtype(input_param, result):
Expand Down