Skip to content

Commit cdba224

Browse files
committed
Merge pull request #4113 from cpcloud/double-fig-series-hist-by-fix
BUG: make sure fig is not doubled when passing by to series.hist
2 parents 5b0db94 + f5c8db5 commit cdba224

File tree

3 files changed

+53
-54
lines changed

3 files changed

+53
-54
lines changed

doc/source/release.rst

+2
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,8 @@ Bug Fixes
435435
- Bug in getitem with a duplicate index when using where (:issue:`4879`)
436436
- Fix Type inference code coerces float column into datetime (:issue:`4601`)
437437
- Fixed ``_ensure_numeric`` does not check for complex numbers (:issue:`4902`)
438+
- Fixed a bug in ``Series.hist`` where two figures were being created when
439+
the ``by`` argument was passed (:issue:`4112`, :issue:`4113`).
438440

439441

440442
pandas 0.12.0

pandas/tests/test_graphics.py

+44-48
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def test_plot(self):
6262
_check_plot_works(self.series[:10].plot, kind='barh')
6363
_check_plot_works(Series(randn(10)).plot, kind='bar', color='black')
6464

65+
@slow
66+
def test_plot_figsize_and_title(self):
6567
# figsize and title
6668
import matplotlib.pyplot as plt
67-
plt.close('all')
6869
ax = self.series.plot(title='Test', figsize=(16, 8))
6970

7071
self.assertEqual(ax.title.get_text(), 'Test')
@@ -79,7 +80,6 @@ def test_bar_colors(self):
7980
default_colors = plt.rcParams.get('axes.color_cycle')
8081
custom_colors = 'rgcby'
8182

82-
plt.close('all')
8383
df = DataFrame(randn(5, 5))
8484
ax = df.plot(kind='bar')
8585

@@ -91,7 +91,7 @@ def test_bar_colors(self):
9191
rs = rect.get_facecolor()
9292
self.assertEqual(xp, rs)
9393

94-
plt.close('all')
94+
tm.close()
9595

9696
ax = df.plot(kind='bar', color=custom_colors)
9797

@@ -103,8 +103,7 @@ def test_bar_colors(self):
103103
rs = rect.get_facecolor()
104104
self.assertEqual(xp, rs)
105105

106-
plt.close('all')
107-
106+
tm.close()
108107
from matplotlib import cm
109108

110109
# Test str -> colormap functionality
@@ -118,7 +117,7 @@ def test_bar_colors(self):
118117
rs = rect.get_facecolor()
119118
self.assertEqual(xp, rs)
120119

121-
plt.close('all')
120+
tm.close()
122121

123122
# Test colormap functionality
124123
ax = df.plot(kind='bar', colormap=cm.jet)
@@ -131,8 +130,7 @@ def test_bar_colors(self):
131130
rs = rect.get_facecolor()
132131
self.assertEqual(xp, rs)
133132

134-
plt.close('all')
135-
133+
tm.close()
136134
df.ix[:, [0]].plot(kind='bar', color='DodgerBlue')
137135

138136
@slow
@@ -192,7 +190,7 @@ def test_hist(self):
192190
_check_plot_works(self.ts.hist, ax=ax)
193191
_check_plot_works(self.ts.hist, ax=ax, figure=fig)
194192
_check_plot_works(self.ts.hist, figure=fig)
195-
plt.close('all')
193+
tm.close()
196194

197195
fig, (ax1, ax2) = plt.subplots(1, 2)
198196
_check_plot_works(self.ts.hist, figure=fig, ax=ax1)
@@ -204,9 +202,8 @@ def test_hist(self):
204202
@slow
205203
def test_hist_layout(self):
206204
n = 10
207-
df = DataFrame({'gender': np.array(['Male',
208-
'Female'])[random.randint(2,
209-
size=n)],
205+
gender = tm.choice(['Male', 'Female'], size=n)
206+
df = DataFrame({'gender': gender,
210207
'height': random.normal(66, 4, size=n), 'weight':
211208
random.normal(161, 32, size=n)})
212209
with tm.assertRaises(ValueError):
@@ -219,23 +216,22 @@ def test_hist_layout(self):
219216
def test_hist_layout_with_by(self):
220217
import matplotlib.pyplot as plt
221218
n = 10
222-
df = DataFrame({'gender': np.array(['Male',
223-
'Female'])[random.randint(2,
224-
size=n)],
219+
gender = tm.choice(['Male', 'Female'], size=n)
220+
df = DataFrame({'gender': gender,
225221
'height': random.normal(66, 4, size=n), 'weight':
226222
random.normal(161, 32, size=n),
227223
'category': random.randint(4, size=n)})
228224
_check_plot_works(df.height.hist, by=df.gender, layout=(2, 1))
229-
plt.close('all')
225+
tm.close()
230226

231227
_check_plot_works(df.height.hist, by=df.gender, layout=(1, 2))
232-
plt.close('all')
228+
tm.close()
233229

234230
_check_plot_works(df.weight.hist, by=df.category, layout=(1, 4))
235-
plt.close('all')
231+
tm.close()
236232

237233
_check_plot_works(df.weight.hist, by=df.category, layout=(4, 1))
238-
plt.close('all')
234+
tm.close()
239235

240236
@slow
241237
def test_hist_no_overlap(self):
@@ -256,6 +252,15 @@ def test_plot_fails_with_dupe_color_and_style(self):
256252
with tm.assertRaises(ValueError):
257253
x.plot(style='k--', color='k')
258254

255+
@slow
256+
def test_hist_by_no_extra_plots(self):
257+
import matplotlib.pyplot as plt
258+
n = 10
259+
df = DataFrame({'gender': tm.choice(['Male', 'Female'], size=n),
260+
'height': random.normal(66, 4, size=n)})
261+
axes = df.height.hist(by=df.gender)
262+
self.assertEqual(len(plt.get_fignums()), 1)
263+
259264
def test_plot_fails_when_ax_differs_from_figure(self):
260265
from pylab import figure, close
261266
fig1 = figure()
@@ -436,7 +441,6 @@ def test_plot_xy(self):
436441
self._check_data(df.plot(y=1), df[1].plot())
437442

438443
# figsize and title
439-
plt.close('all')
440444
ax = df.plot(x=1, y=2, title='Test', figsize=(16, 8))
441445

442446
self.assertEqual(ax.title.get_text(), 'Test')
@@ -456,26 +460,26 @@ def test_xcompat(self):
456460
lines = ax.get_lines()
457461
self.assert_(not isinstance(lines[0].get_xdata(), PeriodIndex))
458462

459-
plt.close('all')
463+
tm.close()
460464
pd.plot_params['xaxis.compat'] = True
461465
ax = df.plot()
462466
lines = ax.get_lines()
463467
self.assert_(not isinstance(lines[0].get_xdata(), PeriodIndex))
464468

465-
plt.close('all')
469+
tm.close()
466470
pd.plot_params['x_compat'] = False
467471
ax = df.plot()
468472
lines = ax.get_lines()
469473
tm.assert_isinstance(lines[0].get_xdata(), PeriodIndex)
470474

471-
plt.close('all')
475+
tm.close()
472476
# useful if you're plotting a bunch together
473477
with pd.plot_params.use('x_compat', True):
474478
ax = df.plot()
475479
lines = ax.get_lines()
476480
self.assert_(not isinstance(lines[0].get_xdata(), PeriodIndex))
477481

478-
plt.close('all')
482+
tm.close()
479483
ax = df.plot()
480484
lines = ax.get_lines()
481485
tm.assert_isinstance(lines[0].get_xdata(), PeriodIndex)
@@ -499,6 +503,7 @@ def check_line(xpl, rsl):
499503
assert_array_equal(xpdata, rsdata)
500504

501505
[check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
506+
tm.close()
502507

503508
@slow
504509
def test_subplots(self):
@@ -537,19 +542,14 @@ def test_plot_bar(self):
537542
columns=['one', 'two', 'three', 'four'])
538543

539544
_check_plot_works(df.plot, kind='bar')
540-
close('all')
541545
_check_plot_works(df.plot, kind='bar', legend=False)
542-
close('all')
543546
_check_plot_works(df.plot, kind='bar', subplots=True)
544-
close('all')
545547
_check_plot_works(df.plot, kind='bar', stacked=True)
546-
close('all')
547548

548549
df = DataFrame(randn(10, 15),
549550
index=list(string.ascii_letters[:10]),
550551
columns=lrange(15))
551552
_check_plot_works(df.plot, kind='bar')
552-
close('all')
553553

554554
df = DataFrame({'a': [0, 1], 'b': [1, 0]})
555555
_check_plot_works(df.plot, kind='bar')
@@ -678,18 +678,18 @@ def test_hist(self):
678678
self.assertAlmostEqual(xtick.get_fontsize(), xf)
679679
self.assertAlmostEqual(xtick.get_rotation(), xrot)
680680

681-
plt.close('all')
681+
tm.close()
682682
# make sure kwargs to hist are handled
683683
ax = ser.hist(normed=True, cumulative=True, bins=4)
684684
# height of last bin (index 5) must be 1.0
685685
self.assertAlmostEqual(ax.get_children()[5].get_height(), 1.0)
686686

687-
plt.close('all')
687+
tm.close()
688688
ax = ser.hist(log=True)
689689
# scale of y must be 'log'
690690
self.assertEqual(ax.get_yscale(), 'log')
691691

692-
plt.close('all')
692+
tm.close()
693693

694694
# propagate attr exception from matplotlib.Axes.hist
695695
with tm.assertRaises(AttributeError):
@@ -698,7 +698,6 @@ def test_hist(self):
698698
@slow
699699
def test_hist_layout(self):
700700
import matplotlib.pyplot as plt
701-
plt.close('all')
702701
df = DataFrame(randn(100, 4))
703702

704703
layout_to_expected_size = (
@@ -847,15 +846,15 @@ def test_line_colors(self):
847846
tmp = sys.stderr
848847
sys.stderr = StringIO()
849848
try:
850-
plt.close('all')
849+
tm.close()
851850
ax2 = df.plot(colors=custom_colors)
852851
lines2 = ax2.get_lines()
853852
for l1, l2 in zip(lines, lines2):
854853
self.assertEqual(l1.get_color(), l2.get_color())
855854
finally:
856855
sys.stderr = tmp
857856

858-
plt.close('all')
857+
tm.close()
859858

860859
ax = df.plot(colormap='jet')
861860

@@ -867,7 +866,7 @@ def test_line_colors(self):
867866
rs = l.get_color()
868867
self.assertEqual(xp, rs)
869868

870-
plt.close('all')
869+
tm.close()
871870

872871
ax = df.plot(colormap=cm.jet)
873872

@@ -881,14 +880,13 @@ def test_line_colors(self):
881880

882881
# make color a list if plotting one column frame
883882
# handles cases like df.plot(color='DodgerBlue')
884-
plt.close('all')
883+
tm.close()
885884
df.ix[:, [0]].plot(color='DodgerBlue')
886885

887886
def test_default_color_cycle(self):
888887
import matplotlib.pyplot as plt
889888
plt.rcParams['axes.color_cycle'] = list('rgbk')
890889

891-
plt.close('all')
892890
df = DataFrame(randn(5, 3))
893891
ax = df.plot()
894892

@@ -992,15 +990,15 @@ def test_grouped_hist(self):
992990
axes = plotting.grouped_hist(df.A, by=df.C)
993991
self.assertEqual(len(axes.ravel()), 4)
994992

995-
plt.close('all')
993+
tm.close()
996994
axes = df.hist(by=df.C)
997995
self.assertEqual(axes.ndim, 2)
998996
self.assertEqual(len(axes.ravel()), 4)
999997

1000998
for ax in axes.ravel():
1001999
self.assert_(len(ax.patches) > 0)
10021000

1003-
plt.close('all')
1001+
tm.close()
10041002
# make sure kwargs to hist are handled
10051003
axes = plotting.grouped_hist(df.A, by=df.C, normed=True,
10061004
cumulative=True, bins=4)
@@ -1010,14 +1008,13 @@ def test_grouped_hist(self):
10101008
height = ax.get_children()[5].get_height()
10111009
self.assertAlmostEqual(height, 1.0)
10121010

1013-
plt.close('all')
1011+
tm.close()
10141012
axes = plotting.grouped_hist(df.A, by=df.C, log=True)
10151013
# scale of y must be 'log'
10161014
for ax in axes.ravel():
10171015
self.assertEqual(ax.get_yscale(), 'log')
10181016

1019-
plt.close('all')
1020-
1017+
tm.close()
10211018
# propagate attr exception from matplotlib.Axes.hist
10221019
with tm.assertRaises(AttributeError):
10231020
plotting.grouped_hist(df.A, by=df.C, foo='bar')
@@ -1026,9 +1023,8 @@ def test_grouped_hist(self):
10261023
def test_grouped_hist_layout(self):
10271024
import matplotlib.pyplot as plt
10281025
n = 100
1029-
df = DataFrame({'gender': np.array(['Male',
1030-
'Female'])[random.randint(2,
1031-
size=n)],
1026+
gender = tm.choice(['Male', 'Female'], size=n)
1027+
df = DataFrame({'gender': gender,
10321028
'height': random.normal(66, 4, size=n),
10331029
'weight': random.normal(161, 32, size=n),
10341030
'category': random.randint(4, size=n)})
@@ -1042,10 +1038,10 @@ def test_grouped_hist_layout(self):
10421038
layout=(2, 1))
10431039
self.assertEqual(df.hist(column='height', by=df.gender,
10441040
layout=(2, 1)).shape, (2,))
1045-
plt.close('all')
1041+
tm.close()
10461042
self.assertEqual(df.hist(column='height', by=df.category,
10471043
layout=(4, 1)).shape, (4,))
1048-
plt.close('all')
1044+
tm.close()
10491045
self.assertEqual(df.hist(column='height', by=df.category,
10501046
layout=(4, 2)).shape, (4, 2))
10511047

pandas/tools/plotting.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -2042,15 +2042,16 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
20422042
"""
20432043
import matplotlib.pyplot as plt
20442044

2045-
fig = kwds.get('figure', _gcf()
2046-
if plt.get_fignums() else plt.figure(figsize=figsize))
2047-
if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
2048-
fig.set_size_inches(*figsize, forward=True)
2049-
20502045
if by is None:
2051-
if kwds.get('layout', None):
2046+
if kwds.get('layout', None) is not None:
20522047
raise ValueError("The 'layout' keyword is not supported when "
20532048
"'by' is None")
2049+
# hack until the plotting interface is a bit more unified
2050+
fig = kwds.pop('figure', plt.gcf() if plt.get_fignums() else
2051+
plt.figure(figsize=figsize))
2052+
if (figsize is not None and tuple(figsize) !=
2053+
tuple(fig.get_size_inches())):
2054+
fig.set_size_inches(*figsize, forward=True)
20542055
if ax is None:
20552056
ax = fig.gca()
20562057
elif ax.get_figure() != fig:

0 commit comments

Comments
 (0)