Skip to content

Commit 7b7474e

Browse files
committed
CLN: more plotting test cleanups
1 parent 1675c2a commit 7b7474e

File tree

4 files changed

+95
-81
lines changed

4 files changed

+95
-81
lines changed

pandas/tests/test_graphics.py

+37-57
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,11 @@ def _skip_if_no_scipy():
2929
raise nose.SkipTest
3030

3131

32+
@tm.mplskip
3233
class TestSeriesPlots(unittest.TestCase):
33-
@classmethod
34-
def setUpClass(cls):
35-
try:
36-
import matplotlib as mpl
37-
mpl.use('Agg', warn=False)
38-
cls.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
39-
except ImportError:
40-
raise nose.SkipTest("matplotlib not installed")
41-
4234
def setUp(self):
35+
import matplotlib as mpl
36+
self.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
4337
self.ts = tm.makeTimeSeries()
4438
self.ts.name = 'ts'
4539

@@ -50,8 +44,7 @@ def setUp(self):
5044
self.iseries.name = 'iseries'
5145

5246
def tearDown(self):
53-
import matplotlib.pyplot as plt
54-
plt.close('all')
47+
tm.close()
5548

5649
@slow
5750
def test_plot(self):
@@ -352,24 +345,14 @@ def test_dup_datetime_index_plot(self):
352345
_check_plot_works(s.plot)
353346

354347

348+
@tm.mplskip
355349
class TestDataFramePlots(unittest.TestCase):
356-
357-
@classmethod
358-
def setUpClass(cls):
359-
# import sys
360-
# if 'IPython' in sys.modules:
361-
# raise nose.SkipTest
362-
363-
try:
364-
import matplotlib as mpl
365-
mpl.use('Agg', warn=False)
366-
cls.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
367-
except ImportError:
368-
raise nose.SkipTest("matplotlib not installed")
350+
def setUp(self):
351+
import matplotlib as mpl
352+
self.mpl_le_1_2_1 = str(mpl.__version__) <= LooseVersion('1.2.1')
369353

370354
def tearDown(self):
371-
import matplotlib.pyplot as plt
372-
plt.close('all')
355+
tm.close()
373356

374357
@slow
375358
def test_plot(self):
@@ -949,19 +932,10 @@ def test_invalid_kind(self):
949932
df.plot(kind='aasdf')
950933

951934

935+
@tm.mplskip
952936
class TestDataFrameGroupByPlots(unittest.TestCase):
953-
@classmethod
954-
def setUpClass(cls):
955-
try:
956-
import matplotlib as mpl
957-
mpl.use('Agg', warn=False)
958-
except ImportError:
959-
raise nose.SkipTest
960-
961937
def tearDown(self):
962-
import matplotlib.pyplot as plt
963-
for fignum in plt.get_fignums():
964-
plt.close(fignum)
938+
tm.close()
965939

966940
@slow
967941
def test_boxplot(self):
@@ -999,13 +973,16 @@ def test_time_series_plot_color_with_empty_kwargs(self):
999973
import matplotlib as mpl
1000974

1001975
def_colors = mpl.rcParams['axes.color_cycle']
976+
index = date_range('1/1/2000', periods=12)
977+
s = Series(np.arange(1, 13), index=index)
978+
979+
ncolors = 3
1002980

1003-
for i in range(3):
1004-
ax = Series(np.arange(12) + 1, index=date_range('1/1/2000',
1005-
periods=12)).plot()
981+
for i in range(ncolors):
982+
ax = s.plot()
1006983

1007984
line_colors = [l.get_color() for l in ax.get_lines()]
1008-
self.assertEqual(line_colors, def_colors[:3])
985+
self.assertEqual(line_colors, def_colors[:ncolors])
1009986

1010987
@slow
1011988
def test_grouped_hist(self):
@@ -1155,27 +1132,30 @@ def _check_plot_works(f, *args, **kwargs):
11551132
import matplotlib.pyplot as plt
11561133

11571134
try:
1158-
fig = kwargs['figure']
1159-
except KeyError:
1160-
fig = plt.gcf()
1161-
plt.clf()
1162-
ax = kwargs.get('ax', fig.add_subplot(211))
1163-
ret = f(*args, **kwargs)
1135+
try:
1136+
fig = kwargs['figure']
1137+
except KeyError:
1138+
fig = plt.gcf()
11641139

1165-
assert ret is not None
1166-
assert_is_valid_plot_return_object(ret)
1140+
plt.clf()
11671141

1168-
try:
1169-
kwargs['ax'] = fig.add_subplot(212)
1142+
ax = kwargs.get('ax', fig.add_subplot(211))
11701143
ret = f(*args, **kwargs)
1171-
except Exception:
1172-
pass
1173-
else:
1144+
11741145
assert_is_valid_plot_return_object(ret)
11751146

