7
7
from collections .abc import (
8
8
Hashable ,
9
9
Iterable ,
10
+ Iterator ,
10
11
Sequence ,
11
12
)
12
13
from typing import (
@@ -431,17 +432,15 @@ def _validate_color_args(self):
431
432
)
432
433
433
434
@final
434
- def _iter_data (self , data = None , keep_index : bool = False , fillna = None ):
435
- if data is None :
436
- data = self .data
437
- if fillna is not None :
438
- data = data .fillna (fillna )
439
-
435
+ @staticmethod
436
+ def _iter_data (
437
+ data : DataFrame | dict [Hashable , Series | DataFrame ]
438
+ ) -> Iterator [tuple [Hashable , np .ndarray ]]:
440
439
for col , values in data .items ():
441
- if keep_index is True :
442
- yield col , values
443
- else :
444
- yield col , values .values
440
+ # This was originally written to use values.values before EAs
441
+ # were implemented; adding np.asarray(...) to keep consistent
442
+ # typing.
443
+ yield col , np . asarray ( values .values )
445
444
446
445
@property
447
446
def nseries (self ) -> int :
@@ -480,7 +479,7 @@ def _has_plotted_object(ax: Axes) -> bool:
480
479
return len (ax .lines ) != 0 or len (ax .artists ) != 0 or len (ax .containers ) != 0
481
480
482
481
@final
483
- def _maybe_right_yaxis (self , ax : Axes , axes_num : int ):
482
+ def _maybe_right_yaxis (self , ax : Axes , axes_num : int ) -> Axes :
484
483
if not self .on_right (axes_num ):
485
484
# secondary axes may be passed via ax kw
486
485
return self ._get_ax_layer (ax )
@@ -656,11 +655,7 @@ def _compute_plot_data(self):
656
655
657
656
numeric_data = data .select_dtypes (include = include_type , exclude = exclude_type )
658
657
659
- try :
660
- is_empty = numeric_data .columns .empty
661
- except AttributeError :
662
- is_empty = not len (numeric_data )
663
-
658
+ is_empty = numeric_data .shape [- 1 ] == 0
664
659
# no non-numeric frames or series allowed
665
660
if is_empty :
666
661
raise TypeError ("no numeric data to plot" )
@@ -682,7 +677,7 @@ def _add_table(self) -> None:
682
677
tools .table (ax , data )
683
678
684
679
@final
685
- def _post_plot_logic_common (self , ax , data ):
680
+ def _post_plot_logic_common (self , ax : Axes , data ) -> None :
686
681
"""Common post process for each axes"""
687
682
if self .orientation == "vertical" or self .orientation is None :
688
683
self ._apply_axis_properties (ax .xaxis , rot = self .rot , fontsize = self .fontsize )
@@ -701,7 +696,7 @@ def _post_plot_logic_common(self, ax, data):
701
696
raise ValueError
702
697
703
698
@abstractmethod
704
- def _post_plot_logic (self , ax , data ) -> None :
699
+ def _post_plot_logic (self , ax : Axes , data ) -> None :
705
700
"""Post process for each axes. Overridden in child classes"""
706
701
707
702
@final
@@ -1056,7 +1051,7 @@ def _get_colors(
1056
1051
)
1057
1052
1058
1053
@final
1059
- def _parse_errorbars (self , label , err ):
1054
+ def _parse_errorbars (self , label : str , err ):
1060
1055
"""
1061
1056
Look for error keyword arguments and return the actual errorbar data
1062
1057
or return the error DataFrame/dict
@@ -1137,7 +1132,10 @@ def match_labels(data, e):
1137
1132
err = np .tile (err , (self .nseries , 1 ))
1138
1133
1139
1134
elif is_number (err ):
1140
- err = np .tile ([err ], (self .nseries , len (self .data )))
1135
+ err = np .tile (
1136
+ [err ], # pyright: ignore[reportGeneralTypeIssues]
1137
+ (self .nseries , len (self .data )),
1138
+ )
1141
1139
1142
1140
else :
1143
1141
msg = f"No valid { label } detected"
@@ -1418,14 +1416,14 @@ def _make_plot(self, fig: Figure) -> None:
1418
1416
1419
1417
x = data .index # dummy, not used
1420
1418
plotf = self ._ts_plot
1421
- it = self . _iter_data ( data = data , keep_index = True )
1419
+ it = data . items ( )
1422
1420
else :
1423
1421
x = self ._get_xticks (convert_period = True )
1424
1422
# error: Incompatible types in assignment (expression has type
1425
1423
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
1426
1424
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
1427
1425
plotf = self ._plot # type: ignore[assignment]
1428
- it = self ._iter_data ()
1426
+ it = self ._iter_data (data = self . data )
1429
1427
1430
1428
stacking_id = self ._get_stacking_id ()
1431
1429
is_errorbar = com .any_not_none (* self .errors .values ())
@@ -1434,7 +1432,12 @@ def _make_plot(self, fig: Figure) -> None:
1434
1432
for i , (label , y ) in enumerate (it ):
1435
1433
ax = self ._get_ax (i )
1436
1434
kwds = self .kwds .copy ()
1437
- style , kwds = self ._apply_style_colors (colors , kwds , i , label )
1435
+ style , kwds = self ._apply_style_colors (
1436
+ colors ,
1437
+ kwds ,
1438
+ i ,
1439
+ label , # pyright: ignore[reportGeneralTypeIssues]
1440
+ )
1438
1441
1439
1442
errors = self ._get_errorbars (label = label , index = i )
1440
1443
kwds = dict (kwds , ** errors )
@@ -1446,7 +1449,7 @@ def _make_plot(self, fig: Figure) -> None:
1446
1449
newlines = plotf (
1447
1450
ax ,
1448
1451
x ,
1449
- y ,
1452
+ y , # pyright: ignore[reportGeneralTypeIssues]
1450
1453
style = style ,
1451
1454
column_num = i ,
1452
1455
stacking_id = stacking_id ,
@@ -1465,7 +1468,14 @@ def _make_plot(self, fig: Figure) -> None:
1465
1468
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
1466
1469
@classmethod
1467
1470
def _plot ( # type: ignore[override]
1468
- cls , ax : Axes , x , y , style = None , column_num = None , stacking_id = None , ** kwds
1471
+ cls ,
1472
+ ax : Axes ,
1473
+ x ,
1474
+ y : np .ndarray ,
1475
+ style = None ,
1476
+ column_num = None ,
1477
+ stacking_id = None ,
1478
+ ** kwds ,
1469
1479
):
1470
1480
# column_num is used to get the target column from plotf in line and
1471
1481
# area plots
@@ -1492,7 +1502,7 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
1492
1502
decorate_axes (ax .right_ax , freq , kwds )
1493
1503
ax ._plot_data .append ((data , self ._kind , kwds ))
1494
1504
1495
- lines = self ._plot (ax , data .index , data .values , style = style , ** kwds )
1505
+ lines = self ._plot (ax , data .index , np . asarray ( data .values ) , style = style , ** kwds )
1496
1506
# set date formatter, locators and rescale limits
1497
1507
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
1498
1508
# expected "DatetimeIndex | PeriodIndex"
@@ -1520,7 +1530,9 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
1520
1530
1521
1531
@final
1522
1532
@classmethod
1523
- def _get_stacked_values (cls , ax : Axes , stacking_id , values , label ):
1533
+ def _get_stacked_values (
1534
+ cls , ax : Axes , stacking_id : int | None , values : np .ndarray , label
1535
+ ) -> np .ndarray :
1524
1536
if stacking_id is None :
1525
1537
return values
1526
1538
if not hasattr (ax , "_stacker_pos_prior" ):
@@ -1540,7 +1552,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
1540
1552
1541
1553
@final
1542
1554
@classmethod
1543
- def _update_stacker (cls , ax : Axes , stacking_id , values ) -> None :
1555
+ def _update_stacker (cls , ax : Axes , stacking_id : int | None , values ) -> None :
1544
1556
if stacking_id is None :
1545
1557
return
1546
1558
if (values >= 0 ).all ():
@@ -1618,7 +1630,7 @@ def _plot( # type: ignore[override]
1618
1630
cls ,
1619
1631
ax : Axes ,
1620
1632
x ,
1621
- y ,
1633
+ y : np . ndarray ,
1622
1634
style = None ,
1623
1635
column_num = None ,
1624
1636
stacking_id = None ,
@@ -1744,7 +1756,7 @@ def _plot( # type: ignore[override]
1744
1756
cls ,
1745
1757
ax : Axes ,
1746
1758
x ,
1747
- y ,
1759
+ y : np . ndarray ,
1748
1760
w ,
1749
1761
start : int | npt .NDArray [np .intp ] = 0 ,
1750
1762
log : bool = False ,
@@ -1763,7 +1775,8 @@ def _make_plot(self, fig: Figure) -> None:
1763
1775
pos_prior = neg_prior = np .zeros (len (self .data ))
1764
1776
K = self .nseries
1765
1777
1766
- for i , (label , y ) in enumerate (self ._iter_data (fillna = 0 )):
1778
+ data = self .data .fillna (0 )
1779
+ for i , (label , y ) in enumerate (self ._iter_data (data = data )):
1767
1780
ax = self ._get_ax (i )
1768
1781
kwds = self .kwds .copy ()
1769
1782
if self ._is_series :
@@ -1842,7 +1855,14 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
1842
1855
1843
1856
self ._decorate_ticks (ax , self ._get_index_name (), str_index , s_edge , e_edge )
1844
1857
1845
- def _decorate_ticks (self , ax : Axes , name , ticklabels , start_edge , end_edge ) -> None :
1858
+ def _decorate_ticks (
1859
+ self ,
1860
+ ax : Axes ,
1861
+ name : str | None ,
1862
+ ticklabels : list [str ],
1863
+ start_edge : float ,
1864
+ end_edge : float ,
1865
+ ) -> None :
1846
1866
ax .set_xlim ((start_edge , end_edge ))
1847
1867
1848
1868
if self .xticks is not None :
@@ -1876,7 +1896,7 @@ def _plot( # type: ignore[override]
1876
1896
cls ,
1877
1897
ax : Axes ,
1878
1898
x ,
1879
- y ,
1899
+ y : np . ndarray ,
1880
1900
w ,
1881
1901
start : int | npt .NDArray [np .intp ] = 0 ,
1882
1902
log : bool = False ,
@@ -1887,7 +1907,14 @@ def _plot( # type: ignore[override]
1887
1907
def _get_custom_index_name (self ):
1888
1908
return self .ylabel
1889
1909
1890
- def _decorate_ticks (self , ax : Axes , name , ticklabels , start_edge , end_edge ) -> None :
1910
+ def _decorate_ticks (
1911
+ self ,
1912
+ ax : Axes ,
1913
+ name : str | None ,
1914
+ ticklabels : list [str ],
1915
+ start_edge : float ,
1916
+ end_edge : float ,
1917
+ ) -> None :
1891
1918
# horizontal bars
1892
1919
ax .set_ylim ((start_edge , end_edge ))
1893
1920
ax .set_yticks (self .tick_pos )
@@ -1921,7 +1948,7 @@ def _make_plot(self, fig: Figure) -> None:
1921
1948
colors = self ._get_colors (num_colors = len (self .data ), color_kwds = "colors" )
1922
1949
self .kwds .setdefault ("colors" , colors )
1923
1950
1924
- for i , (label , y ) in enumerate (self ._iter_data ()):
1951
+ for i , (label , y ) in enumerate (self ._iter_data (data = self . data )):
1925
1952
ax = self ._get_ax (i )
1926
1953
if label is not None :
1927
1954
label = pprint_thing (label )
0 commit comments