diff --git a/doc/source/v0.15.0.txt b/doc/source/v0.15.0.txt index d15a48535f1eb..bbf665b574409 100644 --- a/doc/source/v0.15.0.txt +++ b/doc/source/v0.15.0.txt @@ -303,6 +303,9 @@ Enhancements ~~~~~~~~~~~~ - Added support for bool, uint8, uint16 and uint32 datatypes in ``to_stata`` (:issue:`7097`, :issue:`7365`) +- Added ``layout`` keyword to ``DataFrame.plot`` (:issue:`6667`) +- Allow to pass multiple axes to ``DataFrame.plot``, ``hist`` and ``boxplot`` (:issue:`5353`, :issue:`6970`, :issue:`7069`) + - ``PeriodIndex`` supports ``resolution`` as the same as ``DatetimeIndex`` (:issue:`7708`) - ``pandas.tseries.holiday`` has added support for additional holidays and ways to observe holidays (:issue:`7070`) diff --git a/doc/source/visualization.rst b/doc/source/visualization.rst index 40b5d7c1599c1..e8d3d147479c2 100644 --- a/doc/source/visualization.rst +++ b/doc/source/visualization.rst @@ -946,10 +946,41 @@ with the ``subplots`` keyword: @savefig frame_plot_subplots.png df.plot(subplots=True, figsize=(6, 6)); -Targeting Different Subplots -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using Layout and Targetting Multiple Axes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can pass an ``ax`` argument to :meth:`Series.plot` to plot on a particular axis: +The layout of subplots can be specified by ``layout`` keyword. It can accept +``(rows, columns)``. The ``layout`` keyword can be used in +``hist`` and ``boxplot`` also. If input is invalid, ``ValueError`` will be raised. + +The number of axes which can be contained by rows x columns specified by ``layout`` must be +larger than the number of required subplots. If layout can contain more axes than required, +blank axes are not drawn. + +.. ipython:: python + + @savefig frame_plot_subplots_layout.png + df.plot(subplots=True, layout=(2, 3), figsize=(6, 6)); + +Also, you can pass multiple axes created beforehand as list-like via ``ax`` keyword. +This allows to use more complicated layout. +The passed axes must be the same number as the subplots being drawn. + +When multiple axes are passed via ``ax`` keyword, ``layout``, ``sharex`` and ``sharey`` keywords are ignored. +These must be configured when creating axes. + +.. ipython:: python + + fig, axes = plt.subplots(4, 4, figsize=(6, 6)); + plt.adjust_subplots(wspace=0.5, hspace=0.5); + target1 = [axes[0][0], axes[1][1], axes[2][2], axes[3][3]] + target2 = [axes[3][0], axes[2][1], axes[1][2], axes[0][3]] + + df.plot(subplots=True, ax=target1, legend=False); + @savefig frame_plot_subplots_multi_ax.png + (-df).plot(subplots=True, ax=target2, legend=False); + +Another option is passing an ``ax`` argument to :meth:`Series.plot` to plot on a particular axis: .. ipython:: python :suppress: @@ -964,12 +995,12 @@ You can pass an ``ax`` argument to :meth:`Series.plot` to plot on a particular a .. ipython:: python fig, axes = plt.subplots(nrows=2, ncols=2) - df['A'].plot(ax=axes[0,0]); axes[0,0].set_title('A') - df['B'].plot(ax=axes[0,1]); axes[0,1].set_title('B') - df['C'].plot(ax=axes[1,0]); axes[1,0].set_title('C') + df['A'].plot(ax=axes[0,0]); axes[0,0].set_title('A'); + df['B'].plot(ax=axes[0,1]); axes[0,1].set_title('B'); + df['C'].plot(ax=axes[1,0]); axes[1,0].set_title('C'); @savefig series_plot_multi.png - df['D'].plot(ax=axes[1,1]); axes[1,1].set_title('D') + df['D'].plot(ax=axes[1,1]); axes[1,1].set_title('D'); .. ipython:: python :suppress: diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index b3a92263370e8..1560b78a2f5e0 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -670,7 +670,7 @@ def test_hist_layout_with_by(self): axes = _check_plot_works(df.height.hist, by=df.classroom, layout=(2, 2)) self._check_axes_shape(axes, axes_num=3, layout=(2, 2)) - axes = _check_plot_works(df.height.hist, by=df.category, layout=(4, 2), figsize=(12, 7)) + axes = df.height.hist(by=df.category, layout=(4, 2), figsize=(12, 7)) self._check_axes_shape(axes, axes_num=4, layout=(4, 2), figsize=(12, 7)) @slow @@ -1071,6 +1071,7 @@ def test_subplots(self): for kind in ['bar', 'barh', 'line', 'area']: axes = df.plot(kind=kind, subplots=True, sharex=True, legend=True) self._check_axes_shape(axes, axes_num=3, layout=(3, 1)) + self.assertEqual(axes.shape, (3, )) for ax, column in zip(axes, df.columns): self._check_legend_labels(ax, labels=[com.pprint_thing(column)]) @@ -1133,6 +1134,77 @@ def test_subplots_timeseries(self): self._check_visible(ax.get_yticklabels()) self._check_ticks_props(ax, xlabelsize=7, xrot=45) + def test_subplots_layout(self): + # GH 6667 + df = DataFrame(np.random.rand(10, 3), + index=list(string.ascii_letters[:10])) + + axes = df.plot(subplots=True, layout=(2, 2)) + self._check_axes_shape(axes, axes_num=3, layout=(2, 2)) + self.assertEqual(axes.shape, (2, 2)) + + axes = df.plot(subplots=True, layout=(1, 4)) + self._check_axes_shape(axes, axes_num=3, layout=(1, 4)) + self.assertEqual(axes.shape, (1, 4)) + + with tm.assertRaises(ValueError): + axes = df.plot(subplots=True, layout=(1, 1)) + + # single column + df = DataFrame(np.random.rand(10, 1), + index=list(string.ascii_letters[:10])) + axes = df.plot(subplots=True) + self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) + self.assertEqual(axes.shape, (1, )) + + axes = df.plot(subplots=True, layout=(3, 3)) + self._check_axes_shape(axes, axes_num=1, layout=(3, 3)) + self.assertEqual(axes.shape, (3, 3)) + + @slow + def test_subplots_multiple_axes(self): + # GH 5353, 6970, GH 7069 + fig, axes = self.plt.subplots(2, 3) + df = DataFrame(np.random.rand(10, 3), + index=list(string.ascii_letters[:10])) + + returned = df.plot(subplots=True, ax=axes[0]) + self._check_axes_shape(returned, axes_num=3, layout=(1, 3)) + self.assertEqual(returned.shape, (3, )) + self.assertIs(returned[0].figure, fig) + # draw on second row + returned = df.plot(subplots=True, ax=axes[1]) + self._check_axes_shape(returned, axes_num=3, layout=(1, 3)) + self.assertEqual(returned.shape, (3, )) + self.assertIs(returned[0].figure, fig) + self._check_axes_shape(axes, axes_num=6, layout=(2, 3)) + tm.close() + + with tm.assertRaises(ValueError): + fig, axes = self.plt.subplots(2, 3) + # pass different number of axes from required + df.plot(subplots=True, ax=axes) + + # pass 2-dim axes and invalid layout + # invalid lauout should not affect to input and return value + # (show warning is tested in + # TestDataFrameGroupByPlots.test_grouped_box_multiple_axes + fig, axes = self.plt.subplots(2, 2) + df = DataFrame(np.random.rand(10, 4), + index=list(string.ascii_letters[:10])) + + returned = df.plot(subplots=True, ax=axes, layout=(2, 1)) + self._check_axes_shape(returned, axes_num=4, layout=(2, 2)) + self.assertEqual(returned.shape, (4, )) + + # single column + fig, axes = self.plt.subplots(1, 1) + df = DataFrame(np.random.rand(10, 1), + index=list(string.ascii_letters[:10])) + axes = df.plot(subplots=True, ax=[axes]) + self._check_axes_shape(axes, axes_num=1, layout=(1, 1)) + self.assertEqual(axes.shape, (1, )) + def test_negative_log(self): df = - DataFrame(rand(6, 4), index=list(string.ascii_letters[:6]), @@ -1718,7 +1790,7 @@ def test_hist_df_coord(self): normal_df = DataFrame({'A': np.repeat(np.array([1, 2, 3, 4, 5]), np.array([10, 9, 8, 7, 6])), 'B': np.repeat(np.array([1, 2, 3, 4, 5]), - np.array([8, 8, 8, 8, 8])), + np.array([8, 8, 8, 8, 8])), 'C': np.repeat(np.array([1, 2, 3, 4, 5]), np.array([6, 7, 8, 9, 10]))}, columns=['A', 'B', 'C']) @@ -1726,7 +1798,7 @@ def test_hist_df_coord(self): nan_df = DataFrame({'A': np.repeat(np.array([np.nan, 1, 2, 3, 4, 5]), np.array([3, 10, 9, 8, 7, 6])), 'B': np.repeat(np.array([1, np.nan, 2, 3, 4, 5]), - np.array([8, 3, 8, 8, 8, 8])), + np.array([8, 3, 8, 8, 8, 8])), 'C': np.repeat(np.array([1, 2, 3, np.nan, 4, 5]), np.array([6, 7, 8, 3, 9, 10]))}, columns=['A', 'B', 'C']) @@ -2712,6 +2784,41 @@ def test_grouped_box_layout(self): return_type='dict') self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(1, 4)) + @slow + def test_grouped_box_multiple_axes(self): + # GH 6970, GH 7069 + df = self.hist_df + + # check warning to ignore sharex / sharey + # this check should be done in the first function which + # passes multiple axes to plot, hist or boxplot + # location should be changed if other test is added + # which has earlier alphabetical order + with tm.assert_produces_warning(UserWarning): + fig, axes = self.plt.subplots(2, 2) + df.groupby('category').boxplot(column='height', return_type='axes', ax=axes) + self._check_axes_shape(self.plt.gcf().axes, axes_num=4, layout=(2, 2)) + + fig, axes = self.plt.subplots(2, 3) + returned = df.boxplot(column=['height', 'weight', 'category'], by='gender', + return_type='axes', ax=axes[0]) + returned = np.array(returned.values()) + self._check_axes_shape(returned, axes_num=3, layout=(1, 3)) + self.assert_numpy_array_equal(returned, axes[0]) + self.assertIs(returned[0].figure, fig) + # draw on second row + returned = df.groupby('classroom').boxplot(column=['height', 'weight', 'category'], + return_type='axes', ax=axes[1]) + returned = np.array(returned.values()) + self._check_axes_shape(returned, axes_num=3, layout=(1, 3)) + self.assert_numpy_array_equal(returned, axes[1]) + self.assertIs(returned[0].figure, fig) + + with tm.assertRaises(ValueError): + fig, axes = self.plt.subplots(2, 3) + # pass different number of axes from required + axes = df.groupby('classroom').boxplot(ax=axes) + @slow def test_grouped_hist_layout(self): @@ -2724,12 +2831,12 @@ def test_grouped_hist_layout(self): axes = _check_plot_works(df.hist, column='height', by=df.gender, layout=(2, 1)) self._check_axes_shape(axes, axes_num=2, layout=(2, 1)) - axes = _check_plot_works(df.hist, column='height', by=df.category, layout=(4, 1)) + axes = df.hist(column='height', by=df.category, layout=(4, 1)) self._check_axes_shape(axes, axes_num=4, layout=(4, 1)) - axes = _check_plot_works(df.hist, column='height', by=df.category, - layout=(4, 2), figsize=(12, 8)) + axes = df.hist(column='height', by=df.category, layout=(4, 2), figsize=(12, 8)) self._check_axes_shape(axes, axes_num=4, layout=(4, 2), figsize=(12, 8)) + tm.close() # GH 6769 axes = _check_plot_works(df.hist, column='height', by='classroom', layout=(2, 2)) @@ -2739,13 +2846,32 @@ def test_grouped_hist_layout(self): axes = _check_plot_works(df.hist, by='classroom') self._check_axes_shape(axes, axes_num=3, layout=(2, 2)) - axes = _check_plot_works(df.hist, by='gender', layout=(3, 5)) + axes = df.hist(by='gender', layout=(3, 5)) self._check_axes_shape(axes, axes_num=2, layout=(3, 5)) - axes = _check_plot_works(df.hist, column=['height', 'weight', 'category']) + axes = df.hist(column=['height', 'weight', 'category']) self._check_axes_shape(axes, axes_num=3, layout=(2, 2)) @slow + def test_grouped_hist_multiple_axes(self): + # GH 6970, GH 7069 + df = self.hist_df + + fig, axes = self.plt.subplots(2, 3) + returned = df.hist(column=['height', 'weight', 'category'], ax=axes[0]) + self._check_axes_shape(returned, axes_num=3, layout=(1, 3)) + self.assert_numpy_array_equal(returned, axes[0]) + self.assertIs(returned[0].figure, fig) + returned = df.hist(by='classroom', ax=axes[1]) + self._check_axes_shape(returned, axes_num=3, layout=(1, 3)) + self.assert_numpy_array_equal(returned, axes[1]) + self.assertIs(returned[0].figure, fig) + + with tm.assertRaises(ValueError): + fig, axes = self.plt.subplots(2, 3) + # pass different number of axes from required + axes = df.hist(column='height', ax=axes) + @slow def test_axis_share_x(self): df = self.hist_df # GH4089 diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 7d0eaea5b36d6..18fc2bead02ec 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -246,7 +246,8 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, df = frame._get_numeric_data() n = df.columns.size - fig, axes = _subplots(nrows=n, ncols=n, figsize=figsize, ax=ax, + naxes = n * n + fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False) # no gaps between subplots @@ -752,6 +753,7 @@ class MPLPlot(object): data : """ + _layout_type = 'vertical' _default_rot = 0 orientation = None @@ -767,7 +769,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, xticks=None, yticks=None, sort_columns=False, fontsize=None, secondary_y=False, colormap=None, - table=False, **kwds): + table=False, layout=None, **kwds): self.data = data self.by = by @@ -780,6 +782,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, self.sharex = sharex self.sharey = sharey self.figsize = figsize + self.layout = layout self.xticks = xticks self.yticks = yticks @@ -932,22 +935,22 @@ def _maybe_right_yaxis(self, ax): def _setup_subplots(self): if self.subplots: - nrows, ncols = self._get_layout() - fig, axes = _subplots(nrows=nrows, ncols=ncols, + fig, axes = _subplots(naxes=self.nseries, sharex=self.sharex, sharey=self.sharey, - figsize=self.figsize, ax=self.ax) - if not com.is_list_like(axes): - axes = np.array([axes]) + figsize=self.figsize, ax=self.ax, + layout=self.layout, + layout_type=self._layout_type) else: if self.ax is None: fig = self.plt.figure(figsize=self.figsize) - ax = fig.add_subplot(111) + axes = fig.add_subplot(111) else: fig = self.ax.get_figure() if self.figsize is not None: fig.set_size_inches(self.figsize) - ax = self.ax - axes = [ax] + axes = self.ax + + axes = _flatten(axes) if self.logx or self.loglog: [a.set_xscale('log') for a in axes] @@ -957,12 +960,18 @@ def _setup_subplots(self): self.fig = fig self.axes = axes - def _get_layout(self): - from pandas.core.frame import DataFrame - if isinstance(self.data, DataFrame): - return (len(self.data.columns), 1) + @property + def result(self): + """ + Return result axes + """ + if self.subplots: + if self.layout is not None and not com.is_list_like(self.ax): + return self.axes.reshape(*self.layout) + else: + return self.axes else: - return (1, 1) + return self.axes[0] def _compute_plot_data(self): numeric_data = self.data.convert_objects()._get_numeric_data() @@ -1360,6 +1369,8 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True): class ScatterPlot(MPLPlot): + _layout_type = 'single' + def __init__(self, data, x, y, **kwargs): MPLPlot.__init__(self, data, **kwargs) self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor']) @@ -1372,8 +1383,9 @@ def __init__(self, data, x, y, **kwargs): self.x = x self.y = y - def _get_layout(self): - return (1, 1) + @property + def nseries(self): + return 1 def _make_plot(self): x, y, data = self.x, self.y, self.data @@ -1404,6 +1416,8 @@ def _post_plot_logic(self): class HexBinPlot(MPLPlot): + _layout_type = 'single' + def __init__(self, data, x, y, C=None, **kwargs): MPLPlot.__init__(self, data, **kwargs) @@ -1421,8 +1435,9 @@ def __init__(self, data, x, y, C=None, **kwargs): self.y = y self.C = C - def _get_layout(self): - return (1, 1) + @property + def nseries(self): + return 1 def _make_plot(self): import matplotlib.pyplot as plt @@ -1966,6 +1981,8 @@ def _post_plot_logic(self): class PiePlot(MPLPlot): + _layout_type = 'horizontal' + def __init__(self, data, kind=None, **kwargs): data = data.fillna(value=0) if (data < 0).any().any(): @@ -1978,13 +1995,6 @@ def _args_adjust(self): self.logx = False self.loglog = False - def _get_layout(self): - from pandas import DataFrame - if isinstance(self.data, DataFrame): - return (1, len(self.data.columns)) - else: - return (1, 1) - def _validate_color_args(self): pass @@ -2044,7 +2054,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, legend=True, rot=None, ax=None, style=None, title=None, xlim=None, ylim=None, logx=False, logy=False, xticks=None, yticks=None, kind='line', sort_columns=False, fontsize=None, - secondary_y=False, **kwds): + secondary_y=False, layout=None, **kwds): """ Make line, bar, or scatter plots of DataFrame series with the index on the x-axis @@ -2116,6 +2126,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, position : float Specify relative alignments for bar plot layout. From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 (center) + layout : tuple (optional) + (rows, columns) for the layout of the plot table : boolean, Series or DataFrame, default False If True, draw a table using the data in the DataFrame and the data will be transposed to meet matplotlib's default layout. @@ -2153,7 +2165,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, xlim=xlim, ylim=ylim, title=title, grid=grid, figsize=figsize, logx=logx, logy=logy, sort_columns=sort_columns, secondary_y=secondary_y, - **kwds) + layout=layout, **kwds) elif kind in _series_kinds: if y is None and subplots is False: msg = "{0} requires either y column or 'subplots=True'" @@ -2169,9 +2181,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, fontsize=fontsize, use_index=use_index, sharex=sharex, sharey=sharey, xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, title=title, grid=grid, - figsize=figsize, - sort_columns=sort_columns, - **kwds) + figsize=figsize, layout=layout, + sort_columns=sort_columns, **kwds) else: if x is not None: if com.is_integer(x) and not frame.columns.holds_integer(): @@ -2209,14 +2220,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, title=title, grid=grid, figsize=figsize, logx=logx, logy=logy, sort_columns=sort_columns, - secondary_y=secondary_y, **kwds) + secondary_y=secondary_y, layout=layout, **kwds) plot_obj.generate() plot_obj.draw() - if subplots: - return plot_obj.axes - else: - return plot_obj.axes[0] + return plot_obj.result def plot_series(series, label=None, kind='line', use_index=True, rot=None, @@ -2311,7 +2319,7 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None, plot_obj.draw() # plot_obj.ax is None if we created the first figure - return plot_obj.axes[0] + return plot_obj.result _shared_docs['boxplot'] = """ @@ -2551,12 +2559,13 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, data = data._get_numeric_data() naxes = len(data.columns) - nrows, ncols = _get_layout(naxes, layout=layout) - fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, ax=ax, squeeze=False, - sharex=sharex, sharey=sharey, figsize=figsize) + fig, axes = _subplots(naxes=naxes, ax=ax, squeeze=False, + sharex=sharex, sharey=sharey, figsize=figsize, + layout=layout) + _axes = _flatten(axes) for i, col in enumerate(com._try_sort(data.columns)): - ax = axes[i // ncols, i % ncols] + ax = _axes[i] ax.hist(data[col].dropna().values, bins=bins, **kwds) ax.set_title(col) ax.grid(grid) @@ -2672,7 +2681,7 @@ def plot_group(group, ax): xrot = xrot or rot fig, axes = _grouped_plot(plot_group, data, column=column, - by=by, sharex=sharex, sharey=sharey, + by=by, sharex=sharex, sharey=sharey, ax=ax, figsize=figsize, layout=layout, rot=rot) _set_ticks_props(axes, xlabelsize=xlabelsize, xrot=xrot, @@ -2730,9 +2739,9 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, """ if subplots is True: naxes = len(grouped) - nrows, ncols = _get_layout(naxes, layout=layout) - fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False, - ax=ax, sharex=False, sharey=True, figsize=figsize) + fig, axes = _subplots(naxes=naxes, squeeze=False, + ax=ax, sharex=False, sharey=True, figsize=figsize, + layout=layout) axes = _flatten(axes) ret = compat.OrderedDict() @@ -2773,14 +2782,14 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, grouped = grouped[column] naxes = len(grouped) - nrows, ncols = _get_layout(naxes, layout=layout) - fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, - figsize=figsize, sharex=sharex, sharey=sharey, ax=ax) + fig, axes = _subplots(naxes=naxes, figsize=figsize, + sharex=sharex, sharey=sharey, ax=ax, + layout=layout) - ravel_axes = _flatten(axes) + _axes = _flatten(axes) for i, (key, group) in enumerate(grouped): - ax = ravel_axes[i] + ax = _axes[i] if numeric_only and isinstance(group, DataFrame): group = group._get_numeric_data() plotf(group, ax, **kwargs) @@ -2799,16 +2808,14 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None, by = [by] columns = data._get_numeric_data().columns - by naxes = len(columns) - nrows, ncols = _get_layout(naxes, layout=layout) - fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, - sharex=True, sharey=True, - figsize=figsize, ax=ax) + fig, axes = _subplots(naxes=naxes, sharex=True, sharey=True, + figsize=figsize, ax=ax, layout=layout) - ravel_axes = _flatten(axes) + _axes = _flatten(axes) result = compat.OrderedDict() for i, col in enumerate(columns): - ax = ravel_axes[i] + ax = _axes[i] gp_col = grouped[col] keys, values = zip(*gp_col) re_plotf = plotf(keys, values, ax, **kwargs) @@ -2869,7 +2876,7 @@ def table(ax, data, rowLabels=None, colLabels=None, return table -def _get_layout(nplots, layout=None): +def _get_layout(nplots, layout=None, layout_type='box'): if layout is not None: if not isinstance(layout, (tuple, list)) or len(layout) != 2: raise ValueError('Layout must be a tuple of (rows, columns)') @@ -2881,27 +2888,31 @@ def _get_layout(nplots, layout=None): return layout - if nplots == 1: + if layout_type == 'single': return (1, 1) - elif nplots == 2: - return (1, 2) - elif nplots < 4: - return (2, 2) + elif layout_type == 'horizontal': + return (1, nplots) + elif layout_type == 'vertical': + return (nplots, 1) - k = 1 - while k ** 2 < nplots: - k += 1 - - if (k - 1) * k >= nplots: - return k, (k - 1) - else: - return k, k + layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)} + try: + return layouts[nplots] + except KeyError: + k = 1 + while k ** 2 < nplots: + k += 1 + + if (k - 1) * k >= nplots: + return k, (k - 1) + else: + return k, k -# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0 +# copied from matplotlib/pyplot.py and modified for pandas.plotting -def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=True, - subplot_kw=None, ax=None, **fig_kw): +def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True, + subplot_kw=None, ax=None, layout=None, layout_type='box', **fig_kw): """Create a figure with a set of subplots already made. This utility wrapper makes it convenient to create common layouts of @@ -2909,12 +2920,6 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= Keyword arguments: - nrows : int - Number of rows of the subplot grid. Defaults to 1. - - ncols : int - Number of columns of the subplot grid. Defaults to 1. - naxes : int Number of required axes. Exceeded axes are set invisible. Default is nrows * ncols. @@ -2942,11 +2947,17 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= ax : Matplotlib axis object, optional + layout : tuple + Number of rows and columns of the subplot grid. + If not specified, calculated from naxes and layout_type + + layout_type : {'box', 'horziontal', 'vertical'}, default 'box' + Specify how to layout the subplot grid. + fig_kw : Other keyword arguments to be passed to the figure() call. Note that all keywords not recognized above will be automatically included here. - Returns: fig, ax : tuple @@ -2975,23 +2986,27 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= plt.subplots(2, 2, subplot_kw=dict(polar=True)) """ import matplotlib.pyplot as plt - from pandas.core.frame import DataFrame if subplot_kw is None: subplot_kw = {} - # Create empty object array to hold all axes. It's easiest to make it 1-d - # so we can just append subplots upon creation, and then - nplots = nrows * ncols - - if naxes is None: - naxes = nrows * ncols - elif nplots < naxes: - raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes)) - if ax is None: fig = plt.figure(**fig_kw) else: + if com.is_list_like(ax): + ax = _flatten(ax) + if layout is not None: + warnings.warn("When passing multiple axes, layout keyword is ignored", UserWarning) + if sharex or sharey: + warnings.warn("When passing multiple axes, sharex and sharey are ignored." + "These settings must be specified when creating axes", UserWarning) + if len(ax) == naxes: + fig = ax[0].get_figure() + return fig, ax + else: + raise ValueError("The number of passed axes must be {0}, the same as " + "the output plot".format(naxes)) + fig = ax.get_figure() # if ax is passed and a number of subplots is 1, return ax as it is if naxes == 1: @@ -3004,6 +3019,11 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= "is being cleared", UserWarning) fig.clear() + nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type) + nplots = nrows * ncols + + # Create empty object array to hold all axes. It's easiest to make it 1-d + # so we can just append subplots upon creation, and then axarr = np.empty(nplots, dtype=object) # Create first subplot separately, so we can share it if requested @@ -3074,10 +3094,10 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze= def _flatten(axes): if not com.is_list_like(axes): - axes = [axes] + return np.array([axes]) elif isinstance(axes, (np.ndarray, Index)): - axes = axes.ravel() - return axes + return axes.ravel() + return np.array(axes) def _get_all_lines(ax):