Skip to content

Commit 7298313

Browse files
author
Chang She
committed
add ax kwd to several functions and push ax into subplots so new subplot axes is generated on the ax's figure
1 parent d3bcab7 commit 7298313

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

pandas/tools/plotting.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,15 @@ def _args_adjust(self):
205205
def _setup_subplots(self):
206206
if self.subplots:
207207
nrows, ncols = self._get_layout()
208-
fig, axes = _subplots(nrows=nrows, ncols=ncols,
209-
sharex=self.sharex, sharey=self.sharey,
210-
figsize=self.figsize)
208+
if self.ax is None:
209+
fig, axes = _subplots(nrows=nrows, ncols=ncols,
210+
sharex=self.sharex, sharey=self.sharey,
211+
figsize=self.figsize)
212+
else:
213+
fig, axes = _subplots(nrows=nrows, ncols=ncols,
214+
sharex=self.sharex, sharey=self.sharey,
215+
figsize=self.figsize, ax=self.ax)
216+
211217
else:
212218
if self.ax is None:
213219
fig = self.plt.figure(figsize=self.figsize)
@@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
509515
-------
510516
ax_or_axes : matplotlib.AxesSubplot or list of them
511517
"""
518+
kind = kind.lower().strip()
512519
if kind == 'line':
513520
klass = LinePlot
514521
elif kind in ('bar', 'barh'):
515522
klass = BarPlot
523+
else:
524+
raise ValueError('Invalid chart type given %s' % kind)
516525

517526
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
518527
legend=legend, ax=ax, fontsize=fontsize,
@@ -691,10 +700,11 @@ def plot_group(group, ax):
691700
ax.scatter(xvals, yvals)
692701

693702
if by is not None:
694-
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize)
703+
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
695704
else:
696-
fig = plt.figure()
697-
ax = fig.add_subplot(111)
705+
if ax is None:
706+
fig = plt.figure()
707+
ax = fig.add_subplot(111)
698708
plot_group(data, ax)
699709
ax.set_ylabel(str(y))
700710
ax.set_xlabel(str(x))
@@ -703,7 +713,7 @@ def plot_group(group, ax):
703713

704714

705715
def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
706-
ylabelsize=None, yrot=None, **kwds):
716+
ylabelsize=None, yrot=None, ax=None, **kwds):
707717
"""
708718
Draw Histogram the DataFrame's series using matplotlib / pylab.
709719
@@ -719,6 +729,7 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
719729
If specified changes the y-axis label size
720730
yrot : float, default None
721731
rotation of y axis labels
732+
ax : matplotlib axes object, default None
722733
kwds : other plotting keyword arguments
723734
To be passed to hist function
724735
"""
@@ -727,7 +738,7 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
727738
k = 1
728739
while k ** 2 < n:
729740
k += 1
730-
_, axes = _subplots(nrows=k, ncols=k)
741+
_, axes = _subplots(nrows=k, ncols=k, ax=ax)
731742

732743
for i, col in enumerate(com._try_sort(data.columns)):
733744
ax = axes[i / k][i % k]
@@ -797,7 +808,7 @@ def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None,
797808

798809
def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
799810
figsize=None, sharex=True, sharey=True, layout=None,
800-
rot=0):
811+
rot=0, ax=None):
801812
from pandas.core.frame import DataFrame
802813

803814
# allow to specify mpl default with 'default'
@@ -817,7 +828,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
817828
# default size
818829
figsize = (10, 5)
819830
fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize,
820-
sharex=sharex, sharey=sharey)
831+
sharex=sharex, sharey=sharey, ax=ax)
821832

822833
ravel_axes = []
823834
for row in axes:
@@ -834,7 +845,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
834845

835846
def _grouped_plot_by_column(plotf, data, columns=None, by=None,
836847
numeric_only=True, grid=False,
837-
figsize=None):
848+
figsize=None, ax=None):
838849
import matplotlib.pyplot as plt
839850

840851
grouped = data.groupby(by)
@@ -845,7 +856,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
845856
nrows, ncols = _get_layout(ngroups)
846857
fig, axes = _subplots(nrows=nrows, ncols=ncols,
847858
sharex=True, sharey=True,
848-
figsize=figsize)
859+
figsize=figsize, ax=ax)
849860

850861
if isinstance(axes, plt.Axes):
851862
ravel_axes = [axes]
@@ -890,7 +901,7 @@ def _get_layout(nplots):
890901
# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
891902

892903
def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
893-
subplot_kw=None, **fig_kw):
904+
subplot_kw=None, ax=None, **fig_kw):
894905
"""Create a figure with a set of subplots already made.
895906
896907
This utility wrapper makes it convenient to create common layouts of
@@ -930,6 +941,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
930941
Dict with keywords passed to the figure() call. Note that all keywords
931942
not recognized above will be automatically included here.
932943
944+
ax : Matplotlib axis object, default None
945+
933946
Returns:
934947
935948
fig, ax : tuple
@@ -962,7 +975,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
962975
if subplot_kw is None:
963976
subplot_kw = {}
964977

965-
fig = plt.figure(**fig_kw)
978+
if ax is None:
979+
fig = plt.figure(**fig_kw)
980+
else:
981+
fig = ax.get_figure()
966982

967983
# Create empty object array to hold all axes. It's easiest to make it 1-d
968984
# so we can just append subplots upon creation, and then

0 commit comments

Comments
 (0)