@@ -887,25 +887,31 @@ def _validate_color_args(self):
887
887
" use one or the other or pass 'style' "
888
888
"without a color symbol" )
889
889
890
- def _iter_data (self ):
891
- from pandas .core .frame import DataFrame
892
- if isinstance (self .data , (Series , np .ndarray )):
893
- yield self .label , np .asarray (self .data )
894
- elif isinstance (self .data , DataFrame ):
895
- df = self .data
890
+ def _iter_data (self , data = None , keep_index = False ):
891
+ if data is None :
892
+ data = self .data
896
893
894
+ from pandas .core .frame import DataFrame
895
+ if isinstance (data , (Series , np .ndarray )):
896
+ if keep_index is True :
897
+ yield self .label , data
898
+ else :
899
+ yield self .label , np .asarray (data )
900
+ elif isinstance (data , DataFrame ):
897
901
if self .sort_columns :
898
- columns = com ._try_sort (df .columns )
902
+ columns = com ._try_sort (data .columns )
899
903
else :
900
- columns = df .columns
904
+ columns = data .columns
901
905
902
906
for col in columns :
903
907
# # is this right?
904
908
# empty = df[col].count() == 0
905
909
# values = df[col].values if not empty else np.zeros(len(df))
906
910
907
- values = df [col ].values
908
- yield col , values
911
+ if keep_index is True :
912
+ yield col , data [col ]
913
+ else :
914
+ yield col , data [col ].values
909
915
910
916
@property
911
917
def nseries (self ):
@@ -1593,38 +1599,26 @@ def _plot(data, col_num, ax, label, style, **kwds):
1593
1599
1594
1600
self ._add_legend_handle (newlines [0 ], label , index = col_num )
1595
1601
1596
- if isinstance (data , Series ):
1597
- ax = self . _get_ax ( 0 ) # self.axes[0]
1598
- style = self .style or ''
1599
- label = com . pprint_thing ( self .label )
1602
+ it = self . _iter_data (data = data , keep_index = True )
1603
+ for i , ( label , y ) in enumerate ( it ):
1604
+ ax = self ._get_ax ( i )
1605
+ style = self ._get_style ( i , label )
1600
1606
kwds = self .kwds .copy ()
1601
- self ._maybe_add_color (colors , kwds , style , 0 )
1602
-
1603
- if 'yerr' in kwds :
1604
- kwds ['yerr' ] = kwds ['yerr' ][0 ]
1605
1607
1606
- _plot (data , 0 , ax , label , self .style , ** kwds )
1607
-
1608
- else :
1609
- for i , col in enumerate (data .columns ):
1610
- label = com .pprint_thing (col )
1611
- ax = self ._get_ax (i )
1612
- style = self ._get_style (i , col )
1613
- kwds = self .kwds .copy ()
1614
-
1615
- self ._maybe_add_color (colors , kwds , style , i )
1608
+ self ._maybe_add_color (colors , kwds , style , i )
1616
1609
1617
- # key-matched DataFrame of errors
1618
- if 'yerr' in kwds :
1619
- yerr = kwds ['yerr' ]
1620
- if isinstance (yerr , (DataFrame , dict )):
1621
- if col in yerr .keys ():
1622
- kwds ['yerr' ] = yerr [col ]
1623
- else : del kwds ['yerr' ]
1624
- else :
1625
- kwds ['yerr' ] = yerr [i ]
1610
+ # key-matched DataFrame of errors
1611
+ if 'yerr' in kwds :
1612
+ yerr = kwds ['yerr' ]
1613
+ if isinstance (yerr , (DataFrame , dict )):
1614
+ if label in yerr .keys ():
1615
+ kwds ['yerr' ] = yerr [label ]
1616
+ else : del kwds ['yerr' ]
1617
+ else :
1618
+ kwds ['yerr' ] = yerr [i ]
1626
1619
1627
- _plot (data [col ], i , ax , label , style , ** kwds )
1620
+ label = com .pprint_thing (label )
1621
+ _plot (y , i , ax , label , style , ** kwds )
1628
1622
1629
1623
def _maybe_convert_index (self , data ):
1630
1624
# tsplot converts automatically, but don't want to convert index
@@ -1828,6 +1822,16 @@ class BoxPlot(MPLPlot):
1828
1822
class HistPlot (MPLPlot ):
1829
1823
pass
1830
1824
1825
+ # kinds supported by both dataframe and series
1826
+ _common_kinds = ['line' , 'bar' , 'barh' , 'kde' , 'density' ]
1827
+ # kinds supported by dataframe
1828
+ _dataframe_kinds = ['scatter' , 'hexbin' ]
1829
+ _all_kinds = _common_kinds + _dataframe_kinds
1830
+
1831
+ _plot_klass = {'line' : LinePlot , 'bar' : BarPlot , 'barh' : BarPlot ,
1832
+ 'kde' : KdePlot ,
1833
+ 'scatter' : ScatterPlot , 'hexbin' : HexBinPlot }
1834
+
1831
1835
1832
1836
def plot_frame (frame = None , x = None , y = None , subplots = False , sharex = True ,
1833
1837
sharey = False , use_index = True , figsize = None , grid = None ,
@@ -1921,21 +1925,14 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1921
1925
is a function of one argument that reduces all the values in a bin to
1922
1926
a single number (e.g. `mean`, `max`, `sum`, `std`).
1923
1927
"""
1928
+
1924
1929
kind = _get_standard_kind (kind .lower ().strip ())
1925
- if kind == 'line' :
1926
- klass = LinePlot
1927
- elif kind in ('bar' , 'barh' ):
1928
- klass = BarPlot
1929
- elif kind == 'kde' :
1930
- klass = KdePlot
1931
- elif kind == 'scatter' :
1932
- klass = ScatterPlot
1933
- elif kind == 'hexbin' :
1934
- klass = HexBinPlot
1930
+ if kind in _dataframe_kinds or kind in _common_kinds :
1931
+ klass = _plot_klass [kind ]
1935
1932
else :
1936
1933
raise ValueError ('Invalid chart type given %s' % kind )
1937
1934
1938
- if kind == 'scatter' :
1935
+ if kind in _dataframe_kinds :
1939
1936
plot_obj = klass (frame , x = x , y = y , kind = kind , subplots = subplots ,
1940
1937
rot = rot ,legend = legend , ax = ax , style = style ,
1941
1938
fontsize = fontsize , use_index = use_index , sharex = sharex ,
@@ -1944,16 +1941,6 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1944
1941
figsize = figsize , logx = logx , logy = logy ,
1945
1942
sort_columns = sort_columns , secondary_y = secondary_y ,
1946
1943
** kwds )
1947
- elif kind == 'hexbin' :
1948
- C = kwds .pop ('C' , None ) # remove from kwargs so we can set default
1949
- plot_obj = klass (frame , x = x , y = y , kind = kind , subplots = subplots ,
1950
- rot = rot ,legend = legend , ax = ax , style = style ,
1951
- fontsize = fontsize , use_index = use_index , sharex = sharex ,
1952
- sharey = sharey , xticks = xticks , yticks = yticks ,
1953
- xlim = xlim , ylim = ylim , title = title , grid = grid ,
1954
- figsize = figsize , logx = logx , logy = logy ,
1955
- sort_columns = sort_columns , secondary_y = secondary_y ,
1956
- C = C , ** kwds )
1957
1944
else :
1958
1945
if x is not None :
1959
1946
if com .is_integer (x ) and not frame .columns .holds_integer ():
@@ -2051,14 +2038,9 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
2051
2038
See matplotlib documentation online for more on this subject
2052
2039
"""
2053
2040
2054
- from pandas import DataFrame
2055
2041
kind = _get_standard_kind (kind .lower ().strip ())
2056
- if kind == 'line' :
2057
- klass = LinePlot
2058
- elif kind in ('bar' , 'barh' ):
2059
- klass = BarPlot
2060
- elif kind == 'kde' :
2061
- klass = KdePlot
2042
+ if kind in _common_kinds :
2043
+ klass = _plot_klass [kind ]
2062
2044
else :
2063
2045
raise ValueError ('Invalid chart type given %s' % kind )
2064
2046
0 commit comments