1176-
with ensure_clean() as path:
1177-
plt.savefig(path)
1178-
plt.close(fig)
1147+
try:
1148+
kwargs['ax'] = fig.add_subplot(212)
1149+
ret = f(*args, **kwargs)
1150+
except Exception:
1151+
pass
1152+
else:
1153+
assert_is_valid_plot_return_object(ret)
1154+
1155+
with ensure_clean() as path:
1156+
plt.savefig(path)
1157+
finally:
1158+
tm.close(fig)
11791159

11801160

11811161
def curpath():

pandas/tests/test_rplot.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,11 @@
88
import nose
99

1010

11-
try:
12-
import matplotlib.pyplot as plt
13-
except:
14-
raise nose.SkipTest
15-
16-
1711
def curpath():
1812
pth, _ = os.path.split(os.path.abspath(__file__))
1913
return pth
2014

15+
2116
def between(a, b, x):
2217
"""Check if x is in the somewhere between a and b.
2318
@@ -36,6 +31,8 @@ def between(a, b, x):
3631
else:
3732
return x <= a and x >= b
3833

34+
35+
@tm.mplskip
3936
class TestUtilityFunctions(unittest.TestCase):
4037
"""
4138
Tests for RPlot utility functions.
@@ -74,24 +71,25 @@ def test_dictionary_union(self):
7471
self.assertTrue(2 in keys)
7572
self.assertTrue(3 in keys)
7673
self.assertTrue(4 in keys)
77-
self.assertTrue(rplot.dictionary_union(dict1, {}) == dict1)
78-
self.assertTrue(rplot.dictionary_union({}, dict1) == dict1)
79-
self.assertTrue(rplot.dictionary_union({}, {}) == {})
74+
self.assertEqual(rplot.dictionary_union(dict1, {}), dict1)
75+
self.assertEqual(rplot.dictionary_union({}, dict1), dict1)
76+
self.assertEqual(rplot.dictionary_union({}, {}), {})
8077

8178
def test_merge_aes(self):
8279
layer1 = rplot.Layer(size=rplot.ScaleSize('test'))
8380
layer2 = rplot.Layer(shape=rplot.ScaleShape('test'))
8481
rplot.merge_aes(layer1, layer2)
8582
self.assertTrue(isinstance(layer2.aes['size'], rplot.ScaleSize))
8683
self.assertTrue(isinstance(layer2.aes['shape'], rplot.ScaleShape))
87-
self.assertTrue(layer2.aes['size'] == layer1.aes['size'])
84+
self.assertEqual(layer2.aes['size'], layer1.aes['size'])
8885
for key in layer2.aes.keys():
8986
if key != 'size' and key != 'shape':
9087
self.assertTrue(layer2.aes[key] is None)
9188

9289
def test_sequence_layers(self):
9390
layer1 = rplot.Layer(self.data)
94-
layer2 = rplot.GeomPoint(x='SepalLength', y='SepalWidth', size=rplot.ScaleSize('PetalLength'))
91+
layer2 = rplot.GeomPoint(x='SepalLength', y='SepalWidth',
92+
size=rplot.ScaleSize('PetalLength'))
9593
layer3 = rplot.GeomPolyFit(2)
9694
result = rplot.sequence_layers([layer1, layer2, layer3])
9795
self.assertEqual(len(result), 3)
@@ -102,6 +100,8 @@ def test_sequence_layers(self):
102100
self.assertTrue(self.data is last.data)
103101
self.assertTrue(rplot.sequence_layers([layer1])[0] is layer1)
104102

103+
104+
@tm.mplskip
105105
class TestTrellis(unittest.TestCase):
106106
def setUp(self):
107107
path = os.path.join(curpath(), 'data/tips.csv')
@@ -148,11 +148,15 @@ def test_trellis_cols_rows(self):
148148
self.assertEqual(self.trellis3.cols, 2)
149149
self.assertEqual(self.trellis3.rows, 1)
150150

151+
152+
@tm.mplskip
151153
class TestScaleGradient(unittest.TestCase):
152154
def setUp(self):
153155
path = os.path.join(curpath(), 'data/iris.csv')
154156
self.data = read_csv(path, sep=',')
155-
self.gradient = rplot.ScaleGradient("SepalLength", colour1=(0.2, 0.3, 0.4), colour2=(0.8, 0.7, 0.6))
157+
self.gradient = rplot.ScaleGradient("SepalLength", colour1=(0.2, 0.3,
158+
0.4),
159+
colour2=(0.8, 0.7, 0.6))
156160

157161
def test_gradient(self):
158162
for index in range(len(self.data)):
@@ -164,6 +168,8 @@ def test_gradient(self):
164168
self.assertTrue(between(g1, g2, g))
165169
self.assertTrue(between(b1, b2, b))
166170

171+
172+
@tm.mplskip
167173
class TestScaleGradient2(unittest.TestCase):
168174
def setUp(self):
169175
path = os.path.join(curpath(), 'data/iris.csv')
@@ -190,6 +196,8 @@ def test_gradient2(self):
190196
self.assertTrue(between(g2, g3, g))
191197
self.assertTrue(between(b2, b3, b))
192198

