diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index a0983ecf8578d..4fe40a6ebe4ad 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -142,6 +142,30 @@ def test_hist(self): df = DataFrame(np.random.randn(100, 6)) _check_plot_works(df.hist) + #make sure kwargs are handled + ser = df[0] + xf, yf = 20, 20 + xrot, yrot = 30, 30 + ax = ser.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30) + ytick = ax.get_yticklabels()[0] + xtick = ax.get_xticklabels()[0] + self.assertAlmostEqual(ytick.get_fontsize(), yf) + self.assertAlmostEqual(ytick.get_rotation(), yrot) + self.assertAlmostEqual(xtick.get_fontsize(), xf) + self.assertAlmostEqual(xtick.get_rotation(), xrot) + + xf, yf = 20, 20 + xrot, yrot = 30, 30 + axes = df.hist(xlabelsize=xf, xrot=30, ylabelsize=yf, yrot=30) + for i, ax in enumerate(axes.ravel()): + if i < len(df.columns): + ytick = ax.get_yticklabels()[0] + xtick = ax.get_xticklabels()[0] + self.assertAlmostEqual(ytick.get_fontsize(), yf) + self.assertAlmostEqual(ytick.get_rotation(), yrot) + self.assertAlmostEqual(xtick.get_fontsize(), xf) + self.assertAlmostEqual(xtick.get_rotation(), xrot) + @slow def test_scatter(self): df = DataFrame(np.random.randn(100, 4)) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index f80b69f2104ad..ea30787d0f172 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -205,9 +205,15 @@ def _args_adjust(self): def _setup_subplots(self): if self.subplots: nrows, ncols = self._get_layout() - fig, axes = _subplots(nrows=nrows, ncols=ncols, - sharex=self.sharex, sharey=self.sharey, - figsize=self.figsize) + if self.ax is None: + fig, axes = _subplots(nrows=nrows, ncols=ncols, + sharex=self.sharex, sharey=self.sharey, + figsize=self.figsize) + else: + fig, axes = _subplots(nrows=nrows, ncols=ncols, + sharex=self.sharex, sharey=self.sharey, + figsize=self.figsize, ax=self.ax) + else: if self.ax is None: fig = self.plt.figure(figsize=self.figsize) @@ -509,10 +515,13 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False, ------- ax_or_axes : matplotlib.AxesSubplot or list of them """ + kind = kind.lower().strip() if kind == 'line': klass = LinePlot elif kind in ('bar', 'barh'): klass = BarPlot + else: + raise ValueError('Invalid chart type given %s' % kind) plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, legend=legend, ax=ax, fontsize=fontsize, @@ -691,10 +700,13 @@ def plot_group(group, ax): ax.scatter(xvals, yvals) if by is not None: - fig = _grouped_plot(plot_group, data, by=by, figsize=figsize) + fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax) else: - fig = plt.figure() - ax = fig.add_subplot(111) + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111) + else: + fig = ax.get_figure() plot_group(data, ax) ax.set_ylabel(str(y)) ax.set_xlabel(str(x)) @@ -702,20 +714,33 @@ def plot_group(group, ax): return fig -def hist_frame(data, grid=True, **kwds): +def hist_frame(data, grid=True, xlabelsize=None, xrot=None, + ylabelsize=None, yrot=None, ax=None, **kwds): """ Draw Histogram the DataFrame's series using matplotlib / pylab. Parameters ---------- + grid : boolean, default True + Whether to show axis grid lines + xlabelsize : int, default None + If specified changes the x-axis label size + xrot : float, default None + rotation of x axis labels + ylabelsize : int, default None + If specified changes the y-axis label size + yrot : float, default None + rotation of y axis labels + ax : matplotlib axes object, default None kwds : other plotting keyword arguments To be passed to hist function """ + import matplotlib.pyplot as plt n = len(data.columns) k = 1 while k ** 2 < n: k += 1 - _, axes = _subplots(nrows=k, ncols=k) + _, axes = _subplots(nrows=k, ncols=k, ax=ax) for i, col in enumerate(com._try_sort(data.columns)): ax = axes[i / k][i % k] @@ -723,10 +748,19 @@ def hist_frame(data, grid=True, **kwds): ax.set_title(col) ax.grid(grid) - return axes + if xlabelsize is not None: + plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(ax.get_xticklabels(), rotation=xrot) + if ylabelsize is not None: + plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(ax.get_yticklabels(), rotation=yrot) + return axes -def hist_series(self, ax=None, grid=True, **kwds): +def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None, + ylabelsize=None, yrot=None, **kwds): """ Draw histogram of the input series using matplotlib @@ -734,6 +768,16 @@ def hist_series(self, ax=None, grid=True, **kwds): ---------- ax : matplotlib axis object If not passed, uses gca() + grid : boolean, default True + Whether to show axis grid lines + xlabelsize : int, default None + If specified changes the x-axis label size + xrot : float, default None + rotation of x axis labels + ylabelsize : int, default None + If specified changes the y-axis label size + yrot : float, default None + rotation of y axis labels kwds : keywords To be passed to the actual plotting function @@ -752,12 +796,21 @@ def hist_series(self, ax=None, grid=True, **kwds): ax.hist(values, **kwds) ax.grid(grid) + if xlabelsize is not None: + plt.setp(ax.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(ax.get_xticklabels(), rotation=xrot) + if ylabelsize is not None: + plt.setp(ax.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(ax.get_yticklabels(), rotation=yrot) + return ax def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, figsize=None, sharex=True, sharey=True, layout=None, - rot=0): + rot=0, ax=None): from pandas.core.frame import DataFrame # allow to specify mpl default with 'default' @@ -777,7 +830,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, # default size figsize = (10, 5) fig, axes = _subplots(nrows=nrows, ncols=ncols, figsize=figsize, - sharex=sharex, sharey=sharey) + sharex=sharex, sharey=sharey, ax=ax) ravel_axes = [] for row in axes: @@ -794,7 +847,7 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, def _grouped_plot_by_column(plotf, data, columns=None, by=None, numeric_only=True, grid=False, - figsize=None): + figsize=None, ax=None): import matplotlib.pyplot as plt grouped = data.groupby(by) @@ -805,7 +858,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None, nrows, ncols = _get_layout(ngroups) fig, axes = _subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, - figsize=figsize) + figsize=figsize, ax=ax) if isinstance(axes, plt.Axes): ravel_axes = [axes] @@ -850,7 +903,7 @@ def _get_layout(nplots): # copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0 def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, - subplot_kw=None, **fig_kw): + subplot_kw=None, ax=None, **fig_kw): """Create a figure with a set of subplots already made. This utility wrapper makes it convenient to create common layouts of @@ -890,6 +943,8 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, Dict with keywords passed to the figure() call. Note that all keywords not recognized above will be automatically included here. + ax : Matplotlib axis object, default None + Returns: fig, ax : tuple @@ -922,7 +977,10 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, if subplot_kw is None: subplot_kw = {} - fig = plt.figure(**fig_kw) + if ax is None: + fig = plt.figure(**fig_kw) + else: + fig = ax.get_figure() # Create empty object array to hold all axes. It's easiest to make it 1-d # so we can just append subplots upon creation, and then