@@ -205,9 +205,15 @@ def _args_adjust(self):
205
205
def _setup_subplots (self ):
206
206
if self .subplots :
207
207
nrows , ncols = self ._get_layout ()
208
- fig , axes = _subplots (nrows = nrows , ncols = ncols ,
209
- sharex = self .sharex , sharey = self .sharey ,
210
- figsize = self .figsize )
208
+ if self .ax is None :
209
+ fig , axes = _subplots (nrows = nrows , ncols = ncols ,
210
+ sharex = self .sharex , sharey = self .sharey ,
211
+ figsize = self .figsize )
212
+ else :
213
+ fig , axes = _subplots (nrows = nrows , ncols = ncols ,
214
+ sharex = self .sharex , sharey = self .sharey ,
215
+ figsize = self .figsize , ax = self .ax )
216
+
211
217
else :
212
218
if self .ax is None :
213
219
fig = self .plt .figure (figsize = self .figsize )
@@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
509
515
-------
510
516
ax_or_axes : matplotlib.AxesSubplot or list of them
511
517
"""
518
+ kind = kind .lower ().strip ()
512
519
if kind == 'line' :
513
520
klass = LinePlot
514
521
elif kind in ('bar' , 'barh' ):
515
522
klass = BarPlot
523
+ else :
524
+ raise ValueError ('Invalid chart type given %s' % kind )
516
525
517
526
plot_obj = klass (frame , kind = kind , subplots = subplots , rot = rot ,
518
527
legend = legend , ax = ax , fontsize = fontsize ,
@@ -691,49 +700,84 @@ def plot_group(group, ax):
691
700
ax .scatter (xvals , yvals )
692
701
693
702
if by is not None :
694
- fig = _grouped_plot (plot_group , data , by = by , figsize = figsize )
703
+ fig = _grouped_plot (plot_group , data , by = by , figsize = figsize , ax = ax )
695
704
else :
696
- fig = plt .figure ()
697
- ax = fig .add_subplot (111 )
705
+ if ax is None :
706
+ fig = plt .figure ()
707
+ ax = fig .add_subplot (111 )
708
+ else :
709
+ fig = ax .get_figure ()
698
710
plot_group (data , ax )
699
711
ax .set_ylabel (str (y ))
700
712
ax .set_xlabel (str (x ))
701
713
702
714
return fig
703
715
704
716
705
- def hist_frame (data , grid = True , ** kwds ):
717
+ def hist_frame (data , grid = True , xlabelsize = None , xrot = None ,
718
+ ylabelsize = None , yrot = None , ax = None , ** kwds ):
706
719
"""
707
720
Draw Histogram the DataFrame's series using matplotlib / pylab.
708
721
709
722
Parameters
710
723
----------
724
+ grid : boolean, default True
725
+ Whether to show axis grid lines
726
+ xlabelsize : int, default None
727
+ If specified changes the x-axis label size
728
+ xrot : float, default None
729
+ rotation of x axis labels
730
+ ylabelsize : int, default None
731
+ If specified changes the y-axis label size
732
+ yrot : float, default None
733
+ rotation of y axis labels
734
+ ax : matplotlib axes object, default None
711
735
kwds : other plotting keyword arguments
712
736
To be passed to hist function
713
737
"""
738
+ import matplotlib .pyplot as plt
714
739
n = len (data .columns )
715
740
k = 1
716
741
while k ** 2 < n :
717
742
k += 1
718
- _ , axes = _subplots (nrows = k , ncols = k )
743
+ _ , axes = _subplots (nrows = k , ncols = k , ax = ax )
719
744
720
745
for i , col in enumerate (com ._try_sort (data .columns )):
721
746
ax = axes [i / k ][i % k ]
722
747
ax .hist (data [col ].dropna ().values , ** kwds )
723
748
ax .set_title (col )
724
749
ax .grid (grid )
725
750
726
- return axes
751
+ if xlabelsize is not None :
752
+ plt .setp (ax .get_xticklabels (), fontsize = xlabelsize )
753
+ if xrot is not None :
754
+ plt .setp (ax .get_xticklabels (), rotation = xrot )
755
+ if ylabelsize is not None :
756
+ plt .setp (ax .get_yticklabels (), fontsize = ylabelsize )
757
+ if yrot is not None :
758
+ plt .setp (ax .get_yticklabels (), rotation = yrot )
727
759
760
+ return axes
728
761
729
- def hist_series (self , ax = None , grid = True , ** kwds ):
762
+ def hist_series (self , ax = None , grid = True , xlabelsize = None , xrot = None ,
763
+ ylabelsize = None , yrot = None , ** kwds ):
730
764
"""
731
765
Draw histogram of the input series using matplotlib
732
766
733
767
Parameters
734
768
----------
735
769
ax : matplotlib axis object
736
770
If not passed, uses gca()
771
+ grid : boolean, default True
772
+ Whether to show axis grid lines
773
+ xlabelsize : int, default None
774
+ If specified changes the x-axis label size
775
+ xrot : float, default None
776
+ rotation of x axis labels
777
+ ylabelsize : int, default None
778
+ If specified changes the y-axis label size
779
+ yrot : float, default None
780
+ rotation of y axis labels
737
781
kwds : keywords
738
782
To be passed to the actual plotting function
739
783
@@ -752,12 +796,21 @@ def hist_series(self, ax=None, grid=True, **kwds):
752
796
ax .hist (values , ** kwds )
753
797
ax .grid (grid )
754
798
799
+ if xlabelsize is not None :
800
+ plt .setp (ax .get_xticklabels (), fontsize = xlabelsize )
801
+ if xrot is not None :
802
+ plt .setp (ax .get_xticklabels (), rotation = xrot )
803
+ if ylabelsize is not None :
804
+ plt .setp (ax .get_yticklabels (), fontsize = ylabelsize )
805
+ if yrot is not None :
806
+ plt .setp (ax .get_yticklabels (), rotation = yrot )
807
+
755
808
return ax
756
809
757
810
758
811
def _grouped_plot (plotf , data , column = None , by = None , numeric_only = True ,
759
812
figsize = None , sharex = True , sharey = True , layout = None ,
760
- rot = 0 ):
813
+ rot = 0 , ax = None ):
761
814
from pandas .core .frame import DataFrame
762
815
763
816
# allow to specify mpl default with 'default'
@@ -777,7 +830,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
777
830
# default size
778
831
figsize = (10 , 5 )
779
832
fig , axes = _subplots (nrows = nrows , ncols = ncols , figsize = figsize ,
780
- sharex = sharex , sharey = sharey )
833
+ sharex = sharex , sharey = sharey , ax = ax )
781
834
782
835
ravel_axes = []
783
836
for row in axes :
@@ -794,7 +847,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
794
847
795
848
def _grouped_plot_by_column (plotf , data , columns = None , by = None ,
796
849
numeric_only = True , grid = False ,
797
- figsize = None ):
850
+ figsize = None , ax = None ):
798
851
import matplotlib .pyplot as plt
799
852
800
853
grouped = data .groupby (by )
@@ -805,7 +858,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
805
858
nrows , ncols = _get_layout (ngroups )
806
859
fig , axes = _subplots (nrows = nrows , ncols = ncols ,
807
860
sharex = True , sharey = True ,
808
- figsize = figsize )
861
+ figsize = figsize , ax = ax )
809
862
810
863
if isinstance (axes , plt .Axes ):
811
864
ravel_axes = [axes ]
@@ -850,7 +903,7 @@ def _get_layout(nplots):
850
903
# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
851
904
852
905
def _subplots (nrows = 1 , ncols = 1 , sharex = False , sharey = False , squeeze = True ,
853
- subplot_kw = None , ** fig_kw ):
906
+ subplot_kw = None , ax = None , ** fig_kw ):
854
907
"""Create a figure with a set of subplots already made.
855
908
856
909
This utility wrapper makes it convenient to create common layouts of
@@ -890,6 +943,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
890
943
Dict with keywords passed to the figure() call. Note that all keywords
891
944
not recognized above will be automatically included here.
892
945
946
+ ax : Matplotlib axis object, default None
947
+
893
948
Returns:
894
949
895
950
fig, ax : tuple
@@ -922,7 +977,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
922
977
if subplot_kw is None :
923
978
subplot_kw = {}
924
979
925
- fig = plt .figure (** fig_kw )
980
+ if ax is None :
981
+ fig = plt .figure (** fig_kw )
982
+ else :
983
+ fig = ax .get_figure ()
926
984
927
985
# Create empty object array to hold all axes. It's easiest to make it 1-d
928
986
# so we can just append subplots upon creation, and then
0 commit comments