199+
200+
@tm.mplskip
193201
class TestScaleRandomColour(unittest.TestCase):
194202
def setUp(self):
195203
path = os.path.join(curpath(), 'data/iris.csv')
@@ -208,13 +216,16 @@ def test_random_colour(self):
208216
self.assertTrue(g <= 1.0)
209217
self.assertTrue(b <= 1.0)
210218

219+
220+
@tm.mplskip
211221
class TestScaleConstant(unittest.TestCase):
212222
def test_scale_constant(self):
213223
scale = rplot.ScaleConstant(1.0)
214224
self.assertEqual(scale(None, None), 1.0)
215225
scale = rplot.ScaleConstant("test")
216226
self.assertEqual(scale(None, None), "test")
217227

228+
218229
class TestScaleSize(unittest.TestCase):
219230
def setUp(self):
220231
path = os.path.join(curpath(), 'data/iris.csv')
@@ -235,8 +246,10 @@ def f():
235246
self.assertRaises(ValueError, f)
236247

237248

249+
@tm.mplskip
238250
class TestRPlot(unittest.TestCase):
239251
def test_rplot1(self):
252+
import matplotlib.pyplot as plt
240253
path = os.path.join(curpath(), 'data/tips.csv')
241254
plt.figure()
242255
self.data = read_csv(path, sep=',')
@@ -247,6 +260,7 @@ def test_rplot1(self):
247260
self.plot.render(self.fig)
248261

249262
def test_rplot2(self):
263+
import matplotlib.pyplot as plt
250264
path = os.path.join(curpath(), 'data/tips.csv')
251265
plt.figure()
252266
self.data = read_csv(path, sep=',')
@@ -257,6 +271,7 @@ def test_rplot2(self):
257271
self.plot.render(self.fig)
258272

259273
def test_rplot3(self):
274+
import matplotlib.pyplot as plt
260275
path = os.path.join(curpath(), 'data/tips.csv')
261276
plt.figure()
262277
self.data = read_csv(path, sep=',')
@@ -267,6 +282,7 @@ def test_rplot3(self):
267282
self.plot.render(self.fig)
268283

269284
def test_rplot_iris(self):
285+
import matplotlib.pyplot as plt
270286
path = os.path.join(curpath(), 'data/iris.csv')
271287
plt.figure()
272288
self.data = read_csv(path, sep=',')
@@ -277,5 +293,6 @@ def test_rplot_iris(self):
277293
self.fig = plt.gcf()
278294
plot.render(self.fig)
279295

296+
280297
if __name__ == '__main__':
281298
unittest.main()

pandas/tseries/tests/test_plotting.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,8 @@ def _skip_if_no_scipy():
2626
raise nose.SkipTest
2727

2828

29+
@tm.mplskip
2930
class TestTSPlot(unittest.TestCase):
30-
31-
@classmethod
32-
def setUpClass(cls):
33-
try:
34-
import matplotlib as mpl
35-
mpl.use('Agg', warn=False)
36-
except ImportError:
37-
raise nose.SkipTest
38-
3931
def setUp(self):
4032
freq = ['S', 'T', 'H', 'D', 'W', 'M', 'Q', 'Y']
4133
idx = [period_range('12/31/1999', freq=x, periods=100) for x in freq]
@@ -52,9 +44,7 @@ def setUp(self):
5244
for x in idx]
5345

5446
def tearDown(self):
55-
import matplotlib.pyplot as plt
56-
for fignum in plt.get_fignums():
57-
plt.close(fignum)
47+
tm.close()
5848

5949
@slow
6050
def test_ts_plot_with_tz(self):

pandas/util/testing.py

+27
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from numpy.random import randn, rand
1919
import numpy as np
2020

21+
import nose
22+
2123
from pandas.core.common import isnull, _is_sequence
2224
import pandas.core.index as index
2325
import pandas.core.series as series
@@ -70,6 +72,31 @@ def choice(x, size=10):
7072
except AttributeError:
7173
return np.random.randint(len(x), size=size).choose(x)
7274

75+
76+
def close(fignum=None):
77+
from matplotlib.pyplot import get_fignums, close as _close
78+
79+
if fignum is None:
80+
for fignum in get_fignums():
81+
_close(fignum)
82+
else:
83+
_close(fignum)
84+
85+
86+
def mplskip(cls):
87+
"""Skip a TestCase instance if matplotlib isn't installed"""
88+
@classmethod
89+
def setUpClass(cls):
90+
try:
91+
import matplotlib as mpl
92+
mpl.use("Agg", warn=False)
93+
except ImportError:
94+
raise nose.SkipTest("matplotlib not installed")
95+
96+
cls.setUpClass = setUpClass
97+
return cls
98+
99+
73100
#------------------------------------------------------------------------------
74101
# Console debugging tools
75102

0 commit comments

Comments
 (0)