Skip to content

REF: Mess with data less outside of MPLPlot.__init__ #55889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 56 additions & 40 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,15 @@

from pandas._typing import (
IndexLabel,
NDFrameT,
PlottingOrientation,
npt,
)

from pandas import Series
from pandas import (
PeriodIndex,
Series,
)


def _color_in_style(style: str) -> bool:
Expand Down Expand Up @@ -161,8 +165,6 @@ def __init__(
) -> None:
import matplotlib.pyplot as plt

self.data = data

# if users assign an empty list or tuple, raise `ValueError`
# similar to current `df.box` and `df.hist` APIs.
if by in ([], ()):
Expand Down Expand Up @@ -193,9 +195,11 @@ def __init__(

self.kind = kind

self.subplots = self._validate_subplots_kwarg(subplots)
self.subplots = type(self)._validate_subplots_kwarg(
subplots, data, kind=self._kind
)

self.sharex = self._validate_sharex(sharex, ax, by)
self.sharex = type(self)._validate_sharex(sharex, ax, by)
self.sharey = sharey
self.figsize = figsize
self.layout = layout
Expand Down Expand Up @@ -245,10 +249,11 @@ def __init__(
# parse errorbar input if given
xerr = kwds.pop("xerr", None)
yerr = kwds.pop("yerr", None)
self.errors = {
kw: self._parse_errorbars(kw, err)
for kw, err in zip(["xerr", "yerr"], [xerr, yerr])
}
nseries = self._get_nseries(data)
xerr, data = type(self)._parse_errorbars("xerr", xerr, data, nseries)
yerr, data = type(self)._parse_errorbars("yerr", yerr, data, nseries)
self.errors = {"xerr": xerr, "yerr": yerr}
self.data = data

if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)):
secondary_y = [secondary_y]
Expand All @@ -271,7 +276,8 @@ def __init__(
self._validate_color_args()

@final
def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
@staticmethod
def _validate_sharex(sharex: bool | None, ax, by) -> bool:
if sharex is None:
# if by is defined, subplots are used and sharex should be False
if ax is None and by is None: # pylint: disable=simplifiable-if-statement
Expand All @@ -285,8 +291,9 @@ def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
return bool(sharex)

@final
@staticmethod
def _validate_subplots_kwarg(
self, subplots: bool | Sequence[Sequence[str]]
subplots: bool | Sequence[Sequence[str]], data: Series | DataFrame, kind: str
) -> bool | list[tuple[int, ...]]:
"""
Validate the subplots parameter
Expand Down Expand Up @@ -323,18 +330,18 @@ def _validate_subplots_kwarg(
"area",
"pie",
)
if self._kind not in supported_kinds:
if kind not in supported_kinds:
raise ValueError(
"When subplots is an iterable, kind must be "
f"one of {', '.join(supported_kinds)}. Got {self._kind}."
f"one of {', '.join(supported_kinds)}. Got {kind}."
)

if isinstance(self.data, ABCSeries):
if isinstance(data, ABCSeries):
raise NotImplementedError(
"An iterable subplots for a Series is not supported."
)

columns = self.data.columns
columns = data.columns
if isinstance(columns, ABCMultiIndex):
raise NotImplementedError(
"An iterable subplots for a DataFrame with a MultiIndex column "
Expand Down Expand Up @@ -442,18 +449,22 @@ def _iter_data(
# typing.
yield col, np.asarray(values.values)

@property
def nseries(self) -> int:
def _get_nseries(self, data: Series | DataFrame) -> int:
# When `by` is explicitly assigned, grouped data size will be defined, and
# this will determine number of subplots to have, aka `self.nseries`
if self.data.ndim == 1:
if data.ndim == 1:
return 1
elif self.by is not None and self._kind == "hist":
return len(self._grouped)
elif self.by is not None and self._kind == "box":
return len(self.columns)
else:
return self.data.shape[1]
return data.shape[1]

@final
@property
def nseries(self) -> int:
return self._get_nseries(self.data)

@final
def draw(self) -> None:
Expand Down Expand Up @@ -880,10 +891,12 @@ def _get_xticks(self, convert_period: bool = False):
index = self.data.index
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")

x: list[int] | np.ndarray
if self.use_index:
if convert_period and isinstance(index, ABCPeriodIndex):
self.data = self.data.reindex(index=index.sort_values())
x = self.data.index.to_timestamp()._mpl_repr()
index = cast("PeriodIndex", self.data.index)
x = index.to_timestamp()._mpl_repr()
elif is_any_real_numeric_dtype(index.dtype):
# Matplotlib supports numeric values or datetime objects as
# xaxis values. Taking LBYL approach here, by the time
Expand Down Expand Up @@ -1050,8 +1063,12 @@ def _get_colors(
color=self.kwds.get(color_kwds),
)

# TODO: tighter typing for first return?
@final
def _parse_errorbars(self, label: str, err):
@staticmethod
def _parse_errorbars(
label: str, err, data: NDFrameT, nseries: int
) -> tuple[Any, NDFrameT]:
"""
Look for error keyword arguments and return the actual errorbar data
or return the error DataFrame/dict
Expand All @@ -1071,32 +1088,32 @@ def _parse_errorbars(self, label: str, err):
should be in a ``Mx2xN`` array.
"""
if err is None:
return None
return None, data

def match_labels(data, e):
e = e.reindex(data.index)
return e

# key-matched DataFrame
if isinstance(err, ABCDataFrame):
err = match_labels(self.data, err)
err = match_labels(data, err)
# key-matched dict
elif isinstance(err, dict):
pass

# Series of error values
elif isinstance(err, ABCSeries):
# broadcast error series across data
err = match_labels(self.data, err)
err = match_labels(data, err)
err = np.atleast_2d(err)
err = np.tile(err, (self.nseries, 1))
err = np.tile(err, (nseries, 1))

# errors are a column in the dataframe
elif isinstance(err, str):
evalues = self.data[err].values
self.data = self.data[self.data.columns.drop(err)]
evalues = data[err].values
data = data[data.columns.drop(err)]
err = np.atleast_2d(evalues)
err = np.tile(err, (self.nseries, 1))
err = np.tile(err, (nseries, 1))

elif is_list_like(err):
if is_iterator(err):
Expand All @@ -1108,40 +1125,40 @@ def match_labels(data, e):
err_shape = err.shape

# asymmetrical error bars
if isinstance(self.data, ABCSeries) and err_shape[0] == 2:
if isinstance(data, ABCSeries) and err_shape[0] == 2:
err = np.expand_dims(err, 0)
err_shape = err.shape
if err_shape[2] != len(self.data):
if err_shape[2] != len(data):
raise ValueError(
"Asymmetrical error bars should be provided "
f"with the shape (2, {len(self.data)})"
f"with the shape (2, {len(data)})"
)
elif isinstance(self.data, ABCDataFrame) and err.ndim == 3:
elif isinstance(data, ABCDataFrame) and err.ndim == 3:
if (
(err_shape[0] != self.nseries)
(err_shape[0] != nseries)
or (err_shape[1] != 2)
or (err_shape[2] != len(self.data))
or (err_shape[2] != len(data))
):
raise ValueError(
"Asymmetrical error bars should be provided "
f"with the shape ({self.nseries}, 2, {len(self.data)})"
f"with the shape ({nseries}, 2, {len(data)})"
)

# broadcast errors to each data series
if len(err) == 1:
err = np.tile(err, (self.nseries, 1))
err = np.tile(err, (nseries, 1))

elif is_number(err):
err = np.tile(
[err], # pyright: ignore[reportGeneralTypeIssues]
(self.nseries, len(self.data)),
(nseries, len(data)),
)

else:
msg = f"No valid {label} detected"
raise ValueError(msg)

return err
return err, data # pyright: ignore[reportGeneralTypeIssues]

@final
def _get_errorbars(
Expand Down Expand Up @@ -1215,8 +1232,7 @@ def __init__(self, data, x, y, **kwargs) -> None:
self.y = y

@final
@property
def nseries(self) -> int:
def _get_nseries(self, data: Series | DataFrame) -> int:
return 1

@final
Expand Down
7 changes: 5 additions & 2 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@

from pandas._typing import PlottingOrientation

from pandas import DataFrame
from pandas import (
DataFrame,
Series,
)


class HistPlot(LinePlot):
Expand Down Expand Up @@ -87,7 +90,7 @@ def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
bins = self._calculate_bins(self.data, bins)
return bins

def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray:
"""Calculate bins given data"""
nd_values = data.infer_objects(copy=False)._get_numeric_data()
values = np.ravel(nd_values)
Expand Down