@@ -205,9 +205,15 @@ def _args_adjust(self):
205
205
def _setup_subplots (self ):
206
206
if self .subplots :
207
207
nrows , ncols = self ._get_layout ()
208
- fig , axes = _subplots (nrows = nrows , ncols = ncols ,
209
- sharex = self .sharex , sharey = self .sharey ,
210
- figsize = self .figsize )
208
+ if self .ax is None :
209
+ fig , axes = _subplots (nrows = nrows , ncols = ncols ,
210
+ sharex = self .sharex , sharey = self .sharey ,
211
+ figsize = self .figsize )
212
+ else :
213
+ fig , axes = _subplots (nrows = nrows , ncols = ncols ,
214
+ sharex = self .sharex , sharey = self .sharey ,
215
+ figsize = self .figsize , ax = self .ax )
216
+
211
217
else :
212
218
if self .ax is None :
213
219
fig = self .plt .figure (figsize = self .figsize )
@@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
509
515
-------
510
516
ax_or_axes : matplotlib.AxesSubplot or list of them
511
517
"""
518
+ kind = kind .lower ().strip ()
512
519
if kind == 'line' :
513
520
klass = LinePlot
514
521
elif kind in ('bar' , 'barh' ):
515
522
klass = BarPlot
523
+ else :
524
+ raise ValueError ('Invalid chart type given %s' % kind )
516
525
517
526
plot_obj = klass (frame , kind = kind , subplots = subplots , rot = rot ,
518
527
legend = legend , ax = ax , fontsize = fontsize ,
@@ -691,10 +700,11 @@ def plot_group(group, ax):
691
700
ax .scatter (xvals , yvals )
692
701
693
702
if by is not None :
694
- fig = _grouped_plot (plot_group , data , by = by , figsize = figsize )
703
+ fig = _grouped_plot (plot_group , data , by = by , figsize = figsize , ax = ax )
695
704
else :
696
- fig = plt .figure ()
697
- ax = fig .add_subplot (111 )
705
+ if ax is None :
706
+ fig = plt .figure ()
707
+ ax = fig .add_subplot (111 )
698
708
plot_group (data , ax )
699
709
ax .set_ylabel (str (y ))
700
710
ax .set_xlabel (str (x ))
@@ -703,7 +713,7 @@ def plot_group(group, ax):
703
713
704
714
705
715
def hist_frame (data , grid = True , xlabelsize = None , xrot = None ,
706
- ylabelsize = None , yrot = None , ** kwds ):
716
+ ylabelsize = None , yrot = None , ax = None , ** kwds ):
707
717
"""
708
718
Draw Histogram the DataFrame's series using matplotlib / pylab.
709
719
@@ -719,6 +729,7 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
719
729
If specified changes the y-axis label size
720
730
yrot : float, default None
721
731
rotation of y axis labels
732
+ ax : matplotlib axes object, default None
722
733
kwds : other plotting keyword arguments
723
734
To be passed to hist function
724
735
"""
@@ -727,7 +738,7 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
727
738
k = 1
728
739
while k ** 2 < n :
729
740
k += 1
730
- _ , axes = _subplots (nrows = k , ncols = k )
741
+ _ , axes = _subplots (nrows = k , ncols = k , ax = ax )
731
742
732
743
for i , col in enumerate (com ._try_sort (data .columns )):
733
744
ax = axes [i / k ][i % k ]
@@ -797,7 +808,7 @@ def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None,
797
808
798
809
def _grouped_plot (plotf , data , column = None , by = None , numeric_only = True ,
799
810
figsize = None , sharex = True , sharey = True , layout = None ,
800
- rot = 0 ):
811
+ rot = 0 , ax = None ):
801
812
from pandas .core .frame import DataFrame
802
813
803
814
# allow to specify mpl default with 'default'
@@ -817,7 +828,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
817
828
# default size
818
829
figsize = (10 , 5 )
819
830
fig , axes = _subplots (nrows = nrows , ncols = ncols , figsize = figsize ,
820
- sharex = sharex , sharey = sharey )
831
+ sharex = sharex , sharey = sharey , ax = ax )
821
832
822
833
ravel_axes = []
823
834
for row in axes :
@@ -834,7 +845,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
834
845
835
846
def _grouped_plot_by_column (plotf , data , columns = None , by = None ,
836
847
numeric_only = True , grid = False ,
837
- figsize = None ):
848
+ figsize = None , ax = None ):
838
849
import matplotlib .pyplot as plt
839
850
840
851
grouped = data .groupby (by )
@@ -845,7 +856,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
845
856
nrows , ncols = _get_layout (ngroups )
846
857
fig , axes = _subplots (nrows = nrows , ncols = ncols ,
847
858
sharex = True , sharey = True ,
848
- figsize = figsize )
859
+ figsize = figsize , ax = ax )
849
860
850
861
if isinstance (axes , plt .Axes ):
851
862
ravel_axes = [axes ]
@@ -890,7 +901,7 @@ def _get_layout(nplots):
890
901
# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
891
902
892
903
def _subplots (nrows = 1 , ncols = 1 , sharex = False , sharey = False , squeeze = True ,
893
- subplot_kw = None , ** fig_kw ):
904
+ subplot_kw = None , ax = None , ** fig_kw ):
894
905
"""Create a figure with a set of subplots already made.
895
906
896
907
This utility wrapper makes it convenient to create common layouts of
@@ -930,6 +941,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
930
941
Dict with keywords passed to the figure() call. Note that all keywords
931
942
not recognized above will be automatically included here.
932
943
944
+ ax : Matplotlib axis object, default None
945
+
933
946
Returns:
934
947
935
948
fig, ax : tuple
@@ -962,7 +975,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
962
975
if subplot_kw is None :
963
976
subplot_kw = {}
964
977
965
- fig = plt .figure (** fig_kw )
978
+ if ax is None :
979
+ fig = plt .figure (** fig_kw )
980
+ else :
981
+ fig = ax .get_figure ()
966
982
967
983
# Create empty object array to hold all axes. It's easiest to make it 1-d
968
984
# so we can just append subplots upon creation, and then
0 commit comments