Skip to content

Commit 0d4dc67

Browse files
committed
TST: copy matplotlib subplots function for compat with mpl < 1.0
1 parent 6e1bac7 commit 0d4dc67

File tree

2 files changed

+119
-10
lines changed

2 files changed

+119
-10
lines changed

pandas/core/frame.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -3571,11 +3571,12 @@ def plot(self, subplots=False, sharex=True, sharey=False, use_index=True,
35713571
and will error.
35723572
"""
35733573
import matplotlib.pyplot as plt
3574+
import pandas.tools.plotting as gfx
35743575

35753576
if subplots:
3576-
fig, axes = plt.subplots(nrows=len(self.columns),
3577-
sharex=sharex, sharey=sharey,
3578-
figsize=figsize)
3577+
fig, axes = gfx.subplots(nrows=len(self.columns),
3578+
sharex=sharex, sharey=sharey,
3579+
figsize=figsize)
35793580
else:
35803581
if ax is None:
35813582
fig = plt.figure(figsize=figsize)
@@ -3676,13 +3677,14 @@ def hist(self, grid=True, **kwds):
36763677
kwds : other plotting keyword arguments
36773678
To be passed to hist function
36783679
"""
3680+
import pandas.tools.plotting as gfx
36793681
import matplotlib.pyplot as plt
36803682

36813683
n = len(self.columns)
36823684
k = 1
36833685
while k ** 2 < n:
36843686
k += 1
3685-
_, axes = plt.subplots(nrows=k, ncols=k)
3687+
_, axes = gfx.subplots(nrows=k, ncols=k)
36863688

36873689
for i, col in enumerate(_try_sort(self.columns)):
36883690
ax = axes[i / k][i % k]

pandas/tools/plotting.py

+113-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
def scatter_matrix(data):
24
pass
35

@@ -123,14 +125,12 @@ def plot_group(group, ax):
123125

124126
def _grouped_plot(plotf, data, by=None, numeric_only=True, figsize=(10, 5),
125127
sharex=True, sharey=True):
126-
import matplotlib.pyplot as plt
127-
128128
grouped = data.groupby(by)
129129
ngroups = len(grouped)
130130

131131
nrows, ncols = _get_layout(ngroups)
132-
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize,
133-
sharex=sharex, sharey=sharey)
132+
fig, axes = subplots(nrows=nrows, ncols=ncols, figsize=figsize,
133+
sharex=sharex, sharey=sharey)
134134

135135
ravel_axes = []
136136
for row in axes:
@@ -155,8 +155,8 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
155155
ngroups = len(columns)
156156

157157
nrows, ncols = _get_layout(ngroups)
158-
fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
159-
sharex=True, sharey=True)
158+
fig, axes = subplots(nrows=nrows, ncols=ncols,
159+
sharex=True, sharey=True)
160160

161161
if isinstance(axes, plt.Axes):
162162
ravel_axes = [axes]
@@ -198,6 +198,113 @@ def _get_layout(nplots):
198198
else:
199199
return k, k
200200

201+
# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0
202+
203+
def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
204+
subplot_kw=None, **fig_kw):
205+
"""Create a figure with a set of subplots already made.
206+
207+
This utility wrapper makes it convenient to create common layouts of
208+
subplots, including the enclosing figure object, in a single call.
209+
210+
Keyword arguments:
211+
212+
nrows : int
213+
Number of rows of the subplot grid. Defaults to 1.
214+
215+
ncols : int
216+
Number of columns of the subplot grid. Defaults to 1.
217+
218+
sharex : bool
219+
If True, the X axis will be shared amongst all subplots.
220+
221+
sharex : bool
222+
If True, the Y axis will be shared amongst all subplots.
223+
224+
squeeze : bool
225+
226+
If True, extra dimensions are squeezed out from the returned axis object:
227+
- if only one subplot is constructed (nrows=ncols=1), the resulting
228+
single Axis object is returned as a scalar.
229+
- for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
230+
array of Axis objects are returned as numpy 1-d arrays.
231+
- for NxM subplots with N>1 and M>1 are returned as a 2d array.
232+
233+
If False, no squeezing at all is done: the returned axis object is always
234+
a 2-d array contaning Axis instances, even if it ends up being 1x1.
235+
236+
subplot_kw : dict
237+
Dict with keywords passed to the add_subplot() call used to create each
238+
subplots.
239+
240+
fig_kw : dict
241+
Dict with keywords passed to the figure() call. Note that all keywords
242+
not recognized above will be automatically included here.
243+
244+
Returns:
245+
246+
fig, ax : tuple
247+
- fig is the Matplotlib Figure object
248+
- ax can be either a single axis object or an array of axis objects if
249+
more than one supblot was created. The dimensions of the resulting array
250+
can be controlled with the squeeze keyword, see above.
251+
252+
**Examples:**
253+
254+
x = np.linspace(0, 2*np.pi, 400)
255+
y = np.sin(x**2)
256+
257+
# Just a figure and one subplot
258+
f, ax = plt.subplots()
259+
ax.plot(x, y)
260+
ax.set_title('Simple plot')
261+
262+
# Two subplots, unpack the output array immediately
263+
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
264+
ax1.plot(x, y)
265+
ax1.set_title('Sharing Y axis')
266+
ax2.scatter(x, y)
267+
268+
# Four polar axes
269+
plt.subplots(2, 2, subplot_kw=dict(polar=True))
270+
"""
271+
import matplotlib.pyplot as plt
272+
273+
if subplot_kw is None:
274+
subplot_kw = {}
275+
276+
fig = plt.figure(**fig_kw)
277+
278+
# Create empty object array to hold all axes. It's easiest to make it 1-d
279+
# so we can just append subplots upon creation, and then
280+
nplots = nrows*ncols
281+
axarr = np.empty(nplots, dtype=object)
282+
283+
# Create first subplot separately, so we can share it if requested
284+
ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
285+
if sharex:
286+
subplot_kw['sharex'] = ax0
287+
if sharey:
288+
subplot_kw['sharey'] = ax0
289+
axarr[0] = ax0
290+
291+
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
292+
# convention.
293+
for i in range(1, nplots):
294+
axarr[i] = fig.add_subplot(nrows, ncols, i+1, **subplot_kw)
295+
296+
if squeeze:
297+
# Reshape the array to have the final desired dimension (nrow,ncol),
298+
# though discarding unneeded dimensions that equal 1. If we only have
299+
# one subplot, just return it instead of a 1-element array.
300+
if nplots==1:
301+
return fig, axarr[0]
302+
else:
303+
return fig, axarr.reshape(nrows, ncols).squeeze()
304+
else:
305+
# returned axis array will be always 2-d, even if nrows=ncols=1
306+
return fig, axarr.reshape(nrows, ncols)
307+
201308
if __name__ == '__main__':
202309
import pandas.rpy.common as com
203310
sales = com.load_data('sanfrancisco.home.sales', package='nutshell')

0 commit comments

Comments
 (0)