Skip to content

BUG: Fixes plotting with nullable integers (#32073) #34896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union

import numpy as np

Expand Down Expand Up @@ -76,6 +76,7 @@
if TYPE_CHECKING:
from pandas import Series
from pandas.core.arrays import ExtensionArray # noqa: F401
from pandas.core.indexes.base import Index # noqa: F401

_int8_max = np.iinfo(np.int8).max
_int16_max = np.iinfo(np.int16).max
Expand Down Expand Up @@ -1747,3 +1748,35 @@ def validate_numeric_casting(dtype: np.dtype, value):
):
if is_bool(value):
raise ValueError("Cannot assign bool to float/integer series")


def safe_convert_to_ndarray(values: Union[ArrayLike, "Index"]) -> np.ndarray:
    """
    Convert values to an ndarray, steering clear of ``object`` dtype.

    Input carrying an extension dtype gets special treatment: nullable
    integer and nullable boolean data are cast to float with missing entries
    written as NaN, and timezone-aware datetime data has its timezone removed
    through ``tz_localize(tz=None)`` before conversion. Any other input is
    handed to ``np.asarray`` as-is.

    Parameters
    ----------
    values : Union[ArrayLike, Index]
        Values to be converted to ndarray.

    Returns
    -------
    np.ndarray
        The values cast to np.ndarray.
    """
    has_extension_dtype = hasattr(values, "dtype") and is_extension_array_dtype(
        values.dtype
    )
    if not has_extension_dtype:
        # Plain ndarrays, lists, etc.: nothing special to do.
        return np.asarray(values)

    ext_dtype = values.dtype
    if is_integer_dtype(ext_dtype) or is_bool_dtype(ext_dtype):
        # Masked integer/boolean data: float is the only plain numpy dtype
        # that can represent the missing entries (as NaN).
        return values.to_numpy(dtype=float, na_value=np.nan)
    if is_datetime64tz_dtype(ext_dtype):
        # Drop the timezone so the result is a plain datetime64 array.
        return np.asarray(values.tz_localize(tz=None))
    return np.asarray(values)
11 changes: 3 additions & 8 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ class providing the base-class of operations.
from pandas.errors import AbstractMethodError
from pandas.util._decorators import Appender, Substitution, cache_readonly, doc

from pandas.core.dtypes.cast import maybe_cast_result
from pandas.core.dtypes.cast import maybe_cast_result, safe_convert_to_ndarray
from pandas.core.dtypes.common import (
ensure_float,
is_bool_dtype,
is_datetime64_dtype,
is_extension_array_dtype,
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
Expand Down Expand Up @@ -2052,14 +2050,11 @@ def pre_processor(vals: np.ndarray) -> Tuple[np.ndarray, Optional[Type]]:

inference = None
if is_integer_dtype(vals.dtype):
if is_extension_array_dtype(vals.dtype):
vals = vals.to_numpy(dtype=float, na_value=np.nan)
inference = np.int64
elif is_bool_dtype(vals.dtype) and is_extension_array_dtype(vals.dtype):
vals = vals.to_numpy(dtype=float, na_value=np.nan)
elif is_datetime64_dtype(vals.dtype):
inference = "datetime64[ns]"
vals = np.asarray(vals).astype(float)

vals = safe_convert_to_ndarray(vals)

return vals, inference

Expand Down
3 changes: 2 additions & 1 deletion pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.cast import safe_convert_to_ndarray
from pandas.core.dtypes.common import (
is_float,
is_hashable,
Expand Down Expand Up @@ -421,7 +422,7 @@ def _compute_plot_data(self):
# np.ndarray before plot.
numeric_data = numeric_data.copy()
for col in numeric_data:
numeric_data[col] = np.asarray(numeric_data[col])
numeric_data[col] = safe_convert_to_ndarray(numeric_data[col].values)

self.data = numeric_data

Expand Down
51 changes: 51 additions & 0 deletions pandas/tests/dtypes/cast/test_safe_convert_to_ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
import pytest

from pandas.core.dtypes.cast import safe_convert_to_ndarray

import pandas as pd
import pandas._testing as tm


@pytest.mark.parametrize(
    "values, expected",
    [
        # Plain integer data passes through with its dtype unchanged
        (pd.Series([1, 2, 3], dtype=int), np.array([1, 2, 3], dtype=int)),
        (
            # Nullable integer type cast to float to handle missing values
            pd.Series([1, np.nan, 3], dtype="Int64"),
            np.array([1, np.nan, 3], dtype=float),
        ),
        (
            # Nullable boolean type cast to float to handle missing values
            pd.Series([True, np.nan, False], dtype="boolean"),
            np.array([1.0, np.nan, 0.0], dtype=float),
        ),
        (
            # Timezone-naive datetimes are not changed by the conversion
            pd.to_datetime([2001, None, 2003], format="%Y"),
            np.array(["2001", "NaT", "2003"], dtype="datetime64").astype(
                "datetime64[ns]"
            ),
        ),
        (
            # Timezone-aware (UTC) datetimes are downcast to naive datetimes
            pd.to_datetime([2001, None, 2003], format="%Y", utc=True),
            np.array(["2001", "NaT", "2003"], dtype="datetime64").astype(
                "datetime64[ns]"
            ),
        ),
        (
            # Downcast to naive datetime should result in local dates, not UTC
            pd.to_datetime([2001, None, 2003], format="%Y").tz_localize(
                tz="US/Eastern"
            ),
            np.array(["2001", "NaT", "2003"], dtype="datetime64").astype(
                "datetime64[ns]"
            ),
        ),
    ],
)
def test_safe_convert_to_ndarray(values, expected):
    """Check that each parametrized input converts to the expected ndarray."""
    # ``np.nan`` is used above rather than the ``np.NaN`` alias, which was
    # removed in NumPy 2.0.
    result = safe_convert_to_ndarray(values)
    tm.assert_numpy_array_equal(result, expected)
17 changes: 17 additions & 0 deletions pandas/tests/plotting/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3419,6 +3419,23 @@ def test_xlabel_ylabel_dataframe_subplots(
assert all(ax.get_ylabel() == str(new_label) for ax in axes)
assert all(ax.get_xlabel() == str(new_label) for ax in axes)

def test_nullable_int_plot(self):
# GH 32073
dates = ["2008", "2009", None, "2011", "2012"]
df = pd.DataFrame(
{
"A": [1, 2, 3, 4, 5],
"B": [7, 5, np.nan, 3, 2],
"C": pd.to_datetime(dates, format="%Y"),
"D": pd.to_datetime(dates, format="%Y", utc=True),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MarcoGorelli This is the test case that hits the `elif is_datetime64tz_dtype(values.dtype)` branch in `safe_convert_to_ndarray`.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cvanweelden thanks for pointing me towards this.

Are you sure? I just tried running it and didn't hit it. Perhaps it was true before you changed

safe_convert_to_ndarray(numeric_data[col])

to

safe_convert_to_ndarray(numeric_data[col].values)

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah, you are right. I wasn't aware that the underlying values were actually an ndarray with a non-extension dtype, so I didn't know this change would have that effect. Thanks!

}
)

_check_plot_works(df.plot, x="A", y="B")
_check_plot_works(df[["A", "B"]].astype("Int64").plot, x="A", y="B")
_check_plot_works(df[["A", "C"]].plot, x="A", y="C")
_check_plot_works(df[["A", "D"]].plot, x="A", y="D")


def _generate_4_axes_via_gridspec():
import matplotlib.pyplot as plt
Expand